Skip to content

BUG: dask: asarray should not materialize the graph #235

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 16 additions & 17 deletions array_api_compat/dask/array/_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,24 +144,23 @@ def asarray(
See the corresponding documentation in the array library and/or the array API
specification for more details.
"""
if isinstance(obj, da.Array):
if dtype is not None and dtype != obj.dtype:
if copy is False:
raise ValueError("Unable to avoid copy when changing dtype")
obj = obj.astype(dtype)
return obj.copy() if copy else obj

if copy is False:
# copy=False is not yet implemented in dask
raise NotImplementedError("copy=False is not yet implemented")
elif copy is True:
if isinstance(obj, da.Array) and dtype is None:
return obj.copy()
# Go through numpy, since dask copy is no-op by default
obj = np.array(obj, dtype=dtype, copy=True)
return da.array(obj, dtype=dtype)
else:
if not isinstance(obj, da.Array) or dtype is not None and obj.dtype != dtype:
# copy=True to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)
return obj

return da.asarray(obj, dtype=dtype, **kwargs)
raise NotImplementedError(
"Unable to avoid copy when converting a non-dask object to dask"
)

# copy=None to be uniform across dask < 2024.12 and >= 2024.12
# see https://github.com/dask/dask/pull/11524/
obj = np.array(obj, dtype=dtype, copy=True)
return da.from_array(obj)


from dask.array import (
# Element wise aliases
Expand Down
6 changes: 4 additions & 2 deletions tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,7 @@ def test_all(library):
all_names = module.__all__

if set(dir_names) != set(all_names):
assert set(dir_names) - set(all_names) == set(), f"Some dir() names not included in __all__ for {mod_name}"
assert set(all_names) - set(dir_names) == set(), f"Some __all__ names not in dir() for {mod_name}"
extra_dir = set(dir_names) - set(all_names)
extra_all = set(all_names) - set(dir_names)
assert not extra_dir, f"Some dir() names not included in __all__ for {mod_name}: {extra_dir}"
assert not extra_all, f"Some __all__ names not in dir() for {mod_name}: {extra_all}"
22 changes: 14 additions & 8 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,11 +226,17 @@ def test_asarray_copy(library):
all = xp.all if library != 'dask.array' else lambda x: xp.all(x).compute()

if library == 'numpy' and xp.__version__[0] < '2' and not hasattr(xp, '_CopyMode') :
supports_copy_false = False
elif library in ['cupy', 'dask.array']:
supports_copy_false = False
supports_copy_false_other_ns = False
supports_copy_false_same_ns = False
elif library == 'cupy':
supports_copy_false_other_ns = False
supports_copy_false_same_ns = False
elif library == 'dask.array':
supports_copy_false_other_ns = False
supports_copy_false_same_ns = True
else:
supports_copy_false = True
supports_copy_false_other_ns = True
supports_copy_false_same_ns = True

a = asarray([1])
b = asarray(a, copy=True)
Expand All @@ -240,7 +246,7 @@ def test_asarray_copy(library):
assert all(a[0] == 0)

a = asarray([1])
if supports_copy_false:
if supports_copy_false_same_ns:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0
Expand All @@ -249,7 +255,7 @@ def test_asarray_copy(library):
pytest.raises(NotImplementedError, lambda: asarray(a, copy=False))

a = asarray([1])
if supports_copy_false:
if supports_copy_false_same_ns:
pytest.raises(ValueError, lambda: asarray(a, copy=False,
dtype=xp.float64))
else:
Expand Down Expand Up @@ -281,7 +287,7 @@ def test_asarray_copy(library):
for obj in [True, 0, 0.0, 0j, [0], [[0]]]:
asarray(obj, copy=True) # No error
asarray(obj, copy=None) # No error
if supports_copy_false:
if supports_copy_false_other_ns:
pytest.raises(ValueError, lambda: asarray(obj, copy=False))
else:
pytest.raises(NotImplementedError, lambda: asarray(obj, copy=False))
Expand All @@ -294,7 +300,7 @@ def test_asarray_copy(library):
assert all(b[0] == 1.0)

a = array.array('f', [1.0])
if supports_copy_false:
if supports_copy_false_other_ns:
b = asarray(a, copy=False)
assert is_lib_func(b)
a[0] = 0.0
Expand Down
108 changes: 108 additions & 0 deletions tests/test_dask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from contextlib import contextmanager

import dask
import numpy as np
import pytest
import dask.array as da

from array_api_compat import array_namespace


@pytest.fixture
def xp():
"""Fixture returning the wrapped dask namespace"""
return array_namespace(da.empty(0))


@contextmanager
def assert_no_compute():
"""
Context manager that raises if at any point inside it anything calls compute()
or persist(), e.g. as it can be triggered implicitly by __bool__, __array__, etc.
"""
def get(dsk, *args, **kwargs):
raise AssertionError("Called compute() or persist()")

with dask.config.set(scheduler=get):
yield


def test_assert_no_compute():
"""Test the assert_no_compute context manager"""
a = da.asarray(True)
with pytest.raises(AssertionError, match="Called compute"):
with assert_no_compute():
bool(a)

# Exiting the context manager restores the original scheduler
assert bool(a) is True


# Test no_compute for functions that use generic _aliases with xp=np

def test_unary_ops_no_compute(xp):
with assert_no_compute():
a = xp.asarray([1.5, -1.5])
xp.ceil(a)
xp.floor(a)
xp.trunc(a)
xp.sign(a)


def test_matmul_tensordot_no_compute(xp):
A = da.ones((4, 4), chunks=2)
B = da.zeros((4, 4), chunks=2)
with assert_no_compute():
xp.matmul(A, B)
xp.tensordot(A, B)


# Test no_compute for functions that are fully bespoke for dask

def test_asarray_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
xp.asarray(a)
xp.asarray(a, dtype=np.int16)
xp.asarray(a, dtype=a.dtype)
xp.asarray(a, copy=True)
xp.asarray(a, copy=True, dtype=np.int16)
xp.asarray(a, copy=True, dtype=a.dtype)
xp.asarray(a, copy=False)
xp.asarray(a, copy=False, dtype=a.dtype)


@pytest.mark.parametrize("copy", [True, False])
def test_astype_no_compute(xp, copy):
with assert_no_compute():
a = xp.arange(10)
xp.astype(a, np.int16, copy=copy)
xp.astype(a, a.dtype, copy=copy)


def test_clip_no_compute(xp):
with assert_no_compute():
a = xp.arange(10)
xp.clip(a)
xp.clip(a, 1)
xp.clip(a, 1, 8)


def test_generators_are_lazy(xp):
"""
Test that generator functions are fully lazy, e.g. that
da.ones(n) is not implemented as da.asarray(np.ones(n))
"""
size = 100_000_000_000 # 800 GB
chunks = size // 10 # 10x 80 GB chunks

with assert_no_compute():
xp.zeros(size, chunks=chunks)
xp.ones(size, chunks=chunks)
xp.empty(size, chunks=chunks)
xp.full(size, fill_value=123, chunks=chunks)
a = xp.arange(size, chunks=chunks)
xp.zeros_like(a)
xp.ones_like(a)
xp.empty_like(a)
xp.full_like(a, fill_value=123)
Loading