Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable registering various module types and fix a bug in target_transform #57

Merged
merged 4 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion torchdistill/common/func_util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import torch.nn as nn
from torch import nn

from torchdistill.common import misc_util

Expand All @@ -7,6 +7,16 @@
SCHEDULER_DICT = misc_util.get_classes_as_dict('torch.optim.lr_scheduler')


def register_optimizer(cls_or_func):
OPTIM_DICT[cls_or_func.__name__] = cls_or_func
return cls_or_func


def register_scheduler(cls_or_func):
SCHEDULER_DICT[cls_or_func.__name__] = cls_or_func
return cls_or_func


def get_loss(loss_type, param_dict=dict(), **kwargs):
lower_loss_type = loss_type.lower()
if lower_loss_type in LOSS_DICT:
Expand Down
9 changes: 9 additions & 0 deletions torchdistill/datasets/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import torchvision

DATASET_DICT = dict()
DATASET_DICT.update(torchvision.datasets.__dict__)


def register_dataset(cls_or_func):
DATASET_DICT[cls_or_func.__name__] = cls_or_func
return cls_or_func
12 changes: 7 additions & 5 deletions torchdistill/datasets/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
from torchdistill.common.constant import def_logger
from torchdistill.datasets.coco import ImageToTensor, Compose, CocoRandomHorizontalFlip, get_coco
from torchdistill.datasets.collator import get_collate_func
from torchdistill.datasets.registry import DATASET_DICT
from torchdistill.datasets.sample_loader import get_sample_loader
from torchdistill.datasets.sampler import get_batch_sampler
from torchdistill.datasets.transform import TRANSFORM_CLASS_DICT, CustomCompose
from torchdistill.datasets.wrapper import default_idx2subpath, BaseDatasetWrapper, CacheableDataset, get_dataset_wrapper

logger = def_logger.getChild(__name__)

DATASET_DICT = torchvision.datasets.__dict__
TRANSFORM_CLASS_DICT.update(torchvision.transforms.__dict__)


Expand Down Expand Up @@ -56,10 +56,10 @@ def build_transform(transform_params_config, compose_cls=None):
return torchvision.transforms.Compose(component_list) if compose_cls is None else compose_cls(component_list)


def get_official_dataset(dataset_cls, dataset_params_config):
def get_torchvision_dataset(dataset_cls, dataset_params_config):
params_config = dataset_params_config.copy()
transform = build_transform(params_config.pop('transform_params', None))
target_transform = build_transform(params_config.pop('transform_params', None))
target_transform = build_transform(params_config.pop('target_transform_params', None))
if 'loader' in params_config:
loader_config = params_config.pop('loader')
loader_type = loader_config['type']
Expand Down Expand Up @@ -118,13 +118,15 @@ def get_dataset_dict(dataset_config):
split_config['annotated_only'], split_config.get('random_horizontal_flip', None),
is_segment, transforms, split_config.get('jpeg_quality', None))
elif dataset_type in DATASET_DICT:
dataset_cls = DATASET_DICT[dataset_type]
dataset_cls_or_func = DATASET_DICT[dataset_type]
is_torchvision = dataset_type in torchvision.datasets.__dict__
dataset_splits_config = dataset_config['splits']
for split_name in dataset_splits_config.keys():
st = time.time()
logger.info('Loading {} data'.format(split_name))
split_config = dataset_splits_config[split_name]
org_dataset = get_official_dataset(dataset_cls, split_config['params'])
org_dataset = get_torchvision_dataset(dataset_cls_or_func, split_config['params']) if is_torchvision \
else dataset_cls_or_func(**split_config['params'])
dataset_id = split_config['dataset_id']
random_split_config = split_config.get('random_split', None)
if random_split_config is None:
Expand Down