Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[TGA] Allow setting prefix tokens #3760

Merged
merged 3 commits into from
Jul 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions parlai/agents/test_agents/transformer_generator_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test agent which overrides the `get_prefix_tokens` function for Transformer
Generator Agent in order to test its functionality.

All text generated by this agent should begin with '4 3 2 1 '.
"""
import torch
from typing import Optional

from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.torch_agent import Batch


# Prefix forced onto every generation by this test agent. After detokenization
# the output text therefore begins with '4 3 2 1 ' (tokens joined by spaces).
PREFIX_TEXT = '4 3 2 1'


class TransformerGeneratorPrefixAgent(TransformerGeneratorAgent):
    """
    Transformer Generator Agent whose generations are seeded with a fixed prefix.

    Overrides ``get_prefix_tokens`` so that every hypothesis in the batch is
    forced to begin with the tokenization of ``PREFIX_TEXT``. Exists purely to
    exercise the prefix-token path of the generator agent in tests.
    """

    def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
        """
        Return a ``bsz x len(prefix)`` LongTensor of prefix token ids.

        The tensor is placed on the same device as ``batch.text_vec``.
        """
        token_ids = self.dict.txt2vec(PREFIX_TEXT)
        # Tile the same prefix row once per batch element.
        tiled_rows = [list(token_ids) for _ in range(batch.batchsize)]
        return torch.LongTensor(tiled_rows).to(batch.text_vec.device)
29 changes: 20 additions & 9 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,10 @@ def eval_step(self, batch):
warn_once("--skip-generation true produces limited metrics")
else:
maxlen = self.label_truncate or 256
beam_preds_scores, beams = self._generate(batch, self.beam_size, maxlen)
prefix_tokens = self.get_prefix_tokens(batch)
beam_preds_scores, beams = self._generate(
batch, self.beam_size, maxlen, prefix_tokens=prefix_tokens
)
preds, scores = zip(*beam_preds_scores)
self._add_generation_metrics(batch, preds)

Expand Down Expand Up @@ -1039,6 +1042,17 @@ def _get_next_decoder_input(
decoder_input = torch.cat([prev_input, selection], dim=-1)
return decoder_input

def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
"""
Set prefix tokens to seed decoding at generation time.

By default, we do not utilize prefix tokens, but this is
left overridable by child classes.

Returned tensor should be of dimension bsz x len(prefix)
"""
return None

def _generate(
self,
batch: Batch,
Expand Down Expand Up @@ -1116,15 +1130,12 @@ def _generate(
if prefix_tokens is not None and _ts < prefix_tokens.size(1):
# generate prefix_tokens for every timestep that they exist
# achieve by setting score of all other tokens to be -inf
prefix_toks = prefix_tokens[:, _ts].unsqueeze(-1).repeat(1, beam_size)
prefix_score = score.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.NULL_IDX)
prefix_toks = prefix_tokens[:, _ts]
prefix_mask = torch.ones_like(score, dtype=torch.bool)
prefix_mask[
:, :, prefix_toks
] = False # everything except prefix toks should be neginf
Comment on lines +1133 to +1137
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks this is way cleaner

score[prefix_mask] = neginf(score.dtype)
score[prefix_mask] = score[prefix_mask].scatter_(
-1,
prefix_toks[prefix_mask].unsqueeze(-1),
prefix_score[prefix_mask],
)
for i, b in enumerate(beams):
if not b.is_done():
b.advance(score[i])
Expand Down
39 changes: 34 additions & 5 deletions tests/test_tga.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
import parlai.utils.testing as testing_utils
from parlai.core.params import ParlaiParser
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.agents.test_agents.transformer_generator_prefix import PREFIX_TEXT


class TestUpgradeOpt(unittest.TestCase):
class TestTGA(unittest.TestCase):
"""
Test upgrade_opt behavior.
Test various Torch Generator agent behaviors.
"""

def test_inference(self):
def test_upgrade_opt_inference(self):
"""
Test --inference with simple options.
"""
Expand Down Expand Up @@ -97,9 +98,9 @@ def test_block_full_context(self):
self.assertEqual(agent.beam_block_full_context, True)


class TestTreeSearch(unittest.TestCase):
class TestGeneration(unittest.TestCase):
"""
Tests various Tree Search functionalities.
Tests various generation functionalities.

NOTE: Currently incomplete.
"""
Expand Down Expand Up @@ -142,6 +143,34 @@ def test_full_context_block(self):
[5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
) # 3 is end token.

def test_prefix_tokens(self):
"""
Test functionality of `get_prefix_tokens`.
"""
args = [
'--model-file',
'zoo:unittest/transformer_generator2/model',
'--model',
'test_agents/transformer_generator_prefix',
'--inference',
'beam',
'--truncate',
'1024',
'--beam-size',
'2',
]
pp = ParlaiParser(True, True)
agent = create_agent(pp.parse_args(args), True)
obs = {'text': '1 2 3 4 ' * 256, 'episode_done': False}
agent.observe(obs)
act = agent.act()
beam_texts = [x[0] for x in act['beam_texts']]
for beam in beam_texts:
# check that all beams start with the prefix text
assert beam.startswith(
PREFIX_TEXT
), f"[{beam}] does not start with [{PREFIX_TEXT}]"


if __name__ == '__main__':
unittest.main()