This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit: Decoder-Only Transformer (#4329)

* quick and dirty decoder-only implementation

* fix decoder_only incremental decoding

* remove unused code, add some comments, propagate func signature change

* consolidate code in decoder.py

* unify encoder_state

* export PassThroughEncoder

* add missing build_ functions

* defaults in TransformerDecoderLayer __init__

* comments, consolidating more logic, simplified forward_layers args

* resize token embeddings and unit test

* attempt to suppress some unused import warnings

* padded_tensor fp16 friendly

* autoformat

* decoder_only -> decoder

* more documentation

* update name in test

* add missing dict args

* more argument massaging

* update TestBartDistillation::test_narrow_distillation_losses numbers

* update TestTransformerDistillation::test_narrow_distillation_losses numbers

* fix _pad_tensor in seeker

Co-authored-by: klshuster <kshuster@fb.com>
spencerp and klshuster authored May 4, 2022
1 parent a2cc9f4 commit ecdfbd0
Showing 18 changed files with 857 additions and 300 deletions.
2 changes: 1 addition & 1 deletion parlai/agents/hugging_face/gpt2.py
@@ -318,7 +318,7 @@ def build_model(self, states=None):
     def _encoder_input(self, batch):
         return (batch.text_vec,)
 
-    def _pad_tensor(self, items):
+    def _pad_tensor(self, items, is_label=False):
         """
         Override to always set fp16friendly to False and left_pad to True.
         """
9 changes: 7 additions & 2 deletions parlai/agents/rag/modules.py
@@ -89,7 +89,10 @@ def __init__(self, opt, dictionary, retriever_shared=None):
             padding_idx=self.pad_idx,
         )
         self.seq2seq_decoder = self.build_decoder(
-            opt, embedding=self.embeddings, padding_idx=self.pad_idx
+            opt,
+            embedding=self.embeddings,
+            dictionary=dictionary,
+            padding_idx=self.pad_idx,
         )
 
     @classmethod
@@ -121,7 +124,9 @@ def build_decoder(
         **kwargs,
     ):
         if decoder_class is None:
-            return RagDecoder(opt=opt, embedding=embedding, n_positions=n_positions)
+            return RagDecoder(
+                opt=opt, embedding=embedding, n_positions=n_positions, **kwargs
+            )
         else:
             return decoder_class(opt, *args, **kwargs)
 
95 changes: 95 additions & 0 deletions parlai/agents/transformer/decoder.py
@@ -0,0 +1,95 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional

from parlai.agents.transformer.transformer import add_common_cmdline_args
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.utils.logging import logging
from parlai.utils.misc import recursive_getattr
from parlai.utils.torch import padded_tensor

from .modules import (
PassThroughEncoder,
TransformerDecoderOnly,
TransformerGeneratorModel,
)


class DecoderAgent(TorchGeneratorAgent):
"""
DecoderOnlyAgent.
Implementation of TorchGeneratorAgent, where the model is a Decoder-Only
Transformer.
"""

@classmethod
def add_cmdline_args(
cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
) -> ParlaiParser:
"""
Add command-line arguments specifically for this agent.
"""
agent = parser.add_argument_group('Decoder-Only Transformer Arguments')
add_common_cmdline_args(agent)
cls.dictionary_class().add_cmdline_args(parser, partial_opt=partial_opt)

super().add_cmdline_args(parser, partial_opt=partial_opt)
return agent

def build_model(self, states=None):
"""
Override of ``TorchAgent.build_model``.
"""
assert (
self.opt['n_encoder_layers'] == -1
), "Decoder-only model cannot have encoder layers."
wrapped_class = TransformerGeneratorModel.with_components(
encoder=PassThroughEncoder, decoder=TransformerDecoderOnly
)
return wrapped_class(self.opt, self.dict)

def _pad_tensor(self, items, is_label=False):
"""
Override of ``TorchAgent._pad_tensor``.
Pads context tensor on the left and label tensor on the right, such that when
they are concatenated the example meets in the middle to form a continuous
sequence.
"""
return padded_tensor(
items,
pad_idx=self.NULL_IDX,
left_padded=(not is_label),
fp16friendly=self.fp16,
)

def _resize_token_embeddings(self, state_dict, msg=None):
"""
Resize the token embeddings when adding extra special tokens.
"""
# map extra special tokens carefully
new_size = self.model.embeddings.weight.size()[0]
orig_size = state_dict['embeddings.weight'].size()[0]
logging.info(f'Resizing token embeddings from {orig_size} to {new_size}')
if new_size <= orig_size:
# new size should be greater than original size,
# as we are adding special tokens
raise RuntimeError(msg)

for emb_weights in ['embeddings.weight', 'decoder.embeddings.weight']:
# get new_embs
old_embs = state_dict[emb_weights]
new_embs = recursive_getattr(self.model, emb_weights).to(old_embs.device)
# copy over old weights
new_embs.data[:orig_size, :] = old_embs.data[:orig_size, :]
# reset in state dict
state_dict[emb_weights] = new_embs

return state_dict
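
A note on the `_pad_tensor` override above: contexts are padded on the left and labels on the right, so that when a context row is concatenated with its label row the real tokens form one contiguous span, with padding only at the outer edges, which is what a decoder-only model needs. A minimal, self-contained sketch of the idea (illustrative only: the token ids and pad index below are made up, and the agent itself delegates the actual padding to parlai.utils.torch.padded_tensor):

    import torch

    NULL_IDX = 0  # hypothetical pad token id

    # Two hypothetical context/label pairs of different lengths (token ids).
    contexts = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]
    labels = [torch.LongTensor([20, 21]), torch.LongTensor([22, 23, 24])]

    def pad(items, left_padded):
        """Pad 1-D tensors to a common length, on the left or on the right."""
        width = max(len(t) for t in items)
        out = torch.full((len(items), width), NULL_IDX, dtype=torch.long)
        for i, t in enumerate(items):
            if left_padded:
                out[i, width - len(t):] = t  # pad tokens go on the left
            else:
                out[i, :len(t)] = t          # pad tokens go on the right
        return out

    ctx = pad(contexts, left_padded=True)   # [[5, 6, 7], [0, 8, 9]]
    lab = pad(labels, left_padded=False)    # [[20, 21, 0], [22, 23, 24]]

    # Concatenated along time, each example's real tokens stay contiguous;
    # padding sits only at the outer edges, never in the middle.
    full = torch.cat([ctx, lab], dim=1)
    # [[5, 6, 7, 20, 21,  0],
    #  [0, 8, 9, 22, 23, 24]]
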
13 changes: 11 additions & 2 deletions parlai/agents/transformer/modules/__init__.py
@@ -11,8 +11,17 @@
 )
 from .attention import BasicAttention, MultiHeadAttention  # noqa: F401
 from .ffn import TransformerFFN  # noqa: F401
-from .encoder import TransformerEncoder, TransformerEncoderLayer  # noqa: F401
-from .decoder import TransformerDecoder, TransformerDecoderLayer  # noqa: F401
+from .encoder import (  # noqa: F401
+    PassThroughEncoder,
+    TransformerEncoder,
+    TransformerEncoderLayer,
+)
+from .decoder import (  # noqa: F401
+    TransformerDecoder,
+    TransformerDecoderLayer,
+    TransformerDecoderOnly,
+    TransformerDecoderOnlyLayer,
+)
 from .generator import TransformerGeneratorModel  # noqa: F401
 from .wrappers import TransformerLinearWrapper, TransformerResponseWrapper  # noqa: F401
 from .mem_net import TransformerMemNetModel  # noqa: F401
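
With PassThroughEncoder and TransformerDecoderOnly exported here, downstream code can compose a decoder-only generator through the swappable-components mechanism, mirroring DecoderAgent.build_model above. A sketch (the opt and dictionary variables are placeholders that would normally come from a TorchGeneratorAgent subclass):

    from parlai.agents.transformer.modules import (
        PassThroughEncoder,
        TransformerDecoderOnly,
        TransformerGeneratorModel,
    )

    # Swap the encoder/decoder components of the generator model, exactly as
    # DecoderAgent.build_model does; opt and dictionary are assumed to be a
    # ParlAI Opt and dictionary prepared elsewhere (placeholders here).
    wrapped_class = TransformerGeneratorModel.with_components(
        encoder=PassThroughEncoder, decoder=TransformerDecoderOnly
    )
    model = wrapped_class(opt, dictionary)

Under ParlAI's usual module-to-agent naming, the new agent in parlai/agents/transformer/decoder.py should be reachable as the model key transformer/decoder, though that key is not spelled out in this diff.
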
21 changes: 10 additions & 11 deletions parlai/agents/transformer/modules/attention.py
@@ -15,6 +15,7 @@
 import torch.nn.functional as F
 
 from parlai.core.opt import Opt
+from parlai.core.params import default
 from parlai.utils.torch import neginf
 
 
@@ -98,14 +99,8 @@ def __init__(
     ):
         super(MultiHeadAttention, self).__init__()
 
-        def _default(val, default):
-            """
-            shorthand for explicit None check for optional arguments.
-            """
-            return val if val is not None else default
-
-        n_heads = _default(n_heads, opt['n_heads'])
-        dim = _default(dim, opt['embedding_size'])
+        n_heads = default(n_heads, opt['n_heads'])
+        dim = default(dim, opt['embedding_size'])
 
         self.n_heads = n_heads
         self.dim = dim
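
The local _default helper is replaced by a shared utility imported from parlai.core.params; judging from the deleted lines, the behavior is a plain None check. A minimal sketch of that behavior (not the library's source):

    def default(val, default):
        """Shorthand for an explicit None check on optional arguments."""
        return val if val is not None else default

    assert default(None, 4) == 4  # unset optional argument falls back
    assert default(8, 4) == 8     # explicitly passed value wins
    assert default(0, 4) == 0     # falsy but non-None values are kept
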
@@ -224,9 +219,13 @@ def prepare_head(tensor):
             if static_kv:
                 mask = incr_state['prev_mask']
             else:
-                mask = torch.cat([incr_state['prev_mask'], mask], dim=2)
-                # Prepend along the key_len dimension (analogous to
-                # incr_state['prev_key'])
+                # Mask will be of size (B x query_len x key_len)
+                # During incremental decoding the query will only represent the next token,
+                # whereas the key/value will represent the entire sequence thus far.
+                # In such a case, we only want to look at the last element of the mask in the query dimension.
+                prev_mask = incr_state['prev_mask'][:, -query_len:, :]
+                mask = torch.cat([prev_mask, mask], dim=2)
+                # Prepend along the key_len dimension (analogous to incr_state['prev_key'])
 
         # Save new incremental states. We reshape to allow for reordering along batch
         # dimension.
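
The new comments explain the mask bookkeeping during incremental decoding: the query covers only the newly generated token(s), so only the last query_len rows of the cached (B x query_len x key_len) mask line up with the current query before it is extended along the key dimension. A small, shape-only sketch with made-up sizes (standalone; not the module's actual forward pass):

    import torch

    B, prev_key_len = 2, 5  # hypothetical batch size and cached sequence length
    query_len = 1           # incremental decoding: one new token per step

    # Cached mask from the previous step, plus a mask for the new key position.
    prev_mask = torch.ones(B, prev_key_len, prev_key_len, dtype=torch.bool)
    new_mask = torch.ones(B, query_len, 1, dtype=torch.bool)

    # Keep only the mask rows that correspond to the current (shorter) query...
    sliced = prev_mask[:, -query_len:, :]        # shape (B, 1, prev_key_len)
    # ...then extend along the key dimension, analogous to prev_key/prev_value.
    mask = torch.cat([sliced, new_mask], dim=2)  # shape (B, 1, prev_key_len + 1)
    print(mask.shape)                            # torch.Size([2, 1, 6])
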
(Diffs for the remaining 13 of the 18 changed files are not shown here.)
