Merge pull request #57 from yoshitomo-matsubara/dev
Enable registering various module types and fix a bug in target_transform
yoshitomo-matsubara authored Jan 18, 2021
2 parents 758ed3e + 1a975ed commit bf8d4f8
Showing 3 changed files with 27 additions and 6 deletions.
12 changes: 11 additions & 1 deletion torchdistill/common/func_util.py
@@ -1,4 +1,4 @@
-import torch.nn as nn
+from torch import nn
 
 from torchdistill.common import misc_util
 
@@ -7,6 +7,16 @@
 SCHEDULER_DICT = misc_util.get_classes_as_dict('torch.optim.lr_scheduler')
 
 
+def register_optimizer(cls_or_func):
+    OPTIM_DICT[cls_or_func.__name__] = cls_or_func
+    return cls_or_func
+
+
+def register_scheduler(cls_or_func):
+    SCHEDULER_DICT[cls_or_func.__name__] = cls_or_func
+    return cls_or_func
+
+
 def get_loss(loss_type, param_dict=dict(), **kwargs):
     lower_loss_type = loss_type.lower()
     if lower_loss_type in LOSS_DICT:
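The two decorators added above let user-defined optimizers and schedulers be looked up by class name, alongside the built-in torch.optim entries. A minimal usage sketch, assuming a hypothetical MyOptimizer class that is not part of this PR:

from torch.optim import Optimizer

from torchdistill.common.func_util import OPTIM_DICT, register_optimizer


# Hypothetical optimizer, shown only to illustrate the decorator.
@register_optimizer
class MyOptimizer(Optimizer):
    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    def step(self, closure=None):
        # No-op update; a real optimizer would modify the parameters here.
        pass


# The class is now resolvable by name, e.g. from a config entry type: 'MyOptimizer'.
assert OPTIM_DICT['MyOptimizer'] is MyOptimizer

register_scheduler works the same way against SCHEDULER_DICT.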
9 changes: 9 additions & 0 deletions torchdistill/datasets/registry.py
@@ -0,0 +1,9 @@
+import torchvision
+
+DATASET_DICT = dict()
+DATASET_DICT.update(torchvision.datasets.__dict__)
+
+
+def register_dataset(cls_or_func):
+    DATASET_DICT[cls_or_func.__name__] = cls_or_func
+    return cls_or_func
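The new registry seeds DATASET_DICT with every torchvision dataset, then lets custom datasets join the same namespace. A minimal sketch, assuming a hypothetical MyDataset class:

from torch.utils.data import Dataset

from torchdistill.datasets.registry import DATASET_DICT, register_dataset


# Hypothetical dataset, shown only to illustrate registration.
@register_dataset
class MyDataset(Dataset):
    def __init__(self, size=10):
        self.size = size

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return index, index % 2


assert DATASET_DICT['MyDataset'] is MyDataset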
12 changes: 7 additions & 5 deletions torchdistill/datasets/util.py
@@ -9,14 +9,14 @@
 from torchdistill.common.constant import def_logger
 from torchdistill.datasets.coco import ImageToTensor, Compose, CocoRandomHorizontalFlip, get_coco
 from torchdistill.datasets.collator import get_collate_func
+from torchdistill.datasets.registry import DATASET_DICT
 from torchdistill.datasets.sample_loader import get_sample_loader
 from torchdistill.datasets.sampler import get_batch_sampler
 from torchdistill.datasets.transform import TRANSFORM_CLASS_DICT, CustomCompose
 from torchdistill.datasets.wrapper import default_idx2subpath, BaseDatasetWrapper, CacheableDataset, get_dataset_wrapper
 
 logger = def_logger.getChild(__name__)
 
-DATASET_DICT = torchvision.datasets.__dict__
 TRANSFORM_CLASS_DICT.update(torchvision.transforms.__dict__)
 
 
@@ -56,10 +56,10 @@ def build_transform(transform_params_config, compose_cls=None):
     return torchvision.transforms.Compose(component_list) if compose_cls is None else compose_cls(component_list)
 
 
-def get_official_dataset(dataset_cls, dataset_params_config):
+def get_torchvision_dataset(dataset_cls, dataset_params_config):
     params_config = dataset_params_config.copy()
     transform = build_transform(params_config.pop('transform_params', None))
-    target_transform = build_transform(params_config.pop('transform_params', None))
+    target_transform = build_transform(params_config.pop('target_transform_params', None))
     if 'loader' in params_config:
         loader_config = params_config.pop('loader')
         loader_type = loader_config['type']
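The second hunk is the target_transform bug fix named in the commit title: the old line popped 'transform_params' a second time, and because dict.pop removes the key on its first call, the second call always fell back to the default None, so target_transform was silently dropped. A short demonstration of the pitfall with a toy dict:

params_config = {'transform_params': [{'type': 'ToTensor'}]}
first = params_config.pop('transform_params', None)   # -> [{'type': 'ToTensor'}]
second = params_config.pop('transform_params', None)  # key already removed -> None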
@@ -118,13 +118,15 @@ def get_dataset_dict(dataset_config):
                                  split_config['annotated_only'], split_config.get('random_horizontal_flip', None),
                                  is_segment, transforms, split_config.get('jpeg_quality', None))
     elif dataset_type in DATASET_DICT:
-        dataset_cls = DATASET_DICT[dataset_type]
+        dataset_cls_or_func = DATASET_DICT[dataset_type]
+        is_torchvision = dataset_type in torchvision.datasets.__dict__
         dataset_splits_config = dataset_config['splits']
         for split_name in dataset_splits_config.keys():
             st = time.time()
             logger.info('Loading {} data'.format(split_name))
             split_config = dataset_splits_config[split_name]
-            org_dataset = get_official_dataset(dataset_cls, split_config['params'])
+            org_dataset = get_torchvision_dataset(dataset_cls_or_func, split_config['params']) if is_torchvision \
+                else dataset_cls_or_func(**split_config['params'])
             dataset_id = split_config['dataset_id']
             random_split_config = split_config.get('random_split', None)
             if random_split_config is None:
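With this dispatch, torchvision entries still go through get_torchvision_dataset (which builds transform, target_transform, and an optional loader from the config), while anything added via register_dataset is instantiated directly as dataset_cls_or_func(**split_config['params']). A sketch of a dataset config that would take the non-torchvision path, reusing the hypothetical MyDataset registered earlier (key names follow the split_config reads visible in this diff):

dataset_config = {
    'type': 'MyDataset',  # found in DATASET_DICT but not in torchvision.datasets
    'splits': {
        'train': {
            'dataset_id': 'mydataset/train',
            'params': {'size': 100}  # forwarded as MyDataset(**params)
        }
    }
}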
