Add length of truncation to logs (#3508)
* log average context tokens truncated

* label truncate length, sync naming

* update test

* update tutorial_metrics.md

* fix some weird artwork generated with black
spencerp authored Mar 16, 2021
1 parent fe543b5 commit 333ab04
Showing 3 changed files with 47 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/source/tutorial_metrics.md
@@ -426,6 +426,7 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issues)
| `ctpb` | Context tokens per batch |
| `ctps` | Context tokens per second |
| `ctrunc` | Fraction of samples with some context truncation |
| `context_average_tokens_truncated` | Average length of context tokens truncated |

@stephenroller (Contributor) commented on Mar 17, 2021:

lol Uh this name is a bit brutal man..

@stephenroller (Contributor) commented on Mar 17, 2021:

how about cnumtrunc?

| `exps` | Examples per second |
| `exs` | Number of examples processed since last print |
| `f1` | Unigram F1 overlap, under a standardized (model-independent) tokenizer |
@@ -441,6 +442,7 @@ please [file an issue on GitHub](https://github.com/facebookresearch/ParlAI/issues)
| `ltpb` | Label tokens per batch |
| `ltps` | Label tokens per second |
| `ltrunc` | Fraction of samples with some label truncation |
| `label_average_tokens_truncated` | Average length of label tokens truncated |
| `rouge-1`, `rouge-2`, `rouge-L` | ROUGE metrics |
| `token_acc` | Token-wise accuracy (generative only) |
| `token_em` | Utterance-level token accuracy. Roughly corresponds to perfection under greedy search (generative only) |
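A quick illustration of the two new table rows (the metric names are taken from this diff; the per-example counts below are made-up values): each example contributes the number of tokens dropped by truncation, and the reported value is the mean over the batch.

# Hypothetical per-example counts of truncated tokens for a batch of three.
context_tokens_truncated = [4, 0, 2]
label_tokens_truncated = [6, 0, 0]
# The reported averages would then be:
context_average_tokens_truncated = sum(context_tokens_truncated) / len(context_tokens_truncated)  # 2.0
label_average_tokens_truncated = sum(label_tokens_truncated) / len(label_tokens_truncated)        # 2.0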
41 changes: 39 additions & 2 deletions parlai/core/torch_agent.py
@@ -114,8 +114,10 @@ class Batch(AttrDict):
image: Optional[List[Any]]
_context_original_length: Optional[torch.LongTensor]
_context_truncate_rate: Optional[torch.LongTensor]
_context_truncated_length: Optional[torch.LongTensor]
_label_original_length: Optional[torch.LongTensor]
_label_truncate_rate: Optional[torch.LongTensor]
_label_truncated_length: Optional[torch.LongTensor]

def __init__(
self,
@@ -131,8 +133,10 @@ def __init__(
image=None,
_context_original_length: Optional[torch.LongTensor] = None,
_context_truncate_rate: Optional[torch.LongTensor] = None,
_context_truncated_length: Optional[torch.LongTensor] = None,
_label_original_length: Optional[torch.LongTensor] = None,
_label_truncate_rate: Optional[torch.LongTensor] = None,
_label_truncated_length: Optional[torch.LongTensor] = None,
**kwargs,
):
super().__init__(
@@ -147,8 +151,10 @@ def __init__(
image=image,
_context_original_length=_context_original_length,
_context_truncate_rate=_context_truncate_rate,
_context_truncated_length=_context_truncated_length,
_label_original_length=_label_original_length,
_label_truncate_rate=_label_truncate_rate,
_label_truncated_length=_label_truncated_length,
**kwargs,
)

@@ -1445,6 +1451,9 @@ def _set_text_vec(self, obs, history, truncate):
)
obs.force_set('context_original_length', text_length)
obs.force_set('context_truncate_rate', text_length != len(truncated_vec))
obs.force_set(
'context_truncated_length', max(text_length - len(truncated_vec), 0)
)
obs.force_set('text_vec', torch.LongTensor(truncated_vec))

return obs
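A rough worked example of the arithmetic in the hunk above, using hypothetical values (neither the 13-token context nor the truncation setting appear in this diff):

# Hypothetical: a 13-token context truncated down to 9 tokens.
text_length = 13
truncated_vec = list(range(9))  # stand-in for the kept token ids
context_truncated_length = max(text_length - len(truncated_vec), 0)
assert context_truncated_length == 4  # 4 context tokens were dropped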
@@ -1472,6 +1481,9 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('label_original_length', vec_label_length)
obs.force_set('label_truncate_rate', vec_label_length > len(truncated_vec))
obs.force_set(
'label_truncated_length', max(vec_label_length - len(truncated_vec), 0)
)
obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec))
else:
# pick one label if there are multiple
@@ -1482,6 +1494,9 @@ def _set_label_vec(self, obs, add_start, add_end, truncate):
)
obs.force_set('label_original_length', vec_label_length)
obs.force_set('label_truncate_rate', vec_label_truncated)
obs.force_set(
'label_truncated_length', max(vec_label_length - len(vec_label), 0)
)
obs[label_type + '_vec'] = vec_label
obs[label_type + '_choice'] = label

@@ -1625,7 +1640,8 @@ def batchify(self, obs_batch, sort=False):
valid_inds, exs = zip(*valid_obs)

# TEXT
xs = x_lens = context_original_lengths = context_truncate_rate = None
xs = x_lens = context_original_lengths = None
context_truncate_rate = context_truncated_lengths = None
if any(ex.get('text_vec') is not None for ex in exs):
if any('context_original_length' in ex for ex in exs):
context_truncate_rate = torch.LongTensor(
@@ -1634,6 +1650,10 @@
context_original_lengths = torch.LongTensor(
[ex.get('context_original_length', 0) for ex in exs]
)
if any('context_truncated_length' in ex for ex in exs):
context_truncated_lengths = torch.LongTensor(
[ex.get('context_truncated_length', 0) for ex in exs]
)
_xs = [ex.get('text_vec', self.EMPTY) for ex in exs]
xs, x_lens = self._pad_tensor(_xs)
if sort:
@@ -1646,7 +1666,8 @@
labels_avail = any('labels_vec' in ex for ex in exs)
some_labels_avail = labels_avail or any('eval_labels_vec' in ex for ex in exs)

ys = y_lens = labels = label_original_lengths = label_truncate_rate = None
ys = y_lens = labels = label_original_lengths = None
label_truncate_rate = label_truncated_lengths = None
if some_labels_avail:
if any('label_original_length' in ex for ex in exs):
label_truncate_rate = torch.LongTensor(
@@ -1655,6 +1676,10 @@
label_original_lengths = torch.LongTensor(
[ex.get('label_original_length', 0) for ex in exs]
)
if any('label_truncated_length' in ex for ex in exs):
label_truncated_lengths = torch.LongTensor(
[ex.get('label_truncated_length', 0) for ex in exs]
)
field = 'labels' if labels_avail else 'eval_labels'

label_vecs = [ex.get(field + '_vec', self.EMPTY) for ex in exs]
@@ -1700,8 +1725,10 @@ def batchify(self, obs_batch, sort=False):
observations=exs if self.is_debug else None,
_context_original_length=context_original_lengths,
_context_truncate_rate=context_truncate_rate,
_context_truncated_length=context_truncated_lengths,
_label_original_length=label_original_lengths,
_label_truncate_rate=label_truncate_rate,
_label_truncated_length=label_truncated_lengths,
)

def match_batch(self, batch_reply, valid_inds, output=None):
@@ -2102,13 +2129,23 @@ def batch_act(self, observations):
self.record_local_metric(
'ctrunc', AverageMetric.many(batch._context_truncate_rate)
)
if batch._context_truncated_length is not None:
self.record_local_metric(
'context_average_tokens_truncated',
AverageMetric.many(batch._context_truncated_length),
)
if batch._label_original_length is not None:
self.record_local_metric(
'llen', AverageMetric.many(batch._label_original_length)
)
self.record_local_metric(
'ltrunc', AverageMetric.many(batch._label_truncate_rate)
)
if batch._label_truncated_length is not None:
self.record_local_metric(
'label_average_tokens_truncated',
AverageMetric.many(batch._label_truncated_length),
)

self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize))

6 changes: 6 additions & 0 deletions tests/test_torch_agent.py
@@ -1048,3 +1048,9 @@ def test_truncate_metrics(self):
self.assertEqual(agent._local_metrics['ltrunc'][0].value(), 1.0)
self.assertEqual(agent._local_metrics['clen'][0].value(), 9)
self.assertEqual(agent._local_metrics['llen'][0].value(), 11)
self.assertEqual(
agent._local_metrics['context_average_tokens_truncated'][0].value(), 4
)
self.assertEqual(
agent._local_metrics['label_average_tokens_truncated'][0].value(), 6
)
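The expected values above are consistent with the truncation arithmetic added in torch_agent.py; assuming the test truncates both context and label to 5 tokens (the setting is not visible in this hunk), the numbers work out as:

# Assumed truncation setting of 5 tokens (not shown in this hunk).
truncate = 5
context_original_length = 9    # matches the clen assertion above
label_original_length = 11     # matches the llen assertion above
assert max(context_original_length - truncate, 0) == 4   # context tokens dropped
assert max(label_original_length - truncate, 0) == 6     # label tokens dropped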
