Releases: jax-ml/jax
JAX v0.4.34
New Functionality
- This release includes wheels for Python 3.13. Free-threading mode is not yet supported.
- `jax.errors.JaxRuntimeError` has been added as a public alias for the formerly private `XlaRuntimeError` type.
Breaking changes
- The `jax_pmap_no_rank_reduction` flag is set to `True` by default.
  - `array[0]` on a pmap result now introduces a reshape (use `array[0:1]` instead).
  - The per-shard shape (accessible via `jax_array.addressable_shards` or `jax_array.addressable_data(0)`) now has a leading `(1, ...)`. Update code that directly accesses shards accordingly. The rank of the per-shard shape now matches that of the global shape, which is the same behavior as `jit`. This avoids costly reshapes when passing results from `pmap` into `jit`.
- `jax.experimental.host_callback` has been deprecated since March 2024, with JAX version 0.4.26. Now we set the default value of the `--jax_host_callback_legacy` configuration value to `False`, which means that if your code uses `jax.experimental.host_callback` APIs, those API calls will be implemented in terms of the new `jax.experimental.io_callback` API. If this breaks your code, for a very limited time, you can set `--jax_host_callback_legacy` to `True`. Soon we will remove that configuration option, so you should instead transition to using the new JAX callback APIs. See #20385 for a discussion.
Deprecations
- In `jax.numpy.trim_zeros`, non-arraylike arguments or arraylike arguments with `ndim != 1` are now deprecated, and in the future will result in an error.
- Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30.
- `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
- `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use `jax.errors.JaxRuntimeError` instead.
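A small sketch of the forms that remain supported after these deprecations, assuming a recent JAX install: `trim_zeros` receives a 1-D array, and device type checks use `jax.Device`. The variable names are illustrative only.

```python
import jax
import jax.numpy as jnp

# Supported form: a 1-D array. Python lists or arrays with ndim != 1 are deprecated.
trimmed = jnp.trim_zeros(jnp.array([0, 0, 1, 2, 0]))  # leading and trailing zeros removed

# Device type checks: use jax.Device instead of jax.lib.xla_client.Device.
dev = jax.devices()[0]
is_device = isinstance(dev, jax.Device)
```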
Deletion
- `jax.xla_computation` is deleted. It has been 3 months since its deprecation in the JAX 0.4.30 release. Please use the AOT APIs to get the same functionality as `jax.xla_computation`.
  - `jax.xla_computation(fn)(*args, **kwargs)` can be replaced with `jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo')`.
  - You can also use the `.out_info` property of `jax.stages.Lowered` to get the output information (like tree structure, shape and dtype).
  - For cross-backend lowering, you can replace `jax.xla_computation(fn, backend='tpu')(*args, **kwargs)` with `jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')`.
- `jax.ShapeDtypeStruct` no longer accepts the `named_shape` argument. The argument was only used by `xmap`, which was removed in 0.4.31.
- `jax.tree.map(f, None, non-None)`, which previously emitted a `DeprecationWarning`, now raises an error. `None` is only a tree-prefix of itself. To preserve the previous behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
- `jax.sharding.XLACompatibleSharding` has been removed. Please use `jax.sharding.Sharding`.
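The AOT replacements for `jax.xla_computation` can be sketched as follows. This assumes a JAX version that provides `jax.jit(...).trace(...)` and `Lowered.out_info`, as described in the notes; `fn` is a hypothetical example function. To keep the sketch portable it reads the StableHLO text via `Lowered.as_text()` instead of calling `compiler_ir('hlo')`, and it lowers for `('cpu',)` so it runs on any machine; substituting `('tpu',)` corresponds to the cross-backend recipe above.

```python
import jax
import jax.numpy as jnp

def fn(x):
    # Hypothetical function standing in for whatever was passed to xla_computation.
    return jnp.sin(x) * 2

x = jnp.ones((3,), dtype=jnp.float32)

# Replacement for jax.xla_computation(fn)(x): lower through jit (no compilation yet).
lowered = jax.jit(fn).lower(x)
stablehlo_text = lowered.as_text()  # StableHLO module text

# Output metadata (tree structure, shape, dtype) from the Lowered stage.
out_info = lowered.out_info

# Cross-backend lowering: trace once, then lower for an explicit platform list.
cpu_lowered = jax.jit(fn).trace(x).lower(lowering_platforms=("cpu",))
```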
Bug fixes
- Fixed a bug where `jax.numpy.cumsum` would produce incorrect outputs if a non-boolean input was provided and `dtype=bool` was specified.
- Changed the implementation of `jax.numpy.ldexp` to get the correct gradient.
JAX release v0.4.33
This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.
A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.
This release fixes that issue by pinning a fixed version of `libtpu-nightly`.
This release also fixes an inaccurate result for F64 tanh on CPU (#23590).
Jaxlib release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
JAX release v0.4.32
WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU if there are multiple TPU slices in the job
Jaxlib release v0.4.31
jaxlib-v0.4.31 jaxlib version 0.4.31
JAX release v0.4.31
jax-v0.4.31 jax version 0.4.31
Jaxlib release v0.4.30
jaxlib-v0.4.30 jaxlib version 0.4.30
Jax release v0.4.30
jax-v0.4.30 jax version 0.4.30
Jaxlib release v0.4.29
Bug fixes
- Fixed a bug where XLA sharded some concatenation operations incorrectly, which manifested as an incorrect output for cumulative reductions (#21403).
- Fixed a bug where XLA:CPU miscompiled certain matmul fusions (openxla/xla#13301).
- Fixed a compiler crash on GPU (#21396).
Deprecations
- `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will raise an error in a future version of jax. `None` is only a tree-prefix of itself. To preserve the current behavior, you can ask `jax.tree.map` to treat `None` as a leaf value by writing: `jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
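The `is_leaf` workaround can be seen on a concrete pair of trees. In this sketch, `f`, `a` and `b` are hypothetical placeholders for the names used in the note; the first tree holds `None` where the second holds a number, which is exactly the case covered by the deprecation.

```python
import jax

def f(x, y):
    # Hypothetical leaf-wise function.
    return x + y

a = {"w": 1.0, "b": None}
b = {"w": 2.0, "b": 3.0}

# jax.tree.map(f, a, b) is the deprecated pattern: None in `a` is not a
# tree-prefix of 3.0 in `b`. Treating None as a leaf keeps the old result.
result = jax.tree.map(
    lambda x, y: None if x is None else f(x, y),
    a,
    b,
    is_leaf=lambda x: x is None,
)
```

Here `result` is `{"w": 3.0, "b": None}`: the `None` position is passed through unchanged and `f` is applied only where both trees hold values.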
JAX v0.4.29
Changes
- We anticipate that this will be the last release of JAX and jaxlib supporting a monolithic CUDA jaxlib. Future releases will use the CUDA plugin jaxlib (e.g. `pip install jax[cuda12]`).
- JAX now requires ml_dtypes version 0.4.0 or newer.
- Removed backwards-compatibility support for old usage of the `jax.experimental.export` API. It is no longer possible to use `from jax.experimental.export import export`; instead you should use `from jax.experimental import export`. The removed functionality has been deprecated since 0.4.24.
Deprecations
- `jax.sharding.XLACompatibleSharding` is deprecated. Please use `jax.sharding.Sharding`.
- `jax.experimental.Exported.in_shardings` has been renamed to `jax.experimental.Exported.in_shardings_hlo`. Same for `out_shardings`. The old names will be removed after 3 months.
- Removed a number of previously-deprecated APIs:
  - from `jax.core`: `non_negative_dim`, `DimSize`, `Shape`
  - from `jax.lax`: `tie_in`
  - from `jax.nn`: `normalize`
  - from `jax.interpreters.xla`: `backend_specific_translations`, `translations`, `register_translation`, `xla_destructure`, `TranslationRule`, `TranslationContext`, `XlaOp`.
- The `tol` argument of `jax.numpy.linalg.matrix_rank` is being deprecated and will soon be removed. Use `rtol` instead.
- The `rcond` argument of `jax.numpy.linalg.pinv` is being deprecated and will soon be removed. Use `rtol` instead.
- The deprecated `jax.config` submodule has been removed. To configure JAX, use `import jax` and then reference the config object via `jax.config`.
- `jax.random` APIs no longer accept batched keys, where previously some did unintentionally. Going forward, we recommend explicit use of `jax.vmap` in such cases.
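A sketch of the suggested replacements, assuming a JAX version whose `matrix_rank` and `pinv` accept the recommended `rtol` argument and which supports typed keys via `jax.random.key`. The matrices, tolerances and sample shape are arbitrary illustrative values; the last part maps a sampler over a batch of keys with `jax.vmap` instead of passing the batch directly.

```python
import jax
import jax.numpy as jnp

# matrix_rank: pass rtol (relative to the largest singular value) instead of tol.
m = jnp.array([[1.0, 2.0], [2.0, 4.0]])  # second row is 2x the first, so rank 1
rank = jnp.linalg.matrix_rank(m, rtol=1e-5)

# pinv: pass rtol instead of rcond.
a = jnp.diag(jnp.array([2.0, 4.0]))
a_pinv = jnp.linalg.pinv(a, rtol=1e-6)

# Batched keys: vmap the sampler over the key batch explicitly.
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (3,)))(keys)
```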
New Functionality
- Added `jax.experimental.Exported.in_shardings_jax` to construct shardings that can be used with the JAX APIs from the HloShardings that are stored in the `Exported` objects.