Skip to content
This repository was archived by the owner on Nov 3, 2023. It is now read-only.

Add truncate stats to training logs #3458

Merged
merged 10 commits into from
Mar 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/source/tutorial_metrics.md
Original file line number Diff line number Diff line change
Expand Up @@ -422,6 +422,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu
| `accuracy` | Exact match text accuracy |
| `bleu-4` | BLEU-4 of the generation, under a standardized (model-independent) tokenizer |
| `clip` | Fraction of batches with clipped gradients |
| `context_truncate` | Fraction of samples per batch whose context was truncated |
| `context_length` | Average length of context tokens per batch |
| `ctpb` | Context tokens per batch |
| `ctps` | Context tokens per second |
| `exps` | Examples per second |
Expand All @@ -433,6 +435,8 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issu
| `interdistinct-1`, `interdistinct-2` | Fraction of n-grams unique across _all_ generations |
| `intradistinct-1`, `intradistinct-2` | Fraction of n-grams unique _within_ each utterance |
| `jga` | Joint Goal Accuracy |
| `label_length` | Average length of label tokens per batch |
| `label_truncate` | Fraction of samples per batch whose label was truncated |
| `loss` | Loss |
| `lr` | The most recent learning rate applied |
| `ltpb` | Label tokens per batch |
Expand Down
75 changes: 72 additions & 3 deletions parlai/core/torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
Adafactor,
)
from parlai.core.metrics import (
AverageMetric,
Metrics,
Metric,
GlobalAverageMetric,
Expand Down Expand Up @@ -1304,7 +1305,7 @@ def _v2t(self, vec):
new_vec.append(i)
return self.dict.vec2txt(new_vec)

def _vectorize_text(
def _vectorize_text_with_truncate_stats(
self, text, add_start=False, add_end=False, truncate=None, truncate_left=True
):
"""
Expand All @@ -1328,9 +1329,37 @@ def _vectorize_text(
"""
vec = self.dict.txt2vec(text)
vec = self._add_start_end_tokens(vec, add_start, add_end)
original_length = len(vec)
vec = self._check_truncate(vec, truncate, truncate_left)
if_truncated = original_length > len(vec)
tensor = torch.LongTensor(vec)
return tensor
return tensor, original_length, if_truncated

def _vectorize_text(
    self, text, add_start=False, add_end=False, truncate=None, truncate_left=True
):
    """
    Return a token tensor for the given text.

    Thin wrapper around ``_vectorize_text_with_truncate_stats`` that
    discards the truncation statistics and returns only the tensor.

    :param text:
        String to vectorize.

    :param add_start:
        Add the start token to the front of the tensor.

    :param add_end:
        Add the end token to the end of the tensor.

    :param truncate:
        Truncate to this many tokens >= 0, or None.

    :param truncate_left:
        Truncate from the left side (keep the rightmost tokens). You
        probably want this True for inputs, False for targets.
    """
    tensor, _original_length, _was_truncated = self._vectorize_text_with_truncate_stats(
        text,
        add_start=add_start,
        add_end=add_end,
        truncate=truncate,
        truncate_left=truncate_left,
    )
    return tensor

def _check_truncate(self, vec, truncate, truncate_left=False):
"""
Expand Down Expand Up @@ -1371,9 +1400,12 @@ def _set_text_vec(self, obs, history, truncate):
# check truncation
if obs.get('text_vec') is not None:
truncate_left = not self.history_reversed
text_length = len(obs['text_vec'])
truncated_vec = self._check_truncate(
obs['text_vec'], truncate, truncate_left
)
obs.force_set('original_context_length', text_length)
obs.force_set('if_context_truncate', text_length != len(truncated_vec))
obs.force_set('text_vec', torch.LongTensor(truncated_vec))

return obs
Expand All @@ -1397,13 +1429,20 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):

elif label_type + '_vec' in obs:
# check truncation of pre-computed vector
vec_label_length = len(obs[label_type + '_vec'])
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('original_label_length', vec_label_length)
obs.force_set('if_label_truncate', vec_label_length > len(truncated_vec))
obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec))
else:
# pick one label if there are multiple
lbls = obs[label_type]
label = lbls[0] if len(lbls) == 1 else self.random.choice(lbls)
vec_label = self._vectorize_text(label, add_start, add_end, truncate, False)
vec_label, vec_label_length, vec_label_truncated = self._vectorize_text_with_truncate_stats(
label, add_start, add_end, truncate, False
)
obs.force_set('original_label_length', vec_label_length)
obs.force_set('if_label_truncate', vec_label_truncated)
obs[label_type + '_vec'] = vec_label
obs[label_type + '_choice'] = label

Expand Down Expand Up @@ -1994,6 +2033,36 @@ def batch_act(self, observations):
# check if there are any labels available, if so we will train on them
self.is_training = any('labels' in obs for obs in observations)

# check if we should add truncate stats
if all('if_context_truncate' in obs for obs in observations):
self.record_local_metric(
'context_truncate',
AverageMetric.many(
[float(obs['if_context_truncate']) for obs in observations]
),
)
if all('if_label_truncate' in obs for obs in observations):
self.record_local_metric(
'label_truncate',
AverageMetric.many(
[float(obs['if_label_truncate']) for obs in observations]
),
)
if all('original_context_length' in obs for obs in observations):
self.record_local_metric(
'context_length',
AverageMetric.many(
[obs['original_context_length'] for obs in observations]
),
)
if all('original_label_length' in obs for obs in observations):
self.record_local_metric(
'label_length',
AverageMetric.many(
[obs['original_label_length'] for obs in observations]
),
)

Comment on lines +2036 to +2065
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a question about these - won't this check fail on e.g. the last batch of training/evaluation if there are padding examples?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDTS.

It's going to have to change in another patch anyway.

# create a batch from the vectors
batch = self.batchify(observations)
self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))
Expand Down
14 changes: 14 additions & 0 deletions tests/test_torch_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -1051,3 +1051,17 @@ def get_opt(init_mf, mf):
init_model_file, is_finetune = agent._get_init_model(popt, None)
self.assertEqual(init_model_file, '{}.checkpoint'.format(mf))
self.assertFalse(is_finetune)

def test_truncate_metrics(self):
    """
    Truncation stats should appear as local metrics after the agent acts.

    Uses a message long enough to exceed ``truncate=5`` so that both the
    context and the label are truncated.
    """
    long_message = "I'll be back. I'll be back. I'll be back."
    agent = get_agent(model='test_agents/unigram', truncate=5)
    agent.observe(
        {
            'text': long_message,
            'labels': [long_message],
            'episode_done': True,
        }
    )
    agent.act()
    # expected metric values for this single-example batch
    expected = {
        'context_truncate': 1.0,
        'label_truncate': 1.0,
        'context_length': 9,
        'label_length': 11,
    }
    for metric_name, expected_value in expected.items():
        self.assertEqual(
            agent._local_metrics[metric_name][0].value(), expected_value
        )