
[T5] T5 in ParlAI #3519

Merged
merged 26 commits on Mar 19, 2021
Changes from 6 commits
25 changes: 25 additions & 0 deletions parlai/agents/t5/README.md
@@ -0,0 +1,25 @@
# T5

"Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer"

See https://arxiv.org/abs/1910.10683.


## Implementation

The T5 model in ParlAI is built on the `T5ForConditionalGeneration` model provided by the [HuggingFace Transformers](https://github.com/huggingface/transformers) library. It can be instantiated with any of the pretrained architectures listed below (see the sketch after the list):

- `t5-small`: 60 million parameters
- `t5-base`: 220 million parameters
- `t5-large`: 770 million parameters
- `t5-3b`: 3 billion parameters
- `t5-11b`: 11 billion parameters
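
For reference, here is a minimal sketch of the underlying HuggingFace call that these architecture names map to (the agent wraps this in `build_t5` in `parlai/agents/t5/modules.py`); the checkpoint name shown is just an example:

```python
from transformers import T5ForConditionalGeneration, T5TokenizerFast

# Any of the architecture names above is a valid pretrained checkpoint id.
arch = 't5-small'  # swap for 't5-large', 't5-3b', etc.
model = T5ForConditionalGeneration.from_pretrained(arch)
tokenizer = T5TokenizerFast.from_pretrained(arch)
```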

**Model Parallel**: HuggingFace has implemented model parallelism for T5; however, it is an experimental feature, so proceed at your own risk. You can enable it by specifying `--t5-model-parallel`.
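
Under the hood this corresponds to HuggingFace's experimental `parallelize()` API on the T5 stacks. A rough sketch (not the agent's exact code path) of what the flag enables:

```python
import torch
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained('t5-3b')
if torch.cuda.device_count() > 1:
    # Experimental HF API: shards the T5 blocks across the available GPUs.
    model.parallelize()
```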

## Basic Examples

### Train t5-large on ConvAI2
```bash
parlai train_model -m t5 -mf /tmp/model_file -t convai2 -bs 24 --fp16 true -eps 1 -lr 1e-5 --optimizer adam --t5-model-arch t5-large
```
Empty file added parlai/agents/t5/__init__.py
Empty file.
127 changes: 127 additions & 0 deletions parlai/agents/t5/dict.py
@@ -0,0 +1,127 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapped HF Tokenizer as a ParlAI DictionaryAgent.
"""
from abc import ABC, abstractmethod, abstractproperty
from collections import defaultdict
from transformers import T5TokenizerFast

from parlai.core.dict import DictionaryAgent
from parlai.core.opt import Opt


class HFTokenizerDictionaryAgent(DictionaryAgent, ABC):
"""
Handle Dict Agent responsibilities using a Tokenizer from HF.
"""

def __init__(self, opt: Opt, shared=None):
if not shared:
self.hf_tokenizer = self.build_hf_tokenizer(opt)
self.tok2ind = self.hf_tokenizer.get_vocab()
self.ind2tok = {v: k for k, v in self.tok2ind.items()}
else:
self.hf_tokenizer = shared['hf_tokenizer']
self.tok2ind = shared['tok2ind']
self.ind2tok = shared['ind2tok']

self.freq = defaultdict(int)
self.minfreq = opt.get('dict_minfreq', DictionaryAgent.default_minfreq)

self.start_token = self.hf_tokenizer.cls_token
self.end_token = self.hf_tokenizer.sep_token
self.null_token = self.hf_tokenizer.pad_token
self.unk_token = self.hf_tokenizer.unk_token

self._unk_token_idx = self.hf_tokenizer.unk_token_id
self.lower = opt.get('dict_lower', DictionaryAgent.default_lower)
self.tokenizer = 'bert'
self.opt = opt
self.max_length = (
self.opt['text_truncate'] or self.hf_tokenizer.model_max_length
)

def is_prebuilt(self):
return True

@abstractmethod
def build_hf_tokenizer(self, opt):
"""
Return hf tokenizer.
"""

@abstractmethod
def format_text(self, text: str) -> str:
"""
Format text prior to encoding with tokenizer.
"""

@abstractproperty
def add_special_tokens(self) -> bool:
"""
Whether to add special tokens when tokenizing.
"""

def share(self):
shared = super().share()
shared['hf_tokenizer'] = self.hf_tokenizer
shared['ind2tok'] = self.ind2tok
shared['tok2ind'] = self.tok2ind
return shared

def __len__(self):
if hasattr(self, 'hf_tokenizer'):
return self.hf_tokenizer.vocab_size
else:
return super().__len__()

def txt2vec(self, text, vec_type=list):
return self.hf_tokenizer.encode(
self.format_text(text),
add_special_tokens=self.add_special_tokens,
max_length=self.max_length,
pad_to_max_length=False,
truncation='longest_first',
)

def vec2txt(self, vec, **kwargs):
return self.hf_tokenizer.decode(vec, skip_special_tokens=True, **kwargs)

def act(self):
return {}


class T5TokenizerDictionaryAgent(HFTokenizerDictionaryAgent):
"""
Handle Dict Agent responsibilities using a T5 Tokenizer from HF.
"""

def __init__(self, opt: Opt, shared=None):
super().__init__(opt, shared)

self.start_token = self.hf_tokenizer.pad_token
self.end_token = self.hf_tokenizer.eos_token
self.null_token = self.hf_tokenizer.pad_token
self.unk_token = self.hf_tokenizer.unk_token

self._unk_token_idx = self.hf_tokenizer.unk_token_id

def build_hf_tokenizer(self, opt):
return T5TokenizerFast.from_pretrained(opt['t5_model_arch'], truncation=True)

def format_text(self, text: str) -> str:
"""
Format text prior to encoding with tokenizer.
"""
return text

@property
def add_special_tokens(self) -> bool:
"""
Whether to add special tokens when tokenizing.
"""
return True
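
As a quick illustration of how this dictionary is meant to be used, here is a sketch assuming a minimal hand-built opt (in practice ParlAI's argument parser supplies it); only the keys the dictionary actually reads are set, and the values are illustrative:

```python
from parlai.core.opt import Opt
from parlai.agents.t5.dict import T5TokenizerDictionaryAgent

# Hypothetical minimal opt for illustration only.
opt = Opt({'t5_model_arch': 't5-small', 'text_truncate': 512})
dictionary = T5TokenizerDictionaryAgent(opt)

ids = dictionary.txt2vec('hello world')   # SentencePiece ids plus the trailing </s>
text = dictionary.vec2txt(ids)            # special tokens stripped -> 'hello world'
```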
192 changes: 192 additions & 0 deletions parlai/agents/t5/modules.py
@@ -0,0 +1,192 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Wrapped T5 modules (encoder, decoder, and full model) for ParlAI use.
"""
import torch
from transformers import T5ForConditionalGeneration

try:
from transformers.models.t5.modeling_t5 import T5Stack
except ModuleNotFoundError:
# Prior versions of transformers package do not have T5Stack
T5Stack = object
from typing import Optional, Dict, Any, Tuple

from parlai.core.opt import Opt
from parlai.core.torch_generator_agent import TorchGeneratorModel

from parlai.agents.t5.dict import T5TokenizerDictionaryAgent


def build_t5(opt: Opt) -> T5ForConditionalGeneration:
return T5ForConditionalGeneration.from_pretrained(
opt['t5_model_arch'], dropout_rate=opt['t5_dropout']
)


def set_device(func):
"""
Decorator for setting device.

HF's model parallel uses `torch.cuda.set_device`, which does not interact well with
ParlAI's device handling; this decorator resets the default CUDA device to `cuda:0`
before and after the wrapped call.
"""

def wrap(*args, **kwargs):
if torch.cuda.is_available():
torch.cuda.set_device('cuda:0')
ret = func(*args, **kwargs)
if torch.cuda.is_available():
torch.cuda.set_device('cuda:0')
return ret

return wrap


class ParlaiT5Encoder(torch.nn.Module):
def __init__(
self, opt: Opt, encoder: T5Stack, dictionary: T5TokenizerDictionaryAgent
):
super().__init__()
self.stack = encoder
self.padding_idx = dictionary[dictionary.null_token]
self.paralleled = not opt[
't5_model_parallel'
] # need to parallel in forward; bug in HF

@set_device
def forward(
self,
input: torch.LongTensor,
positions: Optional[torch.LongTensor] = None,
segments: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, torch.BoolTensor]:
"""
Forward pass.

:param LongTensor[batch,seqlen] input:
The input IDs
:param LongTensor[batch,seqlen] positions:
Positions for input IDs
:param LongTensor[batch,seqlen] segments:
If provided, additionally adds ``segments`` as extra embedding features.
"""
if not self.paralleled:
self.stack.parallelize()
mask = input != self.padding_idx
outputs = self.stack(input, attention_mask=mask, output_hidden_states=False)
for k in outputs:
if torch.is_tensor(outputs[k]):
outputs[k] = outputs[k].to(input.device)
return outputs[0], mask


class ParlaiT5Decoder(torch.nn.Module):
def __init__(
self, opt: Opt, decoder: T5Stack, dictionary: T5TokenizerDictionaryAgent
):
super().__init__()
self.stack = decoder
self.padding_idx = dictionary[dictionary.null_token]
self.paralleled = not opt[
't5_model_parallel'
] # need to parallel in forward; bug in HF

@set_device
def forward(
self, input: torch.LongTensor, encoder_state: Tuple[Any], incr_state=None
):
"""
Forward pass.

:param LongTensor[batch,seqlen] input:
The decoder inputs (partial or full decoded token IDs).
:param encoder_state:
Output from the encoder module forward pass.
:param incr_state:
The incremental state: a dictionary whose keys index the layers and whose
values contain the incremental state for each layer.
"""
if not self.paralleled:
self.stack.parallelize()
encoder_output, encoder_mask = encoder_state

mask = input != self.padding_idx
mask[:, 0] = True # first token is pad

outputs = self.stack(
input_ids=input,
attention_mask=mask,
encoder_hidden_states=encoder_output.to(input.device),
encoder_attention_mask=encoder_mask.to(input.device),
)
return outputs[0].to(input.device), incr_state


class ParlaiT5Model(TorchGeneratorModel):
"""
Wrap T5 in ParlAI.
"""

def __init__(self, opt, dictionary):
self.pad_idx = dictionary[dictionary.null_token]
self.start_idx = self.pad_idx
self.end_idx = dictionary[dictionary.end_token]
super().__init__(self.pad_idx, self.start_idx, self.end_idx)
self.t5 = build_t5(opt)
self.encoder = ParlaiT5Encoder(opt, self.t5.get_encoder(), dictionary)
self.decoder = ParlaiT5Decoder(opt, self.t5.get_decoder(), dictionary)

@set_device
def _get_initial_forced_decoder_input(self, bsz: int, inputs: torch.LongTensor):
"""
Return initial input to the decoder.

:param bsz:
batchsize
:param inputs:
inputs to decode

:return initial_input:
initial input for the decoder.
"""
inputs = torch.cat([self.START.detach().expand(bsz, 1), inputs], 1)
return inputs

@set_device
def reorder_encoder_states(self, encoder_states, indices):
"""
Reorder the encoder states.

See ``TorchGeneratorModel.reorder_encoder_states`` for a description.
"""
enc, mask = encoder_states
if not torch.is_tensor(indices):
indices = torch.LongTensor(indices).to(enc.device)
enc = torch.index_select(enc, 0, indices)
mask = torch.index_select(mask, 0, indices)
return enc, mask

def reorder_decoder_incremental_state(
self, incremental_state: Dict[int, dict], inds: torch.Tensor
) -> Dict[int, dict]:
"""
Incremental decoding state is not yet reconciled with HF's cache format, so an empty
dict is returned and no incremental state is reused between steps.
"""
return {}

@set_device
def output(self, tensor):
"""
Compute output logits.
"""
# Taken directly from HuggingFace
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
tensor = tensor * (self.t5.model_dim ** -0.5)
lm_logits = self.t5.lm_head(tensor)
return lm_logits
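
To tie the modules together, here is a minimal sketch (not part of the diff) of building the wrapped model and running its encoder. The opt keys shown are the ones these modules read; the opt itself is hypothetical and would normally come from the T5 agent's command-line setup:

```python
import torch
from parlai.core.opt import Opt
from parlai.agents.t5.dict import T5TokenizerDictionaryAgent
from parlai.agents.t5.modules import ParlaiT5Model

# Hypothetical opt for illustration; values are placeholders.
opt = Opt({
    't5_model_arch': 't5-small',
    't5_dropout': 0.0,
    't5_model_parallel': False,
    'text_truncate': 512,
})
dictionary = T5TokenizerDictionaryAgent(opt)
model = ParlaiT5Model(opt, dictionary).eval()

# Encode a single example; the encoder returns (hidden states, attention mask).
xs = torch.LongTensor([dictionary.txt2vec('translate English to German: hello')])
enc_out, enc_mask = model.encoder(xs)  # shapes: (1, seqlen, hidden), (1, seqlen)
```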