-
Notifications
You must be signed in to change notification settings - Fork 648
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[linen] add share_scope #4102
[linen] add share_scope #4102
Conversation
flax/linen/combinators.py
Outdated
class Transparent(Module, Generic[M]): | ||
"""A Module that shares its scope with an inner Module. This combinator is useful | ||
when you want to wrap a Module and extend itsfunctionality without changing the | ||
parameter structure. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also highlight the difference between
class DenseLora(Dense):
...
vs
class DenseLora(Transparent[Dense]):
...
(The important distinction is that Transparent replaces inheritance with composition, that allows to swap out implementation of inner module with Dense implementation (as long as it has compatible interface), without having to create separate DenseLora for each implementation.
@marksandler2 wondering if we should just expose class DenseLoRA(nn.Module):
inner: nn.Dense
rank: int
def setup(self):
nn.share_scope(self, self.inner)
@nn.compact
def __call__(self, x: jax.Array):
din, dout = x.shape[-1], self.inner.features
A = self.param('A', nn.zeros_init(), (din, self.rank))
B = self.param('B', nn.zeros_init(), (self.rank, dout))
return self.inner(x) + x @ A @ B
class Model(nn.Module):
@nn.compact
def __call__(self, x: jax.Array):
return DenseLoRA(nn.Dense(10), rank=2)(x) |
Hmm, i like this new proposed function a lot actually. It certainly feels a lot less heavy weight (e.g. we don't need to muck around with Transparent[...] inheritance, and the logic could be made a lot more flexible. |
62f2f4a
to
a52ffdc
Compare
a52ffdc
to
36952ea
Compare
Can we merge this? |
What does this PR do?
Adds
nn.share_scope
can be used to share a scope between two Modules. This means that any parameters created by any of the Modules will be added to their common scope. This is useful when you want to wrap a Module and extend its functionality without changing the parameter structure.