diff --git a/array_api_compat/common/_aliases.py b/array_api_compat/common/_aliases.py index 46cbb359..e3adda54 100644 --- a/array_api_compat/common/_aliases.py +++ b/array_api_compat/common/_aliases.py @@ -10,7 +10,7 @@ from ._typing import Array, Device, DType, Namespace from ._helpers import ( array_namespace, - _check_device, + _device_ctx, device as _get_device, is_cupy_namespace as _is_cupy_namespace ) @@ -31,8 +31,8 @@ def arange( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs) def empty( shape: Union[int, Tuple[int, ...]], @@ -42,8 +42,8 @@ def empty( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.empty(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.empty(shape, dtype=dtype, **kwargs) def empty_like( x: Array, @@ -54,8 +54,8 @@ def empty_like( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.empty_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.empty_like(x, dtype=dtype, **kwargs) def eye( n_rows: int, @@ -68,8 +68,8 @@ def eye( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs) def full( shape: Union[int, Tuple[int, ...]], @@ -80,8 +80,8 @@ def full( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.full(shape, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.full(shape, fill_value, dtype=dtype, **kwargs) def full_like( x: Array, @@ -93,8 +93,8 @@ def full_like( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.full_like(x, fill_value, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.full_like(x, fill_value, dtype=dtype, **kwargs) def linspace( start: Union[int, float], @@ -108,8 +108,8 @@ def linspace( endpoint: bool = True, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) + with _device_ctx(xp, device): + return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs) def ones( shape: Union[int, Tuple[int, ...]], @@ -119,8 +119,8 @@ def ones( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.ones(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.ones(shape, dtype=dtype, **kwargs) def ones_like( x: Array, @@ -131,8 +131,8 @@ def ones_like( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.ones_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.ones_like(x, dtype=dtype, **kwargs) def zeros( shape: Union[int, Tuple[int, ...]], @@ -142,8 +142,8 @@ def zeros( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.zeros(shape, dtype=dtype, **kwargs) + with _device_ctx(xp, device): + return xp.zeros(shape, dtype=dtype, **kwargs) def zeros_like( x: Array, @@ -154,8 +154,8 @@ def zeros_like( device: Optional[Device] = None, **kwargs, ) -> Array: - _check_device(xp, device) - return xp.zeros_like(x, dtype=dtype, **kwargs) + with _device_ctx(xp, device, like=x): + return xp.zeros_like(x, dtype=dtype, **kwargs) # np.unique() is split into four functions in the array API: # unique_all, unique_counts, unique_inverse, and unique_values (this is done diff --git a/array_api_compat/common/_helpers.py b/array_api_compat/common/_helpers.py index 67c619b8..9c931614 100644 --- a/array_api_compat/common/_helpers.py +++ b/array_api_compat/common/_helpers.py @@ -7,10 +7,12 @@ """ from __future__ import annotations +import contextlib import sys import math import inspect import warnings +from collections.abc import Generator from typing import Optional, Union, Any from ._typing import Array, Device, Namespace @@ -596,26 +598,42 @@ def your_function(x, y): get_namespace = array_namespace -def _check_device(bare_xp, device): - """ - Validate dummy device on device-less array backends. +def _device_ctx( + bare_xp: Namespace, device: Device, like: Array | None = None +) -> Generator[None]: + """Context manager which changes the current device in CuPy. - Notes - ----- - This function is also invoked by CuPy, which does have multiple devices - if there are multiple GPUs available. - However, CuPy multi-device support is currently impossible - without using the global device or a context manager: - - https://github.com/data-apis/array-api-compat/pull/293 + Used internally by array creation functions in common._aliases. """ + if device is None: + if like is None: + return contextlib.nullcontext() + device = _device(like) + if bare_xp is sys.modules.get('numpy'): - if device not in ("cpu", None): + if device != "cpu": raise ValueError(f"Unsupported device for NumPy: {device!r}") + return contextlib.nullcontext() - elif bare_xp is sys.modules.get('dask.array'): - if device not in ("cpu", _DASK_DEVICE, None): + if bare_xp is sys.modules.get('dask.array'): + if device not in ("cpu", _DASK_DEVICE): raise ValueError(f"Unsupported device for Dask: {device!r}") + return contextlib.nullcontext() + + if bare_xp is sys.modules.get('cupy'): + if not isinstance(device, bare_xp.cuda.Device): + raise TypeError(f"device is not a cupy.cuda.Device: {device!r}") + return device + + # PyTorch doesn't have a "current device" context manager and you + # can't use array creation functions from common._aliases. + raise AssertionError("unreachable") # pragma: nocover + + +def _check_device(bare_xp: Namespace, device: Device) -> None: + """Validate dummy device on device-less array backends.""" + with _device_ctx(bare_xp, device): + pass # Placeholder object to represent the dask device @@ -703,50 +721,39 @@ def device(x: Array, /) -> Device: # Prevent shadowing, used below _device = device + # Based on cupy.array_api.Array.to_device def _cupy_to_device(x, device, /, stream=None): import cupy as cp - from cupy.cuda import Device as _Device - from cupy.cuda import stream as stream_module - from cupy_backends.cuda.api import runtime - if device == x.device: - return x - elif device == "cpu": + if device == "cpu": # allowing us to use `to_device(x, "cpu")` # is useful for portable test swapping between # host and device backends return x.get() - elif not isinstance(device, _Device): - raise ValueError(f"Unsupported device {device!r}") - else: - # see cupy/cupy#5985 for the reason how we handle device/stream here - prev_device = runtime.getDevice() - prev_stream: stream_module.Stream = None - if stream is not None: - prev_stream = stream_module.get_current_stream() - # stream can be an int as specified in __dlpack__, or a CuPy stream - if isinstance(stream, int): - stream = cp.cuda.ExternalStream(stream) - elif isinstance(stream, cp.cuda.Stream): - pass - else: - raise ValueError('the input stream is not recognized') - stream.use() - try: - runtime.setDevice(device.id) - arr = x.copy() - finally: - runtime.setDevice(prev_device) - if stream is not None: - prev_stream.use() - return arr + if not isinstance(device, cp.cuda.Device): + raise TypeError(f"Unsupported device {device!r}") + + # see cupy/cupy#5985 for the reason how we handle device/stream here + + # stream can be an int as specified in __dlpack__, or a CuPy stream + if isinstance(stream, int): + stream = cp.cuda.ExternalStream(stream) + elif stream is None: + stream = contextlib.nullcontext() + elif not isinstance(stream, cp.cuda.Stream): + raise TypeError('the input stream is not recognized') + + with device, stream: + return cp.asarray(x) + def _torch_to_device(x, device, /, stream=None): if stream is not None: raise NotImplementedError return x.to(device) + def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array: """ Copy the array from the device on which it currently resides to the specified ``device``. diff --git a/array_api_compat/cupy/_aliases.py b/array_api_compat/cupy/_aliases.py index fd1460ae..2fd0d5cd 100644 --- a/array_api_compat/cupy/_aliases.py +++ b/array_api_compat/cupy/_aliases.py @@ -64,8 +64,6 @@ finfo = get_xp(cp)(_aliases.finfo) iinfo = get_xp(cp)(_aliases.iinfo) -_copy_default = object() - # asarray also adds the copy keyword, which is not present in numpy 1.0. def asarray( @@ -79,7 +77,7 @@ def asarray( *, dtype: Optional[DType] = None, device: Optional[Device] = None, - copy: Optional[bool] = _copy_default, + copy: Optional[bool] = None, **kwargs, ) -> Array: """ @@ -88,26 +86,15 @@ def asarray( See the corresponding documentation in the array library and/or the array API specification for more details. """ - with cp.cuda.Device(device): - # cupy is like NumPy 1.26 (except without _CopyMode). See the comments - # in asarray in numpy/_aliases.py. - if copy is not _copy_default: - # A future version of CuPy will change the meaning of copy=False - # to mean no-copy. We don't know for certain what version it will - # be yet, so to avoid breaking that version, we use a different - # default value for copy so asarray(obj) with no copy kwarg will - # always do the copy-if-needed behavior. - - # This will still need to be updated to remove the - # NotImplementedError for copy=False, but at least this won't - # break the default or existing behavior. - if copy is None: - copy = False - elif copy is False: - raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") - kwargs['copy'] = copy - - return cp.array(obj, dtype=dtype, **kwargs) + if copy is False: + raise NotImplementedError("asarray(copy=False) is not yet supported in cupy") + + like = obj if isinstance(obj, cp.ndarray) else None + with _helpers._device_ctx(cp, device, like=like): + if copy is None: + return cp.asarray(obj, dtype=dtype, **kwargs) + else: + return cp.array(obj, dtype=dtype, copy=True, **kwargs) def astype(