Skip to content

Commit b34e220

Browse files
committed
[WIP] ENH: dask+cupy, dask+sparse etc. namespaces
1 parent 52e01be commit b34e220

File tree

4 files changed

+81
-7
lines changed

4 files changed

+81
-7
lines changed

array_api_compat/common/_helpers.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,9 @@ def is_dask_namespace(xp: Namespace) -> bool:
397397
"""
398398
Returns True if `xp` is a Dask namespace.
399399
400-
This includes both ``dask.array`` itself and the version wrapped by array-api-compat.
400+
This includes ``dask.array`` itself, the version wrapped by array-api-compat,
401+
and the bespoke namespaces generated by
402+
``array_api_compat.dask.array.wrap_namespace``.
401403
402404
See Also
403405
--------
@@ -411,7 +413,13 @@ def is_dask_namespace(xp: Namespace) -> bool:
411413
is_pydata_sparse_namespace
412414
is_array_api_strict_namespace
413415
"""
414-
return xp.__name__ in {"dask.array", _compat_module_name() + ".dask.array"}
416+
da_compat_name = _compat_module_name() + '.dask.array'
417+
name = xp.__name__
418+
return (
419+
name in {'dask.array', da_compat_name}
420+
or name.startswith(da_compat_name + '.')
421+
and name[len(da_compat_name) + 1:] not in ("linalg", "fft")
422+
)
415423

416424

417425
def is_jax_namespace(xp: Namespace) -> bool:
@@ -597,9 +605,16 @@ def your_function(x, y):
597605
elif is_dask_array(x):
598606
if _use_compat:
599607
_check_api_version(api_version)
600-
from ..dask import array as dask_namespace
601-
602-
namespaces.add(dask_namespace)
608+
from ..dask.array import wrap_namespace
609+
610+
# The meta-namespace is only used to generate the meta-array, so it
611+
# would be useless to create a namespace such as e.g.
612+
# array_api_compat.dask.array.array_api_compat.cupy.
613+
# It would get worse once you vendor array-api-compat!
614+
# So keep it clean with array_api_compat.dask.array.cupy.
615+
mxp = array_namespace(x._meta, use_compat=False)
616+
xp = wrap_namespace(mxp)
617+
namespaces.add(xp)
603618
else:
604619
import dask.array as da
605620

array_api_compat/dask/array/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
# These imports may overwrite names from the import * above.
66
from ._aliases import * # noqa: F403
7+
from ._meta import wrap_namespace # noqa: F401
78

89
__array_api_version__: Final = "2024.12"
910

array_api_compat/dask/array/_aliases.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@ def asarray(
152152
dtype: DType | None = None,
153153
device: Device | None = None,
154154
copy: py_bool | None = None,
155+
like: Array | None = None,
155156
**kwargs: object,
156157
) -> Array:
157158
"""
@@ -168,7 +169,11 @@ def asarray(
168169
if copy is False:
169170
raise ValueError("Unable to avoid copy when changing dtype")
170171
obj = obj.astype(dtype)
171-
return obj.copy() if copy else obj # pyright: ignore[reportAttributeAccessIssue]
172+
if copy:
173+
obj = obj.copy()
174+
if like is not None:
175+
obj = da.asarray(obj, like=like)
176+
return obj
172177

173178
if copy is False:
174179
raise NotImplementedError(
@@ -177,7 +182,11 @@ def asarray(
177182

178183
# copy=None to be uniform across dask < 2024.12 and >= 2024.12
179184
# see https://github.com/dask/dask/pull/11524/
180-
obj = np.array(obj, dtype=dtype, copy=True)
185+
if like is not None:
186+
mxp = array_namespace(like)
187+
obj = mxp.asarray(obj, dtype=dtype, copy=True)
188+
else:
189+
obj = np.array(obj, dtype=dtype, copy=True)
181190
return da.from_array(obj)
182191

183192

array_api_compat/dask/array/_meta.py

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import functools
2+
import sys
3+
import types
4+
5+
from ...common._helpers import is_numpy_namespace
6+
from ...common._typing import Namespace
7+
8+
__all__ = ['wrap_namespace']
9+
10+
11+
def wrap_namespace(xp: Namespace) -> Namespace:
12+
"""Create a bespoke Dask namespace that wraps around another namespace.
13+
14+
Parameters
15+
----------
16+
xp : namespace
17+
Namespace to be wrapped by Dask
18+
19+
Returns
20+
-------
21+
namespace :
22+
A module object that duplicates array_api_compat.dask.array, with the
23+
difference that all creation functions will create an array with the same
24+
meta namespace as the input.
25+
"""
26+
from .. import array as da_compat
27+
28+
if is_numpy_namespace(xp):
29+
return da_compat
30+
31+
mod_name = f'{da_compat.__name__}.{xp.__name__}'
32+
try:
33+
return sys.modules[mod_name]
34+
except KeyError:
35+
pass
36+
37+
mod = types.ModuleType(mod_name)
38+
sys.modules[mod_name] = mod
39+
40+
meta = xp.empty(())
41+
for name, v in da_compat.__dict__.items():
42+
if name.startswith('_'):
43+
continue
44+
if name in {'arange', 'asarray', 'empty', 'eye', 'from_dlpack',
45+
'full', 'linspace', 'ones', 'zeros'}:
46+
v = functools.wraps(v)(functools.partial(v, like=meta))
47+
setattr(mod, name, v)
48+
49+
return mod

0 commit comments

Comments
 (0)