Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DNM] ENH: CuPy multi-device support #293

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
46 changes: 23 additions & 23 deletions array_api_compat/common/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ._typing import Array, Device, DType, Namespace
from ._helpers import (
array_namespace,
_check_device,
_device_ctx,
device as _get_device,
is_cupy_namespace as _is_cupy_namespace
)
Expand All @@ -31,8 +31,8 @@ def arange(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.arange(start, stop=stop, step=step, dtype=dtype, **kwargs)

def empty(
shape: Union[int, Tuple[int, ...]],
Expand All @@ -42,8 +42,8 @@ def empty(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.empty(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.empty(shape, dtype=dtype, **kwargs)

def empty_like(
x: Array,
Expand All @@ -54,8 +54,8 @@ def empty_like(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.empty_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my education: why can the user pass a device= argument to a *_like function? Naively I'd have expected that the _like implies that the device matches that of x. But then you can also pass a dtype= which overrides the dtype of x, so by symmetry allowing a device= makes sense?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Matching device of the input is what the user wants 99% of the time.
Using empty_like etc. on a different device can make some sense when preparing an output vessel that is filled from different sources. TBH, though, the main difference between empty and empty_like, besides convenience, is that the latter can easily duplicate a lazy (unknown) shape. Which frequently prevents masked updates, but that's a separate problem.

return xp.empty_like(x, dtype=dtype, **kwargs)

def eye(
n_rows: int,
Expand All @@ -68,8 +68,8 @@ def eye(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.eye(n_rows, M=n_cols, k=k, dtype=dtype, **kwargs)

def full(
shape: Union[int, Tuple[int, ...]],
Expand All @@ -80,8 +80,8 @@ def full(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.full(shape, fill_value, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.full(shape, fill_value, dtype=dtype, **kwargs)

def full_like(
x: Array,
Expand All @@ -93,8 +93,8 @@ def full_like(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.full_like(x, fill_value, dtype=dtype, **kwargs)

def linspace(
start: Union[int, float],
Expand All @@ -108,8 +108,8 @@ def linspace(
endpoint: bool = True,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)
with _device_ctx(xp, device):
return xp.linspace(start, stop, num, dtype=dtype, endpoint=endpoint, **kwargs)

def ones(
shape: Union[int, Tuple[int, ...]],
Expand All @@ -119,8 +119,8 @@ def ones(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.ones(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.ones(shape, dtype=dtype, **kwargs)

def ones_like(
x: Array,
Expand All @@ -131,8 +131,8 @@ def ones_like(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.ones_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.ones_like(x, dtype=dtype, **kwargs)

def zeros(
shape: Union[int, Tuple[int, ...]],
Expand All @@ -142,8 +142,8 @@ def zeros(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.zeros(shape, dtype=dtype, **kwargs)
with _device_ctx(xp, device):
return xp.zeros(shape, dtype=dtype, **kwargs)

def zeros_like(
x: Array,
Expand All @@ -154,8 +154,8 @@ def zeros_like(
device: Optional[Device] = None,
**kwargs,
) -> Array:
_check_device(xp, device)
return xp.zeros_like(x, dtype=dtype, **kwargs)
with _device_ctx(xp, device, like=x):
return xp.zeros_like(x, dtype=dtype, **kwargs)

# np.unique() is split into four functions in the array API:
# unique_all, unique_counts, unique_inverse, and unique_values (this is done
Expand Down
93 changes: 59 additions & 34 deletions array_api_compat/common/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@
"""
from __future__ import annotations

import contextlib
import sys
import math
import inspect
import warnings
from collections.abc import Generator
from typing import Optional, Union, Any

from ._typing import Array, Device, Namespace
Expand Down Expand Up @@ -595,10 +597,6 @@ def your_function(x, y):
# backwards compatibility alias
get_namespace = array_namespace

def _check_device(xp, device):
if xp == sys.modules.get('numpy'):
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")

# Placeholder object to represent the dask device
# when the array backend is not the CPU.
Expand All @@ -609,6 +607,7 @@ def __repr__(self):

_DASK_DEVICE = _dask_device()


# device() is not on numpy.ndarray or dask.array and to_device() is not on numpy.ndarray
# or cupy.ndarray. They are not included in array objects of this library
# because this library just reuses the respective ndarray classes without
Expand Down Expand Up @@ -685,50 +684,39 @@ def device(x: Array, /) -> Device:
# Prevent shadowing, used below
_device = device


# Based on cupy.array_api.Array.to_device
def _cupy_to_device(x, device, /, stream=None):
    """Copy a CuPy array to ``device``, optionally on a given CUDA stream.

    Parameters
    ----------
    x : cupy.ndarray
        Array to copy.
    device : "cpu" | cupy.cuda.Device
        Target device. ``"cpu"`` transfers the data to host memory.
    stream : int | cupy.cuda.Stream | None, optional
        Stream to perform the copy on. An ``int`` is interpreted as a
        foreign stream pointer, as specified by ``__dlpack__``.

    Returns
    -------
    ``numpy.ndarray`` when ``device == "cpu"``; otherwise a
    ``cupy.ndarray`` resident on ``device``.
    """
    import cupy as cp

    if device == "cpu":
        # allowing us to use `to_device(x, "cpu")`
        # is useful for portable test swapping between
        # host and device backends
        return x.get()

    if not isinstance(device, cp.cuda.Device):
        raise TypeError(f"Unsupported device {device!r}")

    # see cupy/cupy#5985 for the reason how we handle device/stream here

    # stream can be an int as specified in __dlpack__, or a CuPy stream
    if isinstance(stream, int):
        stream = cp.cuda.ExternalStream(stream)
    elif stream is None:
        stream = contextlib.nullcontext()
    elif not isinstance(stream, cp.cuda.Stream):
        raise TypeError('the input stream is not recognized')

    # Entering the Device context makes it current, so cp.asarray
    # materializes the copy on the requested device.
    with device, stream:
        return cp.asarray(x)


def _torch_to_device(x, device, /, stream=None):
if stream is not None:
raise NotImplementedError
return x.to(device)


def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]] = None) -> Array:
"""
Copy the array from the device on which it currently resides to the specified ``device``.
Expand Down Expand Up @@ -811,6 +799,43 @@ def to_device(x: Array, device: Device, /, *, stream: Optional[Union[int, Any]]
return x.to_device(device, stream=stream)


def _device_ctx(
bare_xp: Namespace, device: Device, like: Array | None = None
) -> Generator[None]:
"""Context manager which changes the current device in CuPy.

Used internally by array creation functions in common._aliases.
"""
if device is None:
if like is None:
return contextlib.nullcontext()
device = _device(like)

if bare_xp is sys.modules.get('numpy'):
if device != "cpu":
raise ValueError(f"Unsupported device for NumPy: {device!r}")
return contextlib.nullcontext()

if bare_xp is sys.modules.get('dask.array'):
if device not in ("cpu", _DASK_DEVICE):
raise ValueError(f"Unsupported device for Dask: {device!r}")
return contextlib.nullcontext()

if bare_xp is sys.modules.get('cupy'):
if not isinstance(device, bare_xp.cuda.Device):
raise TypeError(f"device is not a cupy.cuda.Device: {device!r}")
return device

# PyTorch doesn't have a "current device" context manager and you
# can't use array creation functions from common._aliases.
raise AssertionError("unreachable") # pragma: nocover


def _check_device(bare_xp: Namespace, device: Device) -> None:
    """Validate ``device`` against ``bare_xp``; raise if incompatible.

    Delegates to ``_device_ctx`` and immediately enters and exits the
    resulting context, so any side effect of entering it (e.g. switching
    the current CuPy device) is undone before returning.
    """
    ctx = _device_ctx(bare_xp, device)
    with ctx:
        pass


def size(x: Array) -> int | None:
"""
Return the total number of elements of x.
Expand Down
33 changes: 10 additions & 23 deletions array_api_compat/cupy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@
tensordot = get_xp(cp)(_aliases.tensordot)
sign = get_xp(cp)(_aliases.sign)

_copy_default = object()


# asarray also adds the copy keyword, which is not present in numpy 1.0.
def asarray(
Expand All @@ -77,7 +75,7 @@ def asarray(
*,
dtype: Optional[DType] = None,
device: Optional[Device] = None,
copy: Optional[bool] = _copy_default,
copy: Optional[bool] = None,
**kwargs,
) -> Array:
"""
Expand All @@ -86,26 +84,15 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
with cp.cuda.Device(device):
# cupy is like NumPy 1.26 (except without _CopyMode). See the comments
# in asarray in numpy/_aliases.py.
if copy is not _copy_default:
# A future version of CuPy will change the meaning of copy=False
# to mean no-copy. We don't know for certain what version it will
# be yet, so to avoid breaking that version, we use a different
# default value for copy so asarray(obj) with no copy kwarg will
# always do the copy-if-needed behavior.

# This will still need to be updated to remove the
# NotImplementedError for copy=False, but at least this won't
# break the default or existing behavior.
if copy is None:
copy = False
elif copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")
kwargs['copy'] = copy

return cp.array(obj, dtype=dtype, **kwargs)
if copy is False:
raise NotImplementedError("asarray(copy=False) is not yet supported in cupy")

like = obj if _helpers.is_cupy_array(obj) else None
with _helpers._device_ctx(cp, device, like=like):
if copy is None:
return cp.asarray(obj, dtype=dtype, **kwargs)
else:
return cp.array(obj, dtype=dtype, copy=True, **kwargs)


def astype(
Expand Down
5 changes: 4 additions & 1 deletion array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
import dask.array as da

from ...common import _aliases, array_namespace
from ...common import _aliases, _helpers, array_namespace
from ...common._typing import (
Array,
Device,
Expand Down Expand Up @@ -56,6 +56,7 @@ def astype(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

if not copy and dtype == x.dtype:
return x
Expand Down Expand Up @@ -86,6 +87,7 @@ def arange(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

args = [start]
if stop is not None:
Expand Down Expand Up @@ -155,6 +157,7 @@ def asarray(
specification for more details.
"""
# TODO: respect device keyword?
_helpers._check_device(da, device)

if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
Expand Down
6 changes: 3 additions & 3 deletions array_api_compat/numpy/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional, Union

from .._internal import get_xp
from ..common import _aliases
from ..common import _aliases, _helpers
from ..common._typing import NestedSequence, SupportsBufferProtocol
from ._info import __array_namespace_info__
from ._typing import Array, Device, DType
Expand Down Expand Up @@ -95,8 +95,7 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if device not in ["cpu", None]:
raise ValueError(f"Unsupported device for NumPy: {device!r}")
_helpers._check_device(np, device)

if hasattr(np, '_CopyMode'):
if copy is None:
Expand All @@ -122,6 +121,7 @@ def astype(
copy: bool = True,
device: Optional[Device] = None,
) -> Array:
_helpers._check_device(np, device)
return x.astype(dtype=dtype, copy=copy)


Expand Down
Loading