From 3d3386caa9a1334a5846ae49b6878e3c28f4990d Mon Sep 17 00:00:00 2001
From: Kurt Shuster
Date: Tue, 4 Aug 2020 17:48:27 -0400
Subject: [PATCH] [TGA] Option to block full context (#2928)

* block full context

* black

* unskip test

* change tensor to longtensor

* default is true, use upgrade opt

* add warning
---
 parlai/core/torch_agent.py           |  1 +
 parlai/core/torch_generator_agent.py | 24 +++++++++-
 tests/test_tga.py                    | 65 +++++++++++++++++++++++++++-
 3 files changed, 88 insertions(+), 2 deletions(-)

diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py
index 48d029658d5..4984af5af08 100644
--- a/parlai/core/torch_agent.py
+++ b/parlai/core/torch_agent.py
@@ -1363,6 +1363,7 @@ def _set_text_vec(self, obs, history, truncate):
             obs['full_text'] = history_string
             if history_string:
                 obs['text_vec'] = history.get_history_vec()
+                obs['full_text_vec'] = history.get_history_vec()
 
         # check truncation
         if obs.get('text_vec') is not None:
diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 0a17321f668..045eb026664 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -368,6 +368,14 @@ def upgrade_opt(cls, opt_from_disk: Opt):
             ]
             del opt_from_disk['beam_blacklist_filename']
 
+        # 2020-08-04: Introduce full context beam blocking.
+        # Previously, specifying --beam-context-block-ngram > 1 blocked
+        # n-grams appearing in the model's context, which is limited by
+        # truncation parameters. Now, we block on the full dialogue history.
+        if 'beam_block_full_context' not in opt_from_disk:
+            warn_once('Loading model with `--beam-block-full-context false`')
+            opt_from_disk['beam_block_full_context'] = False
+
         return opt_from_disk
 
     @classmethod
@@ -403,6 +411,13 @@ def add_cmdline_args(cls, argparser):
             default=-1,
             help='Size n-grams to block in beam search. val <= 0 implies no blocking',
         )
+        agent.add_argument(
+            '--beam-block-full-context',
+            type='bool',
+            default=True,
+            help='Block n-grams from the *full* history context. Specify False to block '
+            'up to m tokens in the past, where m is the truncation parameter for the agent',
+        )
         agent.add_argument(
             '--beam-length-penalty',
             type=float,
@@ -462,6 +477,7 @@ def __init__(self, opt: Opt, shared=None):
         self.beam_min_length = opt.get('beam_min_length', 1)
         self.beam_block_ngram = opt.get('beam_block_ngram', -1)
         self.beam_context_block_ngram = opt.get('beam_context_block_ngram', -1)
+        self.beam_block_full_context = opt.get('beam_block_full_context', False)
         self.temperature = opt.get('temperature', 1.0)
         assert self.temperature > 0, '--temperature must be greater than 0'
         self.output_token_losses = opt.get('verbose', False)
@@ -981,7 +997,13 @@ def _get_context(self, batch, batch_idx):
 
         Intentionally overridable for more complex model histories.
         """
-        return batch.text_vec[batch_idx]
+        ctxt = batch.text_vec[batch_idx]
+        if self.beam_block_full_context:
+            full_ctxt = batch.observations[batch_idx].get('full_text_vec', ctxt)
+            if not isinstance(full_ctxt, torch.LongTensor):
+                full_ctxt = torch.LongTensor(full_ctxt).to(ctxt)
+            ctxt = full_ctxt
+        return ctxt
 
     def _get_initial_decoder_input(
         self, bsz: int, beam_size: int, dev: torch.device
diff --git a/tests/test_tga.py b/tests/test_tga.py
index 388f606e5b2..d39e01eac47 100644
--- a/tests/test_tga.py
+++ b/tests/test_tga.py
@@ -6,7 +6,6 @@
 """
 Test TorchGeneratorAgent.
 """
-
 import unittest
 from parlai.core.agents import create_agent
 import parlai.utils.testing as testing_utils
@@ -78,6 +77,70 @@ def test_file_inference(self):
         agent = create_agent(opt, True)
         self.assertEqual(agent.opt['inference'], 'beam')
 
+    def test_block_full_context(self):
+        """
+        Test --beam-block-full-context with older model files.
+        """
+        # old model file == beam block full context false
+        pp = ParlaiParser(True, True)
+        opt = pp.parse_args(
+            ['--model-file', 'zoo:unittest/transformer_generator2/model']
+        )
+        agent = create_agent(opt, True)
+        self.assertEqual(agent.opt['beam_block_full_context'], False)
+        self.assertEqual(agent.beam_block_full_context, False)
+
+        # brand new model == beam block full context true
+        pp = ParlaiParser(True, True)
+        opt = pp.parse_args(['--model', 'transformer/generator'])
+        agent = create_agent(opt, True)
+        self.assertEqual(agent.opt['beam_block_full_context'], True)
+        self.assertEqual(agent.beam_block_full_context, True)
+
+
+class TestTreeSearch(unittest.TestCase):
+    """
+    Tests various Tree Search functionalities.
+
+    NOTE: Currently incomplete.
+    """
+
+    def test_full_context_block(self):
+        args = [
+            '--model-file',
+            'zoo:unittest/transformer_generator2/model',
+            '--inference',
+            'beam',
+            '--truncate',
+            '1024',
+        ]
+        pp = ParlaiParser(True, True)
+        agent = create_agent(pp.parse_args(args), True)
+        obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
+        agent.observe(obs)
+        batch = agent.batchify([agent.observation])
+        self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # observe 1 more obs, context is the same (truncation)
+        agent.observe(obs)
+        batch = agent.batchify([agent.observation])
+        self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # Now, set agent's beam_block_full_context
+        args += ['--beam-block-full-context', 'true']
+        agent2 = create_agent(pp.parse_args(args), True)
+        agent2.observe(obs)
+        batch = agent2.batchify([agent2.observation])
+        self.assertEqual(agent2._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # observe 1 more obs, context is larger now
+        agent2.observe(obs)
+        batch = agent2.batchify([agent2.observation])
+        self.assertEqual(
+            agent2._get_context(batch, 0).tolist(),
+            [5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
+        )  # 3 is end token.
+
 
 if __name__ == '__main__':
     unittest.main()