From 83d118ad369a470527e8ec6cd3b988fba5d4fd3e Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Wed, 13 Mar 2024 17:45:57 -0700 Subject: [PATCH] Fix tests after applying JAX key-reuse checker. See: - https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html#experimental-key-reuse-checking - https://jax.readthedocs.io/en/latest/_autosummary/jax.random.clone.html#jax.random.clone PiperOrigin-RevId: 615600399 --- flax/core/lift.py | 50 ++++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/flax/core/lift.py b/flax/core/lift.py index fdd14fa6da..3721b938df 100644 --- a/flax/core/lift.py +++ b/flax/core/lift.py @@ -810,30 +810,35 @@ def find_axis_size(axis, x): raise ValueError('axis_size should be specified manually.') else: d_axis_size = axis_size - split_fn = lambda rng: random.split(rng, d_axis_size) + # random.clone is only available on Jax versions 0.4.26 or newer + # see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html + if hasattr(random, 'clone'): + split_fn = lambda rng: random.split(random.clone(rng), d_axis_size) + else: + split_fn = lambda rng: random.split(rng, d_axis_size) rng_groups = tuple( - tree_map_rngs(split_fn, rng_group) if split else rng_group - for rng_group, split in zip(rng_groups, rng_splits) + tree_map_rngs(split_fn, rng_group) if split else rng_group + for rng_group, split in zip(rng_groups, rng_splits) ) new_variable_groups = [] for var_group, axis in zip(variable_groups, variable_in_axes): if axis is not None: new_variable_groups.append( - meta.remove_axis(var_group, axis, metadata_params) + meta.remove_axis(var_group, axis, metadata_params) ) else: new_variable_groups.append(var_group) variable_groups = tuple(new_variable_groups) @functools.partial( - jax.vmap, - in_axes=(variable_in_axes, rng_axes, in_axes), - out_axes=(out_axes, variable_out_axes), - axis_name=axis_name, - axis_size=axis_size, - spmd_axis_name=spmd_axis_name, + jax.vmap, + in_axes=(variable_in_axes, rng_axes, in_axes), + out_axes=(out_axes, variable_out_axes), + axis_name=axis_name, + axis_size=axis_size, + spmd_axis_name=spmd_axis_name, ) @functools.wraps(fn) def mapped(variable_groups, rng_groups, args): @@ -969,27 +974,32 @@ def find_length(axis, x): raise ValueError('length should be specified manually.') else: d_length = length - split_fn = lambda rng: random.split(rng, d_length) + # random.clone is only available on Jax versions 0.4.26 or newer + # see: https://jax.readthedocs.io/en/latest/jax.experimental.key_reuse.html + if hasattr(random, 'clone'): + split_fn = lambda rng: random.split(random.clone(rng), d_length) + else: + split_fn = lambda rng: random.split(rng, d_length) rng_groups = tuple( - tree_map_rngs(split_fn, rng_group) if split else rng_group - for rng_group, split in zip(rng_groups, rng_splits) + tree_map_rngs(split_fn, rng_group) if split else rng_group + for rng_group, split in zip(rng_groups, rng_splits) ) @functools.partial( - axes_scan.scan, - in_axes=(variable_in_axes, rng_axes, in_axes), - out_axes=(out_axes, variable_out_axes), - length=length, - reverse=reverse, - unroll=unroll, + axes_scan.scan, + in_axes=(variable_in_axes, rng_axes, in_axes), + out_axes=(out_axes, variable_out_axes), + length=length, + reverse=reverse, + unroll=unroll, ) def scanned(broadcast_vars, carry, scan_variable_groups, rng_groups, args): carry_vars, c = carry variable_groups = (broadcast_vars, carry_vars) + scan_variable_groups if data_transform is not None: variable_groups, rng_groups = data_transform( - variable_groups, rng_groups + variable_groups, rng_groups ) scope = scope_fn(variable_groups, rng_groups) c, y = fn(scope, c, *args)