Skip to content

Commit

Permalink
Allow for different parameters to be created when reuse=True.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 579170505
  • Loading branch information
tomhennigan authored and copybara-github committed Nov 3, 2023
1 parent f1a2f8c commit e920309
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
3 changes: 2 additions & 1 deletion haiku/_src/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def pack_into_dict(src: hk.Params,
value = dict(value)
if state:
value = {k: base.StatePair(v, v) for k, v in value.items()}
dst[new_key] = value
dst.setdefault(new_key, {})
dst[new_key].update(value)


def unpack_from_dict(src: hk.Params, prefix: str) -> MutableBundle:
Expand Down
14 changes: 12 additions & 2 deletions haiku/_src/lift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@ def wrapped(*a, **k):
return wrapped


def with_transparent_lift(f):
def with_transparent_lift(f, **kwargs):
def wrapped(*a, **k):
init, apply = transform.transform(f)
params = lift.transparent_lift(init)(None, *a, **k)
params = lift.transparent_lift(init, **kwargs)(None, *a, **k)
return apply(params, None, *a, **k)
return wrapped

Expand Down Expand Up @@ -472,6 +472,16 @@ def fn(x):
ValueError, "close over a module.*transparent_lift"):
fn.init(None, jnp.ones((10, 10)))

def test_transparent_lift_reuse_and_define_new(self):
f = lambda: base.get_parameter("w1", [], init=jnp.zeros)
g = lambda: base.get_parameter("w2", [], init=jnp.ones)
f = with_transparent_lift(f, allow_reuse=True)
g = with_transparent_lift(g, allow_reuse=True)

h = transform.transform(lambda: [f(), g()])
params = h.init(None)
self.assertEqual(params, {"~": {"w1": 0, "w2": 1}})

def test_same_name_across_transforms_no_closed_error(self):
init1, _ = transform.transform(lambda x: Bias()(x)) # pylint: disable=unnecessary-lambda
init2, _ = transform.transform(lambda x: Bias()(x)) # pylint: disable=unnecessary-lambda
Expand Down

0 comments on commit e920309

Please sign in to comment.