From 4071569038502c0ff85e3b37cca343f78d482415 Mon Sep 17 00:00:00 2001 From: Haiku Contributor Date: Tue, 16 Jan 2024 15:07:39 -0800 Subject: [PATCH] enable state with layer stack PiperOrigin-RevId: 598970379 --- haiku/_src/layer_stack.py | 59 ++++++++++++++++------------------ haiku/_src/layer_stack_test.py | 15 ++++++--- 2 files changed, 38 insertions(+), 36 deletions(-) diff --git a/haiku/_src/layer_stack.py b/haiku/_src/layer_stack.py index 43c34b8ec..426f11b4d 100644 --- a/haiku/_src/layer_stack.py +++ b/haiku/_src/layer_stack.py @@ -31,8 +31,8 @@ class LayerStackStateError(Exception): """Raise if trying to use layer_stack with Haiku state.""" LayerStackCarry = collections.namedtuple("LayerStackCarry", ["x"]) -LayerStackScanned = collections.namedtuple("LayerStackScanned", - ["params", "rng", "args_ys"]) +LayerStackScanned = collections.namedtuple( + "LayerStackScanned", ["params", "rng", "state", "args_ys"]) # WrappedFn should take in arbitrarily nested `jax.Array`, and return the # exact same type. We cannot express this with `typing`. So we just use it @@ -162,7 +162,7 @@ def __init__( def __call__(self, x, *args_ys, reverse=False): count = self._count - init_fn, apply_fn = transform.transform(self._call_wrapped) + init_fn, apply_fn = transform.transform_with_state(self._call_wrapped) def per_layer_init_fn(c, a): c, rng = c @@ -170,34 +170,29 @@ def per_layer_init_fn(c, a): rng, next_rng, apply_rng = jax.random.split(rng, 3) else: rng, next_rng, apply_rng = None, None, None - params = init_fn(rng, c, *a) - c, _ = apply_fn(params, apply_rng, c, *a) - return (c, next_rng), params + params, state = init_fn(rng, c, *a) + (c, _), state = apply_fn(params, state, apply_rng, c, *a) + return (c, next_rng), (params, state) def scanned_init_fn(x, rng): - _, params = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys, - length=self._count) + _, (params, state) = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys, + length=self._count) if self._transparency_map is not None: - return _split_params(params, self._count, self._transparency_map) - else: - return params + return (_split_params(params, self._count, self._transparency_map), + _split_params(state, self._count, self._transparency_map)) + return params, state rng = base.maybe_next_rng_key() - try: - if self._transparency_map is not None: - lifted_init_fn = lift.transparent_lift( - scanned_init_fn, allow_reuse=True - ) - else: - lifted_init_fn = lift.lift( - scanned_init_fn, allow_reuse=True, name=self._name - ) - params = lifted_init_fn(x, rng) - except base.NonEmptyStateError as e: - raise LayerStackStateError("LayerStack can only be used on Haiku " - "functions which do not make use of Haiku " - "state.") from e + if self._transparency_map is not None: + params_and_state_fn, updater = lift.transparent_lift_with_state( + scanned_init_fn, allow_reuse=True + ) + else: + params_and_state_fn, updater = lift.lift_with_state( + scanned_init_fn, allow_reuse=True, name=self._name + ) + params, state = params_and_state_fn(x, rng) # Use scan during apply, threading through random seed so that it's # unique for each layer. @@ -206,26 +201,31 @@ def layer( ) -> tuple[LayerStackCarry, Any]: rng = scanned.rng params = scanned.params + state = scanned.state kwargs = {} if self._pass_reverse_to_layer_fn: kwargs["reverse"] = reverse - out_x, z = apply_fn(params, rng, carry.x, *scanned.args_ys, **kwargs) - return LayerStackCarry(x=out_x), z + (out_x, z), state = apply_fn( + params, state, rng, carry.x, *scanned.args_ys, **kwargs) + return LayerStackCarry(x=out_x), (z, state) rng = _get_rng_stack(count) if self._transparency_map is not None: params = _stack_params(params, self._count, self._transparency_map) + state = _stack_params(state, self._count, self._transparency_map) carry = LayerStackCarry(x=x) scanned = LayerStackScanned(params=params, + state=state, rng=rng, args_ys=args_ys) - carry, zs = jax.lax.scan( + carry, (zs, states) = jax.lax.scan( layer, carry, scanned, length=count, unroll=self._unroll, reverse=reverse) + updater.update(states) return carry.x, zs def _call_wrapped( @@ -313,9 +313,6 @@ def layer_stack( that kwargs are not supported, neither are functions with variable number of parameters (specified by ``*args``). - Note that `layer_stack` cannot at the moment be used with functions that build - Haiku modules with state. - If ``with_per_layer_inputs=False`` then the new, wrapped function can be understood as performing the following: diff --git a/haiku/_src/layer_stack_test.py b/haiku/_src/layer_stack_test.py index 1fb9fc3b5..511ad5c06 100644 --- a/haiku/_src/layer_stack_test.py +++ b/haiku/_src/layer_stack_test.py @@ -172,18 +172,23 @@ def stack_fn(x): ValueError, "The function `f` should not have any `varargs`"): build_and_init_stack(VarArgsModule) - def test_layer_stack_no_state_error(self): + def test_layer_stack_with_state(self): def outer_fn_layer_stack(x): - stack = layer_stack.layer_stack(1)(lambda x: base.set_state("hi", x)) + def simple_stateful_layer(x): + base.set_state("hi", x) + return x + stack = layer_stack.layer_stack( + 1, name="with_state")(simple_stateful_layer) return stack(x) layer_stack_fn = transform.transform_with_state(outer_fn_layer_stack) x = jnp.ones((1,)) - with self.assertRaisesRegex(layer_stack.LayerStackStateError, - "LayerStack.*state"): - layer_stack_fn.init(None, x) + params, state = layer_stack_fn.init(None, x) + _, state = layer_stack_fn.apply(params, state, None, x) + + np.testing.assert_allclose(state["with_state/~"]["hi"], np.array([[1.0]])) @parameterized.parameters([1, 2, 4]) def test_layer_stack_grads(self, unroll):