diff --git a/torchdistill/common/func_util.py b/torchdistill/common/func_util.py index fe7ddc0d..651a4fdb 100644 --- a/torchdistill/common/func_util.py +++ b/torchdistill/common/func_util.py @@ -1,4 +1,4 @@ -import torch.nn as nn +from torch import nn from torchdistill.common import misc_util @@ -7,6 +7,16 @@ SCHEDULER_DICT = misc_util.get_classes_as_dict('torch.optim.lr_scheduler') +def register_optimizer(cls_or_func): + OPTIM_DICT[cls_or_func.__name__] = cls_or_func + return cls_or_func + + +def register_scheduler(cls_or_func): + SCHEDULER_DICT[cls_or_func.__name__] = cls_or_func + return cls_or_func + + def get_loss(loss_type, param_dict=dict(), **kwargs): lower_loss_type = loss_type.lower() if lower_loss_type in LOSS_DICT: diff --git a/torchdistill/datasets/registry.py b/torchdistill/datasets/registry.py new file mode 100644 index 00000000..b72baa81 --- /dev/null +++ b/torchdistill/datasets/registry.py @@ -0,0 +1,9 @@ +import torchvision + +DATASET_DICT = dict() +DATASET_DICT.update(torchvision.datasets.__dict__) + + +def register_dataset(cls_or_func): + DATASET_DICT[cls_or_func.__name__] = cls_or_func + return cls_or_func diff --git a/torchdistill/datasets/util.py b/torchdistill/datasets/util.py index 32bcf445..572ce7e3 100644 --- a/torchdistill/datasets/util.py +++ b/torchdistill/datasets/util.py @@ -9,6 +9,7 @@ from torchdistill.common.constant import def_logger from torchdistill.datasets.coco import ImageToTensor, Compose, CocoRandomHorizontalFlip, get_coco from torchdistill.datasets.collator import get_collate_func +from torchdistill.datasets.registry import DATASET_DICT from torchdistill.datasets.sample_loader import get_sample_loader from torchdistill.datasets.sampler import get_batch_sampler from torchdistill.datasets.transform import TRANSFORM_CLASS_DICT, CustomCompose @@ -16,7 +17,6 @@ logger = def_logger.getChild(__name__) -DATASET_DICT = torchvision.datasets.__dict__ TRANSFORM_CLASS_DICT.update(torchvision.transforms.__dict__) @@ -56,10 +56,10 @@ def build_transform(transform_params_config, compose_cls=None): return torchvision.transforms.Compose(component_list) if compose_cls is None else compose_cls(component_list) -def get_official_dataset(dataset_cls, dataset_params_config): +def get_torchvision_dataset(dataset_cls, dataset_params_config): params_config = dataset_params_config.copy() transform = build_transform(params_config.pop('transform_params', None)) - target_transform = build_transform(params_config.pop('transform_params', None)) + target_transform = build_transform(params_config.pop('target_transform_params', None)) if 'loader' in params_config: loader_config = params_config.pop('loader') loader_type = loader_config['type'] @@ -118,13 +118,15 @@ def get_dataset_dict(dataset_config): split_config['annotated_only'], split_config.get('random_horizontal_flip', None), is_segment, transforms, split_config.get('jpeg_quality', None)) elif dataset_type in DATASET_DICT: - dataset_cls = DATASET_DICT[dataset_type] + dataset_cls_or_func = DATASET_DICT[dataset_type] + is_torchvision = dataset_type in torchvision.datasets.__dict__ dataset_splits_config = dataset_config['splits'] for split_name in dataset_splits_config.keys(): st = time.time() logger.info('Loading {} data'.format(split_name)) split_config = dataset_splits_config[split_name] - org_dataset = get_official_dataset(dataset_cls, split_config['params']) + org_dataset = get_torchvision_dataset(dataset_cls_or_func, split_config['params']) if is_torchvision \ + else dataset_cls_or_func(**split_config['params']) dataset_id = split_config['dataset_id'] random_split_config = split_config.get('random_split', None) if random_split_config is None: