Fixup generation utils for prototype release #2065

Merged 3 commits on Feb 17, 2023
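Summary, as read from the diff below: `GenerationUtil` is renamed to `GenerationUtils`, the `max_len` argument is renamed to `max_length` throughout the public API and its log messages, and the `is_encoder_decoder` / `is_huggingface_model` constructor flags now arrive via `**kwargs`.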
@@ -2,7 +2,7 @@
 
 import torch
 from torchtext.models import T5_BASE_GENERATION
-from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtil
+from torchtext.prototype.generate import DEFAULT_MAX_SEQ_LEN, GenerationUtils
 from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
 
 
@@ -26,9 +26,9 @@ def setUp(self) -> None:
         torch.manual_seed(0)
 
     def test_greedy_generate_with_t5(self) -> None:
-        generation_model = GenerationUtil(self.model)
+        generation_model = GenerationUtils(self.model)
 
-        tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
+        tokens = generation_model.generate(self.inputs, num_beams=1, max_length=30)
         generated_text = self.transform.decode(tokens.tolist())
 
         expected_generated_text = [
@@ -42,13 +42,13 @@ def test_greedy_generate_with_t5(self) -> None:
         self.assertEqual(generated_text, expected_generated_text)
 
     def test_generate_errors_with_incorrect_beams(self) -> None:
-        generation_model = GenerationUtil(self.model, is_encoder_decoder=True)
+        generation_model = GenerationUtils(self.model, is_encoder_decoder=True)
 
         with self.assertRaises(ValueError):
             generation_model.generate(self.inputs, num_beams=0)
 
     @patch("logging.Logger.warning")
     def test_warns_when_no_max_len_provided(self, mock) -> None:
-        generation_model = GenerationUtil(self.model)
+        generation_model = GenerationUtils(self.model)
         generation_model.generate(self.inputs)
-        mock.assert_called_with(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+        mock.assert_called_with(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
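
The `@patch("logging.Logger.warning")` decorator in the last test replaces `Logger.warning` with a mock for the duration of the test, so the assertion checks the exact message passed to the logger. A standalone sketch of the same pattern, outside torchtext (the logger name and helper function here are illustrative only):

    import logging
    from unittest.mock import patch

    logger = logging.getLogger("example")

    def warn_about_default() -> None:
        # Stand-in for the fallback warning emitted by generate().
        logger.warning("`max_length` was not specified. Defaulting to 256 tokens.")

    with patch("logging.Logger.warning") as mock:
        warn_about_default()
        # Passes: the mock recorded the call made through logger.warning.
        mock.assert_called_with("`max_length` was not specified. Defaulting to 256 tokens.")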
torchtext/prototype/generate.py (22 additions, 19 deletions)
@@ -10,13 +10,13 @@
 DEFAULT_MAX_SEQ_LEN = 256
 
 
-class GenerationUtil:
+class GenerationUtils:
     """Wrapper to provide generation utils for encoder/decoder models and decoder models.
 
     Example:
         >>> model = T5_BASE_GENERATION.get_model()
-        >>> generative_model = GenerationUtil(model=model)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model = GenerationUtils(model=model)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)
 
     The wrapper can work with *any* model as long as it meets the following requirements:
     1. Is an encoder/decoder or decoder based model.
@@ -26,15 +26,18 @@ class GenerationUtil:
         >>> from transformers import T5Model
         >>> model = T5Model.from_pretrained("t5-base")
         >>> generative_model = GenerationUtils(model=model, is_huggingface_model=True)
-        >>> generative_model.generate(input_ids, num_beams=1, max_len=100)
+        >>> generative_model.generate(input_ids, num_beams=1, max_length=100)
 
     `Note`: We cannot make any claims about the stability of APIs from HuggingFace so all models used from the `transformers`
     library are marked 'experimental.'
 
     More examples can be found in the `notebooks` directory of this repository.
     """
 
-    def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
+    def __init__(self, model: nn.Module, **kwargs) -> None:
         self.model = model
-        self.is_encoder_decoder = is_encoder_decoder
-        self.is_huggingface_model = is_huggingface_model
+        self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", True)
+        self.is_huggingface_model = kwargs.pop("is_huggingface_model", False)
 
     def _prepare_decoder_ids_for_generation(
         self, batch_size: int, pad_idx: int = 0, device: Optional[torch.device] = None, **model_kwargs
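
With the constructor now reading its flags from `**kwargs`, the defaults match the old explicit parameters. A brief sketch of the resulting call sites; the decoder-only case is hypothetical, since the tests above only exercise T5:

    # Encoder/decoder model; is_encoder_decoder defaults to True.
    generation_model = GenerationUtils(model)

    # Explicitly flagging an encoder/decoder model, as in the beam-count test.
    generation_model = GenerationUtils(model, is_encoder_decoder=True)

    # Hypothetical decoder-only model: opt out of the encoder forward pass.
    generation_model = GenerationUtils(decoder_only_model, is_encoder_decoder=False)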
@@ -45,13 +48,13 @@ def _prepare_decoder_ids_for_generation(
         return torch.ones((batch_size, 1), dtype=torch.long, device=device) * pad_idx
 
     def greedy_search(
-        self, input_ids: torch.Tensor, max_len: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
+        self, input_ids: torch.Tensor, max_length: int, eos_idx: int, pad_idx: Optional[int] = None, **model_kwargs
     ) -> torch.Tensor:
         """Greedy search decoding for text generation. Takes the most likely next token every time.
 
         Inputs:
             input_ids (Tensor): Text prompt(s) for greedy generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             eos_idx (int): End of sequence index.
             pad_idx (int): Padding index.
             **model_kwargs
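
Greedy search itself reduces to an argmax over the vocabulary at every step. A minimal, self-contained sketch of a single decoding step (the logits tensor is a stand-in; the real implementation obtains it from the wrapped model's forward pass):

    import torch

    def greedy_step(next_token_logits: torch.Tensor) -> torch.Tensor:
        # Pick the most likely next token for every sequence in the batch.
        # next_token_logits has shape (batch_size, vocab_size).
        return torch.argmax(next_token_logits, dim=-1)

    logits = torch.tensor([[0.1, 2.5, 0.3], [1.7, 0.2, 0.9]])
    print(greedy_step(logits))  # tensor([1, 0])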
@@ -87,20 +90,20 @@ def greedy_search(
             if eos_idx is not None:
                 unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
 
-            # Stop iterating once all sequences are finished or exceed the max_len
-            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_len:
+            # Stop iterating once all sequences are finished or exceed the max_length
+            if unfinished_sequences.max() == 0 or len(input_ids[0]) >= max_length:
                 break
 
         return input_ids
 
-    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_len: Optional[int]) -> torch.Tensor:
+    def beam_search(self, input_ids: torch.Tensor, num_beams: int, max_length: Optional[int]) -> torch.Tensor:
         raise NotImplementedError()
 
     def generate(
         self,
         inputs: Optional[torch.Tensor] = None,
         num_beams: Optional[int] = None,
-        max_len: Optional[int] = None,
+        max_length: Optional[int] = None,
         pad_idx: int = 0,
         eos_idx: int = 1,
     ) -> torch.Tensor:
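
The early exit in greedy_search hinges on the `unfinished_sequences` mask visible above. A small worked example of how the mask zeroes out once a sequence emits `eos_idx`:

    import torch

    eos_idx = 1
    unfinished_sequences = torch.ones(3, dtype=torch.long)

    # Suppose the batch just produced these next tokens; sequence 1 hit EOS.
    next_tokens = torch.tensor([5, 1, 7])
    unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_idx).long())
    print(unfinished_sequences)        # tensor([1, 0, 1])
    print(unfinished_sequences.max())  # tensor(1): decoding continues until this is 0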
@@ -112,7 +115,7 @@
         Args:
             input_ids (Tensor): Ids of tokenized input tokens. The 'seed' text for generation.
             num_beams (int): If provided, specifies the number of beams to use in beam search generation.
-            max_len (int): Max length to generate responses.
+            max_length (int): Max length to generate responses.
             pad_idx (int): Padding index. Defaults to 0.
             eos_idx (int): End of sequence index. Defaults to 1.
@@ -128,14 +131,14 @@
             model_kwargs["encoder_outputs"] = encoder(inputs)
             inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)
 
-        if max_len is None:
+        if max_length is None:
             # Too hard to try to figure out the exact max_seq_length for each model
-            logger.warning(f"`max_len` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
-            max_len = DEFAULT_MAX_SEQ_LEN
+            logger.warning(f"`max_length` was not specified. Defaulting to {DEFAULT_MAX_SEQ_LEN} tokens.")
+            max_length = DEFAULT_MAX_SEQ_LEN
 
         if num_beams == 1 or num_beams is None:
-            return self.greedy_search(inputs, max_len, eos_idx, pad_idx=pad_idx, **model_kwargs)
+            return self.greedy_search(inputs, max_length, eos_idx, pad_idx=pad_idx, **model_kwargs)
         elif num_beams > 1:
-            return self.beam_search(inputs, num_beams, max_len)
+            return self.beam_search(inputs, num_beams, max_length)
         else:
             raise ValueError("`num_beams` must be >= 1.")
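
Putting the renamed pieces together, an end-to-end sketch assembled from the docstring and the T5 test above; the prompt text is illustrative only, and the transform/decode calls follow the test's usage:

    import torch
    from torchtext.models import T5_BASE_GENERATION
    from torchtext.prototype.generate import GenerationUtils

    model = T5_BASE_GENERATION.get_model()
    transform = T5_BASE_GENERATION.transform()
    generation_model = GenerationUtils(model)  # is_encoder_decoder defaults to True

    # T5 is prompted with a task prefix such as "summarize:".
    inputs = transform(["summarize: studies have shown that owning a dog is good for you"])
    tokens = generation_model.generate(inputs, num_beams=1, max_length=30)
    print(transform.decode(tokens.tolist()))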