Skip to content

Commit

Permalink
enable state with layer stack
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 598970379
  • Loading branch information
Haiku Contributor authored and copybara-github committed Jan 16, 2024
1 parent 6339353 commit d5500c2
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 36 deletions.
59 changes: 28 additions & 31 deletions haiku/_src/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ class LayerStackStateError(Exception):
"""Raise if trying to use layer_stack with Haiku state."""

LayerStackCarry = collections.namedtuple("LayerStackCarry", ["x"])
LayerStackScanned = collections.namedtuple("LayerStackScanned",
["params", "rng", "args_ys"])
LayerStackScanned = collections.namedtuple(
"LayerStackScanned", ["params", "rng", "state", "args_ys"])

# WrappedFn should take in arbitrarily nested `jax.Array`, and return the
# exact same type. We cannot express this with `typing`. So we just use it
Expand Down Expand Up @@ -162,42 +162,37 @@ def __init__(

def __call__(self, x, *args_ys, reverse=False):
count = self._count
init_fn, apply_fn = transform.transform(self._call_wrapped)
init_fn, apply_fn = transform.transform_with_state(self._call_wrapped)

def per_layer_init_fn(c, a):
c, rng = c
if rng is not None:
rng, next_rng, apply_rng = jax.random.split(rng, 3)
else:
rng, next_rng, apply_rng = None, None, None
params = init_fn(rng, c, *a)
c, _ = apply_fn(params, apply_rng, c, *a)
return (c, next_rng), params
params, state = init_fn(rng, c, *a)
(c, _), state = apply_fn(params, state, apply_rng, c, *a)
return (c, next_rng), (params, state)

def scanned_init_fn(x, rng):
_, params = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys,
length=self._count)
_, (params, state) = jax.lax.scan(per_layer_init_fn, (x, rng), args_ys,
length=self._count)
if self._transparency_map is not None:
return _split_params(params, self._count, self._transparency_map)
else:
return params
return (_split_params(params, self._count, self._transparency_map),
_split_params(state, self._count, self._transparency_map))
return params, state

rng = base.maybe_next_rng_key()

try:
if self._transparency_map is not None:
lifted_init_fn = lift.transparent_lift(
scanned_init_fn, allow_reuse=True
)
else:
lifted_init_fn = lift.lift(
scanned_init_fn, allow_reuse=True, name=self._name
)
params = lifted_init_fn(x, rng)
except base.NonEmptyStateError as e:
raise LayerStackStateError("LayerStack can only be used on Haiku "
"functions which do not make use of Haiku "
"state.") from e
if self._transparency_map is not None:
params_and_state_fn, updater = lift.transparent_lift_with_state(
scanned_init_fn, allow_reuse=True
)
else:
params_and_state_fn, updater = lift.lift_with_state(
scanned_init_fn, allow_reuse=True, name=self._name
)
params, state = params_and_state_fn(x, rng)

# Use scan during apply, threading through random seed so that it's
# unique for each layer.
Expand All @@ -206,26 +201,31 @@ def layer(
) -> tuple[LayerStackCarry, Any]:
rng = scanned.rng
params = scanned.params
state = scanned.state

kwargs = {}
if self._pass_reverse_to_layer_fn:
kwargs["reverse"] = reverse
out_x, z = apply_fn(params, rng, carry.x, *scanned.args_ys, **kwargs)
return LayerStackCarry(x=out_x), z
(out_x, z), state = apply_fn(
params, state, rng, carry.x, *scanned.args_ys, **kwargs)
return LayerStackCarry(x=out_x), (z, state)

rng = _get_rng_stack(count)

if self._transparency_map is not None:
params = _stack_params(params, self._count, self._transparency_map)
state = _stack_params(state, self._count, self._transparency_map)

carry = LayerStackCarry(x=x)
scanned = LayerStackScanned(params=params,
state=state,
rng=rng,
args_ys=args_ys)

carry, zs = jax.lax.scan(
carry, (zs, states) = jax.lax.scan(
layer, carry, scanned, length=count, unroll=self._unroll,
reverse=reverse)
updater.update(states)
return carry.x, zs

def _call_wrapped(
Expand Down Expand Up @@ -313,9 +313,6 @@ def layer_stack(
that kwargs are not supported, neither are functions with variable number
of parameters (specified by ``*args``).
Note that `layer_stack` cannot at the moment be used with functions that build
Haiku modules with state.
If ``with_per_layer_inputs=False`` then the new, wrapped function can be
understood as performing the following:
Expand Down
15 changes: 10 additions & 5 deletions haiku/_src/layer_stack_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,18 +172,23 @@ def stack_fn(x):
ValueError, "The function `f` should not have any `varargs`"):
build_and_init_stack(VarArgsModule)

def test_layer_stack_no_state_error(self):
def test_layer_stack_with_state(self):
def outer_fn_layer_stack(x):
stack = layer_stack.layer_stack(1)(lambda x: base.set_state("hi", x))
def simple_stateful_layer(x):
base.set_state("hi", x)
return x
stack = layer_stack.layer_stack(
1, name="with_state")(simple_stateful_layer)
return stack(x)

layer_stack_fn = transform.transform_with_state(outer_fn_layer_stack)

x = jnp.ones((1,))

with self.assertRaisesRegex(layer_stack.LayerStackStateError,
"LayerStack.*state"):
layer_stack_fn.init(None, x)
params, state = layer_stack_fn.init(None, x)
_, state = layer_stack_fn.apply(params, state, None, x)

np.testing.assert_allclose(state["with_state/~"]["hi"], np.array([[1.0]]))

@parameterized.parameters([1, 2, 4])
def test_layer_stack_grads(self, unroll):
Expand Down

0 comments on commit d5500c2

Please sign in to comment.