Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

modify layer_stack transparency map #707

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 11 additions & 22 deletions haiku/_src/layer_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,16 @@ def _split_params(
name_map: LayerStackTransparencyMapping,
) -> base.Params:
"""Splits the stacked parameters."""

def _split(x):
return [jnp.squeeze(s, axis=0) for s in jnp.split(x, x.shape[0], axis=0)]

params = {}
for mod_name, mod_params in stacked_params.items():
split_mod_params = {k: _split(v) for k, v in mod_params.items()}
for i in range(num_layers):
new_mod_name = name_map.stacked_to_flat(mod_name, i)
if new_mod_name in params:
raise ValueError(
f"Found conflicting unstacked module name for {mod_name} at"
f" {new_mod_name}."
)
params[new_mod_name] = {k: v[i] for k, v in split_mod_params.items()}

params[new_mod_name] = jax.tree_map(lambda x: x[i], mod_params) # pylint:disable=cell-var-from-loop
return params


Expand All @@ -114,32 +108,27 @@ def _stack_params(
name_map: LayerStackTransparencyMapping,
) -> base.Params:
"""Stacks the split parameters."""
params = {}
make_empty_param_stack = lambda: ([None] * num_layers)

# Construct a separate tree for each loop iteration, which we will then
# multimap over in a call to jnp.stack. This formulation preserves custom
# pytree node types.
param_trees = [{} for _ in range(num_layers)]
for mod_name, mod_params in split_params.items():
stacked_name_idx = name_map.flat_to_stacked(mod_name)
# If the transparency map returns None, this param is not part of the stack.
if stacked_name_idx is None:
continue
stacked_mod_name, idx = stacked_name_idx
if stacked_mod_name not in params:
params[stacked_mod_name] = collections.defaultdict(make_empty_param_stack)

if stacked_mod_name not in param_trees[idx]:
param_trees[idx][stacked_mod_name] = {}
for k, v in mod_params.items():
if params[stacked_mod_name][k][idx] is not None:
if k in param_trees[idx][stacked_mod_name]:
raise ValueError(
f"Found conflicting values for param {stacked_mod_name}/{k} at"
f" index {idx}."
)
params[stacked_mod_name][k][idx] = v

for mod_name, mod_params in params.items():
for k, v in mod_params.items():
if None in v:
raise ValueError(f"Couldn't find all params for {mod_name}/{k}: {v}")
mod_params[k] = jnp.stack(v, axis=0)
param_trees[idx][stacked_mod_name][k] = v

return params
return jax.tree_map(lambda *args: jnp.stack(args, axis=0), *param_trees)


class _LayerStack:
Expand Down
Loading