
[TGA] Option to block full context #2928

Merged 6 commits on Aug 4, 2020
1 change: 1 addition & 0 deletions parlai/core/torch_agent.py
@@ -1363,6 +1363,7 @@ def _set_text_vec(self, obs, history, truncate):
obs['full_text'] = history_string
if history_string:
obs['text_vec'] = history.get_history_vec()
obs['full_text_vec'] = history.get_history_vec()

# check truncation
if obs.get('text_vec') is not None:
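The one-line change above stores an untruncated copy of the history alongside text_vec: text_vec is clipped by the truncation check that follows, while full_text_vec is left intact. A minimal, self-contained sketch of the effect, using toy token ids and a hypothetical truncate length rather than ParlAI's actual History API:

history_vec = [5, 4, 6, 7] * 3   # token ids for the full dialogue history
truncate = 8                     # hypothetical truncation parameter

obs = {
    'text_vec': history_vec,       # later clipped to the last `truncate` tokens
    'full_text_vec': history_vec,  # kept whole for full-context beam blocking
}
obs['text_vec'] = obs['text_vec'][-truncate:]  # what the model actually sees

assert len(obs['full_text_vec']) == 12
assert len(obs['text_vec']) == 8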
24 changes: 23 additions & 1 deletion parlai/core/torch_generator_agent.py
@@ -354,6 +354,14 @@ def upgrade_opt(cls, opt_from_disk: Opt):
]
del opt_from_disk['beam_blacklist_filename']

# 2020-08-04: Introduce full-context beam blocking.
# Previously, specifying --beam-context-block-ngram > 0 would block
# the model from generating n-grams that appear in its context, which
# is limited by truncation parameters. Now, we can block on the full
# dialogue history.
if 'beam_block_full_context' not in opt_from_disk:
warn_once('Loading model with `--beam-block-full-context false`')
opt_from_disk['beam_block_full_context'] = False
Contributor: Hmm, should we perhaps print a message about this so that people know?
Contributor (author): ok, added

return opt_from_disk

@classmethod
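To make the back-compat path above concrete, here is a self-contained sketch of the upgrade pattern, with hypothetical names and plain dicts rather than ParlAI's full upgrade_opt: checkpoints serialized before this change lack the key entirely, so they keep the old truncated-context behavior, while newly created agents pick up the default of True from add_cmdline_args below.

def upgrade_opt_sketch(opt_from_disk):
    # Opts saved before 2020-08-04 predate the flag; restore the old
    # behavior (block n-grams only from the truncated context).
    if 'beam_block_full_context' not in opt_from_disk:
        print('Loading model with `--beam-block-full-context false`')
        opt_from_disk['beam_block_full_context'] = False
    return opt_from_disk

old_opt = {'beam_context_block_ngram': 3}  # opt from a pre-PR checkpoint
assert upgrade_opt_sketch(old_opt)['beam_block_full_context'] is False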
@@ -389,6 +397,13 @@ def add_cmdline_args(cls, argparser):
default=-1,
help='Size n-grams to block in beam search. val <= 0 implies no blocking',
)
agent.add_argument(
'--beam-block-full-context',
type='bool',
default=True,
help='Block n-grams from the *full* history context. Specify False to block '
'up to m tokens in the past, where m is the truncation parameter for the agent',
)
agent.add_argument(
'--beam-length-penalty',
type=float,
@@ -448,6 +463,7 @@ def __init__(self, opt: Opt, shared=None):
self.beam_min_length = opt.get('beam_min_length', 1)
self.beam_block_ngram = opt.get('beam_block_ngram', -1)
self.beam_context_block_ngram = opt.get('beam_context_block_ngram', -1)
self.beam_block_full_context = opt.get('beam_block_full_context', False)
self.temperature = opt.get('temperature', 1.0)
assert self.temperature > 0, '--temperature must be greater than 0'
self.output_token_losses = opt.get('verbose', False)
@@ -967,7 +983,13 @@ def _get_context(self, batch, batch_idx):

Intentionally overridable for more complex model histories.
"""
return batch.text_vec[batch_idx]
ctxt = batch.text_vec[batch_idx]
if self.beam_block_full_context:
full_ctxt = batch.observations[batch_idx].get('full_text_vec', ctxt)
if not isinstance(full_ctxt, torch.LongTensor):
full_ctxt = torch.LongTensor(full_ctxt).to(ctxt)
ctxt = full_ctxt
return ctxt

def _get_initial_decoder_input(self, bsz: int, beam_size: int, dev: torch.device):
"""
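The override above can be exercised in isolation. Below is a self-contained sketch of the same lookup, with a plain dict standing in for the batch observation and a hypothetical function name in place of the real TorchGeneratorAgent method:

import torch

def get_context_sketch(text_vec, observation, beam_block_full_context):
    ctxt = text_vec  # the truncated context the model conditioned on
    if beam_block_full_context:
        # Fall back to the truncated vector if full_text_vec is missing.
        full_ctxt = observation.get('full_text_vec', ctxt)
        if not isinstance(full_ctxt, torch.LongTensor):
            # full_text_vec may arrive as a plain list of ids; convert it
            # and match the device of the truncated context tensor.
            full_ctxt = torch.LongTensor(full_ctxt).to(ctxt)
        ctxt = full_ctxt
    return ctxt

truncated = torch.LongTensor([5, 4, 6, 7])
obs = {'full_text_vec': [5, 4, 6, 7, 3, 5, 4, 6, 7]}  # 3 is the end token
print(get_context_sketch(truncated, obs, beam_block_full_context=True))
# tensor([5, 4, 6, 7, 3, 5, 4, 6, 7])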
65 changes: 64 additions & 1 deletion tests/test_tga.py
@@ -6,7 +6,6 @@
"""
Test TorchGeneratorAgent.
"""

import unittest
from parlai.core.agents import create_agent
import parlai.utils.testing as testing_utils
@@ -78,6 +77,70 @@ def test_file_inference(self):
agent = create_agent(opt, True)
self.assertEqual(agent.opt['inference'], 'beam')

def test_block_full_context(self):
"""
Test --beam-block-full-context with older model files.
"""
# old model file == beam block full context false
pp = ParlaiParser(True, True)
opt = pp.parse_args(
['--model-file', 'zoo:unittest/transformer_generator2/model']
)
agent = create_agent(opt, True)
self.assertEqual(agent.opt['beam_block_full_context'], False)
self.assertEqual(agent.beam_block_full_context, False)

# brand new model == beam block full context true
pp = ParlaiParser(True, True)
opt = pp.parse_args(['--model', 'transformer/generator'])
agent = create_agent(opt, True)
self.assertEqual(agent.opt['beam_block_full_context'], True)
self.assertEqual(agent.beam_block_full_context, True)


class TestTreeSearch(unittest.TestCase):
"""
Tests various Tree Search functionalities.

NOTE: Currently incomplete.
"""

def test_full_context_block(self):
args = [
'--model-file',
'zoo:unittest/transformer_generator2/model',
'--inference',
'beam',
'--truncate',
'1024',
]
pp = ParlaiParser(True, True)
agent = create_agent(pp.parse_args(args), True)
obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
agent.observe(obs)
batch = agent.batchify([agent.observation])
self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)

# observe 1 more obs, context is the same (truncation)
agent.observe(obs)
batch = agent.batchify([agent.observation])
self.assertEqual(agent._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)

# Now, set agent's beam_block_full_context
args += ['--beam-block-full-context', 'true']
agent2 = create_agent(pp.parse_args(args), True)
agent2.observe(obs)
batch = agent2.batchify([agent2.observation])
self.assertEqual(agent2._get_context(batch, 0).tolist(), [5, 4, 6, 7] * 256)

# observe 1 more obs, context is larger now
agent2.observe(obs)
batch = agent2.batchify([agent2.observation])
self.assertEqual(
agent2._get_context(batch, 0).tolist(),
[5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
) # 3 is the end token.


if __name__ == '__main__':
unittest.main()
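For completeness, a usage sketch that mirrors the tests above: explicitly enabling the flag for an older zoo model at parse time (the model path and flags are taken from the diff; the ParlaiParser import path is assumed to be the standard one):

from parlai.core.agents import create_agent
from parlai.core.params import ParlaiParser

pp = ParlaiParser(True, True)
opt = pp.parse_args(
    [
        '--model-file', 'zoo:unittest/transformer_generator2/model',
        '--inference', 'beam',
        '--beam-block-full-context', 'true',
    ]
)
agent = create_agent(opt, True)
# The explicit command-line value overrides the False that upgrade_opt
# would otherwise set for this pre-flag checkpoint.
assert agent.beam_block_full_context is True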