Skip to content

Commit

Permalink
Merge pull request #347 from yoshitomo-matsubara/dev
Browse files Browse the repository at this point in the history
Add default args and kwargs
  • Loading branch information
yoshitomo-matsubara authored May 6, 2023
2 parents 8469898 + e218869 commit 51f364e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion torchdistill/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,14 @@ def build_data_loader(dataset, data_loader_config, distributed, accelerator=None
cache_dir_path = data_loader_config.get('cache_output', None)
dataset_wrapper_config = data_loader_config.get('dataset_wrapper', None)
if isinstance(dataset_wrapper_config, dict) and len(dataset_wrapper_config) > 0:
dataset = get_dataset_wrapper(dataset_wrapper_config['key'], dataset, **dataset_wrapper_config['kwargs'])
dataset_wrapper_args = dataset_wrapper_config.get('args', None)
dataset_wrapper_kwargs = dataset_wrapper_config.get('kwargs', None)
if dataset_wrapper_args is None:
dataset_wrapper_args = list()
if dataset_wrapper_kwargs is None:
dataset_wrapper_kwargs = dict()
dataset = get_dataset_wrapper(dataset_wrapper_config['key'], dataset, *dataset_wrapper_args,
**dataset_wrapper_kwargs)
elif cache_dir_path is not None:
dataset = CacheableDataset(dataset, cache_dir_path, idx2subpath_func=default_idx2subpath)
elif data_loader_config.get('requires_supp', False):
Expand Down

0 comments on commit 51f364e

Please sign in to comment.