
Add truncate stats to training logs #3458

Merged: 10 commits, Mar 8, 2021
Changes from 6 commits
73 changes: 70 additions & 3 deletions parlai/core/torch_agent.py
@@ -43,6 +43,7 @@
Adafactor,
)
from parlai.core.metrics import (
AverageMetric,
Metrics,
Metric,
GlobalAverageMetric,
@@ -1304,7 +1305,7 @@ def _v2t(self, vec):
new_vec.append(i)
return self.dict.vec2txt(new_vec)

def _vectorize_text(
def _vectorize_text_with_truncate_stats(
self, text, add_start=False, add_end=False, truncate=None, truncate_left=True
):
"""
@@ -1328,9 +1329,35 @@
"""
vec = self.dict.txt2vec(text)
vec = self._add_start_end_tokens(vec, add_start, add_end)
original_length = len(vec)
vec = self._check_truncate(vec, truncate, truncate_left)
if_truncated = original_length > len(vec)
tensor = torch.LongTensor(vec)
return tensor
return tensor, original_length, if_truncated
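
As a hedged, concrete illustration of the new return values (the token counts below are made up for the example, they are not from the patch):

    # Suppose self.dict.txt2vec(text) yields 10 tokens and truncate=6 with
    # truncate_left=True, so the rightmost 6 tokens are kept:
    #   tensor          -> LongTensor of length 6
    #   original_length -> 10   (length before truncation, after start/end tokens)
    #   if_truncated    -> True (10 > 6)
    tensor, original_length, if_truncated = self._vectorize_text_with_truncate_stats(
        text, add_start=False, add_end=False, truncate=6, truncate_left=True
    )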

def _vectorize_text(
self, text, add_start=False, add_end=False, truncate=None, truncate_left=True
):
"""
Return vector from text.

:param text:
String to vectorize.

:param add_start:
Add the start token to the front of the tensor.

:param add_end:
Add the end token to the end of the tensor.

:param truncate:
Truncate to this many tokens >= 0, or None.

:param truncate_left:
Truncate from the left side (keep the rightmost tokens). You
probably want this True for inputs, False for targets.
"""
return self._vectorize_text_with_truncate_stats(**locals())[0]
Contributor:
Uh, I'm not sure locals() is the right choice here.
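
For context on that concern: at the point of the call, locals() also contains self, so **locals() forwards self both implicitly (as the bound method's receiver) and as a keyword argument, which raises a TypeError. A minimal sketch of the explicit-argument alternative, illustrative only rather than the merged fix:

    def _vectorize_text(
        self, text, add_start=False, add_end=False, truncate=None, truncate_left=True
    ):
        # Forward the arguments explicitly instead of relying on locals(),
        # and discard the extra truncate stats.
        tensor, _length, _truncated = self._vectorize_text_with_truncate_stats(
            text,
            add_start=add_start,
            add_end=add_end,
            truncate=truncate,
            truncate_left=truncate_left,
        )
        return tensor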


def _check_truncate(self, vec, truncate, truncate_left=False):
"""
@@ -1371,9 +1398,12 @@ def _set_text_vec(self, obs, history, truncate):
# check truncation
if obs.get('text_vec') is not None:
truncate_left = not self.history_reversed
text_length = len(obs['text_vec'])
truncated_vec = self._check_truncate(
obs['text_vec'], truncate, truncate_left
)
obs.force_set('original_context_length', text_length)
obs.force_set('if_context_truncate', text_length != len(truncated_vec))
obs.force_set('text_vec', torch.LongTensor(truncated_vec))

return obs
@@ -1397,13 +1427,20 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):

elif label_type + '_vec' in obs:
# check truncation of pre-computed vector
vec_label_length = len(obs[label_type + '_vec'])
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('original_label_length', vec_label_length)
obs.force_set('if_label_truncate', vec_label_length > len(truncated_vec))
obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec))
else:
# pick one label if there are multiple
lbls = obs[label_type]
label = lbls[0] if len(lbls) == 1 else self.random.choice(lbls)
vec_label = self._vectorize_text(label, add_start, add_end, truncate, False)
vec_label, vec_label_length, vec_label_truncated = self._vectorize_text_with_truncate_stats(
label, add_start, add_end, truncate, False
)
obs.force_set('original_label_length', vec_label_length)
obs.force_set('if_label_truncate', vec_label_truncated)
obs[label_type + '_vec'] = vec_label
obs[label_type + '_choice'] = label

@@ -1994,6 +2031,36 @@ def batch_act(self, observations):
# check if there are any labels available, if so we will train on them
self.is_training = any('labels' in obs for obs in observations)

# check if we should add truncate stats
if all('if_context_truncate' in obs for obs in observations):
self.record_local_metric(
'context_truncate',
AverageMetric.many(
[float(obs['if_context_truncate']) for obs in observations]
),
)
if all('if_label_truncate' in obs for obs in observations):
self.record_local_metric(
'label_truncate',
AverageMetric.many(
[float(obs['if_label_truncate']) for obs in observations]
),
)
if all('original_context_length' in obs for obs in observations):
self.record_local_metric(
'context_length',
AverageMetric.many(
[obs['original_context_length'] for obs in observations]
),
)
if all('original_label_length' in obs for obs in observations):
self.record_local_metric(
'label_length',
AverageMetric.many(
[obs['original_label_length'] for obs in observations]
),
)

Comment on lines +2036 to +2065
Contributor:
Just a question about these: won't this check fail on, e.g., the last batch of training/evaluation if there are padding examples?

Contributor:
I don't think so; it's going to have to change in another patch anyway.
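
One hedged way the padding concern could be handled, sketched here for a single metric: fall back to a default for observations that lack the key (e.g. padding examples) so the metric list still has one entry per observation. This is illustrative only and is not the change that was merged; it assumes padding observations simply lack the 'if_context_truncate' key.

    # Record the stat whenever at least one real observation carries it,
    # treating observations without the key as not truncated.
    if any('if_context_truncate' in obs for obs in observations):
        self.record_local_metric(
            'context_truncate',
            AverageMetric.many(
                [float(obs.get('if_context_truncate', False)) for obs in observations]
            ),
        )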

# create a batch from the vectors
batch = self.batchify(observations)
self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))