diff --git a/docs/source/tutorial_metrics.md b/docs/source/tutorial_metrics.md
index 75be7b42266..4ece1bdce18 100644
--- a/docs/source/tutorial_metrics.md
+++ b/docs/source/tutorial_metrics.md
@@ -426,6 +426,7 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu
 | `ctpb` | Context tokens per batch |
 | `ctps` | Context tokens per second |
 | `ctrunc` | Fraction of samples with some context truncation |
+| `context_average_tokens_truncated` | Average length of context tokens truncated |
 | `exps` | Examples per second |
 | `exs` | Number of examples processed since last print |
 | `f1` | Unigram F1 overlap, under a standardized (model-independent) tokenizer |
@@ -441,6 +442,7 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu
 | `ltpb` | Label tokens per batch |
 | `ltps` | Label tokens per second |
 | `ltrunc` | Fraction of samples with some label truncation |
+| `label_average_tokens_truncated` | Average length of label tokens truncated |
 | `rouge-1`, `rouge-2`, `rouge-L` | ROUGE metrics |
 | `token_acc` | Token-wise accuracy (generative only) |
 | `token_em` | Utterance-level token accuracy. Roughly corresponds to perfection under greedy search (generative only) |
diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py
index 8a6aba311cd..2a43c225dca 100644
--- a/parlai/core/torch_agent.py
+++ b/parlai/core/torch_agent.py
@@ -114,8 +114,10 @@ class Batch(AttrDict):
     image: Optional[List[Any]]
     _context_original_length: Optional[torch.LongTensor]
     _context_truncate_rate: Optional[torch.LongTensor]
+    _context_truncated_length: Optional[torch.LongTensor]
     _label_original_length: Optional[torch.LongTensor]
     _label_truncate_rate: Optional[torch.LongTensor]
+    _label_truncated_length: Optional[torch.LongTensor]

     def __init__(
         self,
@@ -131,8 +133,10 @@ def __init__(
         image=None,
         _context_original_length: Optional[torch.LongTensor] = None,
         _context_truncate_rate: Optional[torch.LongTensor] = None,
+        _context_truncated_length: Optional[torch.LongTensor] = None,
         _label_original_length: Optional[torch.LongTensor] = None,
         _label_truncate_rate: Optional[torch.LongTensor] = None,
+        _label_truncated_length: Optional[torch.LongTensor] = None,
         **kwargs,
     ):
         super().__init__(
@@ -147,8 +151,10 @@
             image=image,
             _context_original_length=_context_original_length,
             _context_truncate_rate=_context_truncate_rate,
+            _context_truncated_length=_context_truncated_length,
             _label_original_length=_label_original_length,
             _label_truncate_rate=_label_truncate_rate,
+            _label_truncated_length=_label_truncated_length,
             **kwargs,
         )

@@ -1445,6 +1451,9 @@ def _set_text_vec(self, obs, history, truncate):
             )
             obs.force_set('context_original_length', text_length)
             obs.force_set('context_truncate_rate', text_length != len(truncated_vec))
+            obs.force_set(
+                'context_truncated_length', max(text_length - len(truncated_vec), 0)
+            )
             obs.force_set('text_vec', torch.LongTensor(truncated_vec))

         return obs
@@ -1472,6 +1481,9 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):
             truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
             obs.force_set('label_original_length', vec_label_length)
             obs.force_set('label_truncate_rate', vec_label_length > len(truncated_vec))
+            obs.force_set(
+                'label_truncated_length', max(vec_label_length - len(truncated_vec), 0)
+            )
             obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec))
         else:
             # pick one label if there are multiple
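The `_set_text_vec`/`_set_label_vec` hunks above record not just whether truncation happened but how many tokens it removed. A minimal standalone sketch of that bookkeeping, with a hypothetical `truncate_tokens` helper standing in for `TorchAgent._check_truncate`:

```python
from typing import List, Optional, Tuple


def truncate_tokens(vec: List[int], truncate: Optional[int]) -> List[int]:
    # Hypothetical stand-in for TorchAgent._check_truncate:
    # keep at most `truncate` tokens.
    return vec if truncate is None else vec[:truncate]


def truncation_stats(vec: List[int], truncate: Optional[int]) -> Tuple[int, bool, int]:
    truncated_vec = truncate_tokens(vec, truncate)
    original_length = len(vec)
    # Did any truncation occur? (feeds ctrunc / ltrunc)
    truncate_rate = original_length != len(truncated_vec)
    # How many tokens were dropped? (the new *_truncated_length field)
    truncated_length = max(original_length - len(truncated_vec), 0)
    return original_length, truncate_rate, truncated_length


# A 9-token context truncated to 5 tokens drops 4:
assert truncation_stats(list(range(9)), 5) == (9, True, 4)
```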
@@ -1482,6 +1494,9 @@
             )
             obs.force_set('label_original_length', vec_label_length)
             obs.force_set('label_truncate_rate', vec_label_truncated)
+            obs.force_set(
+                'label_truncated_length', max(vec_label_length - len(vec_label), 0)
+            )
             obs[label_type + '_vec'] = vec_label
             obs[label_type + '_choice'] = label

@@ -1625,7 +1640,8 @@ def batchify(self, obs_batch, sort=False):
         valid_inds, exs = zip(*valid_obs)

         # TEXT
-        xs = x_lens = context_original_lengths = context_truncate_rate = None
+        xs = x_lens = context_original_lengths = None
+        context_truncate_rate = context_truncated_lengths = None
         if any(ex.get('text_vec') is not None for ex in exs):
             if any('context_original_length' in ex for ex in exs):
                 context_truncate_rate = torch.LongTensor(
@@ -1634,6 +1650,10 @@
                 context_original_lengths = torch.LongTensor(
                     [ex.get('context_original_length', 0) for ex in exs]
                 )
+            if any('context_truncated_length' in ex for ex in exs):
+                context_truncated_lengths = torch.LongTensor(
+                    [ex.get('context_truncated_length', 0) for ex in exs]
+                )
             _xs = [ex.get('text_vec', self.EMPTY) for ex in exs]
             xs, x_lens = self._pad_tensor(_xs)
             if sort:
@@ -1646,7 +1666,8 @@
         labels_avail = any('labels_vec' in ex for ex in exs)
         some_labels_avail = labels_avail or any('eval_labels_vec' in ex for ex in exs)

-        ys = y_lens = labels = label_original_lengths = label_truncate_rate = None
+        ys = y_lens = labels = label_original_lengths = None
+        label_truncate_rate = label_truncated_lengths = None
         if some_labels_avail:
             if any('label_original_length' in ex for ex in exs):
                 label_truncate_rate = torch.LongTensor(
@@ -1655,6 +1676,10 @@
                 label_original_lengths = torch.LongTensor(
                     [ex.get('label_original_length', 0) for ex in exs]
                 )
+            if any('label_truncated_length' in ex for ex in exs):
+                label_truncated_lengths = torch.LongTensor(
+                    [ex.get('label_truncated_length', 0) for ex in exs]
+                )
             field = 'labels' if labels_avail else 'eval_labels'
             label_vecs = [ex.get(field + '_vec', self.EMPTY) for ex in exs]

@@ -1700,8 +1725,10 @@
             observations=exs if self.is_debug else None,
             _context_original_length=context_original_lengths,
             _context_truncate_rate=context_truncate_rate,
+            _context_truncated_length=context_truncated_lengths,
             _label_original_length=label_original_lengths,
             _label_truncate_rate=label_truncate_rate,
+            _label_truncated_length=label_truncated_lengths,
         )

     def match_batch(self, batch_reply, valid_inds, output=None):
@@ -2102,6 +2129,11 @@ def batch_act(self, observations):
                 self.record_local_metric(
                     'ctrunc', AverageMetric.many(batch._context_truncate_rate)
                 )
+                if batch._context_truncated_length is not None:
+                    self.record_local_metric(
+                        'context_average_tokens_truncated',
+                        AverageMetric.many(batch._context_truncated_length),
+                    )
             if batch._label_original_length is not None:
                 self.record_local_metric(
                     'llen', AverageMetric.many(batch._label_original_length)
@@ -2109,6 +2141,11 @@
                 self.record_local_metric(
                     'ltrunc', AverageMetric.many(batch._label_truncate_rate)
                 )
+                if batch._label_truncated_length is not None:
+                    self.record_local_metric(
+                        'label_average_tokens_truncated',
+                        AverageMetric.many(batch._label_truncated_length),
+                    )

         self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))
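In `batchify` the per-observation counts become per-example `LongTensor`s, and `batch_act` reports them through `AverageMetric.many`, i.e. one local metric per example that aggregates to a mean across examples. A rough sketch of that aggregation in plain Python (the `RunningAverage` class below is illustrative, not ParlAI's `AverageMetric`):

```python
class RunningAverage:
    """Illustrative stand-in for how AverageMetric values combine."""

    def __init__(self) -> None:
        self.numer = 0
        self.denom = 0

    def add(self, value: int) -> None:
        self.numer += value
        self.denom += 1

    def value(self) -> float:
        return self.numer / self.denom if self.denom else 0.0


# Per-example dropped-token counts, as in batch._context_truncated_length.tolist():
avg = RunningAverage()
for dropped in [4, 0, 7]:
    avg.add(dropped)
print(avg.value())  # 3.66... -> reported as context_average_tokens_truncated
```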
diff --git a/tests/test_torch_agent.py b/tests/test_torch_agent.py
index 1aa17362086..acff06ab0ab 100644
--- a/tests/test_torch_agent.py
+++ b/tests/test_torch_agent.py
@@ -1048,3 +1048,9 @@ def test_truncate_metrics(self):
         self.assertEqual(agent._local_metrics['ltrunc'][0].value(), 1.0)
         self.assertEqual(agent._local_metrics['clen'][0].value(), 9)
         self.assertEqual(agent._local_metrics['llen'][0].value(), 11)
+        self.assertEqual(
+            agent._local_metrics['context_average_tokens_truncated'][0].value(), 4
+        )
+        self.assertEqual(
+            agent._local_metrics['label_average_tokens_truncated'][0].value(), 6
+        )
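The expected values are consistent with the earlier assertions: `clen` is 9 and `llen` is 11, so with a truncation length of 5 (inferred from the assertions; the test's truncate setting is not shown in this hunk) 4 context tokens and 6 label tokens are dropped:

```python
# Inferred arithmetic; truncate = 5 is an assumption, not shown in this hunk.
truncate = 5
assert max(9 - truncate, 0) == 4   # context_average_tokens_truncated
assert max(11 - truncate, 0) == 6  # label_average_tokens_truncated
```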