Skip to content

Commit

Permalink
[OSCP] 使用 SPU 实现 AP 函数 - 调整实现逻辑并使用更严格的测试
Browse files Browse the repository at this point in the history
  • Loading branch information
z0gSh1u committed Aug 16, 2024
1 parent c1f8e9a commit 8ce3dba
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 28 deletions.
13 changes: 3 additions & 10 deletions sml/metrics/classification/auc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,15 @@
from spu.ops.groupby import groupby_sorted


def binary_clf_curve(sorted_pairs: jnp.ndarray, return_seg_end_marks=False) -> Union[
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
]:
def binary_clf_curve(
sorted_pairs: jnp.ndarray, return_seg_end_marks=False
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Calculate true and false positives per binary classification
threshold (can be used for roc curve or precision/recall curve).
Results may include trailing zeros.
Args:
sorted_pairs: jnp.ndarray
y_true y_score pairs sorted by y_score in decreasing order
return_seg_end_marks: bool
If true, the seg_end_marks array will be returned at the end
Returns:
fps: 1d ndarray
False positives counts, index i records the number
Expand All @@ -47,8 +44,6 @@ def binary_clf_curve(sorted_pairs: jnp.ndarray, return_seg_end_marks=False) -> U
tps[-1] (thus false negatives are given by tps[-1] - tps)
thresholds : 1d ndarray
predicted score sorted in decreasing order
seg_end_marks: 1d ndarray
marking the end of segment in result arrays
References:
Github: scikit-learn _binary_clf_curve.
"""
Expand All @@ -65,8 +60,6 @@ def binary_clf_curve(sorted_pairs: jnp.ndarray, return_seg_end_marks=False) -> U
thresholds = seg_end_marks * thresholds
thresholds, fps, tps = jax.lax.sort([-thresholds] + [fps, tps], num_keys=1)

if return_seg_end_marks:
return fps, tps, -thresholds, seg_end_marks
return fps, tps, -thresholds


Expand Down
36 changes: 22 additions & 14 deletions sml/metrics/classification/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import Tuple

import jax
Expand Down Expand Up @@ -226,7 +227,9 @@ def fun_score(
return fun_result


def precision_recall_curve(y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=1):
def precision_recall_curve(
y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=1, score_eps=1e-3
):
"""Compute precision-recall pairs for different probability thresholds.
Note: this implementation is restricted to the binary classification task.
Expand All @@ -235,10 +238,12 @@ def precision_recall_curve(y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=
----------
y_true : 1d array-like of shape (n,). True binary labels.
y_score : 1d array-like of shape (n,). Target scores.
y_score : 1d array-like of shape (n,). Target scores, non-negative.
pos_label : int, default=1. The label of the positive class.
score_eps : float, default=1e-3. The lower bound for y_score.
Returns
-------
precisions : ndarray of shape (n + 1,).
Expand All @@ -253,23 +258,23 @@ def precision_recall_curve(y_true: jnp.ndarray, y_score: jnp.ndarray, pos_label=
Decreasing thresholds used to compute precision and recall.
Results might include trailing zeros.
"""
# normalize the labels

# normalize the input
y_true = jnp.where(y_true == pos_label, 1, 0)
y_score = jnp.where(y_score < score_eps, score_eps, y_score) # to avoid wrong mask extraction

# compute TP and FP
pairs = jnp.stack([y_true, y_score], axis=1)
sorted_pairs = pairs[jnp.argsort(pairs[:, 1], descending=True, stable=True)]
fp, tp, thresholds, marks = binary_clf_curve(
sorted_pairs, return_seg_end_marks=True
)
fp, tp, thresholds = binary_clf_curve(sorted_pairs)

# compute precision and recalls
precisions = tp / (tp + fp + 1e-10)
# determine the last index where from that on holds trailing zeros in TP because of tied values
last_index = jnp.max(
jnp.where(marks == 0, size=len(marks), fill_value=-1)[0]
) # jnp.argwhere(marks == 0)[-1]
recalls = jnp.where(tp[last_index] == 0, jnp.ones_like(tp), tp / tp[last_index])
mask = jnp.where(thresholds > 0, 1, 0) # tied value entries have mask=0
last = jnp.max(
jnp.where(mask, size=len(mask), fill_value=-1)[0]
) # equivalent of jnp.argwhere(mask)[-1], last index before tied value section
precisions = jnp.where(mask, tp / (tp + fp + 1e-5), 0)
recalls = jnp.where(tp[last] == 0, jnp.ones_like(tp), tp / tp[last])

return (
jnp.hstack((1, precisions)),
Expand All @@ -284,6 +289,7 @@ def average_precision_score(
classes=(0, 1),
average="macro",
pos_label=1,
score_eps=1e-3,
):
"""Compute average precision (AP) from prediction scores.
Expand All @@ -296,7 +302,7 @@ def average_precision_score(
True labels.
y_score : array-like of shape (n_samples,) or (n_samples, n_classes)
Estimated target scores as returned by a classifier.
Estimated target scores as returned by a classifier, non-negative.
classes : 1d array-like, shape (n_classes,), default=(0,1) as for binary classification
Uniquely holds the label for each class.
Expand All @@ -317,6 +323,8 @@ def average_precision_score(
pos_label : int, default=1
The label of the positive class. Only applied to binary y_true.
score_eps : float, default=1e-3. The lower bound for y_score.
Returns
-------
average_precision : float
Expand All @@ -332,7 +340,7 @@ def average_precision_score(
def binary_average_precision(y_true, y_score, pos_label=1):
"""Compute the average precision for binary classification."""
precisions, recalls, _ = precision_recall_curve(
y_true, y_score, pos_label=pos_label
y_true, y_score, pos_label=pos_label, score_eps=score_eps
)

return jnp.sum(jnp.diff(recalls) * precisions[1:])
Expand Down
12 changes: 10 additions & 2 deletions sml/metrics/classification/classification_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def proc(y_true, y_score, **kwargs):
return sk_res, spu_res

def check(res1, res2):
return np.testing.assert_allclose(res1, res2, rtol=1, atol=1e-5)
return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3)

# --- Test binary classification ---
# 0-1 labels, no tied value
Expand All @@ -144,13 +144,21 @@ def check(res1, res2):
y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32)
y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32)
check(*proc(y_true, y_score))
# single label edge case, with tied value
# single label edge case
y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
# zero score edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
# score > 1 edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))

# --- Test multiclass classification ---
y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32)
Expand Down
12 changes: 10 additions & 2 deletions sml/metrics/classification/classification_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def proc(y_true, y_score, **kwargs):
return sk_res, spu_res

def check(res1, res2):
return np.testing.assert_allclose(res1, res2, rtol=1, atol=1e-5)
return np.testing.assert_allclose(res1, res2, rtol=1e-3, atol=1e-3)

# --- Test binary classification ---
# 0-1 labels, no tied value
Expand All @@ -193,13 +193,21 @@ def check(res1, res2):
y_true = jnp.array(np.random.randint(0, 2, 100), dtype=jnp.int32)
y_score = jnp.array(np.hstack((0, 1, np.random.random(98))), dtype=jnp.float32)
check(*proc(y_true, y_score))
# single label edge case, with tied value
# single label edge case
y_true = jnp.array([0, 0, 0, 0], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
y_true = jnp.array([1, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0.4, 0.25, 0.4, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
# zero score edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([0, 0, 0, 0.25, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))
# score > 1 edge case
y_true = jnp.array([0, 0, 1, 1, 1], dtype=jnp.int32)
y_score = jnp.array([1.5, 1.5, 1.5, 0.25, 0.25], dtype=jnp.float32)
check(*proc(y_true, y_score))

# --- Test multiclass classification ---
y_true = np.array([0, 0, 1, 1, 2, 2], dtype=jnp.int32)
Expand Down

0 comments on commit 8ce3dba

Please sign in to comment.