This repository has been archived by the owner on Nov 3, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Add truncate stats to training logs #3458
Merged
Merged
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
3627353
add tr stats
dexterju27 c21d75f
change trigger condition from any to all
dexterju27 ce3066e
bug fix
dexterju27 eb1b6b9
apply comments
dexterju27 e8a1b9c
add label text stats
dexterju27 d8c1397
add wrapper
dexterju27 767ad3d
fix unit test
dexterju27 3c211a1
Merge remote-tracking branch 'origin/master' into add-text-tr-stats
dexterju27 0c926d2
add test
dexterju27 802cec9
update doc
dexterju27
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,6 +43,7 @@ | |
Adafactor, | ||
) | ||
from parlai.core.metrics import ( | ||
AverageMetric, | ||
Metrics, | ||
Metric, | ||
GlobalAverageMetric, | ||
|
@@ -1304,7 +1305,7 @@ def _v2t(self, vec): | |
new_vec.append(i) | ||
return self.dict.vec2txt(new_vec) | ||
|
||
def _vectorize_text( | ||
def _vectorize_text_with_truncate_stats( | ||
self, text, add_start=False, add_end=False, truncate=None, truncate_left=True | ||
): | ||
""" | ||
|
@@ -1328,9 +1329,35 @@ def _vectorize_text( | |
""" | ||
vec = self.dict.txt2vec(text) | ||
vec = self._add_start_end_tokens(vec, add_start, add_end) | ||
original_length = len(vec) | ||
vec = self._check_truncate(vec, truncate, truncate_left) | ||
if_truncated = original_length > len(vec) | ||
tensor = torch.LongTensor(vec) | ||
return tensor | ||
return tensor, original_length, if_truncated | ||
|
||
def _vectorize_text( | ||
self, text, add_start=False, add_end=False, truncate=None, truncate_left=True | ||
): | ||
""" | ||
Return vector from text. | ||
|
||
:param text: | ||
String to vectorize. | ||
|
||
:param add_start: | ||
Add the start token to the front of the tensor. | ||
|
||
:param add_end: | ||
Add the end token to the end of the tensor. | ||
|
||
:param truncate: | ||
Truncate to this many tokens >= 0, or None. | ||
|
||
:param truncate_left: | ||
Truncate from the left side (keep the rightmost tokens). You | ||
probably want this True for inputs, False for targets. | ||
""" | ||
return self._vectorize_text_with_truncate_stats(**locals())[0] | ||
|
||
def _check_truncate(self, vec, truncate, truncate_left=False): | ||
""" | ||
|
@@ -1371,9 +1398,12 @@ def _set_text_vec(self, obs, history, truncate): | |
# check truncation | ||
if obs.get('text_vec') is not None: | ||
truncate_left = not self.history_reversed | ||
text_length = len(obs['text_vec']) | ||
truncated_vec = self._check_truncate( | ||
obs['text_vec'], truncate, truncate_left | ||
) | ||
obs.force_set('original_context_length', text_length) | ||
obs.force_set('if_context_truncate', text_length != len(truncated_vec)) | ||
obs.force_set('text_vec', torch.LongTensor(truncated_vec)) | ||
|
||
return obs | ||
|
@@ -1397,13 +1427,20 @@ def _set_label_vec(self, obs, add_start, add_end, truncate): | |
|
||
elif label_type + '_vec' in obs: | ||
# check truncation of pre-computed vector | ||
vec_label_length = len(obs[label_type + '_vec']) | ||
truncated_vec = self._check_truncate(obs[label_type + '_vec'], truncate) | ||
obs.force_set('original_label_length', vec_label_length) | ||
obs.force_set('if_label_truncate', vec_label_length > len(truncated_vec)) | ||
obs.force_set(label_type + '_vec', torch.LongTensor(truncated_vec)) | ||
else: | ||
# pick one label if there are multiple | ||
lbls = obs[label_type] | ||
label = lbls[0] if len(lbls) == 1 else self.random.choice(lbls) | ||
vec_label = self._vectorize_text(label, add_start, add_end, truncate, False) | ||
vec_label, vec_label_length, vec_label_truncated = self._vectorize_text_with_truncate_stats( | ||
label, add_start, add_end, truncate, False | ||
) | ||
obs.force_set('original_label_length', vec_label_length) | ||
obs.force_set('if_label_truncate', vec_label_truncated) | ||
obs[label_type + '_vec'] = vec_label | ||
obs[label_type + '_choice'] = label | ||
|
||
|
@@ -1994,6 +2031,36 @@ def batch_act(self, observations): | |
# check if there are any labels available, if so we will train on them | ||
self.is_training = any('labels' in obs for obs in observations) | ||
|
||
# check if we should add truncate stats | ||
if all('if_context_truncate' in obs for obs in observations): | ||
self.record_local_metric( | ||
'context_truncate', | ||
AverageMetric.many( | ||
[float(obs['if_context_truncate']) for obs in observations] | ||
), | ||
) | ||
if all('if_label_truncate' in obs for obs in observations): | ||
self.record_local_metric( | ||
'label_truncate', | ||
AverageMetric.many( | ||
[float(obs['if_label_truncate']) for obs in observations] | ||
), | ||
) | ||
if all('original_context_length' in obs for obs in observations): | ||
self.record_local_metric( | ||
'context_length', | ||
AverageMetric.many( | ||
[obs['original_context_length'] for obs in observations] | ||
), | ||
) | ||
if all('original_label_length' in obs for obs in observations): | ||
self.record_local_metric( | ||
'label_length', | ||
AverageMetric.many( | ||
[obs['original_label_length'] for obs in observations] | ||
), | ||
) | ||
|
||
Comment on lines
+2036
to
+2065
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just a question about these — won't this check fail on e.g. the last batch of training/evaluation if there are padding examples? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. IDTS. It's going to have to change in another patch anyway. |
||
# create a batch from the vectors | ||
batch = self.batchify(observations) | ||
self.global_metrics.add('exps', GlobalTimerMetric(batch.batchsize)) | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
uh i'm not sure locals is the right choice here