-
Notifications
You must be signed in to change notification settings - Fork 811
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add generation utils with greedy search and tests
- Loading branch information
1 parent
c008115
commit b699de2
Showing
5 changed files
with
290 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Ensuring the TorchText T5 implementation matches other OSS implementations\n", | ||
"\n", | ||
"> In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from transformers import T5Model\n", | ||
"from torchtext.prototype.models import T5_BASE\n", | ||
"\n", | ||
"import torch" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 30, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"input_sentence = [\"translate to Spanish: My name is Joe\"]\n", | ||
"output_sentence = [\"Me llamo Joe\"]\n", | ||
"\n", | ||
"transform = T5_BASE.transform()\n", | ||
"tt_t5_model = T5_BASE.get_model()\n", | ||
"\n", | ||
"hf_t5_model = T5Model.from_pretrained(\"t5-base\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 31, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"tokenized_sentence = transform(input_sentence)\n", | ||
"tokenized_output = transform(output_sentence)\n", | ||
"\n", | ||
"tt_output = tt_t5_model(encoder_tokens=tokenized_sentence, decoder_tokens=tokenized_output)\n", | ||
"hf_output = hf_t5_model(input_ids=tokenized_sentence, decoder_input_ids=tokenized_output, return_dict=True)\n", | ||
"\n", | ||
"assert torch.all(tt_output[\"encoder_output\"].eq(hf_output[\"encoder_last_hidden_state\"]))\n", | ||
"assert torch.all(tt_output[\"decoder_output\"].eq(hf_output[\"last_hidden_state\"]))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.9.13 ('torchtext39')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.13" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"In order to run this notebook, you will need to install the huggingface library with the following command: `pip install transformers`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/tqdm-4.64.0-py3.9.egg/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", | ||
" from .autonotebook import tqdm as notebook_tqdm\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from transformers import T5ForConditionalGeneration, T5Tokenizer, BartForConditionalGeneration, BartTokenizer, GPT2LMHeadModel, GPT2Tokenizer\n", | ||
"from torchtext.prototype.generate import GenerationUtil" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"t5 = T5ForConditionalGeneration.from_pretrained(\"t5-base\")\n", | ||
"bart = BartForConditionalGeneration.from_pretrained(\"facebook/bart-large-cnn\")\n", | ||
"gpt2 = GPT2LMHeadModel.from_pretrained(\"gpt2\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"/opt/homebrew/Caskroom/miniforge/base/envs/torchtext39/lib/python3.9/site-packages/transformers/models/t5/tokenization_t5.py:164: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.\n", | ||
"For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.\n", | ||
"- Be aware that you SHOULD NOT rely on t5-base automatically truncating your input to 512 when padding/encoding.\n", | ||
"- If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.\n", | ||
"- To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.\n", | ||
" warnings.warn(\n" | ||
] | ||
}, | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['owning a dog is good for you, according to studies. a dog is']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Testing Huggingface's T5\n", | ||
"test_sequence = [\"summarize: studies have shown that owning a dog is good for you\"]\n", | ||
"generative_hf_t5 = GenerationUtil(t5, is_encoder_decoder=True, is_huggingface_model=True)\n", | ||
"t5_tokenizer = T5Tokenizer.from_pretrained(\"t5-base\")\n", | ||
"test_sequence_tk = t5_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", | ||
"tokens = generative_hf_t5.generate(test_sequence_tk, max_len=20, pad_idx=t5.config.pad_token_id)\n", | ||
"print(t5_tokenizer.batch_decode(tokens, skip_special_tokens=True))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"['PG. PG&E said it scheduled the blackouts in response to forecasts for high winds.']\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Testing Huggingface's BART\n", | ||
"test_sequence = [\"PG&E stated it scheduled the blackouts in response to forecasts for high winds \"\n", | ||
" \"amid dry conditions. The aim is to reduce the risk of wildfires. Nearly 800 thousand customers were \"\n", | ||
" \"scheduled to be affected by the shutoffs which were expected to last through at least midday tomorrow.\"]\n", | ||
"generative_hf_bart = GenerationUtil(bart, is_encoder_decoder=True, is_huggingface_model=True)\n", | ||
"bart_tokenizer = BartTokenizer.from_pretrained(\"facebook/bart-large-cnn\")\n", | ||
"test_sequence_tk = bart_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", | ||
"tokens = generative_hf_bart.generate(test_sequence_tk, max_len=20, pad_idx=bart.config.pad_token_id)\n", | ||
"print(bart_tokenizer.batch_decode(tokens, skip_special_tokens=True))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"[\"I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to\"]\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Testing Huggingface's GPT2\n", | ||
"test_sequence = [\"I enjoy walking with my cute dog\"]\n", | ||
"generative_hf_gpt2 = GenerationUtil(gpt2, is_encoder_decoder=False, is_huggingface_model=True)\n", | ||
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n", | ||
"test_sequence_tk = gpt2_tokenizer(test_sequence, return_tensors=\"pt\").input_ids\n", | ||
"tokens = generative_hf_gpt2.generate(test_sequence_tk, max_len=20, pad_idx=gpt2.config.pad_token_id)\n", | ||
"print(gpt2_tokenizer.batch_decode(tokens, skip_special_tokens=True))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3.9.13 ('torchtext39')", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.13" | ||
}, | ||
"orig_nbformat": 4, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "63c8862cb56f124e3ee7674b73de745eeb216416a9b24f78d1fcb7c775bff1b7" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
from unittest.mock import patch | ||
from torchtext.prototype.generate import GenerationUtil | ||
from torchtext.prototype.models import T5_BASE_GENERATION | ||
from torchtext_unittest.common.torchtext_test_case import TorchtextTestCase | ||
import torch | ||
|
||
|
||
class TestGenerationUtil(TorchtextTestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
t5_base = T5_BASE_GENERATION | ||
self.transform = t5_base.transform() | ||
self.model = t5_base.get_model() | ||
self.model.eval() | ||
# Examples taken from T5 Paper and Huggingface | ||
self.inputs = self.transform( | ||
[ | ||
"summarize: studies have shown that owning a dog is good for you", | ||
"translate English to German: That is good.", | ||
"cola sentence: The course is jumping well.", | ||
"stsb sentence1: The rhino grazed on the grass. sentence2: A rhino is grazing in a field.", | ||
"summarize: state authorities dispatched emergency crews tuesday to survey the damage after an onslaught of severe weather in mississippi...", | ||
] | ||
) | ||
torch.manual_seed(0) | ||
|
||
def test_greedy_generate_with_t5(self) -> None: | ||
generation_model = GenerationUtil(self.model) | ||
|
||
tokens = generation_model.generate(self.inputs, num_beams=1, max_len=30) | ||
generated_text = self.transform.decode(tokens.tolist()) | ||
|
||
expected_generated_text = [ | ||
"a dog is good for you, according to studies . owning a dog is good for you, according to studies .", | ||
"Das ist gut.", | ||
"acceptable", | ||
"4.0", | ||
"mississippi authorities dispatch emergency crews to survey damage . severe weather in mississippi has caused extensive damage", | ||
] | ||
|
||
self.assertEqual(generated_text, expected_generated_text) | ||
|
||
def test_generate_errors_with_incorrect_beams(self) -> None: | ||
generation_model = GenerationUtil(self.model, is_encoder_decoder=True) | ||
|
||
with self.assertRaises(ValueError): | ||
generation_model.generate(self.inputs, num_beams=0) | ||
|
||
@patch("logging.Logger.warning") | ||
def test_warns_when_no_max_len_provided(self, mock) -> None: | ||
generation_model = GenerationUtil(self.model) | ||
generation_model.generate(self.inputs) | ||
mock.assert_called_with("`max_len` was not specified. Defaulting to 100 tokens.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters