Update T5 tutorial for 2.0 release #2080

Merged: 2 commits, Feb 27, 2023
164 changes: 18 additions & 146 deletions examples/tutorials/t5_demo.py
@@ -3,6 +3,7 @@
==========================================================================

**Author**: `Pendo Abbo <pabbo@fb.com>`__
**Author**: `Joe Cummings <jrcummings@fb.com>`__

"""

@@ -24,7 +25,6 @@
# Common imports
# --------------
import torch
import torch.nn.functional as F

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@@ -47,7 +47,7 @@
# the T5 model expects the input to be batched.
#

from torchtext.prototype.models import T5Transform
from torchtext.models import T5Transform

padding_idx = 0
eos_idx = 1
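
(The diff view truncates this hunk before the transform is instantiated. For orientation, a minimal sketch of the instantiation, assuming the hosted sentencepiece tokenizer URL and a 512-token maximum sequence length — both values come from outside this diff and are assumptions here:)

max_seq_len = 512  # assumed value, defined outside the visible hunk
transform = T5Transform(
    sp_model_path="https://download.pytorch.org/models/text/t5_tokenizer_base.model",  # assumed URL
    max_seq_len=max_seq_len,
    eos_idx=eos_idx,
    padding_idx=padding_idx,
)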
@@ -66,7 +66,7 @@
#
# ::
#
# from torchtext.prototype.models import T5_BASE_GENERATION
# from torchtext.models import T5_BASE_GENERATION
# transform = T5_BASE_GENERATION.transform()
#

@@ -81,7 +81,7 @@
# https://pytorch.org/text/main/models.html
#
#
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext.models import T5_BASE_GENERATION


t5_base = T5_BASE_GENERATION
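
(The rest of this hunk is cut off by the diff view. Presumably the bundle is materialized roughly as follows — a sketch assuming the standard torchtext bundle API; the exact lines are not shown in this diff:)

model = t5_base.get_model()  # loads pretrained weights (assumed bundle API)
model.eval()                 # inference mode
model.to(DEVICE)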
@@ -92,146 +92,18 @@


#######################################################################
# Sequence Generator
# GenerationUtils
# ------------------
#
# We can define a sequence generator to produce an output sequence based on the input sequence provided. This calls on the
# We can use torchtext's `GenerationUtils` to produce an output sequence based on the input sequence provided. This calls on the
# model's encoder and decoder, and iteratively expands the decoded sequences until the end-of-sequence token is generated
# for all sequences in the batch. The `generate` method shown below uses a beam search to generate the sequences. Larger
# beam sizes can result in better generation at the cost of computational complexity, and a beam size of 1 is equivalent to
# a greedy decoder.
#

from torch import Tensor
from torchtext.prototype.models import T5Model


def beam_search(
beam_size: int,
step: int,
bsz: int,
decoder_output: Tensor,
decoder_tokens: Tensor,
scores: Tensor,
incomplete_sentences: Tensor,
):
probs = F.log_softmax(decoder_output[:, -1], dim=-1)
top = torch.topk(probs, beam_size)

# N is number of sequences in decoder_tokens, L is length of sequences, B is beam_size
# decoder_tokens has shape (N,L) -> (N,B,L)
# top.indices has shape (N,B) - > (N,B,1)
# x has shape (N,B,L+1)
# note that when step == 1, N = batch_size, and when step > 1, N = batch_size * beam_size
x = torch.cat([decoder_tokens.unsqueeze(1).repeat(1, beam_size, 1), top.indices.unsqueeze(-1)], dim=-1)

# beams are first created for a given sequence
if step == 1:
# x has shape (batch_size, B, L+1) -> (batch_size * B, L+1)
# new_scores has shape (batch_size,B)
# incomplete_sentences has shape (batch_size * B) = (N)
new_decoder_tokens = x.view(-1, step + 1)
new_scores = top.values
new_incomplete_sentences = incomplete_sentences

# beams already exist, want to expand each beam into possible new tokens to add
# and for all expanded beams belonging to the same sequence, choose the top k
else:
# scores has shape (batch_size,B) -> (N,1) -> (N,B)
# top.values has shape (N,B)
# new_scores has shape (N,B) -> (batch_size, B^2)
new_scores = (scores.view(-1, 1).repeat(1, beam_size) + top.values).view(bsz, -1)

# v, i have shapes (batch_size, B)
v, i = torch.topk(new_scores, beam_size)

# x has shape (N,B,L+1) -> (batch_size, B, L+1)
# i has shape (batch_size, B) -> (batch_size, B, L+1)
# new_decoder_tokens has shape (batch_size, B, L+1) -> (N, L)
x = x.view(bsz, -1, step + 1)
new_decoder_tokens = x.gather(index=i.unsqueeze(-1).repeat(1, 1, step + 1), dim=1).view(-1, step + 1)

# need to update incomplete sentences in case one of the beams was kicked out
# y has shape (N) -> (N, 1) -> (N, B) -> (batch_size, B^2)
y = incomplete_sentences.unsqueeze(-1).repeat(1, beam_size).view(bsz, -1)

# now can use i to extract those beams that were selected
# new_incomplete_sentences has shape (batch_size, B^2) -> (batch_size, B) -> (N, 1) -> N
new_incomplete_sentences = y.gather(index=i, dim=1).view(bsz * beam_size, 1).squeeze(-1)

# new_scores has shape (batch_size, B)
new_scores = v

return new_decoder_tokens, new_scores, new_incomplete_sentences


def generate(encoder_tokens: Tensor, eos_idx: int, model: T5Model, beam_size: int) -> Tensor:

# pass tokens through encoder
bsz = encoder_tokens.size(0)
encoder_padding_mask = encoder_tokens.eq(model.padding_idx)
encoder_embeddings = model.dropout1(model.token_embeddings(encoder_tokens))
encoder_output = model.encoder(encoder_embeddings, tgt_key_padding_mask=encoder_padding_mask)[0]

encoder_output = model.norm1(encoder_output)
encoder_output = model.dropout2(encoder_output)

# initialize decoder input sequence; T5 uses the padding index as the start token for the decoder sequence
decoder_tokens = torch.ones((bsz, 1), dtype=torch.long) * model.padding_idx
scores = torch.zeros((bsz, beam_size))

# mask to keep track of sequences for which the decoder has not produced an end-of-sequence token yet
incomplete_sentences = torch.ones(bsz * beam_size, dtype=torch.long)

# iteratively generate output sequence until all sequences in the batch have generated the end-of-sequence token
for step in range(model.config.max_seq_len):

if step == 1:
# duplicate and order encoder output so that each beam is treated as its own independent sequence
new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
new_order = new_order.to(encoder_tokens.device).long()
encoder_output = encoder_output.index_select(0, new_order)
encoder_padding_mask = encoder_padding_mask.index_select(0, new_order)

# causal mask and padding mask for decoder sequence
tgt_len = decoder_tokens.shape[1]
decoder_mask = torch.triu(torch.ones((tgt_len, tgt_len), dtype=torch.float64), diagonal=1).bool()
decoder_padding_mask = decoder_tokens.eq(model.padding_idx)

# the T5 implementation uses the padding index to start the sequence; we want to ignore this when masking
decoder_padding_mask[:, 0] = False

# pass decoder sequence through decoder
decoder_embeddings = model.dropout3(model.token_embeddings(decoder_tokens))
decoder_output = model.decoder(
decoder_embeddings,
memory=encoder_output,
tgt_mask=decoder_mask,
tgt_key_padding_mask=decoder_padding_mask,
memory_key_padding_mask=encoder_padding_mask,
)[0]

decoder_output = model.norm2(decoder_output)
decoder_output = model.dropout4(decoder_output)
decoder_output = decoder_output * (model.config.embedding_dim ** -0.5)
decoder_output = model.lm_head(decoder_output)

decoder_tokens, scores, incomplete_sentences = beam_search(
beam_size, step + 1, bsz, decoder_output, decoder_tokens, scores, incomplete_sentences
)
# ignore newest tokens for sentences that are already complete
decoder_tokens[:, -1] *= incomplete_sentences

# update incomplete_sentences to remove those that were just ended
incomplete_sentences = incomplete_sentences - (decoder_tokens[:, -1] == eos_idx).long()

# early stop if all sentences have been ended
if (incomplete_sentences == 0).all():
break

# take most likely sequence
decoder_tokens = decoder_tokens.view(bsz, beam_size, -1)[:, 0, :]
return decoder_tokens
# for all sequences in the batch. The `generate` method shown below uses greedy search to generate the sequences. Beam search and
# other decoding strategies are also supported.
#
#
from torchtext.prototype.generate import GenerationUtils

sequence_generator = GenerationUtils(model)
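
(For a quick sense of the wrapper in use: a minimal sketch on a toy input, mirroring the `generate` calls later in this diff; the example string is illustrative only, and `beam_size=1` corresponds to greedy decoding:)

toy_input = transform(["summarize: studies have shown that owning a dog is good for you"])  # illustrative input
toy_tokens = sequence_generator.generate(toy_input, eos_idx=eos_idx, beam_size=1)  # greedy decoding
print(transform.decode(toy_tokens.tolist()))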


#######################################################################
@@ -343,16 +215,16 @@ def process_labels(labels, x):
# ------------------
#
# We can put all of the components together to generate summaries on the first batch of articles in the CNNDM test set
# using a beam size of 3.
# using a beam size of 1.
#

batch = next(iter(cnndm_dataloader))
input_text = batch["article"]
target = batch["abstract"]
beam_size = 3
beam_size = 1

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(cnndm_batch_size):
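
(The loop body is elided by the diff view; a plausible sketch of the comparison it prints — the exact formatting is hypothetical:)

    # hypothetical body: print each generated summary next to its reference
    print(f"Example {i + 1}:\n")
    print(f"prediction: {output_text[i]}\n")
    print(f"target: {target[i]}\n")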
@@ -442,7 +314,7 @@ def process_labels(labels, x):
beam_size = 1

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(imdb_batch_size):
@@ -536,7 +408,7 @@ def process_labels(labels, x):
beam_size = 4

model_input = transform(input_text)
model_output = generate(model=model, encoder_tokens=model_input, eos_idx=eos_idx, beam_size=beam_size)
model_output = sequence_generator.generate(model_input, eos_idx=eos_idx, beam_size=beam_size)
output_text = transform.decode(model_output.tolist())

for i in range(multi_batch_size):