From bd536d629af1a1fbce5dbaf516e410d1641c2e52 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Tue, 11 Jun 2024 16:53:58 -0600 Subject: [PATCH 1/2] Allow Python scalars in array_namespace They are just ignored. This makes array_namespace easier to use for functions that accept either arrays or scalars. I'm not sure if I should have this behavior by default, or if it should be enabled by a flag. --- array_api_compat/common/_helpers.py | 7 +++++-- tests/test_array_namespace.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 32fb0e70..e5b4133a 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -235,7 +235,8 @@ def array_namespace(*xs, api_version=None, use_compat=None): Parameters ---------- xs: arrays - one or more arrays. + one or more arrays. xs can also be Python scalars (bool, int, float, + or complex), which are ignored. api_version: str The newest version of the spec that you need support for (currently @@ -298,7 +299,9 @@ def your_function(x, y): namespaces = set() for x in xs: - if is_numpy_array(x): + if isinstance(x, (bool, int, float, complex)): + continue + elif is_numpy_array(x): from .. import numpy as numpy_namespace import numpy as np if use_compat is True: diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index 1f83a473..c2e09520 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -92,3 +92,17 @@ def test_api_version(): def test_get_namespace(): # Backwards compatible wrapper assert array_api_compat.get_namespace is array_api_compat.array_namespace + +def test_python_scalars(): + a = torch.asarray([1, 2]) + xp = import_("torch", wrapper=True) + + pytest.raises(TypeError, lambda: array_namespace(1)) + pytest.raises(TypeError, lambda: array_namespace(1.0)) + pytest.raises(TypeError, lambda: array_namespace(1j)) + pytest.raises(TypeError, lambda: array_namespace(True)) + + assert array_namespace(a, 1) == xp + assert array_namespace(a, 1.0) == xp + assert array_namespace(a, 1j) == xp + assert array_namespace(a, True) == xp From 0a2160b1bc97be0f92cd1829c3c5e6719eb938f3 Mon Sep 17 00:00:00 2001 From: Aaron Meurer Date: Wed, 26 Jun 2024 12:09:44 -0600 Subject: [PATCH 2/2] Also ignore None in array_namespace --- array_api_compat/common/_helpers.py | 4 ++-- tests/test_array_namespace.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index e5b4133a..bafe991a 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -236,7 +236,7 @@ def array_namespace(*xs, api_version=None, use_compat=None): ---------- xs: arrays one or more arrays. xs can also be Python scalars (bool, int, float, - or complex), which are ignored. + complex, or None), which are ignored. api_version: str The newest version of the spec that you need support for (currently @@ -299,7 +299,7 @@ def your_function(x, y): namespaces = set() for x in xs: - if isinstance(x, (bool, int, float, complex)): + if isinstance(x, (bool, int, float, complex, type(None))): continue elif is_numpy_array(x): from .. import numpy as numpy_namespace diff --git a/tests/test_array_namespace.py b/tests/test_array_namespace.py index c2e09520..af0ac244 100644 --- a/tests/test_array_namespace.py +++ b/tests/test_array_namespace.py @@ -101,8 +101,10 @@ def test_python_scalars(): pytest.raises(TypeError, lambda: array_namespace(1.0)) pytest.raises(TypeError, lambda: array_namespace(1j)) pytest.raises(TypeError, lambda: array_namespace(True)) + pytest.raises(TypeError, lambda: array_namespace(None)) assert array_namespace(a, 1) == xp assert array_namespace(a, 1.0) == xp assert array_namespace(a, 1j) == xp assert array_namespace(a, True) == xp + assert array_namespace(a, None) == xp