Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
[TGA] Allow setting prefix tokens (#3760)
Browse files Browse the repository at this point in the history
* prefix tokens tga

* oops

* fix test
  • Loading branch information
Emily Dinan authored Jul 6, 2021
1 parent 0b9afe8 commit dd4fd1f
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 14 deletions.
29 changes: 29 additions & 0 deletions parlai/agents/test_agents/transformer_generator_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Test agent which overrides the `get_prefix_tokens` function for Transformer
Generator Agent in order to test its functionality.
All text generated by this agent should begin with '4 3 2 1 '.
"""
import torch
from typing import Optional

from parlai.agents.transformer.transformer import TransformerGeneratorAgent
from parlai.core.torch_agent import Batch


PREFIX_TEXT = '4 3 2 1'


class TransformerGeneratorPrefixAgent(TransformerGeneratorAgent):
    def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
        """
        Force every generation in the batch to begin with ``PREFIX_TEXT``.

        :param batch:
            current batch; only ``batchsize`` and ``text_vec.device`` are used.
        :return:
            a ``bsz x len(prefix)`` LongTensor of prefix token ids, placed on
            the same device as the batch's input text.
        """
        token_ids = self.dict.txt2vec(PREFIX_TEXT)
        # replicate the same prefix row once per batch element
        stacked = torch.LongTensor([token_ids] * batch.batchsize)
        return stacked.to(batch.text_vec.device)
29 changes: 20 additions & 9 deletions parlai/core/torch_generator_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -871,7 +871,10 @@ def eval_step(self, batch):
warn_once("--skip-generation true produces limited metrics")
else:
maxlen = self.label_truncate or 256
beam_preds_scores, beams = self._generate(batch, self.beam_size, maxlen)
prefix_tokens = self.get_prefix_tokens(batch)
beam_preds_scores, beams = self._generate(
batch, self.beam_size, maxlen, prefix_tokens=prefix_tokens
)
preds, scores = zip(*beam_preds_scores)
self._add_generation_metrics(batch, preds)

Expand Down Expand Up @@ -1045,6 +1048,17 @@ def _get_next_decoder_input(
decoder_input = torch.cat([prev_input, selection], dim=-1)
return decoder_input

def get_prefix_tokens(self, batch: Batch) -> Optional[torch.LongTensor]:
    """
    Set prefix tokens to seed decoding at generation time.

    By default, we do not utilize prefix tokens, but this is
    left overridable by child classes.

    Returned tensor should be of dimension bsz x len(prefix)
    """
    # None signals _generate() to skip prefix-forcing entirely.
    return None

def _generate(
self,
batch: Batch,
Expand Down Expand Up @@ -1122,15 +1136,12 @@ def _generate(
if prefix_tokens is not None and _ts < prefix_tokens.size(1):
# generate prefix_tokens for every timestep that they exist
# achieve by setting score of all other tokens to be -inf
prefix_toks = prefix_tokens[:, _ts].unsqueeze(-1).repeat(1, beam_size)
prefix_score = score.gather(-1, prefix_toks.unsqueeze(-1))
prefix_mask = prefix_toks.ne(self.NULL_IDX)
prefix_toks = prefix_tokens[:, _ts]
prefix_mask = torch.ones_like(score, dtype=torch.bool)
prefix_mask[
:, :, prefix_toks
] = False # everything except prefix toks should be neginf
score[prefix_mask] = neginf(score.dtype)
score[prefix_mask] = score[prefix_mask].scatter_(
-1,
prefix_toks[prefix_mask].unsqueeze(-1),
prefix_score[prefix_mask],
)
for i, b in enumerate(beams):
if not b.is_done():
b.advance(score[i])
Expand Down
39 changes: 34 additions & 5 deletions tests/test_tga.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,15 @@
import parlai.utils.testing as testing_utils
from parlai.core.params import ParlaiParser
from parlai.core.torch_generator_agent import TorchGeneratorAgent
from parlai.agents.test_agents.transformer_generator_prefix import PREFIX_TEXT


class TestUpgradeOpt(unittest.TestCase):
class TestTGA(unittest.TestCase):
"""
Test upgrade_opt behavior.
Test various Torch Generator agent behaviors.
"""

def test_inference(self):
def test_upgrade_opt_inference(self):
"""
Test --inference with simple options.
"""
Expand Down Expand Up @@ -97,9 +98,9 @@ def test_block_full_context(self):
self.assertEqual(agent.beam_block_full_context, True)


class TestTreeSearch(unittest.TestCase):
class TestGeneration(unittest.TestCase):
"""
Tests various Tree Search functionalities.
Tests various generation functionalities.
NOTE: Currently incomplete.
"""
Expand Down Expand Up @@ -142,6 +143,34 @@ def test_full_context_block(self):
[5, 4, 6, 7] * 256 + [3] + [5, 4, 6, 7] * 256,
) # 3 is end token.

def test_prefix_tokens(self):
    """
    Test functionality of `get_prefix_tokens`.

    Runs beam search with the prefix test agent and checks that every
    returned beam hypothesis starts with the forced prefix text.
    """
    opt = ParlaiParser(True, True).parse_args(
        [
            '--model-file',
            'zoo:unittest/transformer_generator2/model',
            '--model',
            'test_agents/transformer_generator_prefix',
            '--inference',
            'beam',
            '--truncate',
            '1024',
            '--beam-size',
            '2',
        ]
    )
    agent = create_agent(opt, True)
    agent.observe({'text': '1 2 3 4 ' * 256, 'episode_done': False})
    act = agent.act()
    for entry in act['beam_texts']:
        beam = entry[0]
        # every beam must begin with the prefix text
        assert beam.startswith(
            PREFIX_TEXT
        ), f"[{beam}] does not start with [{PREFIX_TEXT}]"


if __name__ == '__main__':
unittest.main()

0 comments on commit dd4fd1f

Please sign in to comment.