From 36273530c0cb332192df2807481970fd27fc5a67 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 19 Feb 2021 09:40:49 -0800 Subject: [PATCH 1/9] add tr stats --- parlai/core/torch_agent.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 87af63906a9..f8cad93f2b8 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -43,6 +43,7 @@ Adafactor, ) from parlai.core.metrics import ( + AverageMetric, Metrics, Metric, GlobalAverageMetric, @@ -1371,9 +1372,11 @@ def _set_text_vec(self, obs, history, truncate): # check truncation if obs.get('text_vec') is not None: truncate_left = not self.history_reversed + text_length = len(obs['text_vec']) truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) + obs['if_text_truncate'] = text_length != len(truncated_vec) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) return obs @@ -1994,6 +1997,14 @@ def batch_act(self, observations): # check if there are any labels available, if so we will train on them self.is_training = any('labels' in obs for obs in observations) + # check if we should add truncate stats + if any('if_text_truncate' in obs for obs in observations): + self.record_local_metric( + 'tr', + AverageMetric.many( + [float(obs['if_text_truncate']) for obs in observations] + ), + ) # create a batch from the vectors batch = self.batchify(observations) self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize)) From c21d75f8225a36c6c4d97d8bd0c483f81e23cc8d Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 19 Feb 2021 09:52:07 -0800 Subject: [PATCH 2/9] change trigger condition from any to all --- parlai/core/torch_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index f8cad93f2b8..c2b0613731e 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1998,7 +1998,7 @@ def batch_act(self, 
observations): self.is_training = any('labels' in obs for obs in observations) # check if we should add truncate stats - if any('if_text_truncate' in obs for obs in observations): + if all('if_text_truncate' in obs for obs in observations): self.record_local_metric( 'tr', AverageMetric.many( From ce3066e4eeac31c0165c70b1c8b3da24e66d0752 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 19 Feb 2021 11:23:26 -0800 Subject: [PATCH 3/9] bug fix --- parlai/core/torch_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index c2b0613731e..8d12d55368c 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1376,7 +1376,7 @@ def _set_text_vec(self, obs, history, truncate): truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) - obs['if_text_truncate'] = text_length != len(truncated_vec) + obs.force_set('if_text_truncate', text_length != len(truncated_vec)) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) return obs From eb1b6b9d09e43c4d6e8290bdd706be382831cffe Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 1 Mar 2021 12:28:07 -0800 Subject: [PATCH 4/9] apply comments --- parlai/core/torch_agent.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 8d12d55368c..48671c8ddc5 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1376,6 +1376,7 @@ def _set_text_vec(self, obs, history, truncate): truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) + obs.force_set('ori_text_length', text_length) obs.force_set('if_text_truncate', text_length != len(truncated_vec)) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) @@ -1400,7 +1401,10 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): elif label_type + '_vec' in obs: # check truncation of pre-computed vector + 
label_length = len(obs[label_type + '_vec']) truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate) + obs.force_set('ori_label_length', label_length) + obs.force_set('if_label_truncate', label_length > len(truncated_vec)) obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec)) else: # pick one label if there are multiple @@ -2000,11 +2004,29 @@ def batch_act(self, observations): # check if we should add truncate stats if all('if_text_truncate' in obs for obs in observations): self.record_local_metric( - 'tr', + 'truncate', AverageMetric.many( [float(obs['if_text_truncate']) for obs in observations] ), ) + if all('if_label_truncate' in obs for obs in observations): + self.record_local_metric( + 'label_truncate', + AverageMetric.many( + [float(obs['if_label_truncate']) for obs in observations] + ), + ) + if all('ori_text_length' in obs for obs in observations): + self.record_local_metric( + 'text_length', + AverageMetric.many([obs['ori_text_length'] for obs in observations]), + ) + if all('ori_label_length' in obs for obs in observations): + self.record_local_metric( + 'label_length', + AverageMetric.many([obs['ori_label_length'] for obs in observations]), + ) + # create a batch from the vectors batch = self.batchify(observations) self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize)) From e8a1b9ca10176662217bee719a49cbd536487b2b Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 5 Mar 2021 13:33:44 -0800 Subject: [PATCH 5/9] add label text stats --- parlai/core/torch_agent.py | 40 ++++++++++++++++++++++++-------------- 1 file changed, 25 insertions(+), 15 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 48671c8ddc5..eb1426b55a8 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1329,9 +1329,11 @@ def _vectorize_text( """ vec = self.dict.txt2vec(text) vec = self._add_start_end_tokens(vec, add_start, add_end) + original_length = len(vec) vec = 
self._check_truncate(vec, truncate, truncate_left) + if_truncated = original_length > len(vec) tensor = torch.LongTensor(vec) - return tensor + return tensor, original_length, if_truncated def _check_truncate(self, vec, truncate, truncate_left=False): """ @@ -1376,8 +1378,8 @@ def _set_text_vec(self, obs, history, truncate): truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) - obs.force_set('ori_text_length', text_length) - obs.force_set('if_text_truncate', text_length != len(truncated_vec)) + obs.force_set('original_context_length', text_length) + obs.force_set('if_context_truncate', text_length != len(truncated_vec)) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) return obs @@ -1401,16 +1403,20 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): elif label_type + '_vec' in obs: # check truncation of pre-computed vector - label_length = len(obs[label_type + '_vec']) + vec_label_length = len(obs[label_type + '_vec']) truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate) - obs.force_set('ori_label_length', label_length) - obs.force_set('if_label_truncate', label_length > len(truncated_vec)) + obs.force_set('original_label_length', vec_label_length) + obs.force_set('if_label_truncate', vec_label_length > len(truncated_vec)) obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec)) else: # pick one label if there are multiple lbls = obs[label_type] label = lbls[0] if len(lbls) == 1 else self.random.choice(lbls) - vec_label = self._vectorize_text(label, add_start, add_end, truncate, False) + vec_label, vec_label_length, vec_label_truncated = self._vectorize_text( + label, add_start, add_end, truncate, False + ) + obs.force_set('original_label_length', vec_label_length) + obs.force_set('if_label_truncate', vec_label_truncated) obs[label_type + '_vec'] = vec_label obs[label_type + '_choice'] = label @@ -2002,11 +2008,11 @@ def batch_act(self, observations): self.is_training = any('labels' 
in obs for obs in observations) # check if we should add truncate stats - if all('if_text_truncate' in obs for obs in observations): + if all('if_context_truncate' in obs for obs in observations): self.record_local_metric( - 'truncate', + 'context_truncate', AverageMetric.many( - [float(obs['if_text_truncate']) for obs in observations] + [float(obs['if_context_truncate']) for obs in observations] ), ) if all('if_label_truncate' in obs for obs in observations): @@ -2016,15 +2022,19 @@ def batch_act(self, observations): [float(obs['if_label_truncate']) for obs in observations] ), ) - if all('ori_text_length' in obs for obs in observations): + if all('original_context_length' in obs for obs in observations): self.record_local_metric( - 'text_length', - AverageMetric.many([obs['ori_text_length'] for obs in observations]), + 'context_length', + AverageMetric.many( + [obs['original_context_length'] for obs in observations] + ), ) - if all('ori_label_length' in obs for obs in observations): + if all('original_label_length' in obs for obs in observations): self.record_local_metric( 'label_length', - AverageMetric.many([obs['ori_label_length'] for obs in observations]), + AverageMetric.many( + [obs['original_label_length'] for obs in observations] + ), ) # create a batch from the vectors From d8c1397dc27f3bab07f4b5cc3f5fb76743f82da4 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 5 Mar 2021 19:04:02 -0800 Subject: [PATCH 6/9] add wrapper --- parlai/core/torch_agent.py | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index eb1426b55a8..8c824d4144d 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1305,7 +1305,7 @@ def _v2t(self, vec): new_vec.append(i) return self.dict.vec2txt(new_vec) - def _vectorize_text( + def _vectorize_text_with_truncate_stats( self, text, add_start=False, add_end=False, truncate=None, truncate_left=True ): """ @@ 
-1335,6 +1335,30 @@ def _vectorize_text( tensor = torch.LongTensor(vec) return tensor, original_length, if_truncated + def _vectorize_text( + self, text, add_start=False, add_end=False, truncate=None, truncate_left=True + ): + """ + Return vector from text. + + :param text: + String to vectorize. + + :param add_start: + Add the start token to the front of the tensor. + + :param add_end: + Add the end token to the end of the tensor. + + :param truncate: + Truncate to this many tokens >= 0, or None. + + :param truncate_left: + Truncate from the left side (keep the rightmost tokens). You + probably want this True for inputs, False for targets. + """ + return self._vectorize_text_with_truncate_stats(**locals())[0] + def _check_truncate(self, vec, truncate, truncate_left=False): """ Check that vector is truncated correctly. @@ -1412,7 +1436,7 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): # pick one label if there are multiple lbls = obs[label_type] label = lbls[0] if len(lbls) == 1 else self.random.choice(lbls) - vec_label, vec_label_length, vec_label_truncated = self._vectorize_text( + vec_label, vec_label_length, vec_label_truncated = self._vectorize_text_with_truncate_stats( label, add_start, add_end, truncate, False ) obs.force_set('original_label_length', vec_label_length) From 767ad3d9fa0d19e8d36c22f8df3eba5e18e4f032 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Fri, 5 Mar 2021 21:48:14 -0800 Subject: [PATCH 7/9] fix unit test --- parlai/core/torch_agent.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 8c824d4144d..d32606557cb 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1357,7 +1357,9 @@ def _vectorize_text( Truncate from the left side (keep the rightmost tokens). You probably want this True for inputs, False for targets. 
""" - return self._vectorize_text_with_truncate_stats(**locals())[0] + return self._vectorize_text_with_truncate_stats( + text, add_start, add_end, truncate, truncate_left + )[0] def _check_truncate(self, vec, truncate, truncate_left=False): """ From 0c926d2107aba07babdd6dbbac871d9811247db5 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 8 Mar 2021 11:21:26 -0800 Subject: [PATCH 8/9] add test --- tests/test_torch_agent.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/test_torch_agent.py b/tests/test_torch_agent.py index 513e80b0d16..144ae2c70b2 100644 --- a/tests/test_torch_agent.py +++ b/tests/test_torch_agent.py @@ -1051,3 +1051,17 @@ def get_opt(init_mf, mf): init_model_file, is_finetune = agent._get_init_model(popt, None) self.assertEqual(init_model_file, '{}.checkpoint'.format(mf)) self.assertFalse(is_finetune) + + def test_truncate_metrics(self): + agent = get_agent(model='test_agents/unigram', truncate=5) + obs = { + 'text': "I'll be back. I'll be back. I'll be back.", + 'labels': ["I'll be back. I'll be back. 
I'll be back."], + 'episode_done': True, + } + obs = agent.observe(obs) + agent.act() + self.assertEqual(agent._local_metrics['context_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['label_truncate'][0].value(), 1.0) + self.assertEqual(agent._local_metrics['context_length'][0].value(), 9) + self.assertEqual(agent._local_metrics['label_length'][0].value(), 11) From 802cec94a4ac027c403b243c8b61d2b967494909 Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 8 Mar 2021 11:26:19 -0800 Subject: [PATCH 9/9] update doc --- docs/source/tutorial_metrics.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/tutorial_metrics.md b/docs/source/tutorial_metrics.md index f94b8cbabc6..3a2d9bb665e 100644 --- a/docs/source/tutorial_metrics.md +++ b/docs/source/tutorial_metrics.md @@ -422,6 +422,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu | `accuracy` | Exact match text accuracy | | `bleu-4` | BLEU-4 of the generation, under a standardized (model-independent) tokenizer | | `clip` | Fraction of batches with clipped gradients | +| `context_truncate` | Fraction of examples in the batch whose context was truncated | +| `context_length` | Average number of context tokens per example, measured before truncation | | `ctpb` | Context tokens per batch | | `ctps` | Context tokens per second | | `exps` | Examples per second | @@ -433,6 +435,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu | `interdistinct-1`, `interdictinct-2` | Fraction of n-grams unique across _all_ generations | | `intradistinct-1`, `intradictinct-2` | Fraction of n-grams unique _within_ each utterance | | `jga` | Joint Goal Accuracy | +| `label_length` | Average number of label tokens per example, measured before truncation | +| `label_truncate` | Fraction of examples in the batch whose label was truncated | | `loss` | Loss | | `lr` | The most recent learning rate applied | | `ltpb` | Label tokens per batch |