[OSCP] Implement the AP (average_precision_score) function with SPU #801

Merged: 8 commits, Aug 20, 2024
1 change: 1 addition & 0 deletions sml/metrics/classification/BUILD.bazel
@@ -21,6 +21,7 @@ py_library(
srcs = ["classification.py"],
deps = [
":auc",
"//sml/preprocessing",
"//spu/ops/groupby",
],
)
7 changes: 5 additions & 2 deletions sml/metrics/classification/auc.py
@@ -20,12 +20,14 @@
from spu.ops.groupby import groupby_sorted


def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]:
def binary_clf_curve(
sorted_pairs: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Calculate true and false positives per binary classification
threshold (can be used for roc curve or precision/recall curve).
Results may include trailing zeros.
Args:
sorted_pairs: jnp.array
sorted_pairs: jnp.ndarray
(y_true, y_score) pairs sorted by y_score in decreasing order
Returns:
fps: 1d ndarray
@@ -57,6 +59,7 @@ def binary_clf_curve(sorted_pairs: jnp.array) -> Tuple[jnp.array, jnp.array, jnp.array]:
fps = seg_end_marks * fps
thresholds = seg_end_marks * thresholds
thresholds, fps, tps = jax.lax.sort([-thresholds] + [fps, tps], num_keys=1)

return fps, tps, -thresholds
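
For intuition, here is a minimal plaintext sketch (NumPy only, no SPU, and omitting the groupby-based handling of tied scores and the trailing-zero padding) of the cumulative counts that binary_clf_curve produces:

```python
import numpy as np

# Rows of (y_true, y_score), already sorted by y_score in decreasing order.
sorted_pairs = np.array([[1, 0.8], [0, 0.4], [1, 0.35], [0, 0.1]])
y_true = sorted_pairs[:, 0]

tps = np.cumsum(y_true)      # true positives among scores >= threshold
fps = np.cumsum(1 - y_true)  # false positives among scores >= threshold
thresholds = sorted_pairs[:, 1]

print(fps)         # [0. 1. 1. 2.]
print(tps)         # [1. 1. 2. 2.]
print(thresholds)  # [0.8  0.4  0.35 0.1]
```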


156 changes: 155 additions & 1 deletion sml/metrics/classification/classification.py
@@ -16,10 +16,12 @@

import jax
import jax.numpy as jnp
from auc import binary_roc_auc

from sml.preprocessing.preprocessing import label_binarize
from spu.ops.groupby import groupby, groupby_sum

from .auc import binary_clf_curve, binary_roc_auc


def roc_auc_score(y_true, y_pred):
sorted_arr = create_sorted_label_score_pair(y_true, y_pred)
@@ -222,3 +224,155 @@ def fun_score(
else:
raise ValueError("average should be None or 'binary'")
return fun_result


def precision_recall_curve(
y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=1, score_eps=1e-5
):
"""Compute precision-recall pairs for different probability thresholds.

Note: this implementation is restricted to the binary classification task.

Parameters
----------
y_true : 1d array-like of shape (n,). True binary labels.

y_score : 1d array-like of shape (n,). Target scores, non-negative.

pos_label : int, default=1. The label of the positive class.

score_eps : float, default=1e-5. The lower bound for y_score.

Returns
-------
precisions : ndarray of shape (n + 1,).
Precision values where element i + 1 is the precision for
score >= thresholds[i]; the first element is 1.

recalls : ndarray of shape (n + 1,).
Increasing recall values where element i + 1 is the recall for
score >= thresholds[i]; the first element is 0.

thresholds : ndarray of shape (n,).
Decreasing thresholds used to compute precision and recall.
Results might include trailing zeros.
"""

# normalize the input
y_true = jnp.where(y_true == pos_label, 1, 0)
y_score = jnp.where(
y_score < score_eps, score_eps, y_score
) # clip scores so true zeros are not confused with trailing-zero padding

# compute TP and FP
sorted_pairs = create_sorted_label_score_pair(y_true, y_score)
fp, tp, thresholds = binary_clf_curve(sorted_pairs)

# compute precision and recalls
mask = jnp.where(thresholds > 0, 1, 0) # tied value entries have mask=0
precisions = jnp.where(mask, tp / (tp + fp + 1e-5), 0)
max_tp = jnp.max(tp)
recalls = jnp.where(max_tp == 0, jnp.ones_like(tp), tp / max_tp)

return (
jnp.hstack((1, precisions)),
jnp.hstack((0, recalls)),
thresholds,
)
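
For reference, running the function above in plaintext JAX (outside SPU) on the example from sklearn's precision_recall_curve docs yields the same curve as sklearn, just with decreasing thresholds and the (precision=1, recall=0) endpoint prepended rather than appended. A sketch; the expected values are hand-computed:

```python
import jax.numpy as jnp

from sml.metrics.classification.classification import precision_recall_curve

y_true = jnp.array([0, 0, 1, 1])
y_score = jnp.array([0.1, 0.4, 0.35, 0.8])
precisions, recalls, thresholds = precision_recall_curve(y_true, y_score)
# precisions ≈ [1, 1, 0.5, 0.667, 0.5]  (leading 1 is the prepended endpoint)
# recalls    ≈ [0, 0.5, 0.5, 1, 1]      (leading 0 is the prepended endpoint)
# thresholds ≈ [0.8, 0.4, 0.35, 0.1]
```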


def average_precision_score(
y_true: jnp.ndarray,
y_score: jnp.ndarray,
classes=(0, 1),
average="macro",
pos_label=1,
score_eps=1e-5,
):
"""Compute average precision (AP) from prediction scores.

.. math::
\\text{AP} = \\sum_n (R_n - R_{n-1}) P_n

Parameters
----------
y_true : array-like of shape (n_samples,)
True labels.

y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Estimated target scores as returned by a classifier, non-negative.

classes : 1d array-like of shape (n_classes,), default=(0, 1) for binary classification
Uniquely holds the label for each class.
SPU cannot support dynamic shapes, so this parameter must be specified explicitly.

average : {'macro', 'micro', None}, default='macro'
This parameter is required for multiclass/multilabel targets and
will be ignored when y_true is binary.

'macro':
Calculate metrics for each label, and find their unweighted mean.
'micro':
Calculate metrics globally by considering each element of the label
indicator matrix as a label.
None:
Scores for each class are returned.

pos_label : int, default=1
The label of the positive class. Only applied to binary y_true.

score_eps : float, default=1e-5. The lower bound for y_score.

Returns
-------
average_precision : float, or ndarray of shape (n_classes,) when average=None
Average precision score(s).
"""

assert average in (
'macro',
'micro',
None,
), 'average must be either "macro", "micro" or None'

def binary_average_precision(y_true, y_score, pos_label=1):
"""Compute the average precision for binary classification."""
precisions, recalls, _ = precision_recall_curve(
y_true, y_score, pos_label=pos_label, score_eps=score_eps
)

return jnp.sum(jnp.diff(recalls) * precisions[1:])

n_classes = len(classes)
if n_classes <= 2:
# binary classification
# y_true with all-identical labels is a special case treated as binary classification
return binary_average_precision(y_true, y_score, pos_label=pos_label)
else:
# multi-class classification
# binarize labels using one-vs-all scheme into multilabel-indicator
y_true = label_binarize(y_true, classes=classes, n_classes=n_classes)

if average == "micro":
y_true = y_true.ravel()
y_score = y_score.ravel()
elif average == "macro":
pass

# extend the classes dimension if needed
if y_true.ndim == 1:
y_true = y_true[:, jnp.newaxis]
if y_score.ndim == 1:
y_score = y_score[:, jnp.newaxis]

# compute score for each class
n_classes = y_score.shape[1]
score = jnp.zeros((n_classes,))
for c in range(n_classes):
binary_ap = binary_average_precision(
y_true[:, c].ravel(), y_score[:, c].ravel(), pos_label=pos_label
)
score = score.at[c].set(binary_ap)

# average the scores
return jnp.average(score) if average else score
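
As a quick plaintext sanity check (no SPU; the emulation tests below exercise the sealed path), the step-wise sum AP = Σ (R_n - R_{n-1}) P_n over the curve from the previous example matches sklearn:

```python
import numpy as np
import jax.numpy as jnp
from sklearn.metrics import average_precision_score as sk_ap

from sml.metrics.classification.classification import average_precision_score

y_true = jnp.array([0, 0, 1, 1])
y_score = jnp.array([0.1, 0.4, 0.35, 0.8])

# (0.5 - 0)*1.0 + (0.5 - 0.5)*0.5 + (1.0 - 0.5)*0.667 + (1.0 - 1.0)*0.5 ≈ 0.8333
print(average_precision_score(y_true, y_score))        # ≈ 0.8333
print(sk_ap(np.asarray(y_true), np.asarray(y_score)))  # 0.8333...
```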
80 changes: 77 additions & 3 deletions sml/metrics/classification/classification_emul.py
@@ -18,13 +18,15 @@
import jax.numpy as jnp
import numpy as np
from sklearn import metrics
from sklearn.metrics import average_precision_score as sk_average_precision_score

# add ops dir to the path
sys.path.append(os.path.join(os.path.dirname(__file__), '../../'))

import sml.utils.emulation as emulation
from sml.metrics.classification.classification import (
accuracy_score,
average_precision_score,
f1_score,
precision_score,
recall_score,
@@ -42,7 +44,7 @@ def emul_auc(mode: emulation.Mode.MULTIPROCESS):

# Run
result = emulator.run(roc_auc_score)(
y_true, y_pred
*emulator.seal(y_true, y_pred)
) # X, y should be two-dimensional arrays
print(result)

@@ -97,7 +99,7 @@ def check(spu_result, sk_result):
y_true = jnp.array([0, 1, 1, 0, 1, 1])
y_pred = jnp.array([0, 0, 1, 0, 1, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, 'binary', None, 1, False
*emulator.seal(y_true, y_pred), 'binary', None, 1, False
)
sk_result = sklearn_proc(y_true, y_pred)
check(spu_result, sk_result)
@@ -106,12 +108,83 @@ def check(spu_result, sk_result):
y_true = jnp.array([0, 1, 1, 0, 2, 1])
y_pred = jnp.array([0, 0, 1, 0, 2, 1])
spu_result = emulator.run(proc, static_argnums=(2, 5))(
y_true, y_pred, None, [0, 1, 2], 1, True
*emulator.seal(y_true, y_pred), None, [0, 1, 2], 1, True
)
sk_result = sklearn_proc(y_true, y_pred, average=None, labels=[0, 1, 2])
check(spu_result, sk_result)


def emul_average_precision_score(mode: emulation.Mode.MULTIPROCESS):
def procBinary(y_true, y_score, **kwargs):
sk_res = sk_average_precision_score(y_true, y_score, **kwargs)
spu_res = emulator.run(average_precision_score)(
*emulator.seal(y_true, y_score), **kwargs
)
return sk_res, spu_res

def check(res1, res2):
return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3)

# --- Test binary classification ---
# 0-1 labels, no tied value
y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.1, 0.4, 0.35, 0.8], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# 0-1 labels, with tied value, even length
y_true = jnp.array([0, 0, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.4, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# 0-1 labels, with tied value, odd length
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.4, 0.4, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# customized labels
y_true = jnp.array([2, 2, 3, 3], dtype=jnp.int32)
y_score = jnp.array([0.1, 0.2, 0.3, 0.4], dtype=jnp.float32)
check(*procBinary(y_true, y_score, pos_label=3))
# larger random dataset
y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32)
y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# single label edge case
y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# zero score edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))
# score > 1 edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32)
check(*procBinary(y_true, y_score))

# --- Test multiclass classification ---
y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32)
y_score = np.array(
[
[0.7, 0.2, 0.1],
[0.4, 0.3, 0.3],
[0.1, 0.8, 0.1],
[0.2, 0.3, 0.5],
[0.4, 0.4, 0.2],
[0.1, 0.2, 0.7],
],
dtype=jnp.float32,
)
classes = jnp.unique(y_true)
# test over three supported average options
for average in ["macro", "micro", None]:
sk_res = sk_average_precision_score(y_true, y_score, average=average)
spu_res = emulator.run(average_precision_score, static_argnums=(3,))(
*emulator.seal(y_true, y_score), classes, average
)
check(sk_res, spu_res)


if __name__ == "__main__":
try:
# bandwidth and latency only work for docker mode
@@ -124,5 +197,6 @@ def check(spu_result, sk_result):
emulator.up()
emul_auc(emulation.Mode.MULTIPROCESS)
emul_Classification(emulation.Mode.MULTIPROCESS)
emul_average_precision_score(emulation.Mode.MULTIPROCESS)
finally:
emulator.down()