-
Notifications
You must be signed in to change notification settings - Fork 10
ENH: jax_autojit
#284
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
ENH: jax_autojit
#284
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,10 +2,24 @@ | |
|
||
from __future__ import annotations | ||
|
||
import io | ||
import math | ||
from collections.abc import Generator, Iterable | ||
import pickle | ||
import types | ||
from collections.abc import Callable, Generator, Iterable | ||
from functools import wraps | ||
from types import ModuleType | ||
from typing import TYPE_CHECKING, cast | ||
from typing import ( | ||
TYPE_CHECKING, | ||
Any, | ||
ClassVar, | ||
Generic, | ||
Literal, | ||
ParamSpec, | ||
TypeAlias, | ||
TypeVar, | ||
cast, | ||
) | ||
|
||
from . import _compat | ||
from ._compat import ( | ||
|
@@ -19,8 +33,16 @@ | |
from ._typing import Array | ||
|
||
if TYPE_CHECKING: # pragma: no cover | ||
# TODO import from typing (requires Python >=3.13) | ||
from typing_extensions import TypeIs | ||
# TODO import from typing (requires Python >=3.12 and >=3.13) | ||
from typing_extensions import TypeIs, override | ||
else: | ||
|
||
def override(func): | ||
return func | ||
|
||
|
||
P = ParamSpec("P") | ||
T = TypeVar("T") | ||
|
||
|
||
__all__ = [ | ||
|
@@ -29,8 +51,11 @@ | |
"eager_shape", | ||
"in1d", | ||
"is_python_scalar", | ||
"jax_autojit", | ||
"mean", | ||
"meta_namespace", | ||
"pickle_flatten", | ||
"pickle_unflatten", | ||
] | ||
|
||
|
||
|
@@ -306,3 +331,233 @@ def capabilities(xp: ModuleType) -> dict[str, int]: | |
out["boolean indexing"] = True | ||
out["data-dependent shapes"] = True | ||
return out | ||
|
||
|
||
_BASIC_PICKLED_TYPES = frozenset(( | ||
bool, int, float, complex, str, bytes, bytearray, | ||
list, tuple, dict, set, frozenset, range, slice, | ||
types.NoneType, types.EllipsisType, | ||
)) # fmt: skip | ||
_BASIC_REST_TYPES = frozenset(( | ||
type, types.BuiltinFunctionType, types.FunctionType, types.ModuleType | ||
)) # fmt: skip | ||
|
||
FlattenRest: TypeAlias = tuple[object, ...] | ||
|
||
|
||
def pickle_flatten( | ||
obj: object, cls: type[T] | tuple[type[T], ...] | ||
) -> tuple[list[T], FlattenRest]: | ||
""" | ||
Use the pickle machinery to extract objects out of an arbitrary container. | ||
|
||
Unlike regular ``pickle.dumps``, this function always succeeds. | ||
|
||
Parameters | ||
---------- | ||
obj : object | ||
The object to pickle. | ||
cls : type | tuple[type, ...] | ||
One or multiple classes to extract from the object. | ||
The instances of these classes inside ``obj`` will not be pickled. | ||
|
||
Returns | ||
------- | ||
instances : list[cls] | ||
All instances of ``cls`` found inside ``obj`` (not pickled). | ||
rest | ||
Opaque object containing the pickled bytes plus all other objects where | ||
``__reduce__`` / ``__reduce_ex__`` is either not implemented or raised. | ||
These are unpickleable objects, types, modules, and functions. | ||
|
||
This object is *typically* hashable save for fairly exotic objects | ||
that are neither pickleable nor hashable. | ||
|
||
This object is pickleable if everything except ``instances`` was pickleable | ||
in the input object. | ||
|
||
See Also | ||
-------- | ||
pickle_unflatten : Reverse function. | ||
|
||
Examples | ||
-------- | ||
>>> class A: | ||
... def __repr__(self): | ||
... return "<A>" | ||
>>> class NS: | ||
... def __repr__(self): | ||
... return "<NS>" | ||
... def __reduce__(self): | ||
... assert False, "not serializable" | ||
>>> obj = {1: A(), 2: [A(), NS(), A()]} | ||
>>> instances, rest = pickle_flatten(obj, A) | ||
>>> instances | ||
[<A>, <A>, <A>] | ||
>>> pickle_unflatten(instances, rest) | ||
{1: <A>, 2: [<A>, <NS>, <A>]} | ||
|
||
This can be also used to swap inner objects; the only constraint is that | ||
the number of objects in and out must be the same: | ||
|
||
>>> pickle_unflatten(["foo", "bar", "baz"], rest) | ||
{1: "foo", 2: ["bar", <NS>, "baz"]} | ||
""" | ||
instances: list[T] = [] | ||
rest: list[object] = [] | ||
|
||
class Pickler(pickle.Pickler): # numpydoc ignore=GL08 | ||
""" | ||
Use the `pickle.Pickler.persistent_id` hook to extract objects. | ||
""" | ||
|
||
@override | ||
def persistent_id(self, obj: object) -> Literal[0, 1, None]: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To say that I am aggrieved by pyright and numpydoc would be a gentle understatement at this point. |
||
if isinstance(obj, cls): | ||
instances.append(obj) # type: ignore[arg-type] | ||
return 0 | ||
|
||
typ_ = type(obj) | ||
if typ_ in _BASIC_PICKLED_TYPES: # No subclasses! | ||
# If obj is a collection, recursively descend inside it | ||
return None | ||
if typ_ in _BASIC_REST_TYPES: | ||
rest.append(obj) | ||
return 1 | ||
|
||
try: | ||
# Note: a class that defines __slots__ without defining __getstate__ | ||
# cannot be pickled with __reduce__(), but can with __reduce_ex__(5) | ||
_ = obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL) | ||
except Exception: # pylint: disable=broad-exception-caught | ||
rest.append(obj) | ||
return 1 | ||
|
||
# Object can be pickled. Let the Pickler recursively descend inside it. | ||
return None | ||
|
||
f = io.BytesIO() | ||
p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL) | ||
p.dump(obj) | ||
return instances, (f.getvalue(), *rest) | ||
|
||
|
||
def pickle_unflatten(instances: Iterable[object], rest: FlattenRest) -> Any: # type: ignore[explicit-any] | ||
""" | ||
Reverse of ``pickle_flatten``. | ||
|
||
Parameters | ||
---------- | ||
instances : Iterable | ||
Inner objects to be reinserted into the flattened container. | ||
rest : FlattenRest | ||
Extra bits, as returned by ``pickle_flatten``. | ||
|
||
Returns | ||
------- | ||
object | ||
The outer object originally passed to ``pickle_flatten`` after a | ||
pickle->unpickle round-trip. | ||
|
||
See Also | ||
-------- | ||
pickle_flatten : Serializing function. | ||
pickle.loads : Standard unpickle function. | ||
|
||
Notes | ||
----- | ||
The `instances` iterable must yield at least the same number of elements as the ones | ||
returned by ``pickle_without``, but the elements do not need to be the same objects | ||
or even the same types of objects. Excess elements, if any, will be left untouched. | ||
""" | ||
iters = iter(instances), iter(rest) | ||
pik = cast(bytes, next(iters[1])) | ||
|
||
class Unpickler(pickle.Unpickler): # numpydoc ignore=GL08 | ||
"""Mirror of the overridden Pickler in pickle_flatten.""" | ||
|
||
@override | ||
def persistent_load(self, pid: Literal[0, 1]) -> object: # pyright: ignore[reportIncompatibleMethodOverride] # numpydoc ignore=GL08 | ||
try: | ||
return next(iters[pid]) | ||
except StopIteration as e: | ||
msg = "Not enough objects to unpickle" | ||
raise ValueError(msg) from e | ||
|
||
f = io.BytesIO(pik) | ||
return Unpickler(f).load() | ||
|
||
|
||
class _AutoJITWrapper(Generic[T]): # numpydoc ignore=PR01 | ||
""" | ||
Helper of :func:`jax_autojit`. | ||
|
||
Wrap arbitrary inputs and outputs of the jitted function and | ||
convert them to/from PyTrees. | ||
""" | ||
|
||
obj: T | ||
_registered: ClassVar[bool] = False | ||
__slots__: tuple[str, ...] = ("obj",) | ||
|
||
def __init__(self, obj: T) -> None: # numpydoc ignore=GL08 | ||
self._register() | ||
self.obj = obj | ||
|
||
@classmethod | ||
def _register(cls): # numpydoc ignore=SS06 | ||
""" | ||
Register upon first use instead of at import time, to avoid | ||
globally importing JAX. | ||
""" | ||
Comment on lines
+509
to
+512
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Aside note on design: Dask avoids this exact problem by not requiring any decorator and instead duck-type checking for uniquely named dunder methods called |
||
if not cls._registered: | ||
import jax | ||
|
||
jax.tree_util.register_pytree_node( | ||
cls, | ||
lambda obj: pickle_flatten(obj, jax.Array), # pyright: ignore[reportUnknownArgumentType] | ||
lambda aux_data, children: pickle_unflatten(children, aux_data), # pyright: ignore[reportUnknownArgumentType] | ||
) | ||
cls._registered = True | ||
|
||
|
||
def jax_autojit( | ||
func: Callable[P, T], | ||
) -> Callable[P, T]: # numpydoc ignore=PR01,RT01,SS03 | ||
""" | ||
Wrap `func` with ``jax.jit``, with the following differences: | ||
|
||
- Python scalar arguments and return values are not automatically converted to | ||
``jax.Array`` objects. | ||
- All non-array arguments are automatically treated as static. | ||
Unlike ``jax.jit``, static arguments must be either hashable or serializable with | ||
``pickle``. | ||
- Unlike ``jax.jit``, non-array arguments and return values are not limited to | ||
tuple/list/dict, but can be any object serializable with ``pickle``. | ||
- Automatically descend into non-array arguments and find ``jax.Array`` objects | ||
inside them, then rebuild the arguments when entering `func`, swapping the JAX | ||
concrete arrays with tracer objects. | ||
- Automatically descend into non-array return values and find ``jax.Array`` objects | ||
inside them, then rebuild them downstream of exiting the JIT, swapping the JAX | ||
tracer objects with concrete arrays. | ||
|
||
See Also | ||
-------- | ||
jax.jit : JAX JIT compilation function. | ||
""" | ||
import jax | ||
|
||
@jax.jit # type: ignore[misc] # pyright: ignore[reportUntypedFunctionDecorator] | ||
def inner( # type: ignore[decorated-any,explicit-any] # numpydoc ignore=GL08 | ||
wargs: _AutoJITWrapper[Any], | ||
) -> _AutoJITWrapper[T]: | ||
args, kwargs = wargs.obj | ||
res = func(*args, **kwargs) # pyright: ignore[reportCallIssue] | ||
return _AutoJITWrapper(res) | ||
|
||
@wraps(func) | ||
def outer(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08 | ||
wargs = _AutoJITWrapper((args, kwargs)) | ||
return inner(wargs).obj | ||
|
||
return outer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Note: everything in this module is a private helper.