diff --git a/nemo/utils/model_utils.py b/nemo/utils/model_utils.py index 297132ed9c65..86d6aaded806 100644 --- a/nemo/utils/model_utils.py +++ b/nemo/utils/model_utils.py @@ -270,12 +270,21 @@ def resolve_validation_dataloaders(model: 'ModelPT'): if isinstance(ds_values, (list, tuple, ListConfig)): for ds_value in ds_values: - cfg.validation_ds[ds_key] = ds_value + if isinstance(ds_value, (dict, DictConfig)): + # this is a nested dataset + cfg.validation_ds = ds_value + else: + cfg.validation_ds[ds_key] = ds_value + model.setup_validation_data(cfg.validation_ds) dataloaders.append(model._validation_dl) model._validation_dl = dataloaders - model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values] + if isinstance(ds_values[0], (dict, DictConfig)): + # using the name of each of the nested dataset + model._validation_names = [ds.name for ds in ds_values] + else: + model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values] unique_names_check(name_list=model._validation_names) return @@ -340,12 +349,21 @@ def resolve_test_dataloaders(model: 'ModelPT'): if isinstance(ds_values, (list, tuple, ListConfig)): for ds_value in ds_values: - cfg.test_ds[ds_key] = ds_value + if isinstance(ds_value, (dict, DictConfig)): + # this is a nested dataset + cfg.test_ds = ds_value + else: + cfg.test_ds[ds_key] = ds_value + model.setup_test_data(cfg.test_ds) dataloaders.append(model._test_dl) model._test_dl = dataloaders - model._test_names = [parse_dataset_as_name(ds) for ds in ds_values] + if isinstance(ds_values[0], (dict, DictConfig)): + # using the name of each of the nested dataset + model._test_names = [ds.name for ds in ds_values] + else: + model._test_names = [parse_dataset_as_name(ds) for ds in ds_values] unique_names_check(name_list=model._test_names) return