parlai/core/torch_agent.py
Outdated
@@ -1983,6 +1986,14 @@ def batch_act(self, observations):
        # check if there are any labels available, if so we will train on them
        self.is_training = any('labels' in obs for obs in observations)

        # check if we should add truncate stats
        if all('if_text_truncate' in obs for obs in observations):
I thought you'd also set one for average length pre-truncated
Also what about label truncate?
yeah, to me that is more useful --> if every example is 1-2 tokens beyond truncation, it'd be quite misleading to see that the truncate percentage is very high, when in reality it's not nearly as destructive as thought
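To make the concern concrete, here is a small illustration (values are made up) of how a bare truncate rate can mislead when every example only barely exceeds the limit, and why an average pre-truncation length (or tokens-dropped count) is the more useful signal:

```python
# Hypothetical batch: every example is just 1-2 tokens over the limit.
truncate_limit = 128
example_lengths = [129, 130, 129, 130]  # pre-truncation token counts

# The truncate rate reports that 100% of examples were truncated...
truncate_rate = sum(
    length > truncate_limit for length in example_lengths
) / len(example_lengths)

# ...but the average number of tokens actually dropped is tiny.
avg_tokens_dropped = sum(
    max(0, length - truncate_limit) for length in example_lengths
) / len(example_lengths)

print(truncate_rate)       # 1.0 -- looks destructive
print(avg_tokens_dropped)  # 1.5 -- barely any information lost
```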
label truncate is tricky to add, as I don't want to touch the return value of `_vectorize_text`. `_set_label_vec` calls `_vectorize_text` instead of doing the truncation itself.
There's only one re-definition of _vectorize_text
in our entire code-base, externally and internally. I think you can upgrade its return value.
I might have misunderstood you. Isn't this function called everywhere?
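For concreteness, the "upgrade its return value" suggestion could look something like the sketch below. The function name, signature, and tuple shape are hypothetical, not ParlAI's actual API; the point is that the caller (e.g. `_set_label_vec`) gets the truncation stats back without re-doing the truncation itself:

```python
from typing import List, Tuple

def vectorize_text_with_truncate_stats(
    token_ids: List[int], truncate: int = None, truncate_left: bool = True
) -> Tuple[List[int], bool, int]:
    """Return (vec, was_truncated, original_length).

    Hypothetical sketch: vectorization returns its truncation stats so
    callers don't have to duplicate the truncation logic.
    """
    original_length = len(token_ids)
    if truncate is None or original_length <= truncate:
        return token_ids, False, original_length
    # keep the rightmost tokens for inputs, leftmost for targets
    vec = token_ids[-truncate:] if truncate_left else token_ids[:truncate]
    return vec, True, original_length
```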
parlai/core/torch_agent.py
Outdated
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate)
obs.force_set('ori_label_length', label_length)
please use long, verbose names for these. there's no reason not to
Also a bit wondering about whether this will hit the fan in the case of a cached text_vec. I don't think we do use caching anymore but this kinda looks like it.
Force-pushed from ce753c7 to e8a1b9c.
parlai/core/torch_agent.py
Outdated
    Truncate from the left side (keep the rightmost tokens). You
    probably want this True for inputs, False for targets.
    """
    return self._vectorize_text_with_truncate_stats(**locals())[0]
uh i'm not sure locals is the right choice here
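A minimal illustration (hypothetical class and method names) of why forwarding `**locals()` from inside a method is fragile: `locals()` captures `self` and every temporary defined before the call, so the callee receives them all as extra keyword arguments:

```python
class Agent:
    def outer(self, text, truncate=None):
        note = 'temp'  # locals() now also contains 'note'
        # the bound call already supplies self, and **locals() supplies
        # it again, alongside the unexpected 'note' keyword
        return self.inner(**locals())

    def inner(self, text, truncate=None):
        return text

try:
    Agent().outer('hello')
except TypeError as err:
    print(err)  # raises TypeError: self is passed twice
```

Adding or renaming any local variable in `outer` silently changes the callee's keyword arguments, which is why explicit argument forwarding is usually preferred.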
lol sry fix unit tests first
I fixed all the tests except: https://app.circleci.com/pipelines/github/facebookresearch/ParlAI/8731/workflows/f41aca0b-9922-4e5e-9f58-64949fbc574e/jobs/71205
# check if we should add truncate stats
if all('if_context_truncate' in obs for obs in observations):
    self.record_local_metric(
        'context_truncate',
        AverageMetric.many(
            [float(obs['if_context_truncate']) for obs in observations]
        ),
    )
if all('if_label_truncate' in obs for obs in observations):
    self.record_local_metric(
        'label_truncate',
        AverageMetric.many(
            [float(obs['if_label_truncate']) for obs in observations]
        ),
    )
if all('original_context_length' in obs for obs in observations):
    self.record_local_metric(
        'context_length',
        AverageMetric.many(
            [obs['original_context_length'] for obs in observations]
        ),
    )
if all('original_label_length' in obs for obs in observations):
    self.record_local_metric(
        'label_length',
        AverageMetric.many(
            [obs['original_label_length'] for obs in observations]
        ),
    )
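For readers skimming the thread: per batch, the recording above reduces to plain means over the observations. A standalone sketch with made-up observation values, assuming the recorded `AverageMetric`s ultimately report a mean (ParlAI's real class also supports weighted accumulation across batches):

```python
# Hypothetical batch of three observations carrying the truncate stats.
observations = [
    {'if_context_truncate': True,  'original_context_length': 140},
    {'if_context_truncate': False, 'original_context_length': 90},
    {'if_context_truncate': True,  'original_context_length': 200},
]

# fraction of the batch whose context was truncated
context_truncate_rate = sum(
    float(obs['if_context_truncate']) for obs in observations
) / len(observations)

# mean pre-truncation context length
avg_context_length = sum(
    obs['original_context_length'] for obs in observations
) / len(observations)

print(context_truncate_rate)  # 2/3 of the batch was truncated
print(avg_context_length)     # mean of 140, 90, 200
```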
just a question about these - won't this check fail on e.g. the last batch of training/evaluation if there are padding examples?
I don't think so.
It's going to have to change in another patch anyway.
Please add a test which instantiates a test_agent/unigram, does a fixed observe/act, and checks the resulting truncate metric.
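A hedged sketch of the shape such a test could take. `StubAgent` and its methods are stand-ins invented here so the sketch is self-contained; the real test would instantiate test_agent/unigram through ParlAI's agent-creation and testing utilities and read the metric from the agent's report:

```python
class StubAgent:
    """Hypothetical stand-in for a ParlAI agent that records truncate stats."""

    def __init__(self, truncate):
        self.truncate = truncate
        self._metrics = {}

    def observe(self, obs):
        tokens = obs['text'].split()
        # record whether the context exceeded the limit, and its raw length
        self._metrics['context_truncate'] = float(len(tokens) > self.truncate)
        self._metrics['context_length'] = len(tokens)
        return obs

    def act(self):
        return {'text': 'dummy response'}

    def report(self):
        return dict(self._metrics)


def test_truncate_metric():
    agent = StubAgent(truncate=4)
    agent.observe({'text': 'a b c d e f', 'episode_done': True})
    agent.act()
    report = agent.report()
    assert report['context_truncate'] == 1.0
    assert report['context_length'] == 6


test_truncate_metric()
```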
#3498 — merge master into this branch and add your new metrics to this list too, please.
Patch description
This adds truncate stats to TorchAgent.
Testing steps
Tested with existing training and evaluation scripts.
For example
Other information