Skip to content

Commit

Permalink
fix ptl_bugs in slu_models.py (#7689) (#7712)
Browse files Browse the repository at this point in the history
* fix ptl_bugs in slu_models.py



* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* change strategy to ddp_find_unused_parameters_true in  slu example yaml



---------

Signed-off-by: Seonghun Noh <jzi040941@naver.com>
Signed-off-by: Seonghun <jzi040941@naver.com>
Co-authored-by: Seonghun Noh <jzi040941@naver.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: He Huang (Steve) <105218074+stevehuang52@users.noreply.github.com>
  • Loading branch information
4 people authored Oct 13, 2023
1 parent 5a8ad7f commit efb0dda
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ trainer:
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
strategy: ddp_find_unused_parameters_true
accumulate_grad_batches: 1
gradient_clip_val: 0.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
Expand Down
30 changes: 20 additions & 10 deletions nemo/collections/asr/models/slu_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def predict(
predictions = self.sequence_generator.decode_semantics_from_tokens(pred_tokens)
return predictions

def validation_step(self, batch, batch_idx, dataloader_idx=0):
def validation_pass(self, batch, batch_idx, dataloader_idx=0):
if len(batch) == 4:
signal, signal_len, semantics, semantics_len = batch
else:
Expand Down Expand Up @@ -327,19 +327,29 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
'val_wer': wer,
}

def validation_step(self, batch, batch_idx, dataloader_idx=0):
metrics = self.validation_pass(batch, batch_idx, dataloader_idx)
if type(self.trainer.val_dataloaders) == list and len(self.trainer.val_dataloaders) > 1:
self.validation_step_outputs[dataloader_idx].append(metrics)
else:
self.validation_step_outputs.append(metrics)
return metrics

def test_step(self, batch, batch_idx, dataloader_idx=0):
logs = self.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {
'test_loss': logs['val_loss'],
'test_wer_num': logs['val_wer_num'],
'test_wer_denom': logs['val_wer_denom'],
'test_wer': logs['val_wer'],
}
logs = self.validation_pass(batch, batch_idx, dataloader_idx=dataloader_idx)
test_logs = {name.replace("val_", "test_"): value for name, value in logs.items()}
if type(self.trainer.test_dataloaders) == list and len(self.trainer.test_dataloaders) > 1:
self.test_step_outputs[dataloader_idx].append(test_logs)
else:
self.test_step_outputs.append(test_logs)
return test_logs

def test_dataloader(self):
if self._test_dl is not None:
return self._test_dl
if self._test_dl is None:
# None dataloader no longer supported in PTL2.0
self._test_dl = []

return self._test_dl

def _setup_dataloader_from_config(self, config: Optional[Dict]):
if 'augmentor' in config:
Expand Down

0 comments on commit efb0dda

Please sign in to comment.