From d532d70144a38bc8d577eb243a81595e156abc78 Mon Sep 17 00:00:00 2001
From: Emily Dinan
Date: Wed, 30 Jun 2021 14:59:52 -0400
Subject: [PATCH 1/3] prefix tokens tga

---
 .../transformer_generator_prefix.py           | 43 +++++++++++++++++++
 parlai/core/torch_generator_agent.py          | 29 +++++++++----
 tests/test_tga.py                             | 37 +++++++++++++---
 3 files changed, 95 insertions(+), 14 deletions(-)
 create mode 100644 parlai/agents/test_agents/transformer_generator_prefix.py

diff --git a/parlai/agents/test_agents/transformer_generator_prefix.py b/parlai/agents/test_agents/transformer_generator_prefix.py
new file mode 100644
index 00000000000..9fc1d543659
--- /dev/null
+++ b/parlai/agents/test_agents/transformer_generator_prefix.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Test agent which counts its number of unique items.
+"""
+
+from __future__ import annotations
+from typing import Tuple
+from collections import Counter
+
+#!/usr/bin/env python3
+
+# Copyright (c) Facebook, Inc. and its affiliates.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Test agent which overrides the `get_prefix_tokens` function for Transformer
+Generator Agent in order to test its functionality.
+
+All text generated by this agent should begin with '4 3 2 1 '.
+"""
+import torch
+from typing import Optional
+
+from parlai.agents.transformer.transformer import TransformerGeneratorAgent
+from parlai.core.torch_agent import Batch
+
+
+PREFIX_TEXT = '4 3 2 1 '
+
+
+class TransformerGeneratorPrefixAgent(TransformerGeneratorAgent):
+    def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
+        bsz = batch.batchsize
+        dev = batch.text_vec.device
+        prefix_toks = self.dict.txt2vec(PREFIX_TEXT)
+        prefix_toks_batch = [prefix_toks for _ in range(bsz)]
+        return torch.LongTensor(prefix_toks_batch).to(dev)
diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 88c40ac121e..faf7752b040 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -865,7 +865,10 @@ def eval_step(self, batch):
             warn_once("--skip-generation true produces limited metrics")
         else:
             maxlen = self.label_truncate or 256
-            beam_preds_scores, beams = self._generate(batch, self.beam_size, maxlen)
+            prefix_tokens = self.get_prefix_tokens(batch)
+            beam_preds_scores, beams = self._generate(
+                batch, self.beam_size, maxlen, prefix_tokens=prefix_tokens
+            )
             preds, scores = zip(*beam_preds_scores)
 
         self._add_generation_metrics(batch, preds)
@@ -1039,6 +1042,17 @@ def _get_next_decoder_input(
         decoder_input = torch.cat([prev_input, selection], dim=-1)
         return decoder_input
 
+    def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
+        """
+        Set prefix tokens to seed decoding at generation time.
+
+        By default, we do not utilize prefix tokens, but this is
+        left overridable by child classes.
+
+        Returned tensor should be of dimension bsz x len(prefix)
+        """
+        return None
+
     def _generate(
         self,
         batch: Batch,
@@ -1116,15 +1130,12 @@ def _generate(
             if prefix_tokens is not None and _ts < prefix_tokens.size(1):
                 # generate prefix_tokens for every timestep that they exist
                 # achieve by setting score of all other tokens to be -inf
-                prefix_toks = prefix_tokens[:, _ts].unsqueeze(-1).repeat(1, beam_size)
-                prefix_score = score.gather(-1, prefix_toks.unsqueeze(-1))
-                prefix_mask = prefix_toks.ne(self.NULL_IDX)
+                prefix_toks = prefix_tokens[:, _ts]
+                prefix_mask = torch.ones_like(score, dtype=torch.bool)
+                prefix_mask[
+                    :, :, prefix_toks
+                ] = False  # everything except prefix toks should be neginf
                 score[prefix_mask] = neginf(score.dtype)
-                score[prefix_mask] = score[prefix_mask].scatter_(
-                    -1,
-                    prefix_toks[prefix_mask].unsqueeze(-1),
-                    prefix_score[prefix_mask],
-                )
             for i, b in enumerate(beams):
                 if not b.is_done():
                     b.advance(score[i])
diff --git a/tests/test_tga.py b/tests/test_tga.py
index 8fcf29f7a93..1a2f73fb923 100644
--- a/tests/test_tga.py
+++ b/tests/test_tga.py
@@ -11,14 +11,15 @@
 import parlai.utils.testing as testing_utils
 from parlai.core.params import ParlaiParser
 from parlai.core.torch_generator_agent import TorchGeneratorAgent
+from parlai.agents.test_agents.transformer_generator_prefix import PREFIX_TEXT
 
 
-class TestUpgradeOpt(unittest.TestCase):
+class TestTGA(unittest.TestCase):
     """
-    Test upgrade_opt behavior.
+    Test various Torch Generator agent behaviors.
     """
 
-    def test_inference(self):
+    def test_upgrade_opt_inference(self):
         """
         Test --inference with simple options.
         """
@@ -97,9 +98,9 @@ def test_block_full_context(self):
         self.assertEqual(agent.beam_block_full_context, True)
 
 
-class TestTreeSearch(unittest.TestCase):
+class TestGeneration(unittest.TestCase):
     """
-    Tests various Tree Search functionalities.
+    Tests various generation functionalities.
 
     NOTE: Currently incomplete.
     """
@@ -142,6 +143,32 @@ def test_full_context_block(self):
             [5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
         )  # 3 is end token.
 
+    def test_prefix_tokens(self):
+        """
+        Test functionality of `get_prefix_tokens`.
+        """
+        args = [
+            '--model-file',
+            'zoo:unittest/transformer_generator2/model',
+            '--model',
+            'test_agents/transformer_generator_prefix',
+            '--inference',
+            'beam',
+            '--truncate',
+            '1024',
+            '--beam-size',
+            '2',
+        ]
+        pp = ParlaiParser(True, True)
+        agent = create_agent(pp.parse_args(args), True)
+        obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
+        agent.observe(obs)
+        act = agent.act()
+        beam_texts = [x[0] for x in act['beam_texts']]
+        for beam in beam_texts:
+            # check that all beams start with the prefix text
+            assert beam.startswith(PREFIX_TEXT)
+
 
 if __name__ == '__main__':
     unittest.main()

From 09c0306906cdee8ee38aad3e3ba61a1292a289b4 Mon Sep 17 00:00:00 2001
From: Emily Dinan
Date: Wed, 30 Jun 2021 15:01:09 -0400
Subject: [PATCH 2/3] oops

---
 .../test_agents/transformer_generator_prefix.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/parlai/agents/test_agents/transformer_generator_prefix.py b/parlai/agents/test_agents/transformer_generator_prefix.py
index 9fc1d543659..2cd66acc983 100644
--- a/parlai/agents/test_agents/transformer_generator_prefix.py
+++ b/parlai/agents/test_agents/transformer_generator_prefix.py
@@ -4,20 +4,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-"""
-Test agent which counts its number of unique items.
-""" - -from __future__ import annotations -from typing import Tuple -from collections import Counter - -#!/usr/bin/env python3 - -# Copyright (c) Facebook, Inc. and its affiliates. -# This source code is licensed under the MIT license found in the -# LICENSE file in the root directory of this source tree. - """ Test agent which overrides the `get_prefix_tokens` function for Transformer Generator Agent in order to test its functionality. From 18e04b73221833d00901893b3592a7de15fe1d80 Mon Sep 17 00:00:00 2001 From: Emily Dinan Date: Tue, 6 Jul 2021 11:25:11 -0400 Subject: [PATCH 3/3] fix test --- parlai/agents/test_agents/transformer_generator_prefix.py | 2 +- tests/test_tga.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/parlai/agents/test_agents/transformer_generator_prefix.py b/parlai/agents/test_agents/transformer_generator_prefix.py index 2cd66acc983..9bceadb831f 100644 --- a/parlai/agents/test_agents/transformer_generator_prefix.py +++ b/parlai/agents/test_agents/transformer_generator_prefix.py @@ -17,7 +17,7 @@ from parlai.core.torch_agent import Batch -PREFIX_TEXT = '4 3 2 1 ' +PREFIX_TEXT = '4 3 2 1' class TransformerGeneratorPrefixAgent(TransformerGeneratorAgent): diff --git a/tests/test_tga.py b/tests/test_tga.py index 1a2f73fb923..5c6ddd38e86 100644 --- a/tests/test_tga.py +++ b/tests/test_tga.py @@ -167,7 +167,9 @@ def test_prefix_tokens(self): beam_texts = [x[0] for x in act['beam_texts']] for beam in beam_texts: # check that all beams start with the prefix text - assert beam.startswith(PREFIX_TEXT) + assert beam.startswith( + PREFIX_TEXT + ), f"[{beam}] does not start with [{PREFIX_TEXT}]" if __name__ == '__main__':