Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue #7505

Status: Closed. Wants to merge 5 commits.
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -601,6 +601,8 @@ def validation_pass(self, batch, batch_idx, dataloader_idx):
         if AccessMixin.is_access_enabled():
             AccessMixin.reset_registry(self)

+        self.log_dict(tensorboard_logs)
+
         return tensorboard_logs

     def validation_step(self, batch, batch_idx, dataloader_idx=0):
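The change above logs the validation metrics inside the step itself so that checkpoint callbacks can see them. A minimal stdlib-only sketch of why that matters, using a stand-in for Lightning's logging plumbing (`MetricsRegistry` and `monitor_or_raise` are illustrative names, not NeMo or PTL APIs):

```python
class MetricsRegistry:
    """Stand-in for the metrics store that self.log_dict() feeds in Lightning."""

    def __init__(self):
        self.callback_metrics = {}

    def log_dict(self, metrics: dict) -> None:
        # Every logged key becomes visible to checkpoint/early-stopping callbacks.
        self.callback_metrics.update(metrics)


def monitor_or_raise(registry: MetricsRegistry, monitor: str) -> float:
    """Roughly what ModelCheckpoint does with its `monitor` key."""
    if monitor not in registry.callback_metrics:
        raise KeyError(
            f"could not find the monitored key {monitor!r} in "
            f"{sorted(registry.callback_metrics)}"
        )
    return registry.callback_metrics[monitor]


registry = MetricsRegistry()
# After the validation step logs its metrics dict, the monitor resolves.
registry.log_dict({"val_loss": 0.42, "val_wer": 0.18})
print(monitor_or_raise(registry, "val_wer"))  # -> 0.18
```

If `log_dict` is never called, the lookup fails with the same kind of "could not find the monitored key" error quoted later in this thread.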
5 changes: 4 additions & 1 deletion nemo/collections/asr/models/label_models.py
@@ -373,14 +373,17 @@ def evaluation_step(self, batch, batch_idx, dataloader_idx: int = 0, tag: str =
         self._macro_accuracy.update(preds=logits, target=labels)
         stats = self._macro_accuracy._final_state()

-        return {
+        eval_dict = {
             f'{tag}_loss': loss_value,
             f'{tag}_correct_counts': correct_counts,
             f'{tag}_total_counts': total_counts,
             f'{tag}_acc_micro_top_k': acc_top_k,
             f'{tag}_acc_macro_stats': stats,
         }
+
+        self.log_dict(eval_dict)
Collaborator: Not all variables in eval_dict need logging; please move self.log() into multi_evaluation_epoch_end where it calculates the averaged metrics.

Collaborator (author): I think all these variables were logged previously too (example run - https://wandb.ai/nvidia/titanet-chime7-training?workspace=user-kdhawan). Please let me know if you want me to remove some of these from the log.

+        return eval_dict

     def multi_evaluation_epoch_end(self, outputs, dataloader_idx: int = 0, tag: str = 'val'):
         loss_mean = torch.stack([x[f'{tag}_loss'] for x in outputs]).mean()
         correct_counts = torch.stack([x[f'{tag}_correct_counts'] for x in outputs]).sum(axis=0)
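The epoch-end hook shown above reduces the per-batch dicts into per-dataloader metrics: losses are averaged, count tensors are summed elementwise. A sketch of that reduction with plain Python lists standing in for torch tensors (`aggregate_eval_outputs` is an illustrative name):

```python
def aggregate_eval_outputs(outputs, tag="val"):
    """Mean the loss and elementwise-sum the counts over all batches,
    mirroring torch.stack(...).mean() and torch.stack(...).sum(axis=0)."""
    losses = [o[f"{tag}_loss"] for o in outputs]
    loss_mean = sum(losses) / len(losses)

    counts = [o[f"{tag}_correct_counts"] for o in outputs]
    correct_counts = [sum(col) for col in zip(*counts)]

    return {f"{tag}_loss": loss_mean, f"{tag}_correct_counts": correct_counts}


batches = [
    {"val_loss": 0.6, "val_correct_counts": [3, 5]},
    {"val_loss": 0.4, "val_correct_counts": [2, 1]},
]
print(aggregate_eval_outputs(batches))
# -> {'val_loss': 0.5, 'val_correct_counts': [5, 6]}
```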
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
@@ -833,6 +833,8 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):

         self.log('global_step', torch.tensor(self.trainer.global_step, dtype=torch.float32))

+        self.log_dict(tensorboard_logs)
+
         return tensorboard_logs

     def test_step(self, batch, batch_idx, dataloader_idx=0):
9 changes: 8 additions & 1 deletion nemo/collections/asr/models/slu_models.py
@@ -320,13 +320,17 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         wer, wer_num, wer_denom = self._wer.compute()
         self._wer.reset()

-        return {
+        val_logs_dict = {
             'val_loss': loss_value,
             'val_wer_num': wer_num,
             'val_wer_denom': wer_denom,
             'val_wer': wer,
         }

+        self.log_dict(val_logs_dict)
+
+        return val_logs_dict

     def test_step(self, batch, batch_idx, dataloader_idx=0):
         logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
         test_logs = {
@@ -335,6 +339,9 @@ def test_step(self, batch, batch_idx, dataloader_idx=0):
             'test_wer_denom': logs['val_wer_denom'],
             'test_wer': logs['val_wer'],
         }

+        self.log_dict(test_logs)
+
         return test_logs

     def test_dataloader(self):
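The test_step above reuses the validation logic and renames the 'val_*' keys to 'test_*' before logging. That key mapping can be sketched generically (`rename_metrics` is an illustrative helper, not a NeMo function):

```python
def rename_metrics(val_logs: dict, src: str = "val", dst: str = "test") -> dict:
    """Map {'val_loss': ...} to {'test_loss': ...} for the test-time log_dict."""
    # Replace only the first occurrence so metric names containing 'val_'
    # elsewhere are not mangled.
    return {k.replace(f"{src}_", f"{dst}_", 1): v for k, v in val_logs.items()}


val_logs = {"val_loss": 0.3, "val_wer_num": 7, "val_wer_denom": 50, "val_wer": 0.14}
print(rename_metrics(val_logs))
# -> {'test_loss': 0.3, 'test_wer_num': 7, 'test_wer_denom': 50, 'test_wer': 0.14}
```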
8 changes: 5 additions & 3 deletions nemo/collections/asr/models/ssl_models.py
@@ -554,9 +554,11 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
         self.reset_registry()
         del self._in_validation_step

-        return {
-            'val_loss': loss_value,
-        }
+        val_log_dict = {'val_loss': loss_value}
+
+        self.log_dict(val_log_dict)
+
+        return val_log_dict

     def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
         val_loss_mean = torch.stack([x['val_loss'] for x in outputs]).mean()
Collaborator (@stevehuang52, Sep 26, 2023): Is val_loss_mean here the value that should be logged? I mean, should we add logging in the multi_validation_epoch_end function instead? The one in validation_step is per-batch, but we need to monitor the mean loss over all validation data. Similar for test_step and multi_test_epoch_end.

Collaborator (author): The logging has to be done at the validation step itself, akin to the change introduced in this PR for the PTL upgrade - https://github.com/NVIDIA/NeMo/pull/6433/files#diff-b2780d88910b132d177fb0081453ad276c5e4aefe47a87f219e96f38af0625be

Collaborator (author): If we do the logging at multi_validation_epoch_end and multi_test_epoch_end, we still get the current error - ModelCheckpoint(monitor='val_wer') could not find the monitored key in the returned metrics: ['train_loss', 'learning_rate', 'global_step', 'train_backward_timing in s', 'train_step_timing in s', 'training_batch_wer', 'epoch', 'step']. HINT: Did you call log('val_wer', value) in the LightningModule?

Collaborator (author): I'm refactoring the PR currently to make the logging similar to how we do it for ctc_models. I'll make the change for RNNT and Hybrid models for now; maybe we can open another PR next to address these issues for the SLU, SSL, and label models.

Collaborator (author): New PR - #7531
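The reviewer's concern above is per-batch values versus the epoch mean. Lightning reconciles the two by accumulating step-level logged values and averaging them at epoch end when a metric is logged with epoch-level reduction. A stdlib-only sketch of that reduction (`EpochMeanAccumulator` is an illustrative name, not a PTL class):

```python
class EpochMeanAccumulator:
    """Accumulates per-step metric values and reduces them to an epoch mean."""

    def __init__(self):
        self._total = 0.0
        self._count = 0

    def update(self, step_value: float) -> None:
        # Called once per validation batch with the step-level metric.
        self._total += step_value
        self._count += 1

    def compute(self) -> float:
        # Epoch-end reduction: mean over all logged steps.
        return self._total / self._count


acc = EpochMeanAccumulator()
for step_loss in [0.5, 1.0, 1.5]:
    acc.update(step_loss)
print(acc.compute())  # -> 1.0
```

So logging in validation_step both exposes the key to ModelCheckpoint and, with epoch-level reduction enabled, still yields a mean over all validation data.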
14 changes: 10 additions & 4 deletions nemo/core/classes/modelPT.py
@@ -856,12 +856,18 @@ def train_dataloader(self):
         return self._train_dl

     def val_dataloader(self):
-        if self._validation_dl is not None:
-            return self._validation_dl
+        if self._validation_dl is None:
+            # None dataloader no longer supported in PTL2.0
+            self._validation_dl = []
+
+        return self._validation_dl

     def test_dataloader(self):
-        if self._test_dl is not None:
-            return self._test_dl
+        if self._test_dl is None:
+            # None dataloader no longer supported in PTL2.0
+            self._test_dl = []
+
+        return self._test_dl

     def on_validation_epoch_end(self) -> Optional[Dict[str, Dict[str, torch.Tensor]]]:
         """
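The modelPT.py change above normalizes a missing dataloader to an empty list, since returning None from these hooks is no longer accepted in PTL 2.0. The pattern in isolation (`DataModuleStub` is an illustrative stand-in, not a NeMo class):

```python
class DataModuleStub:
    """Stand-in showing the None -> [] normalization for dataloader hooks."""

    def __init__(self, validation_dl=None):
        self._validation_dl = validation_dl

    def val_dataloader(self):
        if self._validation_dl is None:
            # None dataloader no longer supported in PTL 2.0; normalize to [].
            self._validation_dl = []
        return self._validation_dl


print(DataModuleStub().val_dataloader())                      # -> []
print(DataModuleStub(validation_dl=["dl"]).val_dataloader())  # -> ['dl']
```

An empty list tells the trainer "no validation dataloaders" without tripping the None check, while a configured dataloader passes through unchanged.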