Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade jax and fix deprecation warnings #915

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion equinox/_ad.py
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@
import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.numpy as jnp
import jax.tree_util as jtu
@@ -598,7 +599,7 @@ class _ClosureConvert(Module):
# Important that `jaxpr` be a leaf (and not static), so that it is a tuple element
# when passing through `filter_primitive_bind` and thus visible to
# `jax.core.subjaxprs`
jaxpr: jax.core.Jaxpr
jaxpr: jax.extend.core.Jaxpr
consts: PyTree[ArrayLike] # Captured in the PyTree structure of _ClosureConvert
in_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
out_dynamic_struct: _FlatPyTree[jax.ShapeDtypeStruct] = field(static=True)
4 changes: 2 additions & 2 deletions equinox/_make_jaxpr.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import jax
import jax._src.traceback_util as traceback_util
import jax.core
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

@@ -49,7 +49,7 @@ def _fn(*_dynamic_flat):
def filter_make_jaxpr(
fun: Callable[_P, Any],
) -> Callable[
_P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
_P, tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
]:
"""As `jax.make_jaxpr`, but accepts arbitrary PyTrees as input and output.

7 changes: 4 additions & 3 deletions equinox/_unvmap.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
import jax.numpy as jnp
@@ -10,7 +11,7 @@

# unvmap_all

unvmap_all_p = jax.core.Primitive("unvmap_all")
unvmap_all_p = jax.extend.core.Primitive("unvmap_all")


def unvmap_all(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
@@ -41,7 +42,7 @@ def _unvmap_all_batch(x, batch_axes):

# unvmap_any

unvmap_any_p = jax.core.Primitive("unvmap_any")
unvmap_any_p = jax.extend.core.Primitive("unvmap_any")


def unvmap_any(x: Bool[ArrayLike, "..."]) -> Bool[Array, ""]:
@@ -72,7 +73,7 @@ def _unvmap_any_batch(x, batch_axes):

# unvmap_max

unvmap_max_p = jax.core.Primitive("unvmap_max")
unvmap_max_p = jax.extend.core.Primitive("unvmap_max")


def unvmap_max(x: Int[ArrayLike, "..."]) -> Int[Array, ""]:
4 changes: 2 additions & 2 deletions equinox/debug/_announce_transform.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from typing import Any

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -124,7 +124,7 @@ def _mlir(*x, stack, name, intermediates, announce):
return x


announce_jaxpr_p = jax.core.Primitive("announce_jaxpr")
announce_jaxpr_p = jax.extend.core.Primitive("announce_jaxpr")
announce_jaxpr_p.multiple_results = True
announce_jaxpr_p.def_impl(_impl)
announce_jaxpr_p.def_abstract_eval(_abstract)
40 changes: 22 additions & 18 deletions equinox/internal/_finalise_jaxpr.py
Original file line number Diff line number Diff line change
@@ -22,6 +22,7 @@
import jax
import jax.core
import jax.custom_derivatives
import jax.extend.core
import jax.tree_util as jtu
from jaxtyping import PyTree

@@ -36,13 +37,13 @@ def _safe_map(f, *args):

def _maybe_finalise_jaxpr(val: Any):
is_open_jaxpr = False
if isinstance(val, jax.core.Jaxpr):
if isinstance(val, jax.extend.core.Jaxpr):
if len(val.constvars) == 0:
is_open_jaxpr = True
val = jax.core.ClosedJaxpr(val, [])
val = jax.extend.core.ClosedJaxpr(val, [])
else:
return val
if isinstance(val, jax.core.ClosedJaxpr):
if isinstance(val, jax.extend.core.ClosedJaxpr):
val = finalise_jaxpr(val)
if is_open_jaxpr:
val = val.jaxpr
@@ -60,33 +61,33 @@ def _finalise_jaxprs_in_params(params):
return new_params


def _default_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _default_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.bind(*args, **kwargs)


def _impl_finalisation(prim: jax.core.Primitive, *args, **kwargs):
def _impl_finalisation(prim: jax.extend.core.Primitive, *args, **kwargs):
return prim.impl(*args, **kwargs)


primitive_finalisations = {}


def register_impl_finalisation(prim: jax.core.Primitive):
def register_impl_finalisation(prim: jax.extend.core.Primitive):
primitive_finalisations[prim] = ft.partial(_impl_finalisation, prim)


def finalise_eval_jaxpr(jaxpr: jax.core.Jaxpr, consts, *args):
def finalise_eval_jaxpr(jaxpr: jax.extend.core.Jaxpr, consts, *args):
"""As jax.core.eval_jaxpr, but finalises (typically by calling `impl` rather than
`bind` for custom primitives).
"""

def read(v: jax.core.Atom) -> Any:
return v.val if isinstance(v, jax.core.Literal) else env[v]
return v.val if isinstance(v, jax.extend.core.Literal) else env[v]

def write(v: jax.core.Var, val: Any) -> None:
def write(v: jax.extend.core.Var, val: Any) -> None:
env[v] = val

env: dict[jax.core.Var, Any] = {}
env: dict[jax.extend.core.Var, Any] = {}
_safe_map(write, jaxpr.constvars, consts)
_safe_map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
@@ -104,18 +105,18 @@ def write(v: jax.core.Var, val: Any) -> None:
return _safe_map(read, jaxpr.outvars)


def finalise_jaxpr_as_fn(jaxpr: jax.core.ClosedJaxpr):
def finalise_jaxpr_as_fn(jaxpr: jax.extend.core.ClosedJaxpr):
"""As `jax.core.jaxpr_as_fn`, but the result is finalised."""
return ft.partial(finalise_eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)


def finalise_jaxpr(jaxpr: jax.core.ClosedJaxpr) -> jax.core.ClosedJaxpr:
def finalise_jaxpr(jaxpr: jax.extend.core.ClosedJaxpr) -> jax.extend.core.ClosedJaxpr:
"""A jaxpr-to-jaxpr transformation that performs finalisation."""
fn = finalise_jaxpr_as_fn(jaxpr)
args = [
jax.ShapeDtypeStruct(x.aval.shape, x.aval.dtype) for x in jaxpr.jaxpr.invars
]
return cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))
return cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(*args))


def finalise_fn(fn):
@@ -136,13 +137,15 @@ def _finalise_fn(*args):
@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[False] = False
) -> Callable[..., jax.core.ClosedJaxpr]: ...
) -> Callable[..., jax.extend.core.ClosedJaxpr]: ...


@overload
def finalise_make_jaxpr(
fn, *, return_shape: Literal[True] = True
) -> Callable[..., tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]]: ...
) -> Callable[
..., tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
]: ...


@overload
@@ -151,7 +154,8 @@ def finalise_make_jaxpr(
) -> Callable[
...,
Union[
jax.core.ClosedJaxpr, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]]
jax.extend.core.ClosedJaxpr,
tuple[jax.extend.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct]],
],
]: ...

@@ -164,12 +168,12 @@ def _finalise_make_jaxpr(*args):
*args
)
if return_shape:
jaxpr_struct = cast(tuple[jax.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr_struct = cast(tuple[jax.extend.core.ClosedJaxpr, Any], jaxpr_struct)
jaxpr, struct = jaxpr_struct
jaxpr = finalise_jaxpr(jaxpr)
return jaxpr, struct
else:
jaxpr_struct = cast(jax.core.ClosedJaxpr, jaxpr_struct)
jaxpr_struct = cast(jax.extend.core.ClosedJaxpr, jaxpr_struct)
jaxpr = finalise_jaxpr(jaxpr_struct)
return jaxpr

4 changes: 2 additions & 2 deletions equinox/internal/_loop/common.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@
from typing import Any, TYPE_CHECKING, Union

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -105,7 +105,7 @@ def _select_if_vmap_batch(axis_size, axis_name, trace, inputs, batch_axes):
return out, out_axis


select_if_vmap_p = jax.core.Primitive("select_if_vmap")
select_if_vmap_p = jax.extend.core.Primitive("select_if_vmap")
select_if_vmap_p.def_impl(_select_if_vmap_impl)
select_if_vmap_p.def_abstract_eval(_select_if_vmap_abstract)
ad.primitive_jvps[select_if_vmap_p] = _select_if_vmap_jvp
3 changes: 2 additions & 1 deletion equinox/internal/_noinline.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -330,7 +331,7 @@ def _noinline_mlir(ctx, *dynamic, treedef, static, flatten, **kwargs):
return result


noinline_p = jax.core.Primitive("noinline")
noinline_p = jax.extend.core.Primitive("noinline")
noinline_p.multiple_results = True
noinline_p.def_impl(_noinline_impl)
noinline_p.def_abstract_eval(_noinline_abstract)
8 changes: 4 additions & 4 deletions equinox/internal/_nontraceable.py
Original file line number Diff line number Diff line change
@@ -6,7 +6,7 @@
from typing import Optional

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -29,7 +29,7 @@ def _error(*args, name):
return _error


nontraceable_p = jax.core.Primitive("nontraceable")
nontraceable_p = jax.extend.core.Primitive("nontraceable")
nontraceable_p.def_impl(_nontraceable_impl)
nontraceable_p.def_abstract_eval(_nontraceable_impl)
ad.primitive_jvps[nontraceable_p] = _make_error("differentiation")
@@ -53,7 +53,7 @@ def nontraceable(x, *, name="nontraceable operation"):
return combine(dynamic, static)


nondifferentiable_backward_p = jax.core.Primitive("nondifferentiable_backward")
nondifferentiable_backward_p = jax.extend.core.Primitive("nondifferentiable_backward")


def _nondifferentiable_backward_batch(x, batch_axes, *, msg, symbolic):
@@ -137,7 +137,7 @@ def _cannot_batch(x, b, *, msg, allow_constant_across_batch):
raise ValueError(msg)


nonbatchable_p = jax.core.Primitive("nonbatchable")
nonbatchable_p = jax.extend.core.Primitive("nonbatchable")
nonbatchable_p.def_impl(lambda x, *, msg, allow_constant_across_batch: x)
nonbatchable_p.def_abstract_eval(lambda x, *, msg, allow_constant_across_batch: x)
batching.primitive_batchers[nonbatchable_p] = _cannot_batch
5 changes: 3 additions & 2 deletions equinox/internal/_primitive.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

import jax
import jax.core
import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -255,7 +256,7 @@ def _wrapper(dynamic, batch_axes, *, treedef, static, flatten):
return _wrapper


def filter_primitive_bind(prim: jax.core.Primitive, *args) -> PyTree:
def filter_primitive_bind(prim: jax.extend.core.Primitive, *args) -> PyTree:
"""Calls a primitive that has had its rules defined using the filter
functions above.
"""
@@ -301,7 +302,7 @@ def materialise_zeros(primal, tangent, allow_struct=False):


def create_vprim(name: str, impl, abstract_eval, jvp, transpose):
prim = jax.core.Primitive(name)
prim = jax.extend.core.Primitive(name)
prim.multiple_results = True

def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
3 changes: 2 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
@@ -81,8 +81,9 @@ plugins:
- import jaxtyping
- jaxtyping.set_array_name_format("array")
- import jax
- import jax.extend.core
- jax.ShapeDtypeStruct.__module__ = "jax"
- jax.core.ClosedJaxpr.__module__ = "jax.core"
- jax.extend.core.ClosedJaxpr.__module__ = "jax.core"

selection:
inherited_members: true # Allow looking up inherited methods
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@ name = "equinox"
version = "0.11.10"
description = "Elegant easy-to-use neural networks in JAX."
readme = "README.md"
requires-python =">=3.9"
requires-python =">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "Patrick Kidger", email = "contact@kidger.site"},
@@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/equinox" }
dependencies = ["jax>=0.4.13,!=0.4.27", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.20", "typing_extensions>=4.5.0"]

[build-system]
requires = ["hatchling"]
27 changes: 16 additions & 11 deletions tests/test_finalise_jaxpr.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@
import equinox.internal as eqxi
import jax
import jax.core
import jax.extend.core
import jax.lax as lax
import jax.numpy as jnp

@@ -23,7 +24,9 @@ def _assert_vars_equal(obj1, obj2, varnames):
assert a.aval.strip_weak_type() == b.aval.strip_weak_type()


def _assert_jaxpr_equal(jaxpr1: jax.core.ClosedJaxpr, jaxpr2: jax.core.ClosedJaxpr):
def _assert_jaxpr_equal(
jaxpr1: jax.extend.core.ClosedJaxpr, jaxpr2: jax.extend.core.ClosedJaxpr
):
assert jaxpr1.consts == jaxpr2.consts
jaxpr1 = jaxpr1.jaxpr
jaxpr2 = jaxpr2.jaxpr
@@ -41,7 +44,7 @@ def fn(x):
x = x * 2
return x

jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(1))
jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(1))
jaxpr2 = eqxi.finalise_jaxpr(jaxpr)
_assert_jaxpr_equal(jaxpr, jaxpr2)

@@ -53,13 +56,13 @@ def fn(x):
x = jnp.invert(x)
return x

jaxpr = cast(jax.core.ClosedJaxpr, jax.make_jaxpr(fn)(True))
jaxpr = cast(jax.extend.core.ClosedJaxpr, jax.make_jaxpr(fn)(True))
jaxpr2 = eqxi.finalise_jaxpr(jaxpr)
jaxpr3 = eqxi.finalise_jaxpr(jaxpr2)
_assert_jaxpr_equal(jaxpr2, jaxpr3)

jaxpr = jax.make_jaxpr(jax.vmap(fn))(jnp.array([True, False]))
jaxpr = cast(jax.core.ClosedJaxpr, jaxpr)
jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr)
jaxpr2 = eqxi.finalise_jaxpr(jaxpr)
jaxpr3 = eqxi.finalise_jaxpr(jaxpr2)
_assert_jaxpr_equal(jaxpr2, jaxpr3)
@@ -78,9 +81,9 @@ def fn(x):
assert tree_allclose(fn(-1), finalised_fn(-1))

jaxpr = jax.make_jaxpr(fn)(1)
jaxpr = cast(jax.core.ClosedJaxpr, jaxpr)
jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr)
finalised_jaxpr = jax.make_jaxpr(finalised_fn)(1)
finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr)
finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr)
_assert_jaxpr_equal(finalised_jaxpr, jaxpr)


@@ -96,9 +99,11 @@ def fn(x):
assert tree_allclose(fn(True), finalised_fn(True))

finalised_jaxpr = jax.make_jaxpr(finalised_fn)(True)
finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_jaxpr)
finalised_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_jaxpr)
finalised_finalised_jaxpr = jax.make_jaxpr(eqxi.finalise_fn(finalised_fn))(True)
finalised_finalised_jaxpr = cast(jax.core.ClosedJaxpr, finalised_finalised_jaxpr)
finalised_finalised_jaxpr = cast(
jax.extend.core.ClosedJaxpr, finalised_finalised_jaxpr
)
_assert_jaxpr_equal(finalised_jaxpr, finalised_finalised_jaxpr)
for eqn in finalised_jaxpr.eqns:
assert eqn.primitive != eqxi.unvmap_any_p
@@ -114,19 +119,19 @@ def fn(x):
assert tree_allclose(vmap_fn(arg), finalised_vmap_fn(arg))

finalised_vmap_jaxpr = jax.make_jaxpr(finalised_vmap_fn)(jnp.array([False, False]))
finalised_vmap_jaxpr = cast(jax.core.ClosedJaxpr, finalised_vmap_jaxpr)
finalised_vmap_jaxpr = cast(jax.extend.core.ClosedJaxpr, finalised_vmap_jaxpr)
finalised_finalised_vmap_jaxpr = jax.make_jaxpr(
eqxi.finalise_fn(finalised_vmap_fn)
)(jnp.array([False, False]))
finalised_finalised_vmap_jaxpr = cast(
jax.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr
jax.extend.core.ClosedJaxpr, finalised_finalised_vmap_jaxpr
)
for eqn in finalised_vmap_jaxpr.eqns:
assert eqn.primitive != eqxi.unvmap_any_p
_assert_jaxpr_equal(finalised_vmap_jaxpr, finalised_finalised_vmap_jaxpr)


def _assert_no_unvmap(jaxpr: jax.core.Jaxpr):
def _assert_no_unvmap(jaxpr: jax.extend.core.Jaxpr):
for eqn in jaxpr.eqns:
assert eqn.primitive not in (eqxi.unvmap_any_p, eqxi.unvmap_all_p)
for subjaxpr in jax.core.subjaxprs(jaxpr):
3 changes: 2 additions & 1 deletion tests/test_nontraceable.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import equinox.internal as eqxi
import jax
import jax.core
import jax.extend.core
import jax.numpy as jnp
import jax.tree_util as jtu
import pytest
@@ -73,7 +74,7 @@ def run(dynamic, static):
jax.vmap(run, in_axes=(0, None))(dynamic_batch, static)

jaxpr = jax.make_jaxpr(run, static_argnums=1)(dynamic, static)
jaxpr = cast(jax.core.ClosedJaxpr, jaxpr)
jaxpr = cast(jax.extend.core.ClosedJaxpr, jaxpr)
run2 = jax.core.jaxpr_as_fun(jaxpr)

run2(*dynamic_flat) # pyright: ignore
3 changes: 2 additions & 1 deletion tests/test_primitive.py
Original file line number Diff line number Diff line change
@@ -4,6 +4,7 @@
import equinox.internal as eqxi
import jax
import jax.core
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the fix!

By the way, do we need to keep all of the import jax.core around? I'm guessing most of these can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems you failed to push after removing the unused imports?

Copy link
Contributor Author

@DrJessop DrJessop Dec 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@garymm No, a fair amount of files still need jax.core, hence why it's not removed everywhere.

import jax.extend.core
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@@ -14,7 +15,7 @@


def test_call():
newprim_p = jax.core.Primitive("newprim")
newprim_p = jax.extend.core.Primitive("newprim")
newprim_p.multiple_results = True

newprim = ft.partial(eqxi.filter_primitive_bind, newprim_p)