diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 628951dd573..dba102e8816 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -425,7 +425,7 @@ def add_cmdline_args(
         )
         agent.add_argument(
             '--inference',
-            choices={'beam', 'greedy', 'topk', 'nucleus', 'delayedbeam'},
+            choices={'beam', 'greedy', 'topk', 'nucleus', 'delayedbeam', 'delayednucleusbeam'},
             default='greedy',
             help='Generation algorithm',
         )
@@ -994,6 +994,22 @@ def _treesearch_factory(self, device, verbose=False):
                 verbose=verbose,
                 gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False),
             )
+        elif method == 'delayednucleusbeam':
+            return DelayedNucleusBeamSearch(
+                self.opt['topp'],
+                self.opt['beam_delay'],
+                beam_size,
+                min_length=self.beam_min_length,
+                block_ngram=self.beam_block_ngram,
+                context_block_ngram=self.beam_context_block_ngram,
+                length_penalty=self.opt.get('beam_length_penalty', 0.65),
+                padding_token=self.NULL_IDX,
+                bos_token=self.START_IDX,
+                eos_token=self.END_IDX,
+                device=device,
+                verbose=verbose,
+                gpu_beam_blocking=self.opt.get('gpu_beam_blocking', False),
+            )
         elif method == 'topk':
             return TopKSampling(
                 self.opt['topk'],
@@ -1862,6 +1878,19 @@ def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection
             return BeamSearch.select_paths(self, logprobs, prior_scores, current_length)
 
 
+class DelayedNucleusBeamSearch(TreeSearch):
+    def __init__(self, p, delay, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.p = p
+        self.delay = delay
+
+    def select_paths(self, logprobs, prior_scores, current_length) -> _PathSelection:
+        if current_length < self.delay:
+            return NucleusSampling.select_paths(self, logprobs, prior_scores, current_length)
+        else:
+            return BeamSearch.select_paths(self, logprobs, prior_scores, current_length)
+
+
 class TopKSampling(TreeSearch):
     """
     Top-K sampling (Fan et al., 2018).
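
For reference, below is a minimal standalone sketch (outside ParlAI) of what the new `DelayedNucleusBeamSearch` does at each decoding step: sample from the top-p nucleus for the first `delay` tokens, then fall back to ordinary beam search. The helper names (`nucleus_pick`, `beam_pick`, `select_next`) and the default values are illustrative assumptions, not part of the patch.

```python
import torch


def nucleus_pick(logprobs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample one token id per hypothesis from the top-p (nucleus) of its distribution."""
    probs, idx = torch.sort(logprobs.exp(), dim=-1, descending=True)
    # Drop tokens once the cumulative mass *before* them already exceeds p.
    mask = probs.cumsum(dim=-1) - probs > p
    probs = probs.masked_fill(mask, 0.0)
    choice = torch.multinomial(probs, 1)  # sample within the truncated nucleus
    return idx.gather(-1, choice).squeeze(-1)


def beam_pick(logprobs: torch.Tensor, prior_scores: torch.Tensor, beam_size: int):
    """Standard beam step: keep the beam_size best cumulative-score extensions."""
    scores = logprobs + prior_scores.unsqueeze(-1)  # [beam, vocab]
    best_scores, flat_idx = scores.view(-1).topk(beam_size)
    vocab = logprobs.size(-1)
    return flat_idx // vocab, flat_idx % vocab, best_scores  # (hyp ids, token ids, scores)


def select_next(logprobs, prior_scores, step, delay=30, top_p=0.9):
    """Mirror the delayed switch: nucleus sampling first, beam search after `delay` steps."""
    if step < delay:
        tok_ids = nucleus_pick(logprobs, top_p)
        hyp_ids = torch.arange(logprobs.size(0))  # each hypothesis extends itself
        scores = prior_scores + logprobs.gather(-1, tok_ids.unsqueeze(-1)).squeeze(-1)
        return hyp_ids, tok_ids, scores
    return beam_pick(logprobs, prior_scores, logprobs.size(0))
```

With the patch applied, the mode would be requested on the ParlAI command line via `--inference delayednucleusbeam` together with `--topp` and `--beam-delay` (the flags backing the `self.opt['topp']` and `self.opt['beam_delay']` lookups above).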