Skip to content

Commit 851c0d6

Browse files
EdAbatiOmarManzoor
authored andcommitted
FIX: accuracy and zero_loss support for multilabel with Array API (#29336)
Co-authored-by: Omar Salman <omar.salman2007@gmail.com> Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 99d8a32 commit 851c0d6

File tree

4 files changed

+72
-7
lines changed

4 files changed

+72
-7
lines changed

doc/whats_new/v1.5.rst

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ Changelog
4848
instead of implicitly converting those inputs as regular NumPy arrays.
4949
:pr:`29119` by :user:`Olivier Grisel`.
5050

51-
- |Fix| Fix a regression in :func:`metrics.zero_one_loss` causing an error
52-
for Array API dispatch with multilabel inputs.
53-
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>`.
51+
- |Fix| Fix a regression in :func:`metrics.accuracy_score` and in :func:`metrics.zero_one_loss`
52+
causing an error for Array API dispatch with multilabel inputs.
53+
:pr:`29269` by :user:`Yaroslav Korobko <Tialo>` and :pr:`29336` by :user:`Edoardo Abati <EdAbati>`.
5454

5555
:mod:`sklearn.model_selection`
5656
..............................

sklearn/metrics/_classification.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
)
4141
from ..utils._array_api import (
4242
_average,
43+
_count_nonzero,
44+
_is_numpy_namespace,
4345
_union1d,
4446
get_namespace,
4547
get_namespace_and_device,
@@ -97,6 +99,7 @@ def _check_targets(y_true, y_pred):
9799
98100
y_pred : array or indicator matrix
99101
"""
102+
xp, _ = get_namespace(y_true, y_pred)
100103
check_consistent_length(y_true, y_pred)
101104
type_true = type_of_target(y_true, input_name="y_true")
102105
type_pred = type_of_target(y_pred, input_name="y_pred")
@@ -142,8 +145,13 @@ def _check_targets(y_true, y_pred):
142145
y_type = "multiclass"
143146

144147
if y_type.startswith("multilabel"):
145-
y_true = csr_matrix(y_true)
146-
y_pred = csr_matrix(y_pred)
148+
if _is_numpy_namespace(xp):
149+
# XXX: do we really want to sparse-encode multilabel indicators when
150+
# they are passed as a dense arrays? This is not possible for array
151+
# API inputs in general hence we only do it for NumPy inputs. But even
152+
# for NumPy the usefulness is questionable.
153+
y_true = csr_matrix(y_true)
154+
y_pred = csr_matrix(y_pred)
147155
y_type = "multilabel-indicator"
148156

149157
return y_type, y_true, y_pred
@@ -223,7 +231,12 @@ def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
223231
y_type, y_true, y_pred = _check_targets(y_true, y_pred)
224232
check_consistent_length(y_true, y_pred, sample_weight)
225233
if y_type.startswith("multilabel"):
226-
differing_labels = count_nonzero(y_true - y_pred, axis=1)
234+
if _is_numpy_namespace(xp):
235+
differing_labels = count_nonzero(y_true - y_pred, axis=1)
236+
else:
237+
differing_labels = _count_nonzero(
238+
y_true - y_pred, xp=xp, device=device, axis=1
239+
)
227240
score = xp.asarray(differing_labels == 0, device=device)
228241
else:
229242
score = y_true == y_pred

sklearn/utils/_array_api.py

+17
Original file line numberDiff line numberDiff line change
@@ -841,3 +841,20 @@ def indexing_dtype(xp):
841841
# TODO: once sufficiently adopted, we might want to instead rely on the
842842
# newer inspection API: https://github.com/data-apis/array-api/issues/640
843843
return xp.asarray(0).dtype
844+
845+
846+
def _count_nonzero(X, xp, device, axis=None, sample_weight=None):
847+
"""A variant of `sklearn.utils.sparsefuncs.count_nonzero` for the Array API.
848+
849+
It only supports 2D arrays.
850+
"""
851+
assert X.ndim == 2
852+
853+
weights = xp.ones_like(X, device=device)
854+
if sample_weight is not None:
855+
sample_weight = xp.asarray(sample_weight, device=device)
856+
sample_weight = xp.reshape(sample_weight, (sample_weight.shape[0], 1))
857+
weights = xp.astype(weights, sample_weight.dtype) * sample_weight
858+
859+
zero_scalar = xp.asarray(0, device=device, dtype=weights.dtype)
860+
return xp.sum(xp.where(X != 0, weights, zero_scalar), axis=axis)

sklearn/utils/tests/test_array_api.py

+36-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
_atol_for_type,
1414
_average,
1515
_convert_to_numpy,
16+
_count_nonzero,
1617
_estimator_with_converted_arrays,
1718
_is_numpy_namespace,
1819
_nanmax,
@@ -30,7 +31,7 @@
3031
_array_api_for_tests,
3132
skip_if_array_api_compat_not_configured,
3233
)
33-
from sklearn.utils.fixes import _IS_32BIT
34+
from sklearn.utils.fixes import _IS_32BIT, CSR_CONTAINERS
3435

3536

3637
@pytest.mark.parametrize("X", [numpy.asarray([1, 2, 3]), [1, 2, 3]])
@@ -530,3 +531,37 @@ def test_get_namespace_and_device():
530531
assert namespace is xp_torch
531532
assert is_array_api
532533
assert device == some_torch_tensor.device
534+
535+
536+
@pytest.mark.parametrize(
537+
"array_namespace, device, dtype_name", yield_namespace_device_dtype_combinations()
538+
)
539+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
540+
@pytest.mark.parametrize("axis", [0, 1, None, -1, -2])
541+
@pytest.mark.parametrize("sample_weight_type", [None, "int", "float"])
542+
def test_count_nonzero(
543+
array_namespace, device, dtype_name, csr_container, axis, sample_weight_type
544+
):
545+
546+
from sklearn.utils.sparsefuncs import count_nonzero as sparse_count_nonzero
547+
548+
xp = _array_api_for_tests(array_namespace, device)
549+
array = numpy.array([[0, 3, 0], [2, -1, 0], [0, 0, 0], [9, 8, 7], [4, 0, 5]])
550+
if sample_weight_type == "int":
551+
sample_weight = numpy.asarray([1, 2, 2, 3, 1])
552+
elif sample_weight_type == "float":
553+
sample_weight = numpy.asarray([0.5, 1.5, 0.8, 3.2, 2.4], dtype=dtype_name)
554+
else:
555+
sample_weight = None
556+
expected = sparse_count_nonzero(
557+
csr_container(array), axis=axis, sample_weight=sample_weight
558+
)
559+
array_xp = xp.asarray(array, device=device)
560+
561+
with config_context(array_api_dispatch=True):
562+
result = _count_nonzero(
563+
array_xp, xp=xp, device=device, axis=axis, sample_weight=sample_weight
564+
)
565+
566+
assert_allclose(_convert_to_numpy(result, xp=xp), expected)
567+
assert getattr(array_xp, "device", None) == getattr(result, "device", None)

0 commit comments

Comments
 (0)