Skip to content

Commit

Permalink
fix validation_step_outputs initialization for multi-dataloader (NVID…
Browse files Browse the repository at this point in the history
…IA#7546) (NVIDIA#7572)

* added correct validation_step_outputs initialization for multi-dataloader



* changed kernel for display



* Update logic for validation and test step outputs



* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert multidataloader changes in multilang ASR notebook



---------

Signed-off-by: KunalDhawan <kunaldhawan97@gmail.com>
Signed-off-by: smajumdar <titu1994@gmail.com>
Co-authored-by: Kunal Dhawan <kunaldhawan97@gmail.com>
Co-authored-by: Somshubra Majumdar <titu1994@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Sasha Meister <ameister@nvidia.com>
  • Loading branch information
4 people authored and sashameister committed Oct 2, 2023
1 parent c69e7e8 commit 0b75d82
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 16 deletions.
99 changes: 86 additions & 13 deletions nemo/core/classes/modelPT.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,18 +179,11 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):

# Create list of lists for val and test outputs to support multiple dataloaders
# Initialize an empty list as sometimes self._validation_dl can be None at this stage
self.validation_step_outputs = []
# Check len(self._validation_dl) > 1 as sometimes single dataloader can be in a list: [<Dataloader obj>] when ds_item in
# config has 1 item passed in a list
if self._validation_dl and type(self._validation_dl) == list and len(self._validation_dl) > 1:
for _ in range(len(self._validation_dl)):
self.validation_step_outputs.append([])
self._validation_step_outputs = None

# Initialize an empty list as sometimes self._test_dl can be None at this stage
self.test_step_outputs = []
if self._test_dl and type(self._test_dl) == list and len(self._test_dl) > 1:
for _ in range(len(self._test_dl)):
self.test_step_outputs.append([])
self._test_step_outputs = None

# ModelPT wrappers over subclass implementations
self.training_step = model_utils.wrap_training_step(self.training_step)

Expand Down Expand Up @@ -1573,6 +1566,61 @@ def cfg(self, cfg):
if hasattr(self, '_hparams_initial') and 'cfg' in self._hparams_initial:
self._hparams_initial['cfg'] = OmegaConf.to_object(self._cfg)

@property
def validation_step_outputs(self):
    """
    Lazily created cache of validation_step outputs.

    If a cache already exists it is returned unchanged. Otherwise a new one is created and stored:
    a flat empty list, or a list holding one empty list per dataloader when `self._validation_dl`
    is a list/tuple with more than one entry.
    Returns:
        List of outputs of validation_step.
    """
    if self._validation_step_outputs is None:
        fresh_outputs = []
        self._validation_step_outputs = fresh_outputs

        loaders = self._validation_dl
        # Require more than one entry: a single dataloader can still arrive wrapped in a list
        # ([<Dataloader obj>]) when ds_item in the config has 1 item passed in a list.
        # isinstance() is False for None, so an unset dataloader yields the flat empty list.
        if isinstance(loaders, (list, tuple)) and len(loaders) > 1:
            fresh_outputs.extend([] for _ in loaders)

    return self._validation_step_outputs


@validation_step_outputs.setter
def validation_step_outputs(self, value):
    # Replace the cache wholesale; storing None makes the getter build a new container on next access.
    self._validation_step_outputs = value

@property
def test_step_outputs(self):
    """
    Lazily created cache of test_step outputs.

    If a cache already exists it is returned unchanged. Otherwise a new one is created and stored:
    a flat empty list, or a list holding one empty list per dataloader when `self._test_dl`
    is a list/tuple with more than one entry.
    Returns:
        List of outputs of test_step.
    """
    if self._test_step_outputs is None:
        fresh_outputs = []
        self._test_step_outputs = fresh_outputs

        loaders = self._test_dl
        # Require more than one entry: a single dataloader can still arrive wrapped in a list
        # ([<Dataloader obj>]) when ds_item in the config has 1 item passed in a list.
        # isinstance() is False for None, so an unset dataloader yields the flat empty list.
        if isinstance(loaders, (list, tuple)) and len(loaders) > 1:
            fresh_outputs.extend([] for _ in loaders)

    return self._test_step_outputs


@test_step_outputs.setter
def test_step_outputs(self, value):
    # Replace the cache wholesale; storing None makes the getter build a new container on next access.
    self._test_step_outputs = value

@staticmethod
def _is_model_being_restored() -> bool:
app_state = AppState()
Expand Down Expand Up @@ -1714,15 +1762,40 @@ def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, unused: int =
logging.info("====== End nsys profiling ======")
torch.cuda.cudart().cudaProfilerStop()

def _cleanup_on_execution_end(self):
"""
Utility function to clean up the module state at the end of execution.
"""

# dynamic freezing cleanup
if hasattr(self, '_freeze_cfg'):
delattr(self, '_freeze_cfg')

# Clear up the val and test output caches
self._validation_step_outputs = None
self._test_step_outputs = None

def on_train_end(self):
    """ PyTorch Lightning hook:
    https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-train-end
    We use it here to clean up module state at the end of training by delegating to
    `_cleanup_on_execution_end`.
    """

    # The dynamic freezing cleanup is performed inside _cleanup_on_execution_end,
    # so it is not repeated inline here.
    self._cleanup_on_execution_end()

def on_test_end(self):
    """ PyTorch Lightning hook:
    https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-test-end
    Delegates to `_cleanup_on_execution_end` when testing finishes.
    """
    self._cleanup_on_execution_end()

def on_predict_end(self):
    """ PyTorch Lightning hook:
    https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html#on-predict-end
    Delegates to `_cleanup_on_execution_end` when prediction finishes.
    """

    self._cleanup_on_execution_end()

# TODO: Remove in PTL 1.7.2
def cuda(self, device=None):
Expand Down
11 changes: 8 additions & 3 deletions tutorials/asr/Multilang_ASR.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1713,7 +1713,7 @@
},
"outputs": [],
"source": [
"asr_model.setup_multiple_validation_data(val_data_config=validation_ds) "
"asr_model.setup_multiple_validation_data(val_data_config=validation_ds)"
]
},
{
Expand Down Expand Up @@ -2273,7 +2273,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "base",
"language": "python",
"name": "python3"
},
Expand All @@ -2287,11 +2287,16 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
"version": "3.8.12"
},
"nteract": {
"version": "0.28.0"
},
"vscode": {
"interpreter": {
"hash": "1aaa02ce0ce2638a6e16a203f0ce39bc7495f7236d7115882d2d3541e1318e7a"
}
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"013abc9bfddf456abf15dc2b0567d969": {
Expand Down

0 comments on commit 0b75d82

Please sign in to comment.