
Releases: google/flax

Version 0.5.1

10 Jun 20:34

What's Changed

New Contributors

Full Changelog: v0.5.0...v0.5.1

Version 0.5.0

23 May 12:52

New features:

  • Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer
  • Implemented default dtype FLIP.
    This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32.
    This is especially useful for dealing with complex numbers because the standard Modules will no longer truncate
    complex numbers to their real component by default. Instead the complex dtype is preserved by default.

Bug fixes:

  • Fix support for JAX's experimental_name_stack.

Breaking changes:

  • In rare cases the dtype of a layer can change due to default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.

Version 0.4.3

05 May 12:47

Note

Due to a release error we had to roll out a new release, but this version is exactly the same as v0.4.2.

Version 0.4.2

05 May 12:37

What's Changed

  • Canonicalize conv padding by @jheek in #2009
  • Update ScopeParamNotFoundError message. by @melissatan in #2013
  • Set field on dataclass transform decorator by @NeilGirdhar in #1927
  • Don't recommend mixing setup and compact in docs. by @levskaya in #2018
  • Clarifies optim.Adam(weight_decay) parameter. by @copybara-service in #2016
  • Update linear regression example in Jax intro and Flax intro. by @melissatan in #2015
  • Lifted cond by @jheek in #2020
  • Use tree_map instead of deprecated tree_multimap by @jheek in #2024
  • Remove tree_multimap from docs, examples, and tests by @jheek in #2026
  • Fix bug where the linen Module state is reused. by @jheek in #2025
  • Add getattribute with lazy setup trigger. by @levskaya in #2028
  • Better error messages for loading checkpoints. by @copybara-service in #2035
  • Add filterwarning for jax.tree_multimap by @marcvanzee in #2038
  • Adds Flax logo to README by @marcvanzee in #2036
  • Module lifecycle note by @jheek in #1964
  • Fix linter errors in core/scope.py and core/tracers.py. by @copybara-service in #2004
  • Handle edge-case of rate==1.0 in Dropout layer. by @levskaya in #2055
  • Bug fixes and generalizations of nn.partitioning api. by @copybara-service in #2062
  • Add support for JAX dynamic stack-based named_call. by @copybara-service in #2063
  • Updates pooling docstrings by @marcvanzee in #2064
  • Makes annotated_mnist use Optax's xent loss. by @andsteing in #2071
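The `tree_multimap` removal (#2024, #2026) is a mechanical change, because `jax.tree_util.tree_map` accepts additional trees with the same structure as the first one. A small sketch with made-up parameter and gradient values:

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(1.0)}
lr = 0.1

# Formerly jax.tree_multimap(f, params, grads); tree_map takes the extra
# tree directly and applies f leaf by leaf to matching positions.
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params["w"])  # approximately [0.99, 1.98]
```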

Full Changelog: v0.4.1...v0.4.2

Version 0.4.1

23 Mar 14:33

What's Changed

  • Added locally-connected (unshared CNN) layer flax.linen.ConvLocal.
  • Improved seq2seq example: Factored out model and input pipeline code.
  • Added Optax update guide and deprecated flax.optim.
  • Added sep argument to flax.traverse_util.flatten_dict().
  • Implemented Sequential module, in flax.linen.combinators.

Version 0.4.0

27 Jan 14:47

What's Changed

  • Add PReLU Activation by @isaaccorley in #1570
  • Fix GroupNorm type hint for param num_groups. by @lkhphuc in #1657
  • Add named_call overrides to docs by @jheek in #1649
  • mission statement by @jheek in #1668
  • Improves Flax Modules for RTD by @marcvanzee in #1416
  • Add clarifying docstring for 'size' argument to prefetch_to_device's by @avital in #1574
  • Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose by @sgrigory in #1661
  • Fix child scope rng reuse. by @jheek in #1692
  • Numerically stable weight norm by @jheek in #1693
  • Remove cyclic refs from scope by @jheek in #1696
  • Add unroll to jax_utils.scan_in_dim by @ptigwe in #1691
  • Removes rng arguments from Dropout's __call__. by @copybara-service in #1689
  • Add error for empty scopes. by @jheek in #1698
  • correct axis resolution in case of repeated axis in the logica axis r… by @ultrons in #1703
  • Fix lost mutation bug in transforms on nested scopes. by @levskaya in #1716
  • Expose put_variable function to Module. by @levskaya in #1710
  • add eq and hash for scopes by @jheek in #1720
  • Fixes a bug in DenseGeneral. by @copybara-service in #1722
  • Add param_dtype argument to linen Modules by @jheek in #1739
  • Implement custom vjp by @jheek in #1738
  • Handle setup with transformed methods taking submodules of self. by @levskaya in #1745
  • validate RNG key shape against jax's default by @copybara-service in #1780
  • Adds optax update guide. by @andsteing in #1774
  • Implement LazyRNG by @jheek in #1723
  • make params_with_axes() work when params_axes is not mutable by @copybara-service in #1811
  • Updates the ensembling HOWTO to Optax. by @andsteing in #1806
  • Adds prominent scenic link to examples/README.md by @copybara-service in #1809
  • Removes PixelCNN++ example. by @copybara-service in #1819
  • Add support for non-float32 normalization for linen normalization layers by @jheek in #1804
  • Make Filter a Collection instead of a Container by @NeilGirdhar in #1815
  • Removes deprecated API from RTD by @marcvanzee in #1824

New Contributors

Full Changelog: v0.3.6...v0.4.0

Version 0.3.6

27 Oct 21:00
136f41a

Breaking changes:

  • Move flax.nn to flax.deprecated.nn.

New features:

  • Add experimental checkpoint policy argument. See flax.linen.checkpoint
  • Add lifted versions of jvp and vjp.
  • Add lifted transformation for mapping variables. See flax.linen.map_variables.

Version 0.3.5

21 Sep 07:47

Breaking changes:

  • You can no longer pass an int as the kernel_size for a flax.linen.Conv. Instead a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now because this is not ambiguous when the kernel rank is known.
  • flax.linen.enable_named_call and flax.linen.disable_named_call now work anywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now flax.linen.override_named_call, which provides a context manager to locally disable/enable named_call.
  • NamedTuples are no longer converted to tuples on assignment to a linen.Module.

New features:

  • Flax internal stack frames are now removed from exception stack traces.
  • Added flax.linen.nowrap to decorate methods that should not be transformed because they are stateful.
  • Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with --jax_numpy_rank_promotion=raise.

Bugfixes:

  • linen Modules and dataclasses made with flax.struct.dataclass or flax.struct.PyTreeNode are now correctly recognized as dataclasses by static analysis tools like Pylance. Autocomplete of constructors has been verified to work with VSCode.
  • Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
  • Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated (bug).
  • Mixed precision training with float16 now works correctly with the attention layers.
  • auto-generated linen Module hash, eq, repr no longer fail by default on non-init attributes.

Version 0.3.4

18 May 11:27
eba1dba

Possibly breaking changes:

  • When calling init the 'intermediates' collection is no longer mutable.
    Therefore, intermediates will no longer be returned from initialization by default.
  • Don't update batch statistics during initialization.
  • When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the deterministic argument in MultiHeadDotProductAttention.

Other changes:

  • Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
  • Added an NLP text classification example (on the SST-2 dataset) to
    examples/sst2 that uses a bidirectional LSTM (BiLSTM) to encode the input text.
  • Added flax.training.train_state to simplify using Optax optimizers.
  • mutable argument is now available on Module.init and Module.init_with_outputs
  • Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
  • Expose dot_product_attention_weights, allowing access to attention weights.
  • BatchNorm instances will behave correctly during init when called multiple times.
  • Added a more extensive "how to contribute" guide in contributing.md.
  • Add proper cache behavior for lift.jit, fixing cache misses.
  • Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array.
  • Fix linen.Module for deep inheritance chains.
  • Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions.
  • Allow Flax lifted transforms to work on partially applied Modules.
  • Make MultiOptimizer use apply_gradient instead of apply_param_gradient.

Version 0.3.3

31 Mar 14:16
174cf70

Possible breaking changes:

  • Bug Fix: Disallow modifying attributes in Modules after they are initialized.
  • Raise an error when saving a checkpoint which has a smaller step than the
    latest checkpoint already saved.
  • MultiOptimizer now rejects the case where multiple sub optimizers update the
    same parameter.

Other changes:

  • Added custom error classes to many Linen errors. See:
    https://flax.readthedocs.io/en/latest/flax.errors.html
  • Adds Module.bind for binding variables and RNGs to an interactive Module.
  • Adds nn.apply and nn.init for transforming arbitrary functions that take a linen.Module as their first argument.
  • Add option to overwrite existing checkpoints in save_checkpoint.
  • Remove JAX omnistaging check for forward compatibility.
  • Pathlib compatibility for checkpoint paths.
  • Add is_leaf argument to traverse_util.flatten_dict.