diff --git a/parlai/agents/hred/hred.py b/parlai/agents/hred/hred.py
index 888343445b8..376e40a4d2c 100644
--- a/parlai/agents/hred/hred.py
+++ b/parlai/agents/hred/hred.py
@@ -128,6 +128,10 @@ def batchify(self, obs_batch, sort=True):
 
         Store history vec as context_vec.
         """
+        # NOTE: `sort` is set to True here (Default is False in TorchGeneratorAgent)
+        # TODO: Sorting the batch will result in various local metrics being broadcasted
+        # back to individual examples in the wrong order, such as the lengths of
+        # the context and labels. Aggregate metric reports will still be accurate.
         batch = super().batchify(obs_batch, sort)
         # sum here is list concat, not addition
         context_vec, hist_lens_ = self._pad_tensor(
diff --git a/parlai/agents/seq2seq/seq2seq.py b/parlai/agents/seq2seq/seq2seq.py
index df972d11c96..8ce0e0215bd 100644
--- a/parlai/agents/seq2seq/seq2seq.py
+++ b/parlai/agents/seq2seq/seq2seq.py
@@ -209,6 +209,9 @@ def batchify(self, *args, **kwargs):
         Override batchify options for seq2seq.
         """
         kwargs['sort'] = True  # need sorted for pack_padded
+        # TODO: Sorting the batch will result in various local metrics being broadcasted
+        # back to individual examples in the wrong order, such as the lengths of
+        # the context and labels. Aggregate metric reports will still be accurate.
         return super().batchify(*args, **kwargs)
 
     def state_dict(self):
diff --git a/parlai/core/torch_generator_agent.py b/parlai/core/torch_generator_agent.py
index 558b836fbbb..0c16ba4779c 100644
--- a/parlai/core/torch_generator_agent.py
+++ b/parlai/core/torch_generator_agent.py
@@ -639,7 +639,7 @@ def vectorize(self, *args, **kwargs):
         kwargs['add_end'] = True  # we do want this
         return super().vectorize(*args, **kwargs)
 
-    def batchify(self, obs_batch, sort=True):
+    def batchify(self, obs_batch, sort=False):
         batch = super().batchify(obs_batch, sort=sort)
         if (
             self.beam_block_full_context