Releases: google-deepmind/dm-haiku
Haiku 0.0.13
Haiku 0.0.12
Haiku 0.0.11
- `hk.layer_stack` now allows transparent application (no prefix on module names).
- `hk.MultiHeadAttention` allows the bias initializer to be configured or biases to be removed.
- `hk.DepthwiseConvND` now supports `dilation`.
- `hk.dropout` supports `broadcast_dims` (see the sketch after this list).
- `hk.BatchApply` avoids an unnecessary h2d copy during tracing.
- `hk.experimental.profiler_name_scopes` has been removed; these are on by default.
- Added `hk.map`, mirroring `jax.lax.map`.
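As a sketch of the `broadcast_dims` behaviour (assuming dimensions listed there share a single dropout decision instead of sampling a mask entry per element):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  # Share the dropout decision across axis 1: the mask is broadcast over
  # features, so each row of x is kept or dropped as a whole.
  return hk.dropout(hk.next_rng_key(), 0.5, x, broadcast_dims=(1,))

f = hk.transform(f)
x = jnp.ones((4, 8))
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, jax.random.PRNGKey(1), x)
```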
Haiku 0.0.10
- Added `hk.mixed_precision.push_policy`.
- Added `hk.experimental.{get_params,get_initial_state,get_current_state}`.
- Added `hk.experimental.{maybe_get_rng_sequence_state,maybe_replace_rng_sequence_state}`.
- `hk.switch` now supports multiple operands (see the sketch after this list).
- `hk.get_parameter` now supports `init=None`.
- `hk.MethodContext` now includes `orig_class`.
- `hk.GetterContext` now includes `lifted_prefix_name`.
- `hk.layer_stack` now allows parameter reuse.
- Haiku is now compatible with `jax.enable_custom_prng`.
- `TruncatedNormal` now exports lower and upper bounds.
- Haiku init/apply functions now return `dict` rather than `Mapping`.
- `hk.dropout` now supports `broadcast_dims`.
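A minimal sketch of `hk.switch` with multiple operands, assuming it follows the `jax.lax.switch(index, branches, *operands)` calling convention:

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(index, x, y):
  branches = [
      lambda a, b: a + b,  # taken when index == 0
      lambda a, b: a * b,  # taken when index == 1
  ]
  # Trailing operands are forwarded to the selected branch.
  return hk.switch(index, branches, x, y)

f = hk.transform(f)
x, y = jnp.ones(3), 2.0 * jnp.ones(3)
params = f.init(jax.random.PRNGKey(0), 0, x, y)
out = f.apply(params, None, 1, x, y)  # == x * y
```

Note that during `init` Haiku traces every branch, so parameters created in any branch exist in `params`.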
Haiku 0.0.9
What's Changed
- Support `vmap` where `in_axes` is a list rather than a tuple in 307cf7d
- Pass pmap axis specs optionally to `make_model_info` in d0ba451
- Remove use of `jax_experimental_name_stack` flag in dbc0b1f
- Add `param_axis` argument to `RMSNorm` to allow setting the scale parameter shape in a4998a0 (see the sketch below)
- Add documentation and error messages for `w_init` and `w_init_scale` to avoid confusion in #541
- Fix hk.while_loop carrying state when reserving variable sizes of rng keys. by @copybara-service in #551
- Add ensemble example to hk.lift documentation. by @copybara-service in #556
Full Changelog: v0.0.8...v0.0.9
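As a sketch of the `param_axis` argument on `RMSNorm` (assuming it mirrors the `hk.LayerNorm` argument of the same name, i.e. it controls which axis the learned scale is shaped over):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  # NCHW input: normalise over the channel axis, and shape the learned
  # scale over the 16 channels rather than the (default) last axis.
  return hk.RMSNorm(axis=1, param_axis=1)(x)

f = hk.transform(f)
x = jnp.ones((8, 16, 32, 32))
params = f.init(jax.random.PRNGKey(0), x)
```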
Haiku 0.0.8
- Added `experimental.force_name`.
- Added ability to simulate a method name in `experimental.name_scope`.
- Added a config option for PRNG key block size.
- Added `unroll` parameter to `dynamic_unroll` (see the sketch after this list).
- Remove use of deprecated `jax.tree_*` functions.
- Many improvements to our examples.
- Improve error messages in `vmap`.
- Support `jax_experimental_name_stack` in `jaxpr_info`.
- `transform_and_run` now supports a map on PRNG keys.
- `remat` now uses the new JAX remat implementation.
- Scale parameter is now optional in `RMSNorm`.
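A short sketch of the `unroll` parameter, which is forwarded to the underlying scan (as in `jax.lax.scan`, trading compile time and code size for runtime speed):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(xs):
  core = hk.LSTM(hidden_size=16)
  state = core.initial_state(batch_size=xs.shape[1])
  # Unroll the recurrent scan by a factor of 4.
  outs, _ = hk.dynamic_unroll(core, xs, state, unroll=4)
  return outs

f = hk.transform(f)
xs = jnp.ones((8, 2, 32))  # [time, batch, features]; time-major by default
params = f.init(jax.random.PRNGKey(0), xs)
outs = f.apply(params, None, xs)
```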
Haiku 0.0.7
- Bug fix: modules with leading zeros (e.g. `linear_007`) are now correctly handled. 7632aff
- Breaking change: `hk.vmap(..)` now requires `split_rng` to be passed (see the sketch after this list).
- Breaking change: `hk.jit` was removed from the public API.
- Feature: we always add profiler name scopes to Haiku modules with the latest version of JAX.
- Added a tutorial on parameter sharing.
- Added `hk.ModuleProtocol` and `hk.SupportsCall`.
- Added `cross_replica_axis` to `VectorQuantiser`.
- Added `allow_reuse` argument to `hk.lift`.
- Added `fan_in_axes` to `VarianceScaling` initialiser.
- Added `hk.custom_setter(..)` to intercept `hk.set_state(..)`.
- Added `hk.Deferred`.
- Added `hk.experimental.transparent_lift(..)` and `hk.experimental.transparent_lift_with_state(..)`.
- Added `hk.experimental.fast_eval_shape(..)`.
- Added `hk.experimental.current_name()`.
- Added `hk.experimental.DO_NOT_STORE`. 2a6c034
- Added config APIs.
- Added support for the new `jax.named_call` implementation.
- The `HAIKU_FLATMAPPING` env var is no longer used.
- `hk.dropout(..)` now supports dynamic `rate`.
- `hk.without_apply_rng(..)` now supports multi-transformed functions.
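To illustrate the `split_rng` breaking change, a minimal sketch (no Haiku parameters involved; `split_rng=True` gives each mapped instance its own PRNG key, `split_rng=False` broadcasts one key across the batch):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(xs):
  def add_noise(x):
    return x + jax.random.normal(hk.next_rng_key(), x.shape)
  # split_rng must now be passed explicitly.
  return hk.vmap(add_noise, split_rng=True)(xs)

f = hk.transform(f)
xs = jnp.zeros((4, 8))
params = f.init(jax.random.PRNGKey(0), xs)
ys = f.apply(params, jax.random.PRNGKey(1), xs)  # independent noise per row
```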
Haiku 0.0.6
- Haiku now returns plain nested `dict`s rather than `FlatMap` from all APIs.
- `hk.vmap(..)` now optionally takes `split_rng`; this argument will be required in the next version of Haiku.
- `hk.LayerNorm` now accepts `param_axis` in the constructor; this argument will be required in the next version of Haiku (see the sketch after this list).
- `hk.get_channel_index(..)` was added.
- `hk.experimental.lift_with_state(..)` was added.
- `hk.experimental.layer_stack(..)` was added.
- `hk.DepthwiseConv{1,3}D` were added.
- `hk.BatchNorm` now supports sequences in `cross_replica_axis`.
- `hk.experimental.check_jax_usage()` makes Haiku check that JAX control flow/transforms are used correctly.
- `hk.dynamic_unroll(..)` now supports `return_all_states`.
- `hk.cond(..)` supports N operands.
- `hk.experimental.module_auto_repr(False)` can be used to speed up init.
- `hk.data_structures.merge(..)` now supports `check_duplicates`.
- `TruncatedNormal` initialiser now supports complex dtypes.
- `transform(jit(f))` now provides a nice error message.
- `hk.multinomial(..)` now uses `jax.random.categorical`.
- Added `hk.mixed_precision.{current,get}_policy(..)` for introspection.
- Mixed precision policies now support reloaded modules.
Haiku 0.0.5
- Added support for mixed precision training (dba1fd9) via jmp.
- Added `hk.with_empty_state(..)`.
- Added `hk.multi_transform(..)` (#137), supporting transforming multiple functions that share parameters (see the sketch below).
- Added `hk.data_structures.is_subset(..)` to test whether one set of parameters is a subset of another.
- Minimum Python version is now 3.7.
- Multiple changes in preparation for a future version of Haiku changing to plain `dict`s.
- `hk.next_rng_keys(..)` now returns a stacked array rather than a collection.
- `hk.MultiHeadAttention` now supports distinct sequence lengths in query and key/value.
- `hk.LayerNorm` now optionally supports faster (but less stable) variance computation.
- `hk.nets.MLP` now has an `output_shape` property.
- `hk.nets.ResNet` now supports changing strides.
- `UnexpectedTracerError` inside a Haiku transform now has a more useful error message.
- `hk.{lift,custom_creator,custom_getter}` are no longer experimental.
- Haiku now supports JAX's pluggable RNGs.
- We have made multiple improvements to our docs and error messages.
And many other small fixes and improvements.
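A sketch of the `hk.multi_transform(..)` pattern for sharing parameters between several apply functions (the module and function names here are illustrative):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def make_model():
  encoder = hk.Linear(16, name="encoder")
  decoder = hk.Linear(8, name="decoder")

  def init(x):
    # Template function: runs every module once so all parameters exist.
    return decoder(encoder(x))

  return init, (encoder, decoder)

f = hk.multi_transform(make_model)
x = jnp.ones((1, 8))
params = f.init(jax.random.PRNGKey(0), x)
encode, decode = f.apply  # apply fns matching the (encoder, decoder) structure
z = encode(params, None, x)
y = decode(params, None, z)
```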
Haiku 0.0.4
Changelog:
- (Important fix) Fixed strides in the basic block (300e6a4).
- Added `map`, `partition_n` and `traverse` to `data_structures`.
- Added "build your own Haiku" to the docs.
- Added a summarise utility to Haiku.
- Added a visualisation section to the docs.
- Added a `precision` arg to `Linear`, `Conv` and `ConvTranspose` (see the sketch below).
- Added `RMSNorm`.
- Added `module_name` and `name` to `GetterContext`.
- Added `hk.eval_shape`.
- Improved performance of non-cross-replica BN variance.
- Haiku branch functions are only traced once (mirroring JAX).
- Attention logits are now rescaled before the softmax.
- `ModuleMetaclass` now inherits from `Protocol`.
- Removed "dot access" to `FlatMapping`.
- Removed `query_size` from the `MultiHeadAttention` constructor.
And many other small fixes and improvements.
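As a sketch of the `precision` argument (assuming it is accepted at call time and forwarded to the underlying dot product):

```python
import haiku as hk
import jax
import jax.numpy as jnp

def f(x):
  # Request full-precision matmuls; on TPU the default may use the
  # faster, lower-precision matrix unit.
  return hk.Linear(4)(x, precision=jax.lax.Precision.HIGHEST)

f = hk.transform(f)
x = jnp.ones((2, 8))
params = f.init(jax.random.PRNGKey(0), x)
y = f.apply(params, None, x)
```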