Skip to content

Commit

Permalink
enabling heterogeneous val / test datasets (NVIDIA#6306)
Browse files Browse the repository at this point in the history
Signed-off-by: hsiehjackson <c2hsieh@ucsd.edu>
  • Loading branch information
bmwshop authored and hsiehjackson committed Jun 2, 2023
1 parent 6a88747 commit 15f2d25
Showing 1 changed file with 22 additions and 4 deletions.
26 changes: 22 additions & 4 deletions nemo/utils/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,12 +270,21 @@ def resolve_validation_dataloaders(model: 'ModelPT'):
if isinstance(ds_values, (list, tuple, ListConfig)):

for ds_value in ds_values:
cfg.validation_ds[ds_key] = ds_value
if isinstance(ds_value, (dict, DictConfig)):
# this is a nested dataset
cfg.validation_ds = ds_value
else:
cfg.validation_ds[ds_key] = ds_value

model.setup_validation_data(cfg.validation_ds)
dataloaders.append(model._validation_dl)

model._validation_dl = dataloaders
model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values]
if isinstance(ds_values[0], (dict, DictConfig)):
# using the name of each of the nested dataset
model._validation_names = [ds.name for ds in ds_values]
else:
model._validation_names = [parse_dataset_as_name(ds) for ds in ds_values]
unique_names_check(name_list=model._validation_names)
return

Expand Down Expand Up @@ -340,12 +349,21 @@ def resolve_test_dataloaders(model: 'ModelPT'):
if isinstance(ds_values, (list, tuple, ListConfig)):

for ds_value in ds_values:
cfg.test_ds[ds_key] = ds_value
if isinstance(ds_value, (dict, DictConfig)):
# this is a nested dataset
cfg.test_ds = ds_value
else:
cfg.test_ds[ds_key] = ds_value

model.setup_test_data(cfg.test_ds)
dataloaders.append(model._test_dl)

model._test_dl = dataloaders
model._test_names = [parse_dataset_as_name(ds) for ds in ds_values]
if isinstance(ds_values[0], (dict, DictConfig)):
# using the name of each of the nested dataset
model._test_names = [ds.name for ds in ds_values]
else:
model._test_names = [parse_dataset_as_name(ds) for ds in ds_values]

unique_names_check(name_list=model._test_names)
return
Expand Down

0 comments on commit 15f2d25

Please sign in to comment.