From 712dd9efcbabc35e9451d2f467b22e5cd4e3c4ff Mon Sep 17 00:00:00 2001 From: crusaderky Date: Tue, 22 Apr 2025 11:56:49 +0100 Subject: [PATCH 1/2] ENH: `jax_autojit` --- src/array_api_extra/_lib/_utils/_helpers.py | 263 +++++++++++++++++++- src/array_api_extra/testing.py | 85 ++++--- tests/conftest.py | 41 ++- tests/test_at.py | 22 +- tests/test_funcs.py | 20 +- tests/test_helpers.py | 199 ++++++++++++++- tests/test_lazy.py | 6 +- tests/test_testing.py | 96 +++++-- 8 files changed, 628 insertions(+), 104 deletions(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 64006270..18461845 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -2,10 +2,24 @@ from __future__ import annotations +import io import math -from collections.abc import Generator, Iterable +import pickle +import types +from collections.abc import Callable, Generator, Iterable +from functools import wraps from types import ModuleType -from typing import TYPE_CHECKING, cast +from typing import ( + TYPE_CHECKING, + Any, + ClassVar, + Generic, + Literal, + ParamSpec, + TypeAlias, + TypeVar, + cast, +) from . import _compat from ._compat import ( @@ -19,8 +33,16 @@ from ._typing import Array if TYPE_CHECKING: # pragma: no cover - # TODO import from typing (requires Python >=3.13) - from typing_extensions import TypeIs + # TODO import from typing (requires Python >=3.12 and >=3.13) + from typing_extensions import TypeIs, override +else: + + def override(func): + return func + + +P = ParamSpec("P") +T = TypeVar("T") __all__ = [ @@ -29,8 +51,11 @@ "eager_shape", "in1d", "is_python_scalar", + "jax_autojit", "mean", "meta_namespace", + "pickle_flatten", + "pickle_unflatten", ] @@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]: out["boolean indexing"] = True out["data-dependent shapes"] = True return out + + +_BASIC_PICKLED_TYPES = frozenset(( + bool, int, float, complex, str, bytes, bytearray, + list, tuple, dict, set, frozenset, range, slice, + types.NoneType, types.EllipsisType, +)) # fmt: skip +_BASIC_REST_TYPES = frozenset(( + type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType +)) # fmt: skip + +FlattenRest: TypeAlias = tuple[object, ...] + + +def pickle_flatten( + obj: object, cls: type[T] | tuple[type[T], ...] +) -> tuple[list[T], FlattenRest]: + """ + Use the pickle machinery to extract objects out of an arbitrary container. + + Unlike regular ``pickle.dumps``, this function always succeeds. + + Parameters + ---------- + obj : object + The object to pickle. + cls : type | tuple[type, ...] + One or multiple classes to extract from the object. + The instances of these classes inside ``obj`` will not be pickled. + + Returns + ------- + instances : list[cls] + All instances of ``cls`` found inside ``obj`` (not pickled). + rest + Opaque object containing the pickled bytes plus all other objects where + ``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised. + These are unpickleable objects, types, modules, and functions. + + This object is *typically* hashable save for fairly exotic objects + that are neither pickleable nor hashable. + + This object is pickleable if everything except ``instances`` was pickleable + in the input object. + + See Also + -------- + pickle_unflatten : Reverse function. + + Examples + -------- + >>> class A: + ... def __repr__(self): + ... return "" + >>> class NS: + ... def __repr__(self): + ... return "" + ... def __reduce__(self): + ... assert False, "not serializable" + >>> obj = {1: A(), 2: [A(), NS(), A()]} + >>> instances, rest = pickle_flatten(obj, A) + >>> instances + [, , ] + >>> pickle_unflatten(instances, rest) + {1: , 2: [, , ]} + + This can be also used to swap inner objects; the only constraint is that + the number of objects in and out must be the same: + + >>> pickle_unflatten(["foo", "bar", "baz"], rest) + {1: "foo", 2: ["bar", , "baz"]} + """ + instances: list[T] = [] + rest: list[object] = [] + + class Pickler(pickle.Pickler): # numpydoc ignore=GL08 + """ + Use the `pickle.Pickler.persistent_id` hook to extract objects. + """ + + @override + def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08 + if isinstance(obj, cls): + instances.append(obj) # type: ignore[arg-type] + return 0 + + typ_ = type(obj) + if typ_ in _BASIC_PICKLED_TYPES: # No subclasses! + # If obj is a collection, recursively descend inside it + return None + if typ_ in _BASIC_REST_TYPES: + rest.append(obj) + return 1 + + try: + # Note: a class that defines __slots__ without defining __getstate__ + # cannot be pickled with __reduce__(), but can with __reduce_ex__(5) + _ = obj.__reduce_ex__(5) + except Exception: # pylint: disable=broad-exception-caught + rest.append(obj) + return 1 + + # Object can be pickled. Let the Pickler recursively descend inside it. + return None + + f = io.BytesIO() + p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL) + p.dump(obj) + return instances, (f.getvalue(), *rest) + + +def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any] + """ + Reverse of ``pickle_flatten``. + + Parameters + ---------- + instances : Iterable + Inner objects to be reinserted into the flattened container. + rest : FlattenRest + Extra bits, as returned by ``pickle_flatten``. + + Returns + ------- + object + The outer object originally passed to ``pickle_flatten`` after a + pickle->unpickle round-trip. + + See Also + -------- + pickle_flatten : Serializing function. + pickle.loads : Standard unpickle function. + + Notes + ----- + The `instances` iterable must yield at least the same number of elements as the ones + returned by ``pickle_without``, but the elements do not need to be the same objects + or even the same types of objects. Excess elements, if any, will be left untouched. + """ + iters = iter(instances), iter(rest) + pik = cast(bytes, next(iters[1])) + + class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08 + """Mirror of the overridden Pickler in pickle_flatten.""" + + @override + def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08 + try: + return next(iters[pid]) + except StopIteration as e: + msg = "Not enough objects to unpickle" + raise ValueError(msg) from e + + f = io.BytesIO(pik) + return Unpickler(f).load() + + +class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01 + """ + Helper of :func:`jax_autojit`. + + Wrap arbitrary inputs and outputs of the jitted function and + convert them to/from PyTrees. + """ + + obj: T + _registered: ClassVar[bool] = False + __slots__: tuple[str, ...] = ("obj",) + + def __init__(self, obj: T) -> None: # numpydoc ignore=GL08 + self._register() + self.obj = obj + + @classmethod + def _register(cls): # numpydoc ignore=SS06 + """ + Register upon first use instead of at import time, to avoid + globally importing JAX. + """ + if not cls._registered: + import jax + + jax.tree_util.register_pytree_node( + cls, + lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType] + lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType] + ) + cls._registered = True + + +def jax_autojit( + func: Callable[P, T], +) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03 + """ + Wrap `func` with ``jax.jit``, with the following differences: + + - Python scalar arguments and return values are not automatically converted to + ``jax.Array`` objects. + - All non-array arguments are automatically treated as static. + Unlike ``jax.jit``, static arguments must be either hashable or serializable with + ``pickle``. + - Unlike ``jax.jit``, non-array arguments and return values are not limited to + tuple/list/dict, but can be any object serializable with ``pickle``. + - Automatically descend into non-array arguments and find ``jax.Array`` objects + inside them, then rebuild the arguments when entering `func`, swapping the JAX + concrete arrays with tracer objects. + - Automatically descend into non-array return values and find ``jax.Array`` objects + inside them, then rebuild them downstream of exiting the JIT, swapping the JAX + tracer objects with concrete arrays. + + See Also + -------- + jax.jit : JAX JIT compilation function. + """ + import jax + + @jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator] + def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08 + wargs: _AutoJITWrapper[Any], + ) -> _AutoJITWrapper[T]: + args, kwargs = wargs.obj + res = func(*args, **kwargs) # pyright: ignore[reportCallIssue] + return _AutoJITWrapper(res) + + @wraps(func) + def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 + wargs = _AutoJITWrapper((args, kwargs)) + return inner(wargs).obj + + return outer diff --git a/src/array_api_extra/testing.py b/src/array_api_extra/testing.py index 37e8e69e..c14e9a22 100644 --- a/src/array_api_extra/testing.py +++ b/src/array_api_extra/testing.py @@ -7,12 +7,15 @@ from __future__ import annotations import contextlib -from collections.abc import Callable, Iterable, Iterator, Sequence +import enum +import warnings +from collections.abc import Callable, Iterator, Sequence from functools import wraps from types import ModuleType from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, cast from ._lib._utils._compat import is_dask_namespace, is_jax_namespace +from ._lib._utils._helpers import jax_autojit, pickle_flatten, pickle_unflatten __all__ = ["lazy_xp_function", "patch_lazy_xp_functions"] @@ -26,7 +29,7 @@ # Sphinx hacks SchedulerGetCallable = object - def override(func: object) -> object: + def override(func): return func @@ -36,13 +39,22 @@ def override(func: object) -> object: _ufuncs_tags: dict[object, dict[str, Any]] = {} # type: ignore[explicit-any] +class Deprecated(enum.Enum): + """Unique type for deprecated parameters.""" + + DEPRECATED = 1 + + +DEPRECATED = Deprecated.DEPRECATED + + def lazy_xp_function( # type: ignore[explicit-any] func: Callable[..., Any], *, allow_dask_compute: bool | int = False, jax_jit: bool = True, - static_argnums: int | Sequence[int] | None = None, - static_argnames: str | Iterable[str] | None = None, + static_argnums: Deprecated = DEPRECATED, + static_argnames: Deprecated = DEPRECATED, ) -> None: # numpydoc ignore=GL07 """ Tag a function to be tested on lazy backends. @@ -82,16 +94,30 @@ def lazy_xp_function( # type: ignore[explicit-any] Default: False, meaning that `func` must be fully lazy and never materialize the graph. jax_jit : bool, optional - Set to True to replace `func` with ``jax.jit(func)`` after calling the - :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. Set to False - if `func` is only compatible with eager (non-jitted) JAX. Default: True. - static_argnums : int | Sequence[int], optional - Passed to jax.jit. Positional arguments to treat as static (compile-time - constant). Default: infer from `static_argnames` using - `inspect.signature(func)`. - static_argnames : str | Iterable[str], optional - Passed to jax.jit. Named arguments to treat as static (compile-time constant). - Default: infer from `static_argnums` using `inspect.signature(func)`. + Set to True to replace `func` with a smart variant of ``jax.jit(func)`` after + calling the :func:`patch_lazy_xp_functions` test helper with ``xp=jax.numpy``. + Set to False if `func` is only compatible with eager (non-jitted) JAX. + + Unlike with vanilla ``jax.jit``, all arguments and return types that are not JAX + arrays are treated as static; the function can accept and return arbitrary + wrappers around JAX arrays. This difference is because, in real life, most users + won't wrap the function directly with ``jax.jit`` but rather they will use it + within their own code, which is itself then wrapped by ``jax.jit``, and + internally consume the function's outputs. + + In other words, the pattern that is being tested is:: + + >>> @jax.jit + ... def user_func(x): + ... y = user_prepares_inputs(x) + ... z = func(y, some_static_arg=True) + ... return user_consumes(z) + + Default: True. + static_argnums : + Deprecated; ignored + static_argnames : + Deprecated; ignored See Also -------- @@ -108,7 +134,7 @@ def lazy_xp_function( # type: ignore[explicit-any] def test_myfunc(xp): a = xp.asarray([1, 2]) - # When xp=jax.numpy, this is the same as `b = jax.jit(myfunc)(a)` + # When xp=jax.numpy, this is similar to `b = jax.jit(myfunc)(a)` # When xp=dask.array, crash on compute() or persist() b = myfunc(a) @@ -168,12 +194,20 @@ def test_myfunc(xp): b = mymodule.myfunc(a) # This is wrapped when xp=jax.numpy or xp=dask.array c = naked.myfunc(a) # This is not """ + if static_argnums is not DEPRECATED or static_argnames is not DEPRECATED: + warnings.warn( + ( + "The `static_argnums` and `static_argnames` parameters are deprecated " + "and ignored. They will be removed in a future version." + ), + DeprecationWarning, + stacklevel=2, + ) tags = { "allow_dask_compute": allow_dask_compute, "jax_jit": jax_jit, - "static_argnums": static_argnums, - "static_argnames": static_argnames, } + try: func._lazy_xp_function = tags # type: ignore[attr-defined] # pylint: disable=protected-access # pyright: ignore[reportFunctionMemberAccess] except AttributeError: # @cython.vectorize @@ -247,19 +281,9 @@ def iter_tagged() -> ( # type: ignore[explicit-any] monkeypatch.setattr(mod, name, wrapped) elif is_jax_namespace(xp): - import jax - for mod, name, func, tags in iter_tagged(): if tags["jax_jit"]: - # suppress unused-ignore to run mypy in -e lint as well as -e dev - wrapped = cast( # type: ignore[explicit-any] - Callable[..., Any], - jax.jit( - func, - static_argnums=tags["static_argnums"], - static_argnames=tags["static_argnames"], - ), - ) + wrapped = jax_autojit(func) monkeypatch.setattr(mod, name, wrapped) @@ -308,6 +332,7 @@ def _dask_wrap( After the function returns, materialize the graph in order to re-raise exceptions. """ import dask + import dask.array as da func_name = getattr(func, "__name__", str(func)) n_str = f"only up to {n}" if n else "no" @@ -327,6 +352,8 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 # Block until the graph materializes and reraise exceptions. This allows # `pytest.raises` and `pytest.warns` to work as expected. Note that this would # not work on scheduler='distributed', as it would not block. - return dask.persist(out, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] + arrays, rest = pickle_flatten(out, da.Array) + arrays = dask.persist(arrays, scheduler="threads")[0] # type: ignore[attr-defined,no-untyped-call,func-returns-value,index] # pyright: ignore[reportPrivateImportUsage] + return pickle_unflatten(arrays, rest) # pyright: ignore[reportUnknownArgumentType] return wrapper diff --git a/tests/conftest.py b/tests/conftest.py index 5676cc0d..372f9960 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -148,19 +148,7 @@ def xp( patch_lazy_xp_functions(request, monkeypatch, xp=xp) if library.like(Backend.JAX): - import jax - - # suppress unused-ignore to run mypy in -e lint as well as -e dev - jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] - - if library == Backend.JAX_GPU: - try: - device = jax.devices("cuda")[0] - except RuntimeError: - pytest.skip("no CUDA device available") - else: - device = jax.devices("cpu")[0] - jax.config.update("jax_default_device", device) + _setup_jax(library) elif library == Backend.TORCH_GPU: import torch.cuda @@ -175,6 +163,22 @@ def xp( yield xp +def _setup_jax(library: Backend) -> None: + import jax + + # suppress unused-ignore to run mypy in -e lint as well as -e dev + jax.config.update("jax_enable_x64", True) # type: ignore[no-untyped-call,unused-ignore] + + if library == Backend.JAX_GPU: + try: + device = jax.devices("cuda")[0] + except RuntimeError: + pytest.skip("no CUDA device available") + else: + device = jax.devices("cpu")[0] + jax.config.update("jax_default_device", device) + + @pytest.fixture(params=[Backend.DASK]) # Can select the test with `pytest -k dask` def da( request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch @@ -186,6 +190,17 @@ def da( return xp +@pytest.fixture(params=[Backend.JAX, Backend.JAX_GPU]) +def jnp( + request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch +) -> ModuleType: # numpydoc ignore=PR01,RT01 + """Variant of the `xp` fixture that only yields jax.numpy.""" + xp = pytest.importorskip("jax.numpy") + _setup_jax(request.param) + patch_lazy_xp_functions(request, monkeypatch, xp=xp) + return xp + + @pytest.fixture def device( library: Backend, xp: ModuleType diff --git a/tests/test_at.py b/tests/test_at.py index fa9bcdc8..7294a7c4 100644 --- a/tests/test_at.py +++ b/tests/test_at.py @@ -1,5 +1,4 @@ import math -import pickle from collections.abc import Callable, Generator from contextlib import contextmanager from types import ModuleType @@ -41,28 +40,11 @@ def at_op( just a workaround for when one wants to apply jax.jit to `at()` directly, which is not a common use case. """ - if isinstance(idx, (slice | tuple)): - return _at_op(x, None, pickle.dumps(idx), op, y, copy=copy, xp=xp) - return _at_op(x, idx, None, op, y, copy=copy, xp=xp) - - -def _at_op( - x: Array, - idx: SetIndex | None, - idx_pickle: bytes | None, - op: _AtOp, - y: Array | object, - copy: bool | None, - xp: ModuleType | None = None, -) -> Array: - """jitted helper of at_op""" - if idx_pickle: - idx = pickle.loads(idx_pickle) - meth = cast(Callable[..., Array], getattr(at(x, cast(SetIndex, idx)), op.value)) # type: ignore[explicit-any] + meth = cast(Callable[..., Array], getattr(at(x, idx), op.value)) # type: ignore[explicit-any] return meth(y, copy=copy, xp=xp) -lazy_xp_function(_at_op, static_argnames=("op", "idx_pickle", "copy", "xp")) +lazy_xp_function(at_op) @contextmanager diff --git a/tests/test_funcs.py b/tests/test_funcs.py index 0cee0b4d..053e8096 100644 --- a/tests/test_funcs.py +++ b/tests/test_funcs.py @@ -37,17 +37,17 @@ # some xp backends are untyped # mypy: disable-error-code=no-untyped-def -lazy_xp_function(apply_where, static_argnums=(2, 3), static_argnames="xp") -lazy_xp_function(atleast_nd, static_argnames=("ndim", "xp")) -lazy_xp_function(cov, static_argnames="xp") -lazy_xp_function(create_diagonal, static_argnames=("offset", "xp")) -lazy_xp_function(expand_dims, static_argnames=("axis", "xp")) -lazy_xp_function(kron, static_argnames="xp") -lazy_xp_function(nunique, static_argnames="xp") -lazy_xp_function(pad, static_argnames=("pad_width", "mode", "constant_values", "xp")) +lazy_xp_function(apply_where) +lazy_xp_function(atleast_nd) +lazy_xp_function(cov) +lazy_xp_function(create_diagonal) +lazy_xp_function(expand_dims) +lazy_xp_function(kron) +lazy_xp_function(nunique) +lazy_xp_function(pad) # FIXME calls in1d which calls xp.unique_values without size -lazy_xp_function(setdiff1d, jax_jit=False, static_argnames=("assume_unique", "xp")) -lazy_xp_function(sinc, static_argnames="xp") +lazy_xp_function(setdiff1d, jax_jit=False) +lazy_xp_function(sinc) class TestApplyWhere: diff --git a/tests/test_helpers.py b/tests/test_helpers.py index a104e93c..d2068f13 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -1,5 +1,5 @@ from types import ModuleType -from typing import cast +from typing import TYPE_CHECKING, Generic, TypeVar, cast import numpy as np import pytest @@ -13,18 +13,31 @@ capabilities, eager_shape, in1d, + jax_autojit, meta_namespace, ndindex, + pickle_flatten, + pickle_unflatten, ) from array_api_extra._lib._utils._typing import Array, Device, DType from array_api_extra.testing import lazy_xp_function from .conftest import np_compat +if TYPE_CHECKING: + # TODO import from typing (requires Python >=3.12) + from typing_extensions import override +else: + + def override(func): + return func + # mypy: disable-error-code=no-untyped-usage +T = TypeVar("T") + # FIXME calls xp.unique_values without size -lazy_xp_function(in1d, jax_jit=False, static_argnames=("assume_unique", "invert", "xp")) +lazy_xp_function(in1d, jax_jit=False) @pytest.mark.skip_xp_backend(Backend.SPARSE, reason="no unique_inverse") @@ -204,3 +217,185 @@ def test_capabilities(xp: ModuleType): if xp.__array_api_version__ >= "2024.12": expect.add("max dimensions") assert capabilities(xp).keys() == expect + + +class Wrapper(Generic[T]): + """Trivial opaque wrapper. Must be pickleable.""" + + x: T + # __slots__ make this object serializable with __reduce_ex__(5), + # but not with __reduce__ + __slots__: tuple[str, ...] = ("x",) + + def __init__(self, x: T): + self.x = x + + # Note: this makes the object not hashable + @override + def __eq__(self, other: object) -> bool: + return isinstance(other, Wrapper) and self.x == other.x + + +class TestPickleFlatten: + def test_roundtrip(self): + class NotSerializable: + @override + def __reduce__(self) -> tuple[object, ...]: + raise NotImplementedError() + + # Note: NotHashable() instances can be reduced to an + # unserializable local class + class NotHashable: + @override + def __eq__(self, other: object) -> bool: + return isinstance(other, type(self)) and other.__dict__ == self.__dict__ + + with pytest.raises(TypeError): + _ = hash(NotHashable()) + + # Extracted objects need be neither pickleable nor serializable + class C(NotSerializable, NotHashable): + x: int + + def __init__(self, x: int): + self.x = x + + class D(C): + pass + + c1 = C(1) + c2 = C(2) + d3 = D(3) + + # An assorted bunch of opaque containers, standard containers, + # non-serializable objects, and non-hashable objects (but not at the same time) + obj = Wrapper([1, c1, {2: (c2, {NotSerializable()})}, NotHashable(), d3]) + instances, rest = pickle_flatten(obj, C) + + assert instances == [c1, c2, d3] + obj2 = pickle_unflatten(instances, rest) + assert obj2 == obj + + def test_swap_objects(self): + class C: + pass + + obj = [1, C(), {2: (C(), {C()})}] + _, rest = pickle_flatten(obj, C) + obj2 = pickle_unflatten(["foo", "bar", "baz"], rest) + assert obj2 == [1, "foo", {2: ("bar", {"baz"})}] + + def test_multi_class(self): + class C: + pass + + class D: + pass + + c, d = C(), D() + instances, _ = pickle_flatten([c, d], (C, D)) + assert len(instances) == 2 + assert instances[0] is c + assert instances[1] is d + + def test_no_class(self): + obj = {1: "foo", 2: (3, 4)} + instances, rest = pickle_flatten(obj, ()) # type: ignore[var-annotated] + assert instances == [] + obj2 = pickle_unflatten([], rest) + assert obj2 == obj + + def test_flattened_stream(self): + """ + Test that multiple calls to flatten() can feed into the same stream of instances + """ + obj1 = Wrapper(1) + obj2 = [Wrapper(2), Wrapper(3)] + instances1, rest1 = pickle_flatten(obj1, Wrapper) + instances2, rest2 = pickle_flatten(obj2, Wrapper) + it = iter(instances1 + instances2 + [Wrapper(4)]) # pyright: ignore[reportUnknownArgumentType] + assert pickle_unflatten(it, rest1) == obj1 # pyright: ignore[reportUnknownArgumentType] + assert pickle_unflatten(it, rest2) == obj2 # pyright: ignore[reportUnknownArgumentType] + assert list(it) == [Wrapper(4)] # pyright: ignore[reportUnknownArgumentType] + + def test_too_short(self): + obj = [Wrapper(1), Wrapper(2)] + instances, rest = pickle_flatten(obj, Wrapper) + with pytest.raises(ValueError, match="Not enough"): + pickle_unflatten(instances[:1], rest) # pyright: ignore[reportUnknownArgumentType] + + def test_recursion(self): + obj: list[object] = [Wrapper(1)] + obj.append(obj) + + instances, rest = pickle_flatten(obj, Wrapper) + assert instances == [Wrapper(1)] + + obj2 = pickle_unflatten(instances, rest) # pyright: ignore[reportUnknownArgumentType] + assert len(obj2) == 2 + assert obj2[0] is obj[0] + assert obj2[1] is obj2 + + +class TestJAXAutoJIT: + def test_basic(self, jnp: ModuleType): + @jax_autojit + def f(x: Array, k: object = False) -> Array: + return x + 1 if k else x - 1 + + # Basic recognition of static_argnames + xp_assert_equal(f(jnp.asarray([1, 2])), jnp.asarray([0, 1])) + xp_assert_equal(f(jnp.asarray([1, 2]), False), jnp.asarray([0, 1])) + xp_assert_equal(f(jnp.asarray([1, 2]), True), jnp.asarray([2, 3])) + xp_assert_equal(f(jnp.asarray([1, 2]), 1), jnp.asarray([2, 3])) + + # static argument is not an ArrayLike + xp_assert_equal(f(jnp.asarray([1, 2]), "foo"), jnp.asarray([2, 3])) + + # static argument is not hashable, but serializable + xp_assert_equal(f(jnp.asarray([1, 2]), ["foo"]), jnp.asarray([2, 3])) + + def test_wrapper(self, jnp: ModuleType): + @jax_autojit + def f(w: Wrapper[Array]) -> Wrapper[Array]: + return Wrapper(w.x + 1) + + inp = Wrapper(jnp.asarray([1, 2])) + out = f(inp).x + xp_assert_equal(out, jnp.asarray([2, 3])) + + def test_static_hashable(self, jnp: ModuleType): + """Static argument/return value is hashable, but not serializable""" + + class C: + def __reduce__(self) -> object: # type: ignore[explicit-override,override] # pyright: ignore[reportIncompatibleMethodOverride,reportImplicitOverride] + raise Exception() + + @jax_autojit + def f(x: object) -> object: + return x + + inp = C() + out = f(inp) + assert out is inp + + # Serializable opaque input contains non-serializable object plus array + inp = Wrapper((C(), jnp.asarray([1, 2]))) + out = f(inp) + assert isinstance(out, Wrapper) + assert out.x[0] is inp.x[0] + assert out.x[1] is not inp.x[1] + xp_assert_equal(out.x[1], inp.x[1]) # pyright: ignore[reportUnknownArgumentType] + + def test_arraylikes_are_static(self): + pytest.importorskip("jax") + + @jax_autojit + def f(x: list[int]) -> list[int]: + assert isinstance(x, list) + assert x == [1, 2] + return [3, 4] + + out = f([1, 2]) + assert isinstance(out, list) + assert out == [3, 4] diff --git a/tests/test_lazy.py b/tests/test_lazy.py index f40df277..aef73301 100644 --- a/tests/test_lazy.py +++ b/tests/test_lazy.py @@ -15,9 +15,7 @@ from array_api_extra._lib._utils._typing import Array, Device from array_api_extra.testing import lazy_xp_function -lazy_xp_function( - lazy_apply, static_argnames=("func", "shape", "dtype", "as_numpy", "xp") -) +lazy_xp_function(lazy_apply) as_numpy = pytest.mark.parametrize( "as_numpy", @@ -386,7 +384,7 @@ def eager( ) -lazy_xp_function(check_lazy_apply_kwargs, static_argnames=("expect_cls", "as_numpy")) +lazy_xp_function(check_lazy_apply_kwargs) @as_numpy diff --git a/tests/test_testing.py b/tests/test_testing.py index fb9ba581..caba08b4 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -185,32 +185,24 @@ def static_params(x: Array, n: int, flag: bool = False) -> Array: return x * 3.0 -def static_params1(x: Array, n: int, flag: bool = False) -> Array: - return static_params(x, n, flag) +lazy_xp_function(static_params) -def static_params2(x: Array, n: int, flag: bool = False) -> Array: - return static_params(x, n, flag) - - -def static_params3(x: Array, n: int, flag: bool = False) -> Array: - return static_params(x, n, flag) - - -lazy_xp_function(static_params1, static_argnums=(1, 2)) -lazy_xp_function(static_params2, static_argnames=("n", "flag")) -lazy_xp_function(static_params3, static_argnums=1, static_argnames="flag") +def test_lazy_xp_function_static_params(xp: ModuleType): + x = xp.asarray([1.0, 2.0]) + xp_assert_equal(static_params(x, 1), xp.asarray([3.0, 6.0])) + xp_assert_equal(static_params(x, 1, True), xp.asarray([2.0, 4.0])) + xp_assert_equal(static_params(x, 1, False), xp.asarray([3.0, 6.0])) + xp_assert_equal(static_params(x, 0, False), xp.asarray([3.0, 6.0])) + xp_assert_equal(static_params(x, 1, flag=True), xp.asarray([2.0, 4.0])) + xp_assert_equal(static_params(x, n=1, flag=True), xp.asarray([2.0, 4.0])) -@pytest.mark.parametrize("func", [static_params1, static_params2, static_params3]) -def test_lazy_xp_function_static_params(xp: ModuleType, func: Callable[..., Array]): # type: ignore[explicit-any] - x = xp.asarray([1.0, 2.0]) - xp_assert_equal(func(x, 1), xp.asarray([3.0, 6.0])) - xp_assert_equal(func(x, 1, True), xp.asarray([2.0, 4.0])) - xp_assert_equal(func(x, 1, False), xp.asarray([3.0, 6.0])) - xp_assert_equal(func(x, 0, False), xp.asarray([3.0, 6.0])) - xp_assert_equal(func(x, 1, flag=True), xp.asarray([2.0, 4.0])) - xp_assert_equal(func(x, n=1, flag=True), xp.asarray([2.0, 4.0])) +def test_lazy_xp_function_deprecated_static_argnames(): + with pytest.warns(DeprecationWarning, match="static_argnames"): + lazy_xp_function(static_params, static_argnames=["flag"]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] + with pytest.warns(DeprecationWarning, match="static_argnums"): + lazy_xp_function(static_params, static_argnums=[1]) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] try: @@ -273,6 +265,66 @@ def test_lazy_xp_function_eagerly_raises(da: ModuleType): _ = dask_raises(x) +class Wrapper: + """Trivial opaque wrapper. Must be pickleable.""" + + x: Array + + def __init__(self, x: Array): + self.x = x + + +def check_opaque_wrapper(w: Wrapper, xp: ModuleType) -> Wrapper: + assert isinstance(w, Wrapper) + assert array_namespace(w.x) == xp + return Wrapper(w.x + 1) + + +lazy_xp_function(check_opaque_wrapper) + + +def test_lazy_xp_function_opaque_wrappers(xp: ModuleType): + """ + Test that function input and output can be wrapped into arbitrary + serializable Python objects, even if jax.jit does not support them. + """ + x = xp.asarray([1, 2]) + xp2 = array_namespace(x) # Revert NUMPY_READONLY to array_api_compat.numpy + res = check_opaque_wrapper(Wrapper(x), xp2) + xp_assert_equal(res.x, xp.asarray([2, 3])) + + +def test_lazy_xp_function_opaque_wrappers_eagerly_raise(da: ModuleType): + """ + Like `test_lazy_xp_function_eagerly_raises`, but the returned object is + wrapped in an opaque wrapper. + """ + x = da.arange(3) + with pytest.raises(ValueError, match="Hello world"): + _ = Wrapper(dask_raises(x)) + + +def check_recursive(x: list[object]) -> list[object]: + assert isinstance(x, list) + assert x[1] is x + y: list[object] = [cast(Array, x[0]) + 1] + y.append(y) + return y + + +lazy_xp_function(check_recursive) + + +def test_lazy_xp_function_recursive(xp: ModuleType): + """Test that inputs and outputs can be recursive data structures.""" + x: list[object] = [xp.asarray([1, 2])] + x.append(x) + y = check_recursive(x) + assert isinstance(y, list) + xp_assert_equal(cast(Array, y[0]), xp.asarray([2, 3])) + assert y[1] is y + + wrapped = ModuleType("wrapped") naked = ModuleType("naked") From e0f691e0b58ec56dba2243255f3dbffc2ee6a5ca Mon Sep 17 00:00:00 2001 From: Guido Imperiale Date: Mon, 28 Apr 2025 09:50:48 +0100 Subject: [PATCH 2/2] Update src/array_api_extra/_lib/_utils/_helpers.py Co-authored-by: Pearu Peterson --- src/array_api_extra/_lib/_utils/_helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/array_api_extra/_lib/_utils/_helpers.py b/src/array_api_extra/_lib/_utils/_helpers.py index 18461845..0dd7f1ed 100644 --- a/src/array_api_extra/_lib/_utils/_helpers.py +++ b/src/array_api_extra/_lib/_utils/_helpers.py @@ -428,7 +428,7 @@ def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[ try: # Note: a class that defines __slots__ without defining __getstate__ # cannot be pickled with __reduce__(), but can with __reduce_ex__(5) - _ = obj.__reduce_ex__(5) + _ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL) except Exception: # pylint: disable=broad-exception-caught rest.append(obj) return 1