diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index d7e78fd..0a15e89 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -108,6 +108,13 @@ def asarray( return cp.array(obj, dtype=dtype, **kwargs) +def sign(x: ndarray, /) -> ndarray: + # CuPy sign() does not propagate nans. See + # https://github.com/data-apis/array-api-compat/issues/136 + out = cp.sign(x) + out[cp.isnan(x)] = cp.nan + return out + # These functions are completely new here. If the library already has them # (i.e., numpy 2.0), use the library version instead of our wrapper. if hasattr(cp, 'vecdot'): @@ -122,6 +129,6 @@ def asarray( __all__ = _aliases.__all__ + ['asarray', 'bool', 'acos', 'acosh', 'asin', 'asinh', 'atan', 'atan2', 'atanh', 'bitwise_left_shift', 'bitwise_invert', - 'bitwise_right_shift', 'concat', 'pow'] + 'bitwise_right_shift', 'concat', 'pow', 'sign'] _all_ignore = ['cp', 'get_xp'] diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index a6d642e..899d94f 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -706,6 +706,21 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - axis = 0 return torch.index_select(x, axis, indices, **kwargs) +def sign(x: array, /) -> array: + # torch sign() does not support complex numbers and does not propagate + # nans. See https://github.com/data-apis/array-api-compat/issues/136 + if x.dtype.is_complex: + out = x/torch.abs(x) + # sign(0) = 0 but the above formula would give nan + out[x == 0+0j] = 0+0j + return out + else: + out = torch.sign(x) + if x.dtype.is_floating_point: + out[torch.isnan(x)] = torch.nan + return out + + __all__ = ['result_type', 'can_cast', 'permute_dims', 'bitwise_invert', 'newaxis', 'conj', 'add', 'atan2', 'bitwise_and', 'bitwise_left_shift', 'bitwise_or', 'bitwise_right_shift', @@ -719,6 +734,6 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None, **kwargs) - 'broadcast_arrays', 'UniqueAllResult', 'UniqueCountsResult', 'UniqueInverseResult', 'unique_all', 'unique_counts', 'unique_inverse', 'unique_values', 'matmul', 'matrix_transpose', - 'vecdot', 'tensordot', 'isdtype', 'take'] + 'vecdot', 'tensordot', 'isdtype', 'take', 'sign'] _all_ignore = ['torch', 'get_xp'] diff --git a/cupy-xfails.txt b/cupy-xfails.txt index a1878e9..9c865fa 100644 --- a/cupy-xfails.txt +++ b/cupy-xfails.txt @@ -160,7 +160,6 @@ array_api_tests/test_special_cases.py::test_unary[expm1(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[floor(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[log1p(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[round(x_i is -0) -> -0] -array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN] array_api_tests/test_special_cases.py::test_unary[sin(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[sinh(x_i is -0) -> -0] array_api_tests/test_special_cases.py::test_unary[sqrt(x_i is -0) -> -0] diff --git a/torch-xfails.txt b/torch-xfails.txt index 577f464..aedbc4a 100644 --- a/torch-xfails.txt +++ b/torch-xfails.txt @@ -169,7 +169,6 @@ array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity array_api_tests/test_special_cases.py::test_iop[__ifloordiv__(x1_i is -infinity and isfinite(x2_i) and x2_i > 0) -> -infinity] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is +0 and x2_i < 0) -> -0] array_api_tests/test_special_cases.py::test_iop[__imod__(x1_i is -0 and x2_i > 0) -> +0] -array_api_tests/test_special_cases.py::test_unary[sign(x_i is NaN) -> NaN] # Float correction is not supported by pytorch # (https://github.com/data-apis/array-api-tests/issues/168) @@ -186,7 +185,6 @@ array_api_tests/test_statistical_functions.py::test_sum array_api_tests/test_statistical_functions.py::test_prod # These functions do not yet support complex numbers -array_api_tests/test_operators_and_elementwise_functions.py::test_sign array_api_tests/test_operators_and_elementwise_functions.py::test_expm1 array_api_tests/test_operators_and_elementwise_functions.py::test_round array_api_tests/test_set_functions.py::test_unique_counts