Skip to content

ENH: jax_autojit #284

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
263 changes: 259 additions & 4 deletions src/array_api_extra/_lib/_utils/_helpers.py
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: everything in this module is a private helper.

Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,24 @@

from __future__ import annotations

import io
import math
from collections.abc import Generator, Iterable
import pickle
import types
from collections.abc import Callable, Generator, Iterable
from functools import wraps
from types import ModuleType
from typing import TYPE_CHECKING, cast
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Literal,
ParamSpec,
TypeAlias,
TypeVar,
cast,
)

from . import _compat
from ._compat import (
Expand All @@ -19,8 +33,16 @@
from ._typing import Array

if TYPE_CHECKING: # pragma: no cover
# TODO import from typing (requires Python >=3.13)
from typing_extensions import TypeIs
# TODO import from typing (requires Python >=3.12 and >=3.13)
from typing_extensions import TypeIs, override
else:

def override(func):
return func


P = ParamSpec("P")
T = TypeVar("T")


__all__ = [
Expand All @@ -29,8 +51,11 @@
"eager_shape",
"in1d",
"is_python_scalar",
"jax_autojit",
"mean",
"meta_namespace",
"pickle_flatten",
"pickle_unflatten",
]


Expand Down Expand Up @@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
out["boolean indexing"] = True
out["data-dependent shapes"] = True
return out


_BASIC_PICKLED_TYPES = frozenset((
bool, int, float, complex, str, bytes, bytearray,
list, tuple, dict, set, frozenset, range, slice,
types.NoneType, types.EllipsisType,
)) # fmt: skip
_BASIC_REST_TYPES = frozenset((
type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType
)) # fmt: skip

FlattenRest: TypeAlias = tuple[object, ...]


def pickle_flatten(
obj: object, cls: type[T] | tuple[type[T], ...]
) -> tuple[list[T], FlattenRest]:
"""
Use the pickle machinery to extract objects out of an arbitrary container.

Unlike regular ``pickle.dumps``, this function always succeeds.

Parameters
----------
obj : object
The object to pickle.
cls : type | tuple[type, ...]
One or multiple classes to extract from the object.
The instances of these classes inside ``obj`` will not be pickled.

Returns
-------
instances : list[cls]
All instances of ``cls`` found inside ``obj`` (not pickled).
rest
Opaque object containing the pickled bytes plus all other objects where
``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised.
These are unpickleable objects, types, modules, and functions.

This object is *typically* hashable save for fairly exotic objects
that are neither pickleable nor hashable.

This object is pickleable if everything except ``instances`` was pickleable
in the input object.

See Also
--------
pickle_unflatten : Reverse function.

Examples
--------
>>> class A:
... def __repr__(self):
... return "<A>"
>>> class NS:
... def __repr__(self):
... return "<NS>"
... def __reduce__(self):
... assert False, "not serializable"
>>> obj = {1: A(), 2: [A(), NS(), A()]}
>>> instances, rest = pickle_flatten(obj, A)
>>> instances
[<A>, <A>, <A>]
>>> pickle_unflatten(instances, rest)
{1: <A>, 2: [<A>, <NS>, <A>]}

This can be also used to swap inner objects; the only constraint is that
the number of objects in and out must be the same:

>>> pickle_unflatten(["foo", "bar", "baz"], rest)
{1: "foo", 2: ["bar", <NS>, "baz"]}
"""
instances: list[T] = []
rest: list[object] = []

class Pickler(pickle.Pickler): # numpydoc ignore=GL08
"""
Use the `pickle.Pickler.persistent_id` hook to extract objects.
"""

@override
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To say that I am aggrieved by pyright and numpydoc would be a gentle understatement at this point.

if isinstance(obj, cls):
instances.append(obj) # type: ignore[arg-type]
return 0

typ_ = type(obj)
if typ_ in _BASIC_PICKLED_TYPES: # No subclasses!
# If obj is a collection, recursively descend inside it
return None
if typ_ in _BASIC_REST_TYPES:
rest.append(obj)
return 1

try:
# Note: a class that defines __slots__ without defining __getstate__
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
_ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL)
except Exception: # pylint: disable=broad-exception-caught
rest.append(obj)
return 1

# Object can be pickled. Let the Pickler recursively descend inside it.
return None

f = io.BytesIO()
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
p.dump(obj)
return instances, (f.getvalue(), *rest)


def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any]
"""
Reverse of ``pickle_flatten``.

Parameters
----------
instances : Iterable
Inner objects to be reinserted into the flattened container.
rest : FlattenRest
Extra bits, as returned by ``pickle_flatten``.

Returns
-------
object
The outer object originally passed to ``pickle_flatten`` after a
pickle->unpickle round-trip.

See Also
--------
pickle_flatten : Serializing function.
pickle.loads : Standard unpickle function.

Notes
-----
The `instances` iterable must yield at least the same number of elements as the ones
returned by ``pickle_without``, but the elements do not need to be the same objects
or even the same types of objects. Excess elements, if any, will be left untouched.
"""
iters = iter(instances), iter(rest)
pik = cast(bytes, next(iters[1]))

class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08
"""Mirror of the overridden Pickler in pickle_flatten."""

@override
def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08
try:
return next(iters[pid])
except StopIteration as e:
msg = "Not enough objects to unpickle"
raise ValueError(msg) from e

f = io.BytesIO(pik)
return Unpickler(f).load()


class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01
"""
Helper of :func:`jax_autojit`.

Wrap arbitrary inputs and outputs of the jitted function and
convert them to/from PyTrees.
"""

obj: T
_registered: ClassVar[bool] = False
__slots__: tuple[str, ...] = ("obj",)

def __init__(self, obj: T) -> None: # numpydoc ignore=GL08
self._register()
self.obj = obj

@classmethod
def _register(cls): # numpydoc ignore=SS06
"""
Register upon first use instead of at import time, to avoid
globally importing JAX.
"""
Comment on lines +509 to +512
Copy link
Contributor Author

@crusaderky crusaderky Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aside note on design: Dask avoids this exact problem by not requiring any decorator and instead duck-type checking for uniquely named dunder methods called __dask_<...>__

if not cls._registered:
import jax

jax.tree_util.register_pytree_node(
cls,
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType]
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType]
)
cls._registered = True


def jax_autojit(
func: Callable[P, T],
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03
"""
Wrap `func` with ``jax.jit``, with the following differences:

- Python scalar arguments and return values are not automatically converted to
``jax.Array`` objects.
- All non-array arguments are automatically treated as static.
Unlike ``jax.jit``, static arguments must be either hashable or serializable with
``pickle``.
- Unlike ``jax.jit``, non-array arguments and return values are not limited to
tuple/list/dict, but can be any object serializable with ``pickle``.
- Automatically descend into non-array arguments and find ``jax.Array`` objects
inside them, then rebuild the arguments when entering `func`, swapping the JAX
concrete arrays with tracer objects.
- Automatically descend into non-array return values and find ``jax.Array`` objects
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
tracer objects with concrete arrays.

See Also
--------
jax.jit : JAX JIT compilation function.
"""
import jax

@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator]
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08
wargs: _AutoJITWrapper[Any],
) -> _AutoJITWrapper[T]:
args, kwargs = wargs.obj
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue]
return _AutoJITWrapper(res)

@wraps(func)
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
wargs = _AutoJITWrapper((args, kwargs))
return inner(wargs).obj

return outer
Loading