This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Swappable Transformer Components #3567

Merged: 23 commits merged into master from transformer-manifest-mvp on Apr 29, 2021

Conversation

@spencerp (Contributor) commented Apr 1, 2021

Problem

Right now it's cumbersome to override a specific part of a transformer. If you want to significantly modify the encoder self-attention, for example, you either need to add a bunch of conditional branching to MultiHeadAttention, or subclass MultiHeadAttention, TransformerGeneratorModel, TransformerEncoder, and TransformerEncoderLayer.

Purpose of this PR

Allow swapping out any nn.Module in TransformerGeneratorModel while maintaining backwards compatibility with existing call sites.

NOT the Purpose of this PR

This PR is not about reducing subclassing to improve readability, nor about making architectures infinitely composable.

Overview

The main files to look at are:

  • parlai/agents/examples/transformer_variant.py: demonstrates swapping out various components in TransformerGeneratorModel
  • parlai/agents/transformer/modules/modular.py: defines and explains the new classes introduced

A brief summary of the pattern introduced:

  • ModularComponent: Any component with subcomponents that can be swapped out.
  • Subcomponents: A catalog of all swappable nn.Modules in a ModularComponent. Replace any class in the subcomponents with one that shares the same __init__ and forward signatures and it should just work.
  • To modify some business logic in a model: subclass (e.g. modifying TransformerEncoder.forward_embedding).
  • To replace an entire component: use the Subcomponents (e.g. replacing MultiHeadAttention in TransformerDecoderLayer). Both of these strategies are sketched after this list.
  • ModularComponentSpec: A convenience object for specifying a ModularComponent along with a Subcomponents for building its subcomponents.
  • Together these can be thought of as a graph/tree that fully defines an architecture. If you find yourself adding new components to the template, you might want to create a new Module/ModularComponent instead.
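
A minimal, hypothetical sketch of those two strategies (class names prefixed with My are invented for illustration; the real swap API is defined in modular.py):

from parlai.agents.transformer.modules import MultiHeadAttention, TransformerEncoder

# Strategy 1: subclass to change business logic, keeping the parent's signatures.
class MyEncoder(TransformerEncoder):
    def forward_embedding(self, *args, **kwargs):
        # assumes forward_embedding returns (tensor, mask) like the stock encoder
        tensor, mask = super().forward_embedding(*args, **kwargs)
        # ...custom post-processing of the embeddings would go here...
        return tensor, mask

# Strategy 2: a drop-in replacement with matching __init__/forward signatures,
# registered through the Subcomponents catalog rather than by subclassing callers.
class MyMultiHeadAttention(MultiHeadAttention):
    def forward(self, *args, **kwargs):
        # ...custom attention logic would go here...
        return super().forward(*args, **kwargs)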

Some downsides:

  • It does not explicitly allow for fancier, dynamic architecture changes, like using a different layer class for each layer. But this could still be accomplished by passing in a custom layer class that behaves differently depending on the layer (sketched below). If we often find ourselves wanting highly dynamic, composable architectures, we can revisit this.
  • The code that initializes each component gets a little harder to read. TransformerDecoderLayer is easier to understand than swappables.layer.
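
A rough, purely illustrative sketch of that workaround. It assumes the layer constructor is not told its own index, so a class-level counter records build order (a real version would also need to reset the counter between model builds):

from parlai.agents.transformer.modules import TransformerEncoderLayer

class DepthAwareEncoderLayer(TransformerEncoderLayer):
    # Hypothetical: layers are built in order, so count instances to infer depth.
    _next_index = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer_index = DepthAwareEncoderLayer._next_index
        DepthAwareEncoderLayer._next_index += 1

    def forward(self, tensor, mask):
        if self.layer_index % 2 == 1:
            # ...odd-numbered layers could behave differently here...
            pass
        return super().forward(tensor, mask)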

Testing

Added a print statement in MyCustomEncoder.forward and verified that it showed up in stdout (a sketch of such an override follows the log output below):

λ parlai train_model --model examples/transformer_variant --task convai2 --model-file /tmp/testtransformer --beam-size 5 --batchsize 16
...
10:42:49 | training...                                                                                                         
12:54:54 | Custom encoder called!                                                                                              
12:54:55 | Custom attention called!                                                                                                                                                                                                                           
12:54:55 | Custom attention called!                                                                                            
12:54:55 | Custom encoder called!
...
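
For reference, an override producing the log lines above might look roughly like this (a hedged sketch; the actual MyCustomEncoder lives in transformer_variant.py, and the timestamped output suggests ParlAI's logger rather than a bare print):

import parlai.utils.logging as logging
from parlai.agents.transformer.modules import TransformerEncoder

class MyCustomEncoder(TransformerEncoder):
    def forward(self, *args, **kwargs):
        # log on every call so the swapped-in component is visible during training
        logging.info('Custom encoder called!')
        return super().forward(*args, **kwargs)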


@spencerp changed the title from Transformer Manifest MVP to Swappable Transformer Components on Apr 1, 2021
@jxmsML (Contributor) commented Apr 6, 2021

LGTM!!! I tried swapping in my custom encoder and it works perfectly 💯 🚀 Great to have this!!

@jxmsML self-requested a review on April 6, 2021 at 21:57
@klshuster (Contributor) left a comment

for some reason i'm not 100% grasping the interplay between Manifests and ComponentSpecs, but here's my current interpretation:

  • A Manifest defines the components of a TComponent. So like the encoder manifest defines the layers.
  • A ComponentSpec defines the module class itself, and the respective manifest for the class.

I like how it's quite customizable if you know what you're doing. It still seems to suffer from the downsides of a hierarchical setup (you need to traverse quite a few things to find out the different manifests, etc.), but it might be a better tradeoff than what we have already

Resolved review threads on: parlai/agents/examples/transformer_variant.py, parlai/agents/transformer/modules/decoder.py (two threads), parlai/agents/transformer/modules/encoder.py
Forward pass.
"""
residual = tensor
if self.variant == 'prelayernorm':
Contributor:

one could also view the different variants as custom variations, right?

Contributor Author (spencerp):

For sure! I didn't do anything like that as part of this PR to keep it simple, but I could imagine a pretty simple way to implement that without this if statement:

class PreLayerNormTransformerEncoderLayer(TransformerEncoderLayer):
  def forward(self, tensor: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    residual = tensor
    tensor = self.norm1(tensor)
    attended_tensor = self.attention(tensor, mask=mask)[0]
    tensor = residual + self.dropout(attended_tensor)
    residual = tensor
    tensor = self.norm2(tensor)
    tensor = residual + self.dropout(self.ffn(tensor))
    tensor *= mask.unsqueeze(-1).type_as(tensor)
    return tensor

manifest.encoder_layer = ComponentSpec(
  PreLayerNormTransformerEncoderLayer,
  TransformerEncoderLayer.Manifest(),
)

To me the code feels simple enough as is for now, but if we end up with a ton of variations then maybe we'll want to split them out into different implementations like this.

Contributor:

yes agreed, not in the scope of this current PR but could be an interesting application from it

@spencerp (Contributor Author) commented Apr 7, 2021

Thank you both so much for your comments!!

for some reason i'm not 100% grasping the interplay between Manifests and ComponentSpecs, but here's my current interpretation:

Your summary is accurate. A TComponent is like a template, a Manifest is like a form to fill out with all of the slots in that template, and a ComponentSpec combines these two to fully define a component.

It still seems to suffer from the downsides of a hierarchical setup (you need to traverse quite a few things to find out the different manifests, etc.), but it might be a better tradeoff than what we have already

This PR is just meant to solve the problem of "how do I swap out a piece of the Transformer without copying and pasting the whole dang file??". I think getting rid of the hierarchy traversal altogether would require a radical refactor such that none of our nn.Modules initialize new nn.Modules in their __init__.

Maybe another option to get around the traversal issue without a radical refactor could be making the Manifest immutable and final (not subclassable). That would enforce that anyone who customizes a component has to write out the whole manifest, not just their incremental change. That would be quite verbose, though.

@klshuster (Contributor):

This PR is just meant to solve the problem of "how do I swap out a piece of the Transformer without copying and pasting the whole dang file??". I think getting rid of the hierarchy traversal altogether would require a radical refactor such that none of our nn.Modules initialize new nn.Modules in their __init__.

Agreed, let's deal with that separately. I do think this is more flexible than before

Maybe another option to get around the traversal issue without a radical refactor could be making the Manifest immutable and final (not subclassable). That would enforce that anyone who customizes a component has to write out the whole manifest, not just their incremental change. That would be quite verbose, though.

Verbosity, if it improves clarity, may not be the worst thing, but if it requires more work from the user it may not be super desirable

@spencerp (Contributor Author) commented Apr 9, 2021

Verbosity, if it improves clarity, may not be the worst thing, but if it requires more work from the user it may not be super desirable

Take a look at examples/transformer_variant.py now. I tried a more verbose way of specifying the components, and I think it makes it easier to get a holistic view of the architecture of the model.

@klshuster (Contributor):

Verbosity, if it improves clarity, may not be the worst thing, but if it requires more work from the user it may not be super desirable

Take a look at examples/transformer_variant.py now. I tried a more verbose way of specifying the components, and I think it makes it easier to get a holistic view of the architecture of the model.

Yes, I like that a lot actually; obviously it may get a bit more complicated with more complicated models, but that might be a good thing, since you have everything laid out in one specification

@@ -141,7 +218,7 @@ def _default(val, default):
         self.layers = nn.ModuleList()
         for _ in range(self.n_layers):
             self.layers.append(
-                TransformerEncoderLayer(
+                template.layer.build(
Contributor:

this is really hard to read.. can't it be called encoderlayer_class or something (assuming that's what it is)? template is so generic...

Contributor Author (spencerp):

That's pretty much what it was called in the previous commit lol. The more I stare at and play with the code the less sure I am about what's intuitive and what's confusing.

def __init__(
    self, klass: Type[MC], template: Optional[ModularComponent.Template] = None
) -> None:
    self._klass = klass
Contributor:

what is 'klass' (sorry)?

Contributor Author (spencerp):

It's type-annotated in __init__ for reference, but it's a class. Spelled with a k so it doesn't clash with the python reserved word class. So TransformerEncoderLayer, for example.

self.pad_idx = dictionary[dictionary.null_token]
self.start_idx = dictionary[dictionary.start_token]
self.end_idx = dictionary[dictionary.end_token]
super().__init__(self.pad_idx, self.start_idx, self.end_idx)
template = template or self.Template()
Contributor:

ditto


def build_model(self, states=None):
    wrapped_class = TransformerGeneratorModel.with_components(
        encoder=TransformerEncoder.with_components(
Contributor:

yes this is the one

Contributor:

agreed, this looks nice
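
For context, a hedged reconstruction of what the full specification might look like, extrapolating from the truncated snippet above. The component keyword names (layer, self_attention, encoder_attention, feedforward), the final constructor call, and the MyCustomAttention class are assumptions, not taken verbatim from the PR:

def build_model(self, states=None):
    wrapped_class = TransformerGeneratorModel.with_components(
        encoder=TransformerEncoder.with_components(
            layer=TransformerEncoderLayer.with_components(
                self_attention=MyCustomAttention,  # illustrative custom class
                feedforward=TransformerFFN,
            )
        ),
        decoder=TransformerDecoder.with_components(
            layer=TransformerDecoderLayer.with_components(
                self_attention=MultiHeadAttention,
                encoder_attention=MultiHeadAttention,
                feedforward=TransformerFFN,
            )
        ),
    )
    return wrapped_class(self.opt, self.dict)  # assumed constructor arguments

Laying the whole tree out in one place like this is what gives the holistic view discussed earlier in the thread.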

@stephenroller (Contributor) left a comment

Needs a tutorial in website docs showing essentially the same as examples, but with more prose.

Resolved review threads on: parlai/agents/transformer/modules/interfaces.py, parlai/agents/examples/transformer_variant.py
@stephenroller (Contributor):

(I like this final form a lot)

@@ -0,0 +1,60 @@
# Swapping Out Model Subcomponents
Contributor:

say Transformer explicitly, since this only works for Transformers

Contributor Author (spencerp):

It technically can be added to any model (or any class at all, really). But I'll take your advice to make it easier to think about for now.

@spencerp merged commit 6040079 into master on Apr 29, 2021
@spencerp deleted the transformer-manifest-mvp branch on April 29, 2021 at 14:08
6 participants