This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add truncate stats to training logs #3458

Merged: 10 commits merged from add-text-tr-stats into master on Mar 8, 2021

Conversation

@dexterju27 (Contributor) commented on Feb 19, 2021

Patch description
This adds truncation stats to TorchAgent.

Testing steps
Tested with an existing training or evaluation script, for example:

parlai em -t wizard_of_wikipedia:Generator --prepend-gold-knowledge true -mf ./model
13:32:30 | 0.1% complete (3 / 3,939), 0:00:10 elapsed, 3:46:03 eta
    accuracy    bleu-4  context_length  context_truncate  ctpb  ctps  exps  exs     f1  gpu_mem  label_length  label_truncate  loss       lr  ltpb  ltps   ppl  token_acc  total_train_updates   tpb  tps
           0 4.531e-11           184.3             .6667   103  29.9 .2903    3 .08522    .3847         31.33               0 2.693 1.51e-06 31.33 9.096 14.77      .4894               202171 134.3   39
13:32:40 | 0.2% complete (8 / 3,939), 0:00:20 elapsed, 2:50:51 eta
    accuracy    bleu-4  context_length  context_truncate  ctpb  ctps  exps  exs     f1  gpu_mem  label_length  label_truncate  loss       lr  ltpb  ltps   ppl  token_acc  total_train_updates   tpb   tps
           0 2.994e-09           227.2             .7500 111.8 42.86 .3835    8 .06685    .3337          26.5               0 2.724 1.51e-06  26.5 10.16 15.25      .4481               202171 138.2 53.02
13:32:55 | 0.4% complete (17 / 3,939), 0:00:35 elapsed, 2:17:40 eta
    accuracy  bleu-4  context_length  context_truncate  ctpb  ctps  exps  exs    f1  gpu_mem  label_length  label_truncate  loss       lr  ltpb  ltps  ppl  token_acc  total_train_updates   tpb   tps
           0  .01764           210.3             .6471 108.5 51.54 .4749   17 .1281    .3333         25.94               0  2.46 1.51e-06 25.94 12.32 11.7      .5125               202171 134.5 63.86
13:33:08 | 0.5% complete (20 / 3,939), 0:00:48 elapsed, 2:39:52 eta
    accuracy  bleu-4  context_length  context_truncate  ctpb  ctps  exps  exs    f1  gpu_mem  label_length  label_truncate  loss       lr  ltpb  ltps   ppl  token_acc  total_train_updates   tpb   tps
           0  .01499           216.1             .7000 111.5 45.54 .4086   20 .1383    .3339          27.7               0 2.523 1.51e-06  27.7 11.32 12.47      .4856               202171 139.2 56.86
13:33:20 | 0.6% complete (23 / 3,939), 0:01:01 elapsed, 2:54:19 eta
    accuracy  bleu-4  context_length  context_truncate  ctpb  ctps  exps  exs    f1  gpu_mem  label_length  label_truncate  loss       lr  ltpb  ltps   ppl  token_acc  total_train_updates   tpb   tps
           0  .01304           208.9             .6957 111.3 41.68 .3744   23 .1272    .3339         27.48               0 2.401 1.51e-06 27.48 10.29 11.03      .5127               202171 138.8 51.97
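
The new columns are context_truncate, label_truncate, context_length, and label_length: the first two report the fraction of examples in the batch whose context or label was truncated (in the first report above, .6667 means 2 of the 3 examples had their context cut), and the last two report the average pre-truncation length in tokens. A rough sketch of how the averages are assembled from per-example values (the numbers below are made up, not taken from the run above):

    from parlai.core.metrics import AverageMetric

    # Hypothetical batch of three observations, two with truncated contexts.
    if_context_truncate = [True, True, False]
    original_context_length = [250, 300, 90]

    # AverageMetric.many builds one metric per example; averaging them gives the
    # fraction of truncated examples (2/3 -> .6667) and the mean pre-truncation
    # length (~213) that show up as context_truncate / context_length above.
    truncate_metrics = AverageMetric.many([float(x) for x in if_context_truncate])
    length_metrics = AverageMetric.many(original_context_length)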

Other information

@@ -1983,6 +1986,14 @@ def batch_act(self, observations):
        # check if there are any labels available, if so we will train on them
        self.is_training = any('labels' in obs for obs in observations)

        # check if we should add truncate stats
        if all('if_text_truncate' in obs for obs in observations):
Contributor:
I thought you'd also set one for the average pre-truncation length.

Also, what about label truncate?

Contributor:

Yeah, to me that is more useful: if every example is only 1-2 tokens beyond the truncation limit, it'd be quite misleading to see a very high truncate percentage when in reality the truncation is not nearly as destructive as it looks.

Contributor (Author):

Label truncate is tricky to add, as I don't want to touch the return value of _vectorize_text; _set_label_vec calls _vectorize_text instead of doing the truncation itself.

Contributor:

There's only one re-definition of _vectorize_text in our entire code-base, externally and internally. I think you can upgrade its return value.

Contributor (Author):

I might have misunderstood you.
This function is called everywhere?
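
(For reference, the approach the patch ends up taking, visible in a later snippet on this page, keeps _vectorize_text's existing return value and delegates to a tuple-returning helper. A rough sketch with simplified arguments; the real method also handles add_start/add_end and other options:)

    def _vectorize_text_with_truncate_stats(self, text, truncate=None, truncate_left=True):
        # Sketch only: tokenize, remember the pre-truncation length, then truncate.
        # (torch is already imported in torch_agent.py)
        vec = self.dict.txt2vec(text)
        original_length = len(vec)
        if truncate is not None:
            vec = self._check_truncate(vec, truncate, truncate_left)
        if_truncated = original_length > len(vec)
        return torch.LongTensor(vec), original_length, if_truncated

    def _vectorize_text(self, text, truncate=None, truncate_left=True):
        # Existing callers keep receiving just the tensor.
        return self._vectorize_text_with_truncate_stats(text, truncate, truncate_left)[0]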

truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('ori_label_length', label_length)
Contributor:

Please use long, verbose names for these. There's no reason not to.

Contributor:

Also wondering a bit about whether this will hit the fan in the case of a cached text_vec. I don't think we use caching anymore, but this kinda looks like it.
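
(Side note on the force_set call quoted above: a ParlAI Message refuses to silently overwrite an existing key, which is what a cached, already-vectorized observation would require. A minimal sketch of that behavior, with made-up field values:)

    from parlai.core.message import Message

    msg = Message({'text': 'hello', 'text_vec': [1, 2, 3]})
    # Plain item assignment on a key that already exists raises a RuntimeError:
    # msg['text_vec'] = [1, 2]
    # force_set is the explicit override for that case; new keys can be set normally.
    msg.force_set('text_vec', [1, 2])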


Truncate from the left side (keep the rightmost tokens). You
probably want this True for inputs, False for targets.
"""
return self._vectorize_text_with_truncate_stats(**locals())[0]
Contributor:

Uh, I'm not sure locals() is the right choice here.
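
(The concern with **locals() is that it forwards every name currently in the local scope, including self and any temporary variable, so the call quietly changes meaning whenever the function body does. A sketch of the explicit alternative, with the parameter names simplified:)

    # Instead of forwarding whatever happens to be in scope:
    #     return self._vectorize_text_with_truncate_stats(**locals())[0]
    # pass the arguments by name:
    return self._vectorize_text_with_truncate_stats(
        text, truncate=truncate, truncate_left=truncate_left
    )[0]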

@stephenroller (Contributor) left a comment:

lol, sorry: fix unit tests first


Comment on lines +2036 to +2065
# check if we should add truncate stats
if all('if_context_truncate' in obs for obs in observations):
    self.record_local_metric(
        'context_truncate',
        AverageMetric.many(
            [float(obs['if_context_truncate']) for obs in observations]
        ),
    )
if all('if_label_truncate' in obs for obs in observations):
    self.record_local_metric(
        'label_truncate',
        AverageMetric.many(
            [float(obs['if_label_truncate']) for obs in observations]
        ),
    )
if all('original_context_length' in obs for obs in observations):
    self.record_local_metric(
        'context_length',
        AverageMetric.many(
            [obs['original_context_length'] for obs in observations]
        ),
    )
if all('original_label_length' in obs for obs in observations):
    self.record_local_metric(
        'label_length',
        AverageMetric.many(
            [obs['original_label_length'] for obs in observations]
        ),
    )

Contributor:

Just a question about these: won't this check fail on, e.g., the last batch of training/evaluation if there are padding examples?

Contributor:

I don't think so.

It's going to have to change in another patch anyway.
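
(For reference, the scenario the question describes: if the final batch is padded out with empty observations, those carry none of the new keys, so the all(...) guard skips recording the stats for that whole batch. A made-up illustration, assuming padding examples arrive as empty messages:)

    observations = [
        {'text': 'hi', 'if_context_truncate': False, 'original_context_length': 12},
        {},  # hypothetical padding example: has none of the new keys
    ]
    all('if_context_truncate' in obs for obs in observations)  # -> False, so nothing is recorded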

@stephenroller (Contributor) left a comment:

Please add a test which instantiates a test_agent/unigram, does a fixed observe/act, and checks the resulting truncate metric.
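
(Roughly the shape such a test could take. This is only a sketch: the agent registry name, the truncate value, and the assumption that the stats land on the observation returned by observe() are illustrative, not taken from the merged patch:)

    from parlai.core.agents import create_agent
    from parlai.core.message import Message
    from parlai.core.params import ParlaiParser


    def test_context_truncate_stats():
        # Tiny text truncate so the fixed context below is guaranteed to be cut.
        parser = ParlaiParser(True, True)
        opt = parser.parse_args(['--model', 'test_agents/unigram', '--text-truncate', '4'])
        agent = create_agent(opt)

        # Fixed observe/act cycle.
        obs = agent.observe(
            Message({'text': 'one two three four five six', 'episode_done': True})
        )
        agent.act()

        # With this patch the observation should carry the truncate stats.
        assert obs['if_context_truncate']
        assert obs['original_context_length'] > 4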

@stephenroller (Contributor):

#3498: please merge master into this branch and add your new metrics to this list too.

@dexterju27 merged commit 34eb8fb into master on Mar 8, 2021
@dexterju27 deleted the add-text-tr-stats branch on Mar 8, 2021 at 21:09