diff --git a/docs/source/tutorial_metrics.md b/docs/source/tutorial_metrics.md index f94b8cbabc6..3a2d9bb665e 100644 --- a/docs/source/tutorial_metrics.md +++ b/docs/source/tutorial_metrics.md @@ -422,6 +422,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu | `accuracy` | Exact match text accuracy | | `bleu-4` | BLEU-4 of the generation, under a standardized (model-independent) tokenizer | | `clip` | Fraction of batches with clipped gradients | +| `context_truncate` | Fraction of examples in each batch whose context was truncated | +| `context_length` | Average context length in tokens per batch, measured before truncation | | `ctpb` | Context tokens per batch | | `ctps` | Context tokens per second | | `exps` | Examples per second | @@ -433,6 +435,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu | `interdistinct-1`, `interdictinct-2` | Fraction of n-grams unique across _all_ generations | | `intradistinct-1`, `intradictinct-2` | Fraction of n-grams unique _within_ each utterance | | `jga` | Joint Goal Accuracy | +| `label_length` | Average label length in tokens per batch, measured before truncation | +| `label_truncate` | Fraction of examples in each batch whose label was truncated | | `loss` | Loss | | `lr` | The most recent learning rate applied | | `ltpb` | Label tokens per batch | diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 87af63906a9..d32606557cb 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -43,6 +43,7 @@ Adafactor, ) from parlai.core.metrics import ( + AverageMetric, Metrics, Metric, GlobalAverageMetric, @@ -1304,7 +1305,7 @@ def _v2t(self, vec): new_vec.append(i) return self.dict.vec2txt(new_vec) - def _vectorize_text( + def _vectorize_text_with_truncate_stats( self, text, add_start=False, add_end=False, truncate=None, truncate_left=True ): """ @@ -1328,9 +1329,37 @@ def _vectorize_text( """ vec = self.dict.txt2vec(text) vec = self._add_start_end_tokens(vec, add_start, add_end) + original_length = 
len(vec) vec = self._check_truncate(vec, truncate, truncate_left) + if_truncated = original_length > len(vec) tensor = torch.LongTensor(vec) - return tensor + return tensor, original_length, if_truncated + + def _vectorize_text( + self, text, add_start=False, add_end=False, truncate=None, truncate_left=True + ): + """ + Return vector from text. + + :param text: + String to vectorize. + + :param add_start: + Add the start token to the front of the tensor. + + :param add_end: + Add the end token to the end of the tensor. + + :param truncate: + Truncate to this many tokens >= 0, or None. + + :param truncate_left: + Truncate from the left side (keep the rightmost tokens). You + probably want this True for inputs, False for targets. + """ + return self._vectorize_text_with_truncate_stats( + text, add_start, add_end, truncate, truncate_left + )[0] def _check_truncate(self, vec, truncate, truncate_left=False): """ @@ -1371,9 +1400,12 @@ def _set_text_vec(self, obs, history, truncate): # check truncation if obs.get('text_vec') is not None: truncate_left = not self.history_reversed + text_length = len(obs['text_vec']) truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) + obs.force_set('original_context_length', text_length) + obs.force_set('if_context_truncate', text_length != len(truncated_vec)) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) return obs @@ -1397,13 +1429,20 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): elif label_type + '_vec' in obs: # check truncation of pre-computed vector + vec_label_length = len(obs[label_type + '_vec']) truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate) + obs.force_set('original_label_length', vec_label_length) + obs.force_set('if_label_truncate', vec_label_length > len(truncated_vec)) obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec)) else: # pick one label if there are multiple lbls = obs[label_type] label = lbls[0] if len(lbls) == 1 
else self.random.choice(lbls) - vec_label = self._vectorize_text(label, add_start, add_end, truncate, False) + vec_label, vec_label_length, vec_label_truncated = self._vectorize_text_with_truncate_stats( + label, add_start, add_end, truncate, False + ) + obs.force_set('original_label_length', vec_label_length) + obs.force_set('if_label_truncate', vec_label_truncated) obs[label_type + '_vec'] = vec_label obs[label_type + '_choice'] = label @@ -1994,6 +2033,36 @@ def batch_act(self, observations): # check if there are any labels available, if so we will train on them self.is_training = any('labels' in obs for obs in observations) + # check if we should add truncate stats + if all('if_context_truncate' in obs for obs in observations): + self.record_local_metric( + 'context_truncate', + AverageMetric.many( + [float(obs['if_context_truncate']) for obs in observations] + ), + ) + if all('if_label_truncate' in obs for obs in observations): + self.record_local_metric( + 'label_truncate', + AverageMetric.many( + [float(obs['if_label_truncate']) for obs in observations] + ), + ) + if all('original_context_length' in obs for obs in observations): + self.record_local_metric( + 'context_length', + AverageMetric.many( + [obs['original_context_length'] for obs in observations] + ), + ) + if all('original_label_length' in obs for obs in observations): + self.record_local_metric( + 'label_length', + AverageMetric.many( + [obs['original_label_length'] for obs in observations] + ), + ) + # create a batch from the vectors batch = self.batchify(observations) self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize)) diff --git a/tests/test_torch_agent.py b/tests/test_torch_agent.py index 513e80b0d16..144ae2c70b2 100644 --- a/tests/test_torch_agent.py +++ b/tests/test_torch_agent.py @@ -1051,3 +1051,17 @@ def get_opt(init_mf, mf): init_model_file, is_finetune = agent._get_init_model(popt, None) self.assertEqual(init_model_file, '{}.checkpoint'.format(mf)) 
self.assertFalse(is_finetune) + + def test_truncate_metrics(self): + agent = get_agent(model='test_agents/unigram', truncate=5) + obs = { + 'text': "I'll be back. I'll be back. I'll be back.", + 'labels': ["I'll be back. I'll be back. I'll be back."], + 'episode_done': True, + } + obs = agent.observe(obs) + agent.act() + self.assertEqual(agent._local_metrics['context_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['label_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['context_length'][0].value(), 9) + self.assertEqual(agent._local_metrics['label_length'][0].value(), 11)