Conversation
LGTM!!! I tried swapping in my custom encoder and it works perfectly 💯 🚀 Great to have this!!
for some reason i'm not 100% grasping the interplay between Manifests and ComponentSpecs, but here's my current interpretation:
- A `Manifest` defines the components of a `TComponent`. So like the encoder manifest defines the layers.
- A `ComponentSpec` defines the module class itself, and the respective manifest for the class.
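
If I'm reading it right, something like this (my own sketch using the names from this revision of the PR; `MyEncoderLayer` is made up, and the exact constructor details are guesses):

```python
# Hypothetical sketch, not the actual implementation: a Manifest lists the
# components a ModularComponent is built from, and each entry is a
# ComponentSpec pairing a module class with the Manifest for *its* components.
manifest = TransformerEncoder.Manifest()        # assumed: the encoder's manifest (its layers, etc.)
manifest.encoder_layer = ComponentSpec(
    MyEncoderLayer,                             # hypothetical drop-in layer class
    TransformerEncoderLayer.Manifest(),         # manifest describing the layer's own parts
)
```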
I like how it's quite customizable if you know what you're doing. It still seems to suffer from the downsides of a hierarchical setup (you need to traverse quite a few things to find out the different manifests, etc.), but it might be a better tradeoff than what we have already.
```python
        Forward pass.
        """
        residual = tensor
        if self.variant == 'prelayernorm':
```
one could also view the different variants as custom variations, right?
For sure! I didn't do anything like that as part of this PR to keep it simple, but I could imagine a pretty simple way to implement that without this `if` statement:
```python
class PreLayerNormTransformerEncoderLayer(TransformerEncoderLayer):
    def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        residual = tensor
        tensor = self.norm1(tensor)
        attended_tensor = self.attention(tensor, mask=mask)[0]
        tensor = residual + self.dropout(attended_tensor)
        residual = tensor
        tensor = self.norm2(tensor)
        tensor = residual + self.dropout(self.ffn(tensor))
        tensor *= mask.unsqueeze(-1).type_as(tensor)
        return tensor


manifest.encoder_layer = ComponentSpec(
    PreLayerNormTransformerEncoderLayer,
    TransformerEncoderLayer.Manifest(),
)
```
To me the code feels simple enough as is for now, but if we end up with a ton of variations then maybe we'll want to split them out into different implementations like this.
yes agreed, not in the scope of this current PR, but it could be an interesting application of it
Thank you both so much for your comments!!
Your summary is accurate.

This PR is just meant to solve the problem of "how do I swap out a piece of the Transformer without copying and pasting the whole dang file??". I think getting rid of the hierarchy traversal altogether would require a radical refactor such that none of our […]. Maybe another option to get around the traversal issue without a radical refactor could be making the […].
Agreed, let's deal with that separately. I do think this is more flexible than before
Verbosity, if it improves clarity, may not be the worst thing, but if it requires more work from the user it may not be super desirable.
Take a look at […]
Yes, I like that a lot actually; obviously it may get a bit more complicated with more complicated models, but that might be a good thing, since you have everything laid out in one specification!
```diff
@@ -141,7 +218,7 @@ def _default(val, default):
         self.layers = nn.ModuleList()
         for _ in range(self.n_layers):
             self.layers.append(
-                TransformerEncoderLayer(
+                template.layer.build(
```
this is really hard to read... can't it be called `encoderlayer_class` or something (assuming that's what it is)? `template` is so generic...
That's pretty much what it was called in the previous commit lol. The more I stare at and play with the code the less sure I am about what's intuitive and what's confusing.
```python
    def __init__(
        self, klass: Type[MC], template: Optional[ModularComponent.Template] = None
    ) -> None:
        self._klass = klass
```
what is 'klass' (sorry)?
It's type-annotated in `__init__` for reference, but it's a class. Spelled with a k so it doesn't clash with the Python reserved word `class`. So `TransformerEncoderLayer`, for example.
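
To illustrate the general pattern with a toy, self-contained example (not the actual PR code; the class and method names here are made up):

```python
from typing import Generic, Optional, Type, TypeVar

MC = TypeVar("MC")


class ComponentBuilder(Generic[MC]):
    """Toy illustration: store a class ("klass") plus a template now, build an instance later."""

    def __init__(self, klass: Type[MC], template: Optional[object] = None) -> None:
        self._klass = klass        # e.g. TransformerEncoderLayer, passed as a class, not an instance
        self._template = template  # optional recipe for the class's own subcomponents

    def build(self, *args, **kwargs) -> MC:
        # Instantiate whichever class was registered, forwarding constructor args.
        return self._klass(*args, **kwargs)
```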
```python
        self.pad_idx = dictionary[dictionary.null_token]
        self.start_idx = dictionary[dictionary.start_token]
        self.end_idx = dictionary[dictionary.end_token]
        super().__init__(self.pad_idx, self.start_idx, self.end_idx)
        template = template or self.Template()
```
ditto
```python
    def build_model(self, states=None):
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=TransformerEncoder.with_components(
```
yes this is the one
agreed, this looks nice
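
The snippet quoted above is cut off; in full, a `build_model` following this pattern looks roughly like the sketch below (the exact keyword names such as `layer` and `self_attention`, and the custom module `MyCustomMultiHeadAttention`, are assumptions for illustration):

```python
    def build_model(self, states=None):
        # Build a TransformerGeneratorModel whose encoder layers use a custom
        # self-attention module, leaving everything else stock.
        wrapped_class = TransformerGeneratorModel.with_components(
            encoder=TransformerEncoder.with_components(
                layer=TransformerEncoderLayer.with_components(
                    self_attention=MyCustomMultiHeadAttention,  # hypothetical custom module
                )
            )
        )
        return wrapped_class(opt=self.opt, dictionary=self.dict)
```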
Needs a tutorial in the website docs showing essentially the same as the examples, but with more prose.
(I like this final form a lot)
```diff
@@ -0,0 +1,60 @@
+# Swapping Out Model Subcomponents
```
say Transformer explicitly, since this only works for Transformers
It technically can be added to any model (or any class at all, really). But I'll take your advice to make it easier to think about for now.
### Problem

Right now it's cumbersome to override some specific part of a transformer. If you want to significantly modify the encoder self attention, for example, you either need to add a bunch of conditional branching in `MultiHeadAttention`, or subclass `MultiHeadAttention` and `TransformerGeneratorModel` and `TransformerEncoder` and `TransformerEncoderLayer`.

### Purpose of this PR

Allow swapping out of any `nn.Module` in `TransformerGeneratorModel` while maintaining backwards compatibility with existing call sites.

### NOT the Purpose of this PR
Reducing subclassing to improve readability, nor making architectures infinitely composable.
### Overview

The main files to look at are:

- `parlai/agents/examples/transformer_variant.py`: demonstrates swapping out various components in `TransformerGeneratorModel`
- `parlai/agents/transformer/modules/modular.py`: defines and explains the new classes introduced

A brief summary of the pattern introduced:

- […] `nn.Module`s in a `ModularComponent`. Replace any class in the subcomponents with one that shares the same `__init__` and `forward` signatures and it should just work (see the sketch below).
- […] (e.g. `TransformerEncoder.forward_embedding`).
- […] `Subcomponents` (e.g. replacing `MultiHeadAttention` in `TransformerDecoderLayer`).
- […] `ModularComponent` along with a `Subcomponents` for building its subcomponents.
- […] `Module`/`ModularComponent` instead.

Some downsides:

- […] `TransformerDecoderLayer` is easier to understand than `swappables.layer`.
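
As a rough illustration of the first point above (the class and keyword names below are assumptions for illustration, not the exact API in this PR):

```python
# Assumed illustration: a drop-in attention module that keeps MultiHeadAttention's
# __init__/forward signatures can replace the stock component directly.
class MyMultiHeadAttention(MultiHeadAttention):
    def forward(self, *args, **kwargs):
        # custom behavior would go here; fall back to the stock computation
        return super().forward(*args, **kwargs)


# Swap it into the decoder layer; the keyword name "self_attention" is assumed.
CustomDecoderLayer = TransformerDecoderLayer.with_components(
    self_attention=MyMultiHeadAttention
)
```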
### Testing

Added a print statement in `MyCustomEncoder.forward` and verified it showed in stdout.
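
A toy version of that kind of smoke test (the body of the class is illustrative only):

```python
# Hypothetical custom encoder for a quick smoke test: if the print shows up in
# stdout during a forward pass, the swapped-in component is actually being used.
class MyCustomEncoder(TransformerEncoder):
    def forward(self, *args, **kwargs):
        print("MyCustomEncoder.forward called")
        return super().forward(*args, **kwargs)
```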