[TGA] Option to block full context (#2928)
* block full context

* black

* unskip test

* change tensor to longtensor

* default is true, use upgrade opt

* add warning
klshuster authored Aug 4, 2020
1 parent 21463cc commit 3d3386c
Showing 3 changed files with 88 additions and 2 deletions.
1 change: 1 addition & 0 deletions parlai/core/torch_agent.py
@@ -1363,6 +1363,7 @@ def _set_text_vec(self, obs, history, truncate):
             obs['full_text'] = history_string
             if history_string:
                 obs['text_vec'] = history.get_history_vec()
+                obs['full_text_vec'] = history.get_history_vec()

         # check truncation
         if obs.get('text_vec') is not None:
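The new `full_text_vec` field is written before the truncation check below it runs, so it always carries the complete vectorized history, while `text_vec` may subsequently be clipped. A minimal sketch, with hypothetical token IDs and `--truncate 4`, of what an observation could look like after `_set_text_vec`:

# Hypothetical observation after _set_text_vec with --truncate 4:
obs = {
    'full_text': 'hello world how are you today',
    'text_vec': [7, 8, 9, 10],              # truncated to the last 4 tokens
    'full_text_vec': [5, 6, 7, 8, 9, 10],   # the entire dialogue history
}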
24 changes: 23 additions & 1 deletion parlai/core/torch_generator_agent.py
@@ -368,6 +368,14 @@ def upgrade_opt(cls, opt_from_disk: Opt):
                 ]
             del opt_from_disk['beam_blacklist_filename']

+        # 2020-08-04: Introduce full context beam blocking.
+        # Previously, specifying --beam-context-block-ngram > 1 would block
+        # the model from generating n-grams from its context, which is limited
+        # by truncation parameters. Now, we block on the full dialogue history.
+        if 'beam_block_full_context' not in opt_from_disk:
+            warn_once('Loading model with `--beam-block-full-context false`')
+            opt_from_disk['beam_block_full_context'] = False
+
         return opt_from_disk

     @classmethod
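The effect is that checkpoints saved before this change keep the old, truncation-limited blocking instead of silently adopting the new default. A self-contained sketch of just this upgrade step (the real `upgrade_opt` also runs the earlier upgrades shown above):

def upgrade_opt_sketch(opt_from_disk):
    # Checkpoints that predate the flag keep the old behavior.
    if 'beam_block_full_context' not in opt_from_disk:
        opt_from_disk['beam_block_full_context'] = False
    return opt_from_disk

old_opt = {'beam_context_block_ngram': 3}  # hypothetical on-disk opt
assert upgrade_opt_sketch(old_opt)['beam_block_full_context'] is False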
@@ -403,6 +411,13 @@ def add_cmdline_args(cls, argparser):
             default=-1,
             help='Size n-grams to block in beam search. val <= 0 implies no blocking',
         )
+        agent.add_argument(
+            '--beam-block-full-context',
+            type='bool',
+            default=True,
+            help='Block n-grams from the *full* history context. Specify False to block '
+            'up to m tokens in the past, where m is the truncation parameter for the agent',
+        )
         agent.add_argument(
             '--beam-length-penalty',
             type=float,
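Usage sketch: since this is a standard ParlAI boolean flag, it can be toggled at parse time; the tests below exercise the same pattern:

from parlai.core.params import ParlaiParser

pp = ParlaiParser(True, True)
# New models default to True; pass 'false' to restore the truncated behavior.
opt = pp.parse_args(
    ['--model', 'transformer/generator', '--beam-block-full-context', 'false']
)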
@@ -462,6 +477,7 @@ def __init__(self, opt: Opt, shared=None):
         self.beam_min_length = opt.get('beam_min_length', 1)
         self.beam_block_ngram = opt.get('beam_block_ngram', -1)
         self.beam_context_block_ngram = opt.get('beam_context_block_ngram', -1)
+        self.beam_block_full_context = opt.get('beam_block_full_context', False)
         self.temperature = opt.get('temperature', 1.0)
         assert self.temperature > 0, '--temperature must be greater than 0'
         self.output_token_losses = opt.get('verbose', False)
@@ -981,7 +997,13 @@ def _get_context(self, batch, batch_idx):
         Intentionally overridable for more complex model histories.
         """
-        return batch.text_vec[batch_idx]
+        ctxt = batch.text_vec[batch_idx]
+        if self.beam_block_full_context:
+            full_ctxt = batch.observations[batch_idx].get('full_text_vec', ctxt)
+            if not isinstance(full_ctxt, torch.LongTensor):
+                full_ctxt = torch.LongTensor(full_ctxt).to(ctxt)
+            ctxt = full_ctxt
+        return ctxt

     def _get_initial_decoder_input(
         self, bsz: int, beam_size: int, dev: torch.device
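The vector returned by `_get_context` is what context n-gram blocking consults during search. A minimal standalone sketch (not ParlAI's actual `TreeSearch` code) of how such a context can veto a hypothesis whose newest n-gram already appears in the context:

def context_ngrams(context_tokens, n):
    # Every n-gram appearing in the (possibly full-history) context.
    return {
        tuple(context_tokens[i : i + n])
        for i in range(len(context_tokens) - n + 1)
    }


def is_blocked(hypothesis, context_tokens, n):
    # True if the hypothesis' trailing n-gram occurs in the context.
    if n <= 0 or len(hypothesis) < n:
        return False
    return tuple(hypothesis[-n:]) in context_ngrams(context_tokens, n)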
65 changes: 64 additions & 1 deletion tests/test_tga.py
@@ -6,7 +6,6 @@
"""
Test TorchGeneratorAgent.
"""

import unittest
from parlai.core.agents import create_agent
import parlai.utils.testing as testing_utils
@@ -78,6 +77,70 @@ def test_file_inference(self):
         agent = create_agent(opt, True)
         self.assertEqual(agent.opt['inference'], 'beam')

+    def test_block_full_context(self):
+        """
+        Test --beam-block-full-context with older model files.
+        """
+        # old model file == beam block full context false
+        pp = ParlaiParser(True, True)
+        opt = pp.parse_args(
+            ['--model-file', 'zoo:unittest/transformer_generator2/model']
+        )
+        agent = create_agent(opt, True)
+        self.assertEqual(agent.opt['beam_block_full_context'], False)
+        self.assertEqual(agent.beam_block_full_context, False)
+
+        # brand new model == beam block full context true
+        pp = ParlaiParser(True, True)
+        opt = pp.parse_args(['--model', 'transformer/generator'])
+        agent = create_agent(opt, True)
+        self.assertEqual(agent.opt['beam_block_full_context'], True)
+        self.assertEqual(agent.beam_block_full_context, True)
+
+
+class TestTreeSearch(unittest.TestCase):
+    """
+    Tests various Tree Search functionalities.
+
+    NOTE: Currently incomplete.
+    """
+
+    def test_full_context_block(self):
+        args = [
+            '--model-file',
+            'zoo:unittest/transformer_generator2/model',
+            '--inference',
+            'beam',
+            '--truncate',
+            '1024',
+        ]
+        pp = ParlaiParser(True, True)
+        agent = create_agent(pp.parse_args(args), True)
+        obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
+        agent.observe(obs)
+        batch = agent.batchify([agent.observation])
+        self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # observe 1 more obs, context is the same (truncation)
+        agent.observe(obs)
+        batch = agent.batchify([agent.observation])
+        self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # Now, set agent's beam_block_full_context
+        args += ['--beam-block-full-context', 'true']
+        agent2 = create_agent(pp.parse_args(args), True)
+        agent2.observe(obs)
+        batch = agent2.batchify([agent2.observation])
+        self.assertEqual(agent2._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)
+
+        # observe 1 more obs, context is larger now
+        agent2.observe(obs)
+        batch = agent2.batchify([agent2.observation])
+        self.assertEqual(
+            agent2._get_context(batch, 0).tolist(),
+            [5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
+        )  # 3 is end token.
+

 if __name__ == '__main__':
     unittest.main()