
Releases: jax-ml/jax

JAX v0.4.34

04 Oct 14:51
  • New Functionality

    • This release includes wheels for Python 3.13. Free-threading mode is not yet
      supported.
    • jax.errors.JaxRuntimeError has been added as a public alias for the
      formerly private XlaRuntimeError type.
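
A minimal, hedged sketch of catching backend failures under the new public name; the computation here is arbitrary and normally succeeds, it only shows where the exception type would be used:

```python
import jax
import jax.numpy as jnp

try:
    y = jax.jit(lambda x: x @ x)(jnp.ones((4, 4)))
except jax.errors.JaxRuntimeError as err:
    # Runtime failures from the XLA backend (for example, out-of-memory)
    # surface under this public name; handlers written against the formerly
    # private XlaRuntimeError keep working, since the new name is an alias.
    print(f"backend runtime failure: {err}")
```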
  • Breaking changes

    • jax_pmap_no_rank_reduction flag is set to True by default.
      • array[0] on a pmap result now introduces a reshape (use array[0:1]
        instead).
      • The per-shard shape (accessible via jax_array.addressable_shards or
        jax_array.addressable_data(0)) now has a leading (1, ...). Update code
        that directly accesses shards accordingly. The rank of the per-shard
        shape now matches that of the global shape, which is the same behavior
        as jit. This avoids costly reshapes when passing results from pmap into
        jit. See the first sketch after this list.
    • jax.experimental.host_callback has been deprecated since March 2024
      (JAX version 0.4.26). The default value of the
      --jax_host_callback_legacy configuration option is now False, which
      means that if your code uses jax.experimental.host_callback APIs, those
      API calls will be implemented in terms of the new
      jax.experimental.io_callback API. If this breaks your code, you can set
      --jax_host_callback_legacy to True for a very limited time. That option
      will be removed soon, so you should instead transition to the new JAX
      callback APIs (see the second sketch after this list). See #20385 for a
      discussion.
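
A small sketch of the new pmap indexing behavior described in the first item above; the mapped function and shapes are illustrative only:

```python
import jax
import jax.numpy as jnp

n = jax.device_count()
xs = jnp.arange(n * 4.0).reshape(n, 4)
ys = jax.pmap(lambda x: x * 2)(xs)

first = ys[0:1]                   # shape (1, 4); no reshape needed
also_first = ys[0]                # shape (4,); now introduces a reshape

# Per-shard data keeps a leading axis of length 1, matching jit semantics.
shard0 = ys.addressable_data(0)   # shape (1, 4) under the new default
```

And a minimal sketch of the io_callback-based style that host_callback users are encouraged to migrate to; the host function is a placeholder:

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

def log_on_host(x):
    print("host saw:", x)  # runs in the host Python process
    return x

@jax.jit
def f(x):
    # io_callback requires the result shape/dtype to be declared up front.
    return io_callback(log_on_host, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

f(jnp.arange(3.0))
```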
  • Deprecations

    • In jax.numpy.trim_zeros, non-arraylike arguments or arraylike
      arguments with ndim != 1 are now deprecated, and in the future will result
      in an error.
    • Internal pretty-printing tools jax.core.pp_* have been removed, after
      being deprecated in JAX v0.4.30.
    • jax.lib.xla_client.Device is deprecated; use jax.Device instead.
    • jax.lib.xla_client.XlaRuntimeError has been deprecated. Use
      jax.errors.JaxRuntimeError instead.
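
For reference, a tiny example of the jax.numpy.trim_zeros usage that remains supported (a 1-D array input); the values are arbitrary:

```python
import jax.numpy as jnp

x = jnp.array([0, 0, 3, 1, 0, 2, 0])
jnp.trim_zeros(x)  # Array([3, 1, 0, 2], dtype=int32)

# Non-arraylike inputs, or inputs with ndim != 1, are deprecated and will
# eventually raise an error; convert such inputs to a 1-D array first.
```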
  • Deletions

    • jax.xla_computation has been deleted; it was deprecated three months ago,
      in the 0.4.30 JAX release.
      Please use the AOT APIs to get the same functionality as
      jax.xla_computation (see the first sketch after this list).
      • jax.xla_computation(fn)(*args, **kwargs) can be replaced with
        jax.jit(fn).lower(*args, **kwargs).compiler_ir('hlo').
      • You can also use the .out_info property of jax.stages.Lowered to get
        output information (such as tree structure, shape, and dtype).
      • For cross-backend lowering, you can replace
        jax.xla_computation(fn, backend='tpu')(*args, **kwargs) with
        jax.jit(fn).trace(*args, **kwargs).lower(lowering_platforms=('tpu',)).compiler_ir('hlo').
    • jax.ShapeDtypeStruct no longer accepts the named_shape argument.
      The argument was only used by xmap which was removed in 0.4.31.
    • jax.tree.map(f, None, non-None), which previously emitted a
      DeprecationWarning, now raises an error. None
      is only a tree-prefix of itself. To preserve the current behavior, you can
      ask jax.tree.map to treat None as a leaf value by writing:
      jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).
    • jax.sharding.XLACompatibleSharding has been removed. Please use
      jax.sharding.Sharding.
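
A sketch of the AOT-based replacement for jax.xla_computation described above; fn and its argument are placeholders:

```python
import jax
import jax.numpy as jnp

def fn(x):
    return jnp.sin(x) ** 2

x = jnp.ones((8,))
lowered = jax.jit(fn).lower(x)

hlo = lowered.compiler_ir('hlo')  # replaces jax.xla_computation(fn)(x)
out_info = lowered.out_info       # output tree structure, shapes, and dtypes

# For cross-backend lowering, per the note above:
# jax.jit(fn).trace(x).lower(lowering_platforms=('tpu',)).compiler_ir('hlo')
```

And a short example of keeping the old jax.tree.map behavior for None leaves, as described in the jax.tree.map item above; the pytrees are arbitrary:

```python
import jax

a = {"w": None, "b": 1.0}
b = {"w": 5.0, "b": 2.0}

# Treat None as a leaf so the None/non-None mismatch at "w" no longer raises.
out = jax.tree.map(
    lambda x, y: None if x is None else x + y,
    a, b,
    is_leaf=lambda x: x is None,
)
# out == {"w": None, "b": 3.0}
```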
  • Bug fixes

    • Fixed a bug where jax.numpy.cumsum would produce incorrect outputs
      if a non-boolean input was provided and dtype=bool was specified.
    • Reworked the implementation of jax.numpy.ldexp so that it produces the
      correct gradient.
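
For context, a small example of the cumsum case addressed by the first fix; with the fix, integer input with dtype=bool should follow NumPy semantics (a cumulative logical OR):

```python
import jax.numpy as jnp

x = jnp.array([0, 1, 0, 2])
jnp.cumsum(x, dtype=bool)  # expected: Array([False, True, True, True])
```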

JAX release v0.4.33

16 Sep 18:42

This is a patch release on top of jax 0.4.32 that fixes two bugs found in that
release.

A TPU-only data corruption bug was found in the version of libtpu pinned by
JAX 0.4.32, which manifested only if multiple TPU slices were present in the
same job, for example, if training on multiple v5e slices.

This release fixes that issue by pinning a fixed version of libtpu-nightly.

This release also fixes an inaccurate result for F64 tanh on CPU (#23590).

Jaxlib release v0.4.32

11 Sep 20:03

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU that occurs when a job uses multiple TPU slices.

JAX release v0.4.32

11 Sep 20:05

WARNING: This release has been yanked from PyPI because of a data corruption bug on TPU that occurs when a job uses multiple TPU slices.

Jaxlib release v0.4.31

30 Jul 00:09
jaxlib-v0.4.31

jaxlib version 0.4.31

JAX release v0.4.31

30 Jul 00:10
jax-v0.4.31

jax version 0.4.31

Jaxlib release v0.4.30

18 Jun 15:07
jaxlib-v0.4.30

jaxlib version 0.4.30

JAX release v0.4.30

18 Jun 15:07
jax-v0.4.30

jax version 0.4.30

Jaxlib release v0.4.29

10 Jun 18:31
  • Bug fixes

    • Fixed a bug where XLA sharded some concatenation operations incorrectly,
      which manifested as an incorrect output for cumulative reductions (#21403).
    • Fixed a bug where XLA:CPU miscompiled certain matmul fusions
      (openxla/xla#13301).
    • Fixed a compiler crash on GPU (#21396).
  • Deprecations

    • jax.tree.map(f, None, non-None) now emits a DeprecationWarning, and will
      raise an error in a future version of jax. None is only a tree-prefix of
      itself. To preserve the current behavior, you can ask jax.tree.map to
      treat None as a leaf value by writing:
      jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None).

JAX v0.4.29

10 Jun 18:31
  • Changes

    • We anticipate that this will be the last release of JAX and jaxlib
      supporting a monolithic CUDA jaxlib. Future releases will use the CUDA
      plugin jaxlib (e.g. pip install jax[cuda12]).
    • JAX now requires ml_dtypes version 0.4.0 or newer.
    • Removed backwards-compatibility support for old usage of the
      jax.experimental.export API. It is no longer possible to use
      from jax.experimental.export import export; instead, use
      from jax.experimental import export.
      The removed functionality had been deprecated since 0.4.24.
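
In code, the change is only the import path:

```python
# Removed:  from jax.experimental.export import export
# Use the module-level import instead:
from jax.experimental import export
```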
  • Deprecations

    • jax.sharding.XLACompatibleSharding is deprecated. Please use
      jax.sharding.Sharding.
    • jax.experimental.Exported.in_shardings has been renamed to
      jax.experimental.Exported.in_shardings_hlo; the same applies to
      out_shardings. The old names will be removed after three months.
    • Removed a number of previously-deprecated APIs:
      • from jax.core: non_negative_dim, DimSize, Shape
      • from jax.lax: tie_in
      • from jax.nn: normalize
      • from jax.interpreters.xla: backend_specific_translations,
        translations, register_translation, xla_destructure,
        TranslationRule, TranslationContext, XlaOp.
    • The tol argument of jax.numpy.linalg.matrix_rank is being
      deprecated and will soon be removed. Use rtol instead.
    • The rcond argument of jax.numpy.linalg.pinv is being
      deprecated and will soon be removed. Use rtol instead.
    • The deprecated jax.config submodule has been removed. To configure JAX
      use import jax and then reference the config object via jax.config.
    • jax.random APIs no longer accept batched keys, where previously
      some did unintentionally. Going forward, we recommend explicit use of
      jax.vmap in such cases.
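
A brief sketch collecting the migrations mentioned above (the rtol arguments, the jax.config object, and explicit vmap over batched keys); the arrays and shapes are arbitrary:

```python
import jax
import jax.numpy as jnp

a = jnp.eye(3)

# tol / rcond are deprecated in favor of rtol:
rank = jnp.linalg.matrix_rank(a, rtol=1e-6)
pinv = jnp.linalg.pinv(a, rtol=1e-6)

# The jax.config submodule is gone; reach the config object through jax:
jax.config.update("jax_enable_x64", False)

# Batched keys are no longer accepted implicitly; vmap over keys explicitly:
keys = jax.random.split(jax.random.key(0), 4)
samples = jax.vmap(lambda k: jax.random.normal(k, (2,)))(keys)
```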
  • New Functionality

    • Added jax.experimental.Exported.in_shardings_jax to construct
      shardings that can be used with the JAX APIs from the HloShardings
      that are stored in the Exported objects.