Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue #7505
Conversation
LGTM. Did you check for other models as well?
@@ -601,6 +601,9 @@ def validation_pass(self, batch, batch_idx, dataloader_idx):
        if AccessMixin.is_access_enabled():
            AccessMixin.reset_registry(self)

        # adding this as return values are no longer logged automatically in PTL 2.0
remove the comments
Fixed in the latest commit. Also added the necessary logging changes for label_models, slu_models, and ssl_models.
Thanks Kunal. @athitten how is this different from adding to test_step_outputs or validation_step_outputs?
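(For context, a rough sketch of the PTL 2.0 pattern that validation_step_outputs refers to; the class name and the placeholder loss below are illustrative, not the actual NeMo implementation.)

```python
import torch
import pytorch_lightning as pl


class ASRModuleSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        # PTL 2.0 dropped the `outputs` argument from the epoch-end hooks,
        # so per-batch results have to be collected manually.
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        val_loss = torch.tensor(0.0)  # placeholder for the real per-batch loss
        self.validation_step_outputs.append({'val_loss': val_loss})
        return val_loss

    def on_validation_epoch_end(self):
        # Aggregate the collected per-batch values into epoch-level metrics.
        val_loss_mean = torch.stack(
            [x['val_loss'] for x in self.validation_step_outputs]
        ).mean()
        self.log('val_loss', val_loss_mean, sync_dist=True)
        self.validation_step_outputs.clear()  # free memory for the next epoch
```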
        self.log_dict(val_log_dict)

        return val_log_dict

    def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
        val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
Is loss_val_mean the loss value that should be logged? I mean, should we add the logging in the multi_validation_epoch_end function instead? The one in validation_step is per-batch, but we need to monitor the mean loss over all the validation data. Similarly for test_step and multi_test_epoch_end.
The logging has to be done at the validation step itself, akin to the change introduced in this PR for the PTL upgrade: https://github.com/NVIDIA/NeMo/pull/6433/files#diff-b2780d88910b132d177fb0081453ad276c5e4aefe47a87f219e96f38af0625be
If we do the logging at multi_validation_epoch_end and multi_test_epoch_end, we still get the current error: ModelCheckpoint(monitor='val_wer') could not find the monitored key in the returned metrics: ['train_loss', 'learning_rate', 'global_step', 'train_backward_timing in s', 'train_step_timing in s', 'training_batch_wer', 'epoch', 'step']. HINT: Did you call log('val_wer', value) in the LightningModule?
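(For illustration: ModelCheckpoint resolves its monitor key against metrics recorded via self.log / self.log_dict, not against values merely returned from a hook, so val_wer has to be logged somewhere in the module. A minimal sketch, assuming a hypothetical _compute_batch_wer helper.)

```python
def validation_step(self, batch, batch_idx, dataloader_idx=0):
    # `_compute_batch_wer` is a hypothetical helper standing in for the real WER update.
    wer = self._compute_batch_wer(batch)
    # Logging with on_epoch=True makes Lightning aggregate the per-batch values into
    # an epoch-level 'val_wer' that ModelCheckpoint(monitor='val_wer') can find.
    self.log('val_wer', wer, on_epoch=True, sync_dist=True)
```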
I'm currently refactoring the PR to make the logging similar to how we do it for ctc_models. I'll make the change for the RNNT and Hybrid models for now; maybe we can open another PR next to address these issues for the SLU, SSL, and label models.
New PR - #7531
            f'{tag}_loss': loss_value,
            f'{tag}_correct_counts': correct_counts,
            f'{tag}_total_counts': total_counts,
            f'{tag}_acc_micro_top_k': acc_top_k,
            f'{tag}_acc_macro_stats': stats,
        }

        self.log_dict(eval_dict)
Not all variables in eval_dict need logging; please move self.log() into multi_evaluation_epoch_end, where it calculates the averaged metrics.
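(A rough sketch of that suggestion, assuming outputs carries the per-batch entries shown in the diff above; the exact signature and metric names are simplified, not the NeMo API.)

```python
def multi_evaluation_epoch_end(self, outputs, dataloader_idx=0, tag='val'):
    # Aggregate the per-batch counts first, then log only the averaged metrics;
    # the raw correct/total counts stay out of the logger.
    loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
    correct = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum()
    total = torch.stack([x[f'{tag}_total_counts'] for x in outputs]).sum()
    self.log_dict({
        f'{tag}_loss': loss_mean,
        f'{tag}_acc': correct.float() / total.float(),
    }, sync_dist=True)
```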
I think all of these variables were logged previously too (example run: https://wandb.ai/nvidia/titanet-chime7-training?workspace=user-kdhawan). Please let me know if you want me to remove some of them from the log.
Opened a new PR for this issue - #7531
What does this PR do?
This PR adds fixes for PTL 2.0-related ASR bugs in r1.21.0: validation metrics logging and the None dataloader issue.
Collection:
ASR, Core
Changelog
Before your PR is "Ready for review"
Pre checks:
PR Type:
Additional Information