Commit b6ca7a9

Merge pull request #2519 from daxiongshu/fea-precision-recall-curve-cupy

[Review] Precision recall curve using cupy

dantegd authored Jul 23, 2020
2 parents 88b3fa3 + b1d4f81
Showing 4 changed files with 220 additions and 47 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@
- PR #2394: Adding cosine & correlation distance for KNN
- PR #2392: PCA can accept sparse inputs, and sparse prim for computing covariance
- PR #2465: Support pandas 1.0+
- PR #2519: Precision recall curve using cupy
- PR #2500: Replace UMAP functionality dependency on nvgraph with RAFT Spectral Clustering
- PR #2520: TfidfVectorizer estimator
- PR #2461: Add KNN Sparse Output Functionality
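For quick orientation, here is a minimal usage sketch of the API this PR adds. It mirrors the docstring example in python/cuml/metrics/_ranking.py below, so the expected values come from the source rather than from a new run:

    import numpy as np
    from cuml.metrics import precision_recall_curve

    y_true = np.array([0, 0, 1, 1])
    y_scores = np.array([0.1, 0.4, 0.35, 0.8])

    # The returned arrays live on the GPU as cupy arrays.
    precision, recall, thresholds = precision_recall_curve(y_true, y_scores)
    # precision  -> [0.66666667, 0.5, 1., 1.]
    # recall     -> [1., 0.5, 0.5, 0.]
    # thresholds -> [0.35, 0.4, 0.8]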
1 change: 1 addition & 0 deletions python/cuml/metrics/__init__.py
@@ -22,6 +22,7 @@
from cuml.metrics.accuracy import accuracy_score
from cuml.metrics.cluster.adjustedrandindex import adjusted_rand_score
from cuml.metrics._ranking import roc_auc_score
from cuml.metrics._ranking import precision_recall_curve
from cuml.metrics._classification import log_loss
from cuml.metrics.cluster.homogeneity_score import homogeneity_score
from cuml.metrics.cluster.completeness_score import completeness_score
203 changes: 157 additions & 46 deletions python/cuml/metrics/_ranking.py
@@ -21,12 +21,106 @@
import math


@with_cupy_rmm
def precision_recall_curve(y_true, probs_pred):
    """
    Compute precision-recall pairs for different probability thresholds.

    Note: this implementation is restricted to the binary classification
    task.

    The precision is the ratio ``tp / (tp + fp)`` where ``tp`` is the number
    of true positives and ``fp`` the number of false positives. The precision
    is intuitively the ability of the classifier not to label as positive a
    sample that is negative.

    The recall is the ratio ``tp / (tp + fn)`` where ``tp`` is the number of
    true positives and ``fn`` the number of false negatives. The recall is
    intuitively the ability of the classifier to find all the positive
    samples.

    The last precision and recall values are 1. and 0. respectively and do
    not have a corresponding threshold. This ensures that the graph starts
    on the y axis.

    Read more in the :ref:`User Guide <precision_recall_f_measure_metrics>`.

    Parameters
    ----------
    y_true : array, shape = [n_samples]
        True binary labels, {0, 1}.
    probs_pred : array, shape = [n_samples]
        Estimated probabilities or decision function.

    Returns
    -------
    precision : array, shape = [n_thresholds + 1]
        Precision values such that element i is the precision of
        predictions with score >= thresholds[i] and the last element is 1.
    recall : array, shape = [n_thresholds + 1]
        Decreasing recall values such that element i is the recall of
        predictions with score >= thresholds[i] and the last element is 0.
    thresholds : array, shape = [n_thresholds <= len(np.unique(probs_pred))]
        Increasing thresholds on the decision function used to compute
        precision and recall.

    Examples
    --------
    .. code-block:: python

        import numpy as np
        from cuml.metrics import precision_recall_curve
        y_true = np.array([0, 0, 1, 1])
        y_scores = np.array([0.1, 0.4, 0.35, 0.8])
        precision, recall, thresholds = precision_recall_curve(
            y_true, y_scores)
        print(precision)
        print(recall)
        print(thresholds)

    Output:

    .. code-block:: python

        array([0.66666667, 0.5       , 1.        , 1.        ])
        array([1. , 0.5, 0.5, 0. ])
        array([0.35, 0.4 , 0.8 ])

    """
    y_true, n_rows, n_cols, ytype = \
        input_to_cuml_array(y_true, check_dtype=[np.int32, np.int64,
                                                 np.float32, np.float64])

    y_score, _, _, _ = \
        input_to_cuml_array(probs_pred, check_dtype=[np.int32, np.int64,
                                                     np.float32, np.float64],
                            check_rows=n_rows, check_cols=n_cols)

    y_true = y_true.to_output('cupy')
    y_score = y_score.to_output('cupy')

    if not cp.any(y_true):
        raise ValueError("precision_recall_curve cannot be used when "
                         "y_true is all zero.")

    fps, tps, thresholds = _binary_clf_curve(y_true, y_score)
    precision = cp.flip(tps / (tps + fps), axis=0)
    recall = cp.flip(tps / tps[-1], axis=0)
    n = (recall == 1).sum()

    if n > 1:
        precision = precision[n - 1:]
        recall = recall[n - 1:]
        thresholds = thresholds[n - 1:]
    precision = cp.concatenate([precision, cp.ones(1)])
    recall = cp.concatenate([recall, cp.zeros(1)])

    return precision, recall, thresholds
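
To make the flip-and-trim logic above concrete, here is a worked host-side trace of the docstring example in plain numpy. The tps/fps/thresholds values are the ones _binary_clf_curve (defined below) produces for this input; numpy stands in for cupy purely for illustration:

    import numpy as np

    # What _binary_clf_curve yields for y_true=[0,0,1,1], y_score=[0.1,0.4,0.35,0.8]:
    tps = np.array([1., 1., 2., 2.])               # cumulative true positives, scores descending
    fps = np.array([0., 1., 1., 2.])               # cumulative false positives
    thresholds = np.array([0.1, 0.35, 0.4, 0.8])   # cp.unique returns ascending order

    precision = np.flip(tps / (tps + fps))         # [0.5, 0.66666667, 0.5, 1.]
    recall = np.flip(tps / tps[-1])                # [1., 1., 0.5, 0.5]
    n = int((recall == 1).sum())                   # 2 leading entries with recall == 1
    precision, recall = precision[n - 1:], recall[n - 1:]
    thresholds = thresholds[n - 1:]
    precision = np.concatenate([precision, np.ones(1)])
    recall = np.concatenate([recall, np.zeros(1)])
    # precision  -> [0.66666667, 0.5, 1., 1.]
    # recall     -> [1., 0.5, 0.5, 0.]
    # thresholds -> [0.35, 0.4, 0.8]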


@with_cupy_rmm
def roc_auc_score(y_true, y_score):
"""
Compute Area Under the Receiver Operating Characteristic Curve
(ROC AUC) from prediction scores. Note -- this implementation can
only be used with binary classification.
Compute Area Under the Receiver Operating Characteristic Curve (ROC AUC)
from prediction scores.
Note: this implementation can only be used with binary classification.
Parameters
----------
@@ -46,15 +140,19 @@ def roc_auc_score(y_true, y_score):
    Examples
    --------
    .. code-block:: python

        import numpy as np
        from cuml.metrics import roc_auc_score
        y_true = np.array([0, 0, 1, 1])
        y_scores = np.array([0.1, 0.4, 0.35, 0.8])
        print(roc_auc_score(y_true, y_scores))

    Output:

    .. code-block:: python

        0.75

    """
    y_true, n_rows, n_cols, ytype = \
        input_to_cuml_array(y_true, check_dtype=[np.int32, np.int64,
                                                 np.float32, np.float64])
@@ -63,70 +161,83 @@ def roc_auc_score(y_true, y_score):
        input_to_cuml_array(y_score, check_dtype=[np.int32, np.int64,
                                                  np.float32, np.float64],
                            check_rows=n_rows, check_cols=n_cols)

    return _binary_roc_auc_score(y_true, y_score)


def _binary_clf_curve(y_true, y_score):

    if y_true.dtype.kind == 'f' and np.any(y_true != y_true.astype(int)):
        raise ValueError("Continuous format of y_true "
                         "is not supported.")

    ids = cp.argsort(-y_score)
    sorted_score = y_score[ids]

    ones = y_true[ids].astype('float32')  # for calculating true positives
    zeros = 1 - ones  # for calculating false positives

    # assign a group id to each run of equal scores
    group = _group_same_scores(sorted_score)
    num = int(group[-1])

    tps = cp.zeros(num, dtype='float32')
    fps = cp.zeros(num, dtype='float32')

    tps = _addup_x_in_group(group, ones, tps)
    fps = _addup_x_in_group(group, zeros, fps)

    tps = cp.cumsum(tps)
    fps = cp.cumsum(fps)
    thresholds = cp.unique(y_score)
    return fps, tps, thresholds
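
The grouping and scatter-add steps above can be hard to picture from the kernel alone. The following numpy sketch re-enacts the sort, group, scatter-add, and cumulative-sum pipeline on the host; np.add.at plays the role of the atomicAdd kernel in _addup_x_in_group (an analogy for illustration, not the shipped GPU path):

    import numpy as np

    y_true = np.array([0., 0., 1., 1.], dtype=np.float32)
    y_score = np.array([0.1, 0.4, 0.35, 0.8], dtype=np.float32)

    ids = np.argsort(-y_score)            # sort indices, scores descending
    sorted_score = y_score[ids]           # [0.8, 0.4, 0.35, 0.1]
    ones = y_true[ids]                    # positives at each sorted position
    zeros = 1 - ones                      # negatives at each sorted position

    # Same run-length grouping as _group_same_scores below:
    mask = np.empty(sorted_score.shape, dtype=bool)
    mask[0] = True
    mask[1:] = sorted_score[1:] != sorted_score[:-1]
    group = np.cumsum(mask)               # [1, 2, 3, 4] (no ties here)

    num = int(group[-1])
    tps = np.zeros(num, dtype=np.float32)
    fps = np.zeros(num, dtype=np.float32)
    np.add.at(tps, group - 1, ones)       # host analogue of the atomicAdd kernel
    np.add.at(fps, group - 1, zeros)

    tps, fps = np.cumsum(tps), np.cumsum(fps)
    # tps -> [1., 1., 2., 2.], fps -> [0., 1., 1., 2.]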


def _binary_roc_auc_score(y_true, y_score):
    """Compute binary roc_auc_score using cupy"""
    y_true = y_true.to_output('cupy')
    y_score = y_score.to_output('cupy')

    if cp.unique(y_true).shape[0] == 1:
        raise ValueError("roc_auc_score cannot be used when "
                         "only one class present in y_true. ROC AUC score "
                         "is not defined in that case.")

    if cp.unique(y_score).shape[0] == 1:
        return 0.5

    fps, tps, thresholds = _binary_clf_curve(y_true, y_score)
    tpr = tps / tps[-1]
    fpr = fps / fps[-1]

    return _calculate_area_under_curve(fpr, tpr).item()


def _addup_x_in_group(group, x, result):
    addup_x_in_group_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void addup_x_in_group(const int* group, const float* x,
                              float* result, int N)
        {
            int tid = blockDim.x * blockIdx.x + threadIdx.x;
            if(tid < N){
                atomicAdd(result + group[tid] - 1, x[tid]);
            }
        }
        ''', 'addup_x_in_group')

    N = x.shape[0]
    tpb = 256  # threads per block
    bpg = math.ceil(N / tpb)  # blocks per grid
    addup_x_in_group_kernel((bpg,), (tpb,), (group, x, result, N))
    return result


def _group_same_scores(sorted_score):
    mask = cp.empty(sorted_score.shape, dtype=cp.bool_)
    mask[0] = True
    mask[1:] = sorted_score[1:] != sorted_score[:-1]
    group = cp.cumsum(mask, dtype=cp.int32)
    return group
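
The grouping exists because tied scores must collapse onto a single threshold. A quick usage check of _group_same_scores on input with a tie:

    import cupy as cp

    sorted_score = cp.array([0.8, 0.4, 0.4, 0.1], dtype=cp.float32)  # descending, one tie
    group = _group_same_scores(sorted_score)
    # group -> [1, 2, 2, 3]: the tied 0.4 scores share group 2, so their
    # positive/negative counts are pooled before the cumulative sums.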


def _calculate_area_under_curve(fpr, tpr):
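The body of _calculate_area_under_curve is collapsed in this diff. As a hedged sketch only: given the monotone fpr/tpr arrays its callers pass, a trapezoid-rule integration with an implicit (0, 0) starting point would reproduce the documented output; the helper name and exact handling below are assumptions, not the merged implementation:

    import cupy as cp

    def _trapezoid_auc_sketch(fpr, tpr):
        # Hypothetical stand-in for _calculate_area_under_curve (assumption).
        # Prepend the origin so the curve starts at (0, 0), then apply the
        # trapezoid rule: sum of segment width times mean segment height.
        fpr = cp.concatenate([cp.zeros(1), fpr])
        tpr = cp.concatenate([cp.zeros(1), tpr])
        return cp.sum((fpr[1:] - fpr[:-1]) * (tpr[1:] + tpr[:-1]) / 2)

On the docstring example (fpr = [0, 0.5, 0.5, 1], tpr = [0.5, 0.5, 1, 1]) this yields 0.75, matching the documented ROC AUC.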
62 changes: 61 additions & 1 deletion python/cuml/test/test_metrics.py
@@ -53,8 +53,11 @@
from cuml.common import has_scipy

from cuml.metrics import roc_auc_score
from cuml.metrics import precision_recall_curve
from cuml.metrics import log_loss
from sklearn.metrics import roc_auc_score as sklearn_roc_auc_score
from sklearn.metrics import precision_recall_curve \
    as sklearn_precision_recall_curve


@pytest.mark.parametrize('datatype', [np.float32, np.float64])
@@ -641,12 +644,69 @@ def test_roc_auc_score_at_limits():
    y_pred = np.array([0., 0.5, 1.], dtype=np.float64)

    err_msg = ("Continuous format of y_true "
               "is not supported.")

    with pytest.raises(ValueError, match=err_msg):
        roc_auc_score(y_true, y_pred)


def test_precision_recall_curve():
    y_true = np.array([0, 0, 1, 1])
    y_score = np.array([0.1, 0.4, 0.35, 0.8])
    precision_using_sk, recall_using_sk, thresholds_using_sk = \
        sklearn_precision_recall_curve(
            y_true, y_score)

    precision, recall, thresholds = precision_recall_curve(
        y_true, y_score)

    assert array_equal(precision, precision_using_sk)
    assert array_equal(recall, recall_using_sk)
    assert array_equal(thresholds, thresholds_using_sk)


def test_precision_recall_curve_at_limits():
    y_true = np.array([0., 0., 0.], dtype=np.float64)
    y_pred = np.array([0., 0.5, 1.], dtype=np.float64)

    err_msg = ("precision_recall_curve cannot be used when "
               "y_true is all zero.")

    with pytest.raises(ValueError, match=err_msg):
        precision_recall_curve(y_true, y_pred)

    y_true = np.array([0., 0.5, 1.0], dtype=np.float64)
    y_pred = np.array([0., 0.5, 1.], dtype=np.float64)

    err_msg = ("Continuous format of y_true "
               "is not supported.")

    with pytest.raises(ValueError, match=err_msg):
        precision_recall_curve(y_true, y_pred)


@pytest.mark.parametrize('n_samples', [50, 500000])
@pytest.mark.parametrize('dtype', [np.int32, np.int64, np.float32, np.float64])
def test_precision_recall_curve_random(n_samples, dtype):

    y_true, _, _, _ = generate_random_labels(
        lambda rng: rng.randint(0, 2, n_samples).astype(dtype))

    y_score, _, _, _ = generate_random_labels(
        lambda rng: rng.randint(0, 1000, n_samples).astype(dtype))

    precision_using_sk, recall_using_sk, thresholds_using_sk = \
        sklearn_precision_recall_curve(
            y_true, y_score)

    precision, recall, thresholds = precision_recall_curve(
        y_true, y_score)

    assert array_equal(precision, precision_using_sk)
    assert array_equal(recall, recall_using_sk)
    assert array_equal(thresholds, thresholds_using_sk)


def test_log_loss():
    y_true = np.array([0, 0, 1, 1])
    y_pred = np.array([0.1, 0.4, 0.35, 0.8])