Add generation utils with greedy search and tests
joecummings committed Dec 20, 2022
1 parent c008115 commit b699de2
Showing 5 changed files with 290 additions and 3 deletions.
83 changes: 83 additions & 0 deletions notebooks/hf_vs_tt_t5.ipynb
@@ -0,0 +1,83 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Ensuring the TorchText T5 implementation matches other OSS implementations\n",
"\n",
"> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"from transformers import T5Model\n",
"from torchtext.prototype.models import T5_BASE\n",
"\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"input_sentence = [\"translate to Spanish: My name is Joe\"]\n",
"output_sentence = [\"Me llamo Joe\"]\n",
"\n",
"transform = T5_BASE.transform()\n",
"tt_t5_model = T5_BASE.get_model()\n",
"\n",
"hf_t5_model = T5Model.from_pretrained(\"t5-base\")"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"tokenized_sentence = transform(input_sentence)\n",
"tokenized_output = transform(output_sentence)\n",
"\n",
"tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n",
"hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n",
"\n",
"assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n",
"assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
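The notebook above asserts exact elementwise equality between the two implementations with `Tensor.eq`. If the outputs ever diverge by floating-point noise (for example, across hardware or library versions), a tolerance-based comparison is a common looser check. A minimal sketch, reusing the notebook's `tt_output` and `hf_output` dictionaries; the tolerances are illustrative assumptions, not values from this commit:

```python
import torch

# Tolerance-based parity check: allows tiny floating-point drift between
# the TorchText and Hugging Face outputs instead of requiring bit-identical
# tensors. Assumes tt_output / hf_output as produced in the notebook above.
torch.testing.assert_close(
    tt_output["encoder_output"],
    hf_output["encoder_last_hidden_state"],
    rtol=1e-5,
    atol=1e-5,
)
torch.testing.assert_close(
    tt_output["decoder_output"],
    hf_output["last_hidden_state"],
    rtol=1e-5,
    atol=1e-5,
)
```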
151 changes: 151 additions & 0 deletions notebooks/hf_with_torchtext_gen.ipynb
@@ -0,0 +1,151 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n",
"from torchtext.prototype.generate import GenerationUtil"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n",
"bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n",
"gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n",
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n",
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n",
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n",
"- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n",
" warnings.warn(\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"['owning a dog is good for you, according to studies. a dog is']\n"
]
}
],
"source": [
"# Testing Huggingface's T5\n",
"test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n",
"generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n",
"t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n",
"test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n",
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n"
]
}
],
"source": [
"# Testing Huggingface's BART\n",
"test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n",
" \"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were \"\n",
" \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n",
"generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n",
"bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n",
"test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n",
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n"
]
}
],
"source": [
"# Testing Huggingface's GPT2\n",
"test_sequence = [\"I enjoy walking with my cute dog\"]\n",
"generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n",
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n",
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n",
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.13 ('torchtext39')",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
53 changes: 53 additions & 0 deletions test/torchtext_unittest/prototype/test_generate.py
@@ -0,0 +1,53 @@
from unittest.mock import patch
from torchtext.prototype.generate import GenerationUtil
from torchtext.prototype.models import T5_BASE_GENERATION
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase
import torch


class TestGenerationUtil(TorchtextTestCase):
def setUp(self) -> None:
super().setUp()
t5_base = T5_BASE_GENERATION
self.transform = t5_base.transform()
self.model = t5_base.get_model()
self.model.eval()
# Examples taken from the T5 paper and Hugging Face
self.inputs = self.transform(
[
"summarize: studies have shown that owning a dog is good for you",
"translate English to German: That is good.",
"cola sentence: The course is jumping well.",
"stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.",
"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...",
]
)
torch.manual_seed(0)

def test_greedy_generate_with_t5(self) -> None:
generation_model = GenerationUtil(self.model)

tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30)
generated_text = self.transform.decode(tokens.tolist())

expected_generated_text = [
"a dog is good for you, according to studies . owning a dog is good for you, according to studies .",
"Das ist gut.",
"acceptable",
"4.0",
"mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage",
]

self.assertEqual(generated_text, expected_generated_text)

def test_generate_errors_with_incorrect_beams(self) -> None:
generation_model = GenerationUtil(self.model, is_encoder_decoder=True)

with self.assertRaises(ValueError):
generation_model.generate(self.inputs, num_beams=0)

@patch("logging.Logger.warning")
def test_warns_when_no_max_len_provided(self, mock) -> None:
generation_model = GenerationUtil(self.model)
generation_model.generate(self.inputs)
mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.")
5 changes: 3 additions & 2 deletions torchtext/prototype/generate.py
@@ -29,6 +29,7 @@ class GenerationUtil:
More examples can be found in the `notebooks` directory of this repository.
"""

def __init__(self, model: nn.Module, is_encoder_decoder: bool = True, is_huggingface_model: bool = False) -> None:
self.model = model
self.is_encoder_decoder = is_encoder_decoder
@@ -53,7 +54,7 @@ def greedy_search(
eos_idx (int): End of sequence index.
pad_idx (int): Padding index.
**model_kwargs
Returns:
Batch of sequences decoded by greedy search.
"""
@@ -125,7 +126,7 @@ def generate(
encoder = self.model.get_encoder()
model_kwargs["encoder_outputs"] = encoder(inputs)
inputs = self._prepare_decoder_ids_for_generation(len(inputs), device=inputs.device, **model_kwargs)

if max_len is None:
# Too hard to try to figure out the exact max_seq_length for each model
logger.warning("`max_len` was not specified. Defaulting to 256 tokens.")
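The hunks above show only fragments of `greedy_search` and `generate`. For orientation, here is a minimal sketch of the kind of greedy decoding loop the utility implements. This is an illustration under assumed names (`model` is assumed to be a callable returning per-step vocabulary scores of shape `(batch, seq, vocab)`), not the exact torchtext code:

```python
import torch

def greedy_search_sketch(model, input_ids, max_len, eos_idx, pad_idx):
    """Illustrative greedy decoding loop; not the exact torchtext implementation."""
    tokens = input_ids
    finished = torch.zeros(tokens.size(0), dtype=torch.bool, device=tokens.device)
    for _ in range(max_len):
        logits = model(tokens)  # assumed: (batch, seq, vocab) scores
        next_token = logits[:, -1, :].argmax(dim=-1)  # greedy: highest-scoring token
        next_token = next_token.masked_fill(finished, pad_idx)  # pad finished rows
        tokens = torch.cat([tokens, next_token.unsqueeze(-1)], dim=-1)
        finished |= next_token == eos_idx
        if finished.all():  # stop once every sequence has emitted EOS
            break
    return tokens
```

Beam search (`num_beams > 1` in `generate`) would replace the single `argmax` with tracking the top-k cumulative scores per step; with `num_beams=1`, as in the tests above, generation reduces to this greedy loop.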
1 change: 0 additions & 1 deletion torchtext/prototype/models/t5/model.py
@@ -134,7 +134,6 @@ def __init__(
for p in self.parameters():
p.requires_grad = False

@torch.jit.ignore
def prepare_inputs_for_generation(self, input_ids, encoder_outputs):
return {"decoder_tokens": input_ids, "encoder_outputs": encoder_outputs}

