Skip to content

Commit

Permalink
Fix PTL2.0 related ASR bugs in r1.21.0: Val metrics logging, None dataloader issue (NVIDIA#7531) (NVIDIA#7533)
Browse files Browse the repository at this point in the history

* fix none dataloader issue ptl2



* ptl2.0 logging fixes for rnnt_models



---------

Signed-off-by: KunalDhawan <kunaldhawan97@gmail.com>
Co-authored-by: Kunal Dhawan <kunaldhawan97@gmail.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
  • Loading branch information
3 people authored Sep 27, 2023
1 parent 9c4fbe1 commit 231d08b
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions nemo/collections/asr/models/rnnt_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,7 +772,7 @@ def predict_step(self, batch, batch_idx, dataloader_idx=0):
sample_id = sample_id.cpu().detach().numpy()
return list(zip(sample_id, best_hyp_text))

def validation_step(self, batch, batch_idx, dataloader_idx=0):
def validation_pass(self, batch, batch_idx, dataloader_idx=0):
signal, signal_len, transcript, transcript_len = batch

# forward() only performs encoder forward
Expand Down Expand Up @@ -835,15 +835,21 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):

return tensorboard_logs

def validation_step(self, batch, batch_idx, dataloader_idx=0):
    """Run one validation batch and record its metrics for later aggregation.

    The computation is delegated to ``validation_pass``. The result is appended
    to ``self.validation_step_outputs``: into the sub-list selected by
    ``dataloader_idx`` when the trainer holds a list of more than one
    validation dataloader, otherwise directly into the flat list.

    Args:
        batch: Batch forwarded unchanged to ``validation_pass``.
        batch_idx: Index of the batch, forwarded to ``validation_pass``.
        dataloader_idx: Index of the dataloader the batch came from; also
            selects the output sub-list in the multi-dataloader case.

    Returns:
        The metrics object returned by ``validation_pass``.
    """
    metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
    val_dataloaders = self.trainer.val_dataloaders
    # isinstance (instead of an exact type comparison) also recognises list
    # subclasses; any non-list value, including None, uses the flat list.
    if isinstance(val_dataloaders, list) and len(val_dataloaders) > 1:
        self.validation_step_outputs[dataloader_idx].append(metrics)
    else:
        self.validation_step_outputs.append(metrics)
    return metrics

# NOTE(review): this span is a rendered diff whose +/- markers and indentation
# were lost in extraction, so two versions of the test_step body appear back to
# back. The page header reports 15 additions and 9 deletions for this file; that
# count is consistent with the first body (the validation_step call through the
# 'val_loss' check) being the removed version and the second body (the
# validation_pass call onward) being the added version — TODO confirm against
# the actual commit.
def test_step(self, batch, batch_idx, dataloader_idx=0):
# Presumably removed version: calls validation_step and copies selected val_*
# entries into test_* keys by hand ('test_loss' only when 'val_loss' is present).
logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {
'test_wer_num': logs['val_wer_num'],
'test_wer_denom': logs['val_wer_denom'],
# 'test_wer': logs['val_wer'],
}
if 'val_loss' in logs:
test_logs['test_loss'] = logs['val_loss']
# Presumably added version: calls validation_pass directly, builds a new dict in
# which "val_" is replaced by "test_" in every key, and appends that dict to
# self.test_step_outputs — into the sub-list chosen by dataloader_idx when
# trainer.test_dataloaders is a list with more than one entry, otherwise into
# the flat list.
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
return test_logs

def multi_validation_epoch_end(self, outputs, dataloader_idx: int = 0):
Expand Down

0 comments on commit 231d08b

Please sign in to comment.