JAX v0.4.36

Breaking Changes
- This release lands "stackless", an internal change to JAX's tracing machinery. We made trace dispatch purely a function of context rather than a function of both context and data. This let us delete a lot of machinery for managing data-dependent tracing: levels, sublevels, post_process_call, new_base_main, custom_bind, and so on. The change should only affect users that use JAX internals. If you do use JAX internals, then you may need to update your code (see commit c36e1f7 for clues about how to do this). There might also be version skew issues with JAX libraries that do this. If you find that this change breaks code that does not use JAX internals, then try the config.jax_data_dependent_tracing_fallback flag as a workaround, and if you need help updating your code, then please file a bug.
- Using jax.experimental.jax2tf.convert with native_serialization=False or with enable_xla=False has been deprecated since July 2024 (JAX v0.4.31), and support for these use cases has now been removed. jax2tf with native serialization is still supported.
- In jax.interpreters.xla, the xb, xc, and xe symbols have been removed after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge, xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.
- The deprecated module jax.experimental.export has been removed. It was replaced by jax.export in JAX v0.4.30. See the migration guide for information on migrating to the new API.
- The initial argument to jax.nn.softmax and jax.nn.log_softmax has been removed, after being deprecated in v0.4.27.
- Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key) now raises an error. Previously, this returned a scalar object array.
- The following deprecated methods and functions in jax.export have been removed:
  - jax.export.DisabledSafetyCheck.shape_assertions: it already had no effect.
  - jax.export.Exported.lowering_platforms: use platforms.
  - jax.export.Exported.mlir_module_serialization_version: use calling_convention_version.
  - jax.export.Exported.uses_shape_polymorphism: use uses_global_constants.
  - the lowering_platforms kwarg for jax.export.export: use platforms instead.
- The kwargs symbolic_scope and symbolic_constraints of jax.export.symbolic_args_specs have been removed. They were deprecated in June 2024. Use scope and constraints instead.
- Hashing of tracers, which has been deprecated since version 0.4.30, now results in a TypeError.
- Refactor: the JAX build CLI (build/build.py) now uses a subcommand structure and replaces the previous build.py usage. Run python build/build.py --help for more details. Brief overview of the new subcommand options:
  - build: Builds JAX wheel packages. For example, python build/build.py build --wheels=jaxlib,jax-cuda-pjrt
  - requirements_update: Updates requirements_lock.txt files.
- jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional inputs. To recover the previous behavior, you can call jax.numpy.ravel on the function inputs.
- jax.scipy.special.gamma and jax.scipy.special.gammasgn now return NaN for negative integer inputs, to match the behavior of SciPy from scipy/scipy#21827.
- jax.clear_backends was removed after being deprecated in v0.4.26.
- We removed the custom call "__gpu$xla.gpu.triton" from the list of custom calls for which we guarantee export stability. This is because this custom call relies on Triton IR, which is not guaranteed to be stable. If you need to export code that uses this custom call, you can use the disabled_checks parameter. See more details in the documentation.
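The jax.export migration covered by the bullets above (jax.export instead of jax.experimental.export, platforms instead of lowering_platforms, and disabled_checks for the Triton custom call) can be sketched as follows. This is a minimal example, not a reference: the function f is hypothetical, and the sketch assumes a CPU-only machine because the export is pinned to platforms=["cpu"]. The Triton check is disabled only to show the syntax; f does not use Triton.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export  # replaces the removed jax.experimental.export


def f(x):
    # Hypothetical function to export.
    return jnp.sin(x) * 2.0


# Abstract input: a float32 vector of length 3.
x_spec = jax.ShapeDtypeStruct((3,), jnp.float32)

exp = export.export(
    jax.jit(f),
    platforms=["cpu"],  # `platforms=` replaces the removed `lowering_platforms=`
    # Opt out of the export-stability check for the Triton custom call.
    disabled_checks=[export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")],
)(x_spec)

print(exp.platforms)                   # replaces Exported.lowering_platforms
print(exp.calling_convention_version)  # replaces mlir_module_serialization_version

# Serialize, deserialize, and call the restored artifact.
restored = export.deserialize(exp.serialize())
x = np.arange(3, dtype=np.float32)
y = restored.call(x)
print(y)
```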
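Code that passed initial together with where to jax.nn.softmax or jax.nn.log_softmax can simply drop initial. A minimal sketch with made-up logits and mask values:

```python
import numpy as np
import jax
import jax.numpy as jnp

logits = jnp.array([1.0, 2.0, 3.0])
mask = jnp.array([True, True, False])  # exclude the last entry

# No `initial=` any more: the mask is passed with `where=` alone.
probs = jax.nn.softmax(logits, where=mask)
log_probs = jax.nn.log_softmax(logits, where=mask)

print(probs)          # the masked-out entry is 0; the others sum to 1
print(log_probs[:2])  # log of the unmasked probabilities
```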
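Code that relied on np.asarray(key) for a typed key can extract the raw key bits explicitly instead. A minimal sketch using jax.random.key_data and jax.random.wrap_key_data; the shape of the raw data depends on the configured PRNG implementation, so the sketch does not assume it:

```python
import numpy as np
import jax
import jax.numpy as jnp

key = jax.random.key(0)  # typed PRNG key (dtype key<...>)

# np.asarray(key) now raises; extract the raw uint32 key data explicitly.
raw = np.asarray(jax.random.key_data(key))
print(raw.dtype, raw.shape)  # the shape depends on the PRNG implementation

# Raw key data can be wrapped back into a typed key when needed.
same_key = jax.random.wrap_key_data(jnp.asarray(raw))
print(jax.random.uniform(key) == jax.random.uniform(same_key))
```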
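For jax.export.symbolic_args_specs, the constraints= keyword (and likewise scope=) takes the place of the removed symbolic_* spellings. A sketch under the assumption that the exported function is a plain sum; the dimension names a and b and the constraint "a >= b" are illustrative choices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import export

# A sample argument; only its dtype and rank matter for the spec.
sample = np.ones((4, 3), dtype=np.float32)

# `constraints=` replaces the removed `symbolic_constraints=` keyword
# (and `scope=` replaces `symbolic_scope=`).
spec = export.symbolic_args_specs(sample, "a, b", constraints=["a >= b"])
print(spec.shape)  # two symbolic dimensions, a and b

# Export once for all shapes (a, b) with a >= b, then call with a concrete shape.
exp = export.export(jax.jit(jnp.sum))(spec)
total = exp.call(np.ones((5, 2), dtype=np.float32))
print(total)  # sum of ten ones
```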
New Features
- jax.jit got a new compiler_options: dict[str, Any] argument, for passing compilation options to XLA. For the moment it's undocumented and may be in flux.
- jax.tree_util.register_dataclass now allows metadata fields to be declared inline via dataclasses.field. See the function documentation for examples.
- Added jax.numpy.put_along_axis.
- jax.lax.linalg.eig and the related jax.numpy functions (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now supported on GPU. See #24663 for more details.
- Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0; the default is 0.0.
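jax.numpy.linalg.eig is called the same way on every supported backend, so with the GPU support above the same code can now run there too. A minimal sketch that checks the eigendecomposition relation on a small triangular matrix, on whatever device JAX picks by default:

```python
import numpy as np
import jax
import jax.numpy as jnp

a = jnp.array([[2.0, 1.0],
               [0.0, 3.0]])  # upper triangular, so the eigenvalues are 2 and 3

w, v = jnp.linalg.eig(a)  # complex eigenvalues and right eigenvectors
print(jax.default_backend(), np.sort(np.asarray(w).real))

# Each column v[:, j] satisfies a @ v[:, j] == w[j] * v[:, j].
print(jnp.allclose(a @ v, v * w, atol=1e-5))
```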
Bug fixes
- Fixed a bug where the GPU implementations of LU and QR decomposition would
result in an indexing overflow for batch sizes close to int32 max. See
#24843 for more details.
Deprecations
- jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated; use jax.Array instead.
- jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError instead.