Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
apply comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dexterju27 committed Mar 1, 2021
1 parent ff6796e commit ce753c7
Showing 1 changed file with 23 additions and 1 deletion.
24 changes: 23 additions & 1 deletion parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,7 @@ def _set_text_vec(self, obs, history, truncate):
truncated_vec = self._check_truncate(
obs['text_vec'], truncate, truncate_left
)
obs.force_set('ori_text_length', text_length)
obs.force_set('if_text_truncate', text_length != len(truncated_vec))
obs.force_set('text_vec', torch.LongTensor(truncated_vec))

Expand All @@ -1397,7 +1398,10 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):

elif label_type + '_vec' in obs:
# check truncation of pre-computed vector
label_length = len(obs[label_type + '_vec'])
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('ori_label_length', label_length)
obs.force_set('if_label_truncate', label_length > len(truncated_vec))
obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec))
else:
# pick one label if there are multiple
Expand Down Expand Up @@ -1989,11 +1993,29 @@ def batch_act(self, observations):
# check if we should add truncate stats
if all('if_text_truncate' in obs for obs in observations):
self.record_local_metric(
'tr',
'truncate',
AverageMetric.many(
[float(obs['if_text_truncate']) for obs in observations]
),
)
if all('if_label_truncate' in obs for obs in observations):
self.record_local_metric(
'label_truncate',
AverageMetric.many(
[float(obs['if_label_truncate']) for obs in observations]
),
)
if all('ori_text_length' in obs for obs in observations):
self.record_local_metric(
'text_length',
AverageMetric.many([obs['ori_text_length'] for obs in observations]),
)
if all('ori_label_length' in obs for obs in observations):
self.record_local_metric(
'label_length',
AverageMetric.many([obs['ori_label_length'] for obs in observations]),
)

# create a batch from the vectors
batch = self.batchify(observations)
self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))
Expand Down

0 comments on commit ce753c7

Please sign in to comment.