v0.8.3
What's Changed
- Add git fetch upstream to contributing doc. by @carlosgmartin in #3757
- removed getattr/setattr unboxing magic from nnx.Pytree by @chiamp in #3743
- added Einsum layer to NNX by @chiamp in #3741
- Make TrainState's step possibly jax.Array. This makes replicate valid for type checking. by @copybara-service in #3763
- v0.8.3 by @cgarciae in #3758
- [nnx] fix demo notebook by @cgarciae in #3744
- added nnx api reference by @chiamp in #3762
- updated rng docstring for init, apply and make_rng by @chiamp in #3765
- use note box in make_rng docstring by @cgarciae in #3767
- [nnx] improved graph update mechanism by @cgarciae in #3759
- use note box in docstrings by @chiamp in #3769
- Add reset_gate flag to MGUCell. by @carlosgmartin in #3760
- Access thread_resources via jax.interpreters.pxla instead of jax.experimental.maps by @copybara-service in #3775
- Minor doc improvements by @canyon289 in #3588
- added MGU reset_gate test by @chiamp in #3773
- [nnx] Pytrees are Trees by @cgarciae in #3768
- Use short-circuiting access to debug_key_reuse by @copybara-service in #3781
- fix tabulate on norm wrappers by @chiamp in #3772
- Add kw_only struct.dataclass test by @chiamp in #3651
- extended PyTreeNode to take dataclass kwargs by @chiamp in #3785
- [nnx] Arrays are state by @cgarciae in #3791
- [nnx] add GraphNode base class by @cgarciae in #3790
- [nnx] jit accepts many Modules by @cgarciae in #3783
- Exposing the experimental _split_transpose JAX scan parameter in Flax. by @copybara-service in #3795
- Expose nnx.GraphNode by @chiamp in #3796
- [nnx] Rngs and RngStream inherit from GraphNode by @cgarciae in #3793
- [nnx] TrainState uses struct by @cgarciae in #3788
- [nnx] split returns graphdef first by @cgarciae in #3794
- Remove the uninitialized field "embedding" in nn.Embed by @copybara-service in #3801
- Add nnx.training by @chiamp in #3782
- [nnx] non-str State keys by @cgarciae in #3802
- [nnx] allow all jit kwargs in nnx.jit by @cgarciae in #3809
- [nnx] simplify readme by @cgarciae in #3805
- [nnx] Fix nnx basics by @cgarciae in #3812
- [nnx] grad accepts argnums by @cgarciae in #3798
- [nnx] improve toy examples by @cgarciae in #3813
- [nnx] expose Sequential by @cgarciae in #3814
- [nnx] Rng Variable tags by @cgarciae in #3807
- [nnx] remove copy in graph unflatten by @cgarciae in #3804
- fixed optax guide links and docstring typos by @chiamp in #3789
- added dropout broadcast test by @chiamp in #3776
- relaxed grads kwarg for Optimizer.update by @chiamp in #3818
- added tree_map deprecation warning filter by @chiamp in #3828
- updated tree_map by @chiamp in #3823
- added NNX vs JAX transformations guide by @chiamp in #3819
- Updated NNX MNIST tutorial by @chiamp in #3810
- [nnx] add Dropout.rngs by @cgarciae in #3815
- removed autosummary from linen docs by @chiamp in #3792
- Fix cloudpickle sentinel cloning by @cgarciae in #3825
- [nnx] remove pytreelib by @cgarciae in #3816
- [nnx] fix nnx_basics by @cgarciae in #3839
- [linen] fix DenseGeneral init by @cgarciae in #3834
- [nnx] jit constrain object state by @cgarciae in #3817
- Copybara import of the project: by @copybara-service in #3857
- Add example of unbox() and replace_boxed() to the jit guide by @IvyZX in #3843
- RNNCellBase refactor FLIP by @cgarciae in #3099
- [nnx] Some small documentation suggestions. by @gnecula in #3861
- updated nnx dropout by @chiamp in #3841
- Fix LogicalRules type annotation. (Tuple[str] is a tuple with single element string, by @copybara-service in #3877
- Add option to skip float32 promotion when computing means and variances for normalization. by @copybara-service in #3873
- added nnx api reference link by @chiamp in #3871
- option of forcing the input of softmax to be fp32 for better numerical stability in mixed-precision training. by @copybara-service in #3874
- allow custom dot_general for einsum. by @copybara-service in #3884
- [NVIDIA] Extend the custom fp8 accumulate dtype in non-jit scenarios by @kaixih in #3827
- updated robots.txt by @chiamp in #3886
- fixed autosummary links by @chiamp in #3887
- Fix jax.tree_util.register_dataclass in older JAX versions. by @copybara-service in #3885
- [nnx] v0.1 by @cgarciae in #3876
Full Changelog: v0.8.2...v0.8.3
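Two of the entries above, #3873 (optionally skipping float32 promotion for normalization statistics) and #3874 (optionally forcing the softmax input to fp32), both concern upcasting low-precision activations before a reduction. The following is a minimal sketch of that idea in plain JAX. It assumes nothing about Flax's actual parameter names: softmax_fp32, mean_var and promote_to_fp32 are hypothetical helpers written only for illustration, not Flax API.

```python
import jax
import jax.numpy as jnp


def softmax_fp32(logits: jax.Array) -> jax.Array:
  """Softmax computed on float32 inputs, result cast back to the input dtype.

  Hypothetical helper illustrating the idea behind #3874: upcasting
  low-precision logits (e.g. bfloat16) before exp/normalize reduces
  rounding error in mixed-precision training.
  """
  probs = jax.nn.softmax(logits.astype(jnp.float32), axis=-1)
  return probs.astype(logits.dtype)


def mean_var(x: jax.Array, promote_to_fp32: bool = True):
  """Mean and variance over the last axis.

  Hypothetical helper illustrating the idea behind #3873.
  promote_to_fp32=True upcasts x first, so the statistics come back in
  float32; False leaves x in its own dtype, so the statistics come back
  in that dtype. The flag name is an assumption, not Flax's parameter.
  """
  if promote_to_fp32:
    x = x.astype(jnp.float32)
  mean = jnp.mean(x, axis=-1, keepdims=True)
  # Two-pass variance: average squared deviation from the mean.
  var = jnp.mean(jnp.square(x - mean), axis=-1, keepdims=True)
  return mean, var


if __name__ == "__main__":
  logits = jnp.array([[1.0, 2.0, 3.0]], dtype=jnp.bfloat16)
  print(softmax_fp32(logits).dtype)  # bfloat16

  x = jnp.arange(4, dtype=jnp.bfloat16).reshape(1, 4)
  m, v = mean_var(x)
  print(m.dtype, v.dtype)  # float32 float32
```

The trade-off is speed and memory versus accuracy: keeping reductions in bfloat16 avoids a cast, while upcasting protects sums and exponentials from bfloat16's 8-bit mantissa.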