Releases · google/flax
Version 0.5.1
What's Changed
- Adds flax import to summary.py by @marcvanzee in #2138
- Add options for fallback behavior. by @copybara-service in #2130
- Upgrade to modern python idioms using pyupgrade. by @levskaya in #2132
- Update download_dataset_metadata.sh by @mattiasmar in #1801
- Mark correct minimum jax version requirement by @PhilipVinc in #2136
- Edited contributing.md by @IvyZX in #2151
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/imagenet by @dependabot in #2143
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/wmt by @dependabot in #2142
- Add typehint to Module.scope by @cgarciae in #2106
- Correcting Mistakes In Flip Docs by @saiteja13427 in #2140
- Add CAUSAL padding for 1D convolution. by @copybara-service in #2141
- Calculate cumulative number of issues and PRs by @cgarciae in #2154
- Improve setup instructions in contributing guide by @cgarciae in #2155
- Forward unroll argument in lifted scan by @jheek in #2158
- Improve tabulate by @cgarciae in #2162
- Remove unused variable from nlp_seq example by @marcvanzee in #2163
- Allow nn.cond, nn.while to act on bound methods. by @levskaya in #2172
- 0.5.1 by @cgarciae in #2180
- Update normalization.py by @yechengxi in #2182
New Contributors
- @mattiasmar made their first contribution in #1801
- @PhilipVinc made their first contribution in #2136
- @IvyZX made their first contribution in #2151
- @saiteja13427 made their first contribution in #2140
- @yechengxi made their first contribution in #2182
Full Changelog: v0.5.0...v0.5.1
Version 0.5.0
New features:
- Added `flax.jax_utils.pad_shard_unpad()` by @lucasb-eyer
- Implemented the default dtype FLIP.
  This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32.
  This is especially useful when dealing with complex numbers, because the standard Modules will no longer truncate
  complex numbers to their real component by default; instead, the complex dtype is preserved (see the sketch below).
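A minimal sketch of the new default-dtype behavior (the layer, shapes, and values are arbitrary illustrations, assuming Flax 0.5.0+ and a working JAX install): a `Dense` layer applied to a complex input now produces a complex output instead of silently dropping the imaginary part.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=3)
x = jnp.ones((2, 4), dtype=jnp.complex64)        # complex-valued input
params = model.init(jax.random.PRNGKey(0), x)    # params still default to float32
y = model.apply(params, x)
print(y.dtype)  # complex64: the output dtype is inferred from inputs/params, not forced to float32
```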
Bug fixes:
- Fix support for JAX's experimental_name_stack.
Breaking changes:
- In rare cases the dtype of a layer can change due to default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.
Version 0.4.3
Note
Due to a release error, we had to roll out a new release; this version is exactly the same as v0.4.2.
Version 0.4.2
What's Changed
- Canonicalize conv padding by @jheek in #2009
- Update ScopeParamNotFoundError message. by @melissatan in #2013
- Set field on dataclass transform decorator by @NeilGirdhar in #1927
- Don't recommend mixing setup and compact in docs. by @levskaya in #2018
- Clarifies `optim.Adam(weight_decay)` parameter. by @copybara-service in #2016
- Update linear regression example in Jax intro and Flax intro. by @melissatan in #2015
- Lifted cond by @jheek in #2020
- Use tree_map instead of deprecated tree_multimap by @jheek in #2024
- Remove tree_multimap from docs, examples, and tests by @jheek in #2026
- Fix bug where the linen Module state is reused. by @jheek in #2025
- Add getattribute with lazy setup trigger. by @levskaya in #2028
- Better error messages for loading checkpoints. by @copybara-service in #2035
- Add filterwarning for jax.tree_multimap by @marcvanzee in #2038
- Adds Flax logo to README by @marcvanzee in #2036
- Module lifecycle note by @jheek in #1964
- Fix linter errors in core/scope.py and core/tracers.py. by @copybara-service in #2004
- Handle edge-case of rate==1.0 in Dropout layer. by @levskaya in #2055
- Bug fixes and generalizations of nn.partitioning api. by @copybara-service in #2062
- Add support for JAX dynamic stack-based named_call. by @copybara-service in #2063
- Updates pooling docstrings by @marcvanzee in #2064
- Makes annotated_mnist use Optax's xent loss. by @andsteing in #2071
Full Changelog: v0.4.1...v0.4.2
Version 0.4.1
What's Changed
- Added locally-connected (unshared CNN) layer `flax.linen.ConvLocal`.
- Improved seq2seq example: factored out model and input pipeline code.
- Added Optax update guide and deprecated `flax.optim`.
- Added `sep` argument to `flax.traverse_util.flatten_dict()` (see the sketch after this list).
- Implemented Sequential module, in `flax.linen.combinators`.
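A minimal sketch of the new `sep` argument (the nested dict is an arbitrary illustration; assumes Flax 0.4.1+): passing `sep` joins nested keys into flat string keys instead of key tuples.

```python
from flax import traverse_util

params = {'encoder': {'dense': {'kernel': 1, 'bias': 2}}}

flat = traverse_util.flatten_dict(params, sep='/')
# {'encoder/dense/kernel': 1, 'encoder/dense/bias': 2}
print(flat)
```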
Version 0.4.0
What's Changed
- Add PReLU Activation by @isaaccorley in #1570
- Fix GroupNorm type hint for param num_groups. by @lkhphuc in #1657
- Add named_call overrides to docs by @jheek in #1649
- mission statement by @jheek in #1668
- Improves Flax Modules for RTD by @marcvanzee in #1416
- Add clarifying docstring for 'size' argument to prefetch_to_device's by @avital in #1574
- Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose by @sgrigory in #1661
- Fix child scope rng reuse. by @jheek in #1692
- Numerically stable weight norm by @jheek in #1693
- Remove cyclic refs from scope by @jheek in #1696
- Add `unroll` to `jax_utils.scan_in_dim` by @ptigwe in #1691
- Removes `rng` arguments from Dropout's `__call__`. by @copybara-service in #1689
- Add error for empty scopes. by @jheek in #1698
- correct axis resolution in case of repeated axis in the logical axis r… by @ultrons in #1703
- Fix lost mutation bug in transforms on nested scopes. by @levskaya in #1716
- Expose put_variable function to Module. by @levskaya in #1710
- add eq and hash for scopes by @jheek in #1720
- Fixes a bug in DenseGeneral. by @copybara-service in #1722
- Add param_dtype argument to linen Modules by @jheek in #1739
- Implement custom vjp by @jheek in #1738
- Handle setup with transformed methods taking submodules of self. by @levskaya in #1745
- validate RNG key shape against jax's default by @copybara-service in #1780
- Adds optax update guide. by @andsteing in #1774
- Implement LazyRNG by @jheek in #1723
- make params_with_axes() work when params_axes is not mutable by @copybara-service in #1811
- Updates the ensembling HOWTO to Optax. by @andsteing in #1806
- Adds prominent `scenic` link to `examples/README.md` by @copybara-service in #1809
- Removes PixelCNN++ example. by @copybara-service in #1819
- Add support for non-float32 normalization for linen normalization layers by @jheek in #1804
- Make Filter a Collection instead of a Container by @NeilGirdhar in #1815
- Removes deprecated API from RTD by @marcvanzee in #1824
New Contributors
- @isaaccorley made their first contribution in #1570
- @lkhphuc made their first contribution in #1657
- @sgrigory made their first contribution in #1661
- @ptigwe made their first contribution in #1691
- @ultrons made their first contribution in #1703
- @dependabot made their first contribution in #1749
- @NeilGirdhar made their first contribution in #1699
- @saeta made their first contribution in #1784
- @melissatan made their first contribution in #1793
Full Changelog: v0.3.6...v0.4.0
Version 0.3.6
Breaking changes:
- Move `flax.nn` to `flax.deprecated.nn`.
New features:
- Add experimental checkpoint policy argument. See `flax.linen.checkpoint`.
- Add lifted versions of jvp and vjp.
- Add lifted transformation for mapping variables. See `flax.linen.map_variables`.
Version 0.3.5
Breaking changes:
- You can no longer pass an int as the kernel_size for a `flax.linen.Conv`. Instead, a type error is raised stating that a tuple/list should be provided. Stride and dilation arguments do support broadcasting a single int value now, because this is not ambiguous when the kernel rank is known (see the sketch after this list).
- flax.linen.enable_named_call and flax.linen.disable_named_call now work anywhere, instead of only affecting Modules constructed after the enable/disable call. Additionally, there is now flax.linen.override_named_call, which provides a context manager to locally disable/enable named_call.
- NamedTuples are no longer converted to tuples on assignment to a linen.Module.
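A minimal sketch of the `kernel_size` change (layer sizes are arbitrary illustrations; assumes this release's behavior):

```python
import flax.linen as nn

# Passing an int for kernel_size now raises a TypeError (use a tuple/list instead):
# conv = nn.Conv(features=8, kernel_size=3)

# Correct: give the kernel size per spatial dimension; strides and
# kernel_dilation may still be a single int, which is broadcast.
conv = nn.Conv(features=8, kernel_size=(3,), strides=1, kernel_dilation=1)
```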
New features:
- Flax internal stack frames are now removed from exception state traces.
- Added flax.linen.nowrap to decorate methods that should not be transformed because they are stateful (see the sketch after this list).
- Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with `--jax_numpy_rank_promotion=raise`.
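A minimal sketch of `flax.linen.nowrap` (the module and helper name are arbitrary illustrations): it marks a helper method so that Module method wrapping (e.g. named_call) leaves it alone.

```python
import flax.linen as nn

class MLP(nn.Module):
  @nn.nowrap
  def _make_dense(self, features):
    # Helper that constructs submodules; excluded from method wrapping.
    return nn.Dense(features)

  @nn.compact
  def __call__(self, x):
    x = self._make_dense(8)(x)
    return nn.Dense(1)(x)
```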
Bugfixes:
- linen Modules and dataclasses made with flax.struct.dataclass or flax.struct.PyTreeNode are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
- Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
- Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated (bug).
- Mixed precision training with float16 now works correctly with the attention layers.
- auto-generated linen Module hash, eq, repr no longer fail by default on non-init attributes.
Version 0.3.4
Possibly breaking changes:
- When calling `init`, the 'intermediates' collection is no longer mutable.
  Therefore, intermediates will no longer be returned from initialization by default (see the sketch after this list).
- Don't update batch statistics during initialization.
- When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the `deterministic` argument in `MultiHeadDotProductAttention`.
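A minimal sketch of the new init behavior (the module and variable names are arbitrary illustrations; assumes Flax 0.3.4+): values recorded with `sow` are no longer returned by `init` unless the collection is requested via `mutable`.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    h = nn.Dense(4)(x)
    self.sow('intermediates', 'hidden', h)  # recorded only if the collection is mutable
    return nn.Dense(1)(h)

x = jnp.ones((1, 3))
variables = MLP().init(jax.random.PRNGKey(0), x)
print('intermediates' in variables)  # False: only 'params' by default

variables = MLP().init(jax.random.PRNGKey(0), x, mutable=['params', 'intermediates'])
print('intermediates' in variables)  # True when requested explicitly
```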
Other changes:
- Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
- Added an NLP text classification example (on the SST-2 dataset) to `examples/sst2` that uses a bidirectional LSTM (BiLSTM) to encode the input text.
- Added `flax.training.train_state` to simplify using Optax optimizers (see the sketch after this list).
- `mutable` argument is now available on `Module.init` and `Module.init_with_output`.
- Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
- Expose `dot_product_attention_weights`, allowing access to attention weights.
- `BatchNorm` instances will behave correctly during init when called multiple times.
- Added a more extensive "how to contribute" guide in `contributing.md`.
- Add proper cache behavior for `lift.jit`, fixing cache misses.
- Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array.
- Fix `linen.Module` for deep inheritance chains.
- Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions.
- Allow Flax lifted transforms to work on partially applied Modules.
- Make `MultiOptimizer` use `apply_gradient` instead of `apply_param_gradient`.
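A minimal sketch of `flax.training.train_state` with an Optax optimizer (the model, learning rate, and placeholder gradients are arbitrary illustrations):

```python
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state

model = nn.Dense(features=1)
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 3)))['params']

state = train_state.TrainState.create(
    apply_fn=model.apply,              # stored for convenient use in train steps
    params=params,
    tx=optax.sgd(learning_rate=0.1),   # any optax.GradientTransformation
)

grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  # placeholder gradients
state = state.apply_gradients(grads=grads)                    # one optimizer step
```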
Version 0.3.3
Possible breaking changes:
- Bug Fix: Disallow modifying attributes in Modules after they are initialized.
- Raise an error when saving a checkpoint which has a smaller step than the latest checkpoint already saved.
- MultiOptimizer now rejects the case where multiple sub optimizers update the same parameter.
Other changes:
- Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html
- Adds `Module.bind` for binding variables and RNGs to an interactive Module (see the sketch after this list).
- Adds `nn.apply` and `nn.init` for transforming arbitrary functions that take a `linen.Module` as their first argument.
- Add option to overwrite existing checkpoints in `save_checkpoint`.
- Remove JAX omnistaging check for forward compatibility.
- Pathlib compatibility for checkpoint paths.
- `is_leaf` argument in `traverse_util.flatten_dict`.
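A minimal sketch of `Module.bind` for interactive use (the layer and input are arbitrary illustrations; assumes Flax 0.3.3+):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.Dense(features=2)
x = jnp.ones((1, 4))
variables = model.init(jax.random.PRNGKey(0), x)

bound = model.bind(variables)  # a bound, interactive view of the Module
y = bound(x)                   # call directly, without model.apply(...)
print(y.shape)                 # (1, 2)
```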