From ce753c717958d91d2d5598331eaa1b6476a06c5e Mon Sep 17 00:00:00 2001 From: Dexter Ju Date: Mon, 1 Mar 2021 12:28:07 -0800 Subject: [PATCH] apply comments --- parlai/core/torch_agent.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/parlai/core/torch_agent.py b/parlai/core/torch_agent.py index 697278b1ce7..0448af26db1 100644 --- a/parlai/core/torch_agent.py +++ b/parlai/core/torch_agent.py @@ -1373,6 +1373,7 @@ def _set_text_vec(self, obs, history, truncate): truncated_vec = self._check_truncate( obs['text_vec'], truncate, truncate_left ) + obs.force_set('ori_text_length', text_length) obs.force_set('if_text_truncate', text_length != len(truncated_vec)) obs.force_set('text_vec', torch.LongTensor(truncated_vec)) @@ -1397,7 +1398,10 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): elif label_type + '_vec' in obs: # check truncation of pre-computed vector + label_length = len(obs[label_type + '_vec']) truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate) + obs.force_set('ori_label_length', label_length) + obs.force_set('if_label_truncate', label_length > len(truncated_vec)) obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec)) else: # pick one label if there are multiple @@ -1989,11 +1993,29 @@ def batch_act(self, observations): # check if we should add truncate stats if all('if_text_truncate' in obs for obs in observations): self.record_local_metric( - 'tr', + 'truncate', AverageMetric.many( [float(obs['if_text_truncate']) for obs in observations] ), ) + if all('if_label_truncate' in obs for obs in observations): + self.record_local_metric( + 'label_truncate', + AverageMetric.many( + [float(obs['if_label_truncate']) for obs in observations] + ), + ) + if all('ori_text_length' in obs for obs in observations): + self.record_local_metric( + 'text_length', + AverageMetric.many([obs['ori_text_length'] for obs in observations]), + ) + if all('ori_label_length' in obs for obs in observations): + self.record_local_metric( + 'label_length', + AverageMetric.many([obs['ori_label_length'] for obs in observations]), + ) + # create a batch from the vectors batch = self.batchify(observations) self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))