"""
Various helper functions which are not part of the spec.
Functions which start with an underscore are for internal use only but helpers
that are in __all__ are intended as additional helper functions for use by end
users of the compat library.
"""
from __future__ import annotations

from typing import TYPE_CHECKING
if TYPE_CHECKING:
    from typing import Optional, Union, Any
    from ._typing import Array, Device

import sys
import math

def _is_numpy_array(x):
    # Avoid importing NumPy if it isn't already
    if 'numpy' not in sys.modules:
        return False

    import numpy as np

    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, (np.ndarray, np.generic))

def _is_cupy_array(x):
    # Avoid importing CuPy if it isn't already
    if 'cupy' not in sys.modules:
        return False

    import cupy as cp

    # TODO: Should we reject ndarray subclasses?
    return isinstance(x, (cp.ndarray, cp.generic))

def _is_torch_array(x):
    # Avoid importing torch if it isn't already
    if 'torch' not in sys.modules:
        return False

    import torch

    # TODO: Should we reject Tensor subclasses?
    return isinstance(x, torch.Tensor)

def _is_dask_array(x):
    # Avoid importing dask if it isn't already
    if 'dask.array' not in sys.modules:
        return False

    import dask.array

    return isinstance(x, dask.array.Array)

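# All four _is_*_array helpers above share the same lazy-detection idiom:
# consult sys.modules first so that calling them never imports a library the
# user hasn't already loaded. A minimal sketch of the idiom for a
# hypothetical library "somelib" (the module name and ndarray class are
# assumptions for illustration, not part of this file):
def _example_is_somelib_array(x):
    if 'somelib' not in sys.modules:
        # somelib was never imported, so x cannot be a somelib array
        return False
    import somelib  # cheap: the module is already in sys.modules
    return isinstance(x, somelib.ndarray)
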
def is_array_api_obj(x):
    """
    Check if x is an array API compatible array object.
    """
    return _is_numpy_array(x) \
        or _is_cupy_array(x) \
        or _is_torch_array(x) \
        or _is_dask_array(x) \
        or hasattr(x, '__array_namespace__')

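# Illustrative usage sketch (the helper name is hypothetical): validating
# input before handing it to namespace-generic code.
def _example_require_array(x):
    if not is_array_api_obj(x):
        raise TypeError(f"expected an array, got {type(x).__name__}")
    return x
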
def _check_api_version(api_version):
    if api_version is not None and api_version != '2021.12':
        raise ValueError("Only the 2021.12 version of the array API specification is currently supported")

def array_namespace(*xs, api_version=None, _use_compat=True):
    """
    Get the array API compatible namespace for the arrays `xs`.

    `xs` should contain one or more arrays.

    Typical usage is

        def your_function(x, y):
            xp = array_api_compat.array_namespace(x, y)
            # Now use xp as the array library namespace
            return xp.mean(x, axis=0) + 2*xp.std(y, axis=0)

    api_version should be the newest version of the spec that you need support
    for (currently the compat library wrapped APIs only support v2021.12).
    """
    namespaces = set()
    for x in xs:
        if _is_numpy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import numpy as numpy_namespace
                namespaces.add(numpy_namespace)
            else:
                import numpy as np
                namespaces.add(np)
        elif _is_cupy_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import cupy as cupy_namespace
                namespaces.add(cupy_namespace)
            else:
                import cupy as cp
                namespaces.add(cp)
        elif _is_torch_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from .. import torch as torch_namespace
                namespaces.add(torch_namespace)
            else:
                import torch
                namespaces.add(torch)
        elif _is_dask_array(x):
            _check_api_version(api_version)
            if _use_compat:
                from ..dask import array as dask_namespace
                namespaces.add(dask_namespace)
            else:
                raise TypeError("_use_compat cannot be False if input array is a dask array!")
        elif hasattr(x, '__array_namespace__'):
            namespaces.add(x.__array_namespace__(api_version=api_version))
        else:
            # TODO: Support Python scalars?
            raise TypeError(f"{type(x).__name__} is not a supported array type")

    if not namespaces:
        raise TypeError("Unrecognized array input")

    if len(namespaces) != 1:
        raise TypeError(f"Multiple namespaces for array inputs: {namespaces}")

    xp, = namespaces

    return xp
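
# Illustrative usage sketch (the function name is hypothetical): a single
# implementation that works for NumPy, CuPy, torch, and dask inputs alike.
# Mixing arrays from different libraries raises TypeError, as above.
def _example_standardize(x):
    xp = array_namespace(x)
    return (x - xp.mean(x)) / xp.std(x)
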
# backwards compatibility alias
get_namespace = array_namespace

def _check_device(xp, device):
    if xp == sys.modules.get('numpy'):
        if device not in ["cpu", None]:
            raise ValueError(f"Unsupported device for NumPy: {device!r}")

# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in the array objects of this library
# because this library just reuses the respective ndarray classes without
# wrapping or subclassing them. These helper functions can be used instead of
# the wrapper functions for libraries that need to support both NumPy/CuPy and
# other libraries that use devices.
def device(x: Array, /) -> Device:
    """
    Hardware device the array data resides on.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.

    Returns
    -------
    out: device
        a ``device`` object (see the "Device Support" section of the array API specification).
    """
    if _is_numpy_array(x):
        return "cpu"
    return x.device

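# Illustrative usage sketch (the helper name is hypothetical): device() gives
# a uniform answer even for NumPy arrays, which carry no .device attribute of
# their own.
def _example_on_same_device(x, y):
    return device(x) == device(y)
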
# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    import cupy as cp
    from cupy.cuda import Device as _Device
    from cupy.cuda import stream as stream_module
    from cupy_backends.cuda.api import runtime

    if device == x.device:
        return x
    elif device == "cpu":
        # allowing us to use `to_device(x, "cpu")` is useful for portable
        # test swapping between host and device backends
        return x.get()
    elif not isinstance(device, _Device):
        raise ValueError(f"Unsupported device {device!r}")
    else:
        # see cupy/cupy#5985 for the reasoning behind how we handle
        # device/stream here
        prev_device = runtime.getDevice()
        prev_stream: stream_module.Stream = None
        if stream is not None:
            prev_stream = stream_module.get_current_stream()
            # stream can be an int as specified in __dlpack__, or a CuPy stream
            if isinstance(stream, int):
                stream = cp.cuda.ExternalStream(stream)
            elif isinstance(stream, cp.cuda.Stream):
                pass
            else:
                raise ValueError('the input stream is not recognized')
            stream.use()
        try:
            runtime.setDevice(device.id)
            arr = x.copy()
        finally:
            runtime.setDevice(prev_device)
            if stream is not None:
                prev_stream.use()
        return arr

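# Illustrative sketch (the helper name is hypothetical): _cupy_to_device is
# normally reached through to_device() below; an explicit CuPy stream (or a
# raw int stream pointer, as in __dlpack__) can be passed to enqueue the copy
# on a non-default stream.
def _example_cupy_copy_on_stream(x, dev):
    import cupy as cp
    s = cp.cuda.Stream(non_blocking=True)
    return _cupy_to_device(x, dev, stream=s)
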
def _torch_to_device(x, device, /, stream=None):
    if stream is not None:
        raise NotImplementedError
    return x.to(device)

def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
    """
    Copy the array from the device on which it currently resides to the specified ``device``.

    Parameters
    ----------
    x: array
        array instance from NumPy or an array API compatible library.
    device: device
        a ``device`` object (see the "Device Support" section of the array API specification).
    stream: Optional[Union[int, Any]]
        stream object to use during copy. In addition to the types supported
        in ``array.__dlpack__``, implementations may choose to support any
        library-specific stream object with the caveat that any code using
        such an object would not be portable.

    Returns
    -------
    out: array
        an array with the same data and data type as ``x`` and located on the specified ``device``.

    .. note::
       If ``stream`` is given, the copy operation should be enqueued on the
       provided ``stream``; otherwise, the copy operation should be enqueued
       on the default stream/queue. Whether the copy is performed
       synchronously or asynchronously is implementation-dependent.
       Accordingly, if synchronization is required to guarantee data safety,
       this must be clearly explained in a conforming library's documentation.
    """
    if _is_numpy_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    elif _is_cupy_array(x):
        # cupy does not yet have to_device
        return _cupy_to_device(x, device, stream=stream)
    elif _is_torch_array(x):
        return _torch_to_device(x, device, stream=stream)
    elif _is_dask_array(x):
        if stream is not None:
            raise ValueError("The stream argument to to_device() is not supported")
        # TODO: What if our array is on the GPU already?
        if device == 'cpu':
            return x
        raise ValueError(f"Unsupported device {device!r}")
    return x.to_device(device, stream=stream)

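# Illustrative usage sketch (the helper name is hypothetical): moving any
# supported array to the host, e.g. to compare against a NumPy reference.
# "cpu" is accepted for NumPy, CuPy, torch, and dask inputs above.
def _example_to_host(x):
    return to_device(x, "cpu")
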
def size(x):
    """
    Return the total number of elements of x.
    """
    if None in x.shape:
        return None
    return math.prod(x.shape)

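# Illustrative usage sketch (the helper name is hypothetical): size() is safe
# for arrays with unknown dimensions (shape entries of None), where
# math.prod(x.shape) alone would fail. Note dtype.itemsize is a
# NumPy/CuPy/torch attribute, not something the array API spec guarantees.
def _example_nbytes(x):
    n = size(x)
    return None if n is None else n * x.dtype.itemsize

# The module docstring refers to __all__; listing the public helpers defined
# above (assumed to match the docstring's intent).
__all__ = ['is_array_api_obj', 'array_namespace', 'get_namespace', 'device',
           'to_device', 'size']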