From 2a6f454fa2dc7312b35c9ec276f0078c9eb642dc Mon Sep 17 00:00:00 2001 From: Yoshitomo Matsubara Date: Sat, 24 Dec 2022 23:40:09 -0800 Subject: [PATCH] rename registry dicts and arguments for registry key --- tests/registry_test.py | 88 +++++++-------- torchdistill/core/registry.py | 16 +-- torchdistill/datasets/registry.py | 100 +++++++++--------- torchdistill/datasets/sampler.py | 4 +- torchdistill/datasets/transform.py | 20 ++-- torchdistill/datasets/util.py | 10 +- torchdistill/losses/registry.py | 72 ++++++------- torchdistill/models/__init__.py | 6 +- .../models/custom/bottleneck/registry.py | 23 ++-- torchdistill/models/registry.py | 58 +++++----- torchdistill/optim/registry.py | 12 +-- 11 files changed, 209 insertions(+), 200 deletions(-) diff --git a/tests/registry_test.py b/tests/registry_test.py index 296b977c..0db68708 100644 --- a/tests/registry_test.py +++ b/tests/registry_test.py @@ -1,17 +1,17 @@ from unittest import TestCase -from torchdistill.core.registry import register_forward_proc_func, PROC_FUNC_DICT +from torchdistill.core.registry import register_forward_proc_func, FORWARD_PROC_FUNC_DICT from torchdistill.datasets.registry import register_dataset, register_collate_func, register_sample_loader_class, \ - register_sample_loader_func, register_batch_sampler_class, register_transform_class, register_dataset_wrapper, \ - DATASET_DICT, COLLATE_FUNC_DICT, SAMPLE_LOADER_CLASS_DICT, SAMPLE_LOADER_FUNC_DICT, BATCH_SAMPLER_CLASS_DICT, \ - TRANSFORM_CLASS_DICT, WRAPPER_CLASS_DICT -from torchdistill.losses.registry import register_custom_loss, CUSTOM_LOSS_CLASS_DICT, register_loss_wrapper, \ + register_sample_loader_func, register_batch_sampler, register_transform, register_dataset_wrapper, \ + DATASET_DICT, COLLATE_FUNC_DICT, SAMPLE_LOADER_CLASS_DICT, SAMPLE_LOADER_FUNC_DICT, BATCH_SAMPLER_DICT, \ + TRANSFORM_DICT, DATASET_WRAPPER_DICT +from torchdistill.losses.registry import register_custom_loss, CUSTOM_LOSS_DICT, register_loss_wrapper, \ register_single_loss, register_org_loss, \ - LOSS_WRAPPER_CLASS_DICT, SINGLE_LOSS_CLASS_DICT, ORG_LOSS_LIST, register_func2extract_org_output, \ + LOSS_WRAPPER_DICT, SINGLE_LOSS_DICT, ORG_LOSS_LIST, register_func2extract_org_output, \ FUNC2EXTRACT_ORG_OUTPUT_DICT -from torchdistill.models.registry import get_model, register_adaptation_module, ADAPTATION_CLASS_DICT, \ +from torchdistill.models.registry import get_model, register_adaptation_module, ADAPTATION_MODULE_DICT, \ register_model_class, register_model_func, MODEL_CLASS_DICT, MODEL_FUNC_DICT, register_special_module, \ - SPECIAL_CLASS_DICT + SPECIAL_MODULE_DICT from torchdistill.optim.registry import register_optimizer, register_scheduler, OPTIM_DICT, SCHEDULER_DICT @@ -51,20 +51,20 @@ def test_register_forward_proc_func(self): def test_forward_proc0(model, batch): return model(batch) - assert PROC_FUNC_DICT['test_forward_proc0'] == test_forward_proc0 + assert FORWARD_PROC_FUNC_DICT['test_forward_proc0'] == test_forward_proc0 @register_forward_proc_func() def test_forward_proc1(model, batch): return model(batch) - assert PROC_FUNC_DICT['test_forward_proc1'] == test_forward_proc1 + assert FORWARD_PROC_FUNC_DICT['test_forward_proc1'] == test_forward_proc1 random_name = 'custom_forward_proc_name2' @register_forward_proc_func(key=random_name) def test_forward_proc2(model, batch, label): return model(batch, label) - assert PROC_FUNC_DICT[random_name] == test_forward_proc2 + assert FORWARD_PROC_FUNC_DICT[random_name] == test_forward_proc2 def test_register_collate_func(self): 
@register_collate_func @@ -129,50 +129,50 @@ def test_sample_loader2(batch, label): assert SAMPLE_LOADER_FUNC_DICT[random_name] == test_sample_loader2 def test_register_sampler(self): - @register_batch_sampler_class + @register_batch_sampler class TestBatchSampler0(object): def __init__(self): self.name = 'test0' - assert BATCH_SAMPLER_CLASS_DICT['TestBatchSampler0'] == TestBatchSampler0 + assert BATCH_SAMPLER_DICT['TestBatchSampler0'] == TestBatchSampler0 - @register_batch_sampler_class() + @register_batch_sampler() class TestBatchSampler1(object): def __init__(self): self.name = 'test1' - assert BATCH_SAMPLER_CLASS_DICT['TestBatchSampler1'] == TestBatchSampler1 + assert BATCH_SAMPLER_DICT['TestBatchSampler1'] == TestBatchSampler1 random_name = 'custom_batch_sampler_class_name2' - @register_batch_sampler_class(key=random_name) + @register_batch_sampler(key=random_name) class TestBatchSampler2(object): def __init__(self): self.name = 'test2' - assert BATCH_SAMPLER_CLASS_DICT[random_name] == TestBatchSampler2 + assert BATCH_SAMPLER_DICT[random_name] == TestBatchSampler2 def test_register_transform(self): - @register_transform_class() + @register_transform() class TestTransform0(object): def __init__(self): self.name = 'test0' - assert TRANSFORM_CLASS_DICT['TestTransform0'] == TestTransform0 + assert TRANSFORM_DICT['TestTransform0'] == TestTransform0 - @register_transform_class() + @register_transform() class TestTransform1(object): def __init__(self): self.name = 'test1' - assert TRANSFORM_CLASS_DICT['TestTransform1'] == TestTransform1 + assert TRANSFORM_DICT['TestTransform1'] == TestTransform1 random_name = 'custom_transform_class_name2' - @register_transform_class(key=random_name) + @register_transform(key=random_name) class TestTransform2(object): def __init__(self): self.name = 'test2' - assert TRANSFORM_CLASS_DICT[random_name] == TestTransform2 + assert TRANSFORM_DICT[random_name] == TestTransform2 def test_register_dataset_wrapper(self): @register_dataset_wrapper @@ -180,14 +180,14 @@ class TestDatasetWrapper0(object): def __init__(self): self.name = 'test0' - assert WRAPPER_CLASS_DICT['TestDatasetWrapper0'] == TestDatasetWrapper0 + assert DATASET_WRAPPER_DICT['TestDatasetWrapper0'] == TestDatasetWrapper0 @register_dataset_wrapper() class TestDatasetWrapper1(object): def __init__(self): self.name = 'test1' - assert WRAPPER_CLASS_DICT['TestDatasetWrapper1'] == TestDatasetWrapper1 + assert DATASET_WRAPPER_DICT['TestDatasetWrapper1'] == TestDatasetWrapper1 random_name = 'custom_dataset_wrapper_class_name2' @register_dataset_wrapper(key=random_name) @@ -195,7 +195,7 @@ class TestDatasetWrapper2(object): def __init__(self): self.name = 'test2' - assert WRAPPER_CLASS_DICT[random_name] == TestDatasetWrapper2 + assert DATASET_WRAPPER_DICT[random_name] == TestDatasetWrapper2 def test_register_custom_loss_class(self): @register_custom_loss @@ -203,14 +203,14 @@ class TestCustomLoss0(object): def __init__(self): self.name = 'test0' - assert CUSTOM_LOSS_CLASS_DICT['TestCustomLoss0'] == TestCustomLoss0 + assert CUSTOM_LOSS_DICT['TestCustomLoss0'] == TestCustomLoss0 @register_custom_loss() class TestCustomLoss1(object): def __init__(self): self.name = 'test1' - assert CUSTOM_LOSS_CLASS_DICT['TestCustomLoss1'] == TestCustomLoss1 + assert CUSTOM_LOSS_DICT['TestCustomLoss1'] == TestCustomLoss1 random_name = 'custom_loss_class_name2' @register_custom_loss(key=random_name) @@ -218,7 +218,7 @@ class TestCustomLoss2(object): def __init__(self): self.name = 'test2' - assert 
CUSTOM_LOSS_CLASS_DICT[random_name] == TestCustomLoss2 + assert CUSTOM_LOSS_DICT[random_name] == TestCustomLoss2 def test_register_loss_wrapper_class(self): @register_loss_wrapper @@ -226,14 +226,14 @@ class TestLossWrapper0(object): def __init__(self): self.name = 'test0' - assert LOSS_WRAPPER_CLASS_DICT['TestLossWrapper0'] == TestLossWrapper0 + assert LOSS_WRAPPER_DICT['TestLossWrapper0'] == TestLossWrapper0 @register_loss_wrapper() class TestLossWrapper1(object): def __init__(self): self.name = 'test1' - assert LOSS_WRAPPER_CLASS_DICT['TestLossWrapper1'] == TestLossWrapper1 + assert LOSS_WRAPPER_DICT['TestLossWrapper1'] == TestLossWrapper1 random_name = 'custom_loss_wrapper_class_name2' @register_loss_wrapper(key=random_name) @@ -241,7 +241,7 @@ class TestLossWrapper2(object): def __init__(self): self.name = 'test2' - assert LOSS_WRAPPER_CLASS_DICT[random_name] == TestLossWrapper2 + assert LOSS_WRAPPER_DICT[random_name] == TestLossWrapper2 def test_register_single_loss(self): @register_single_loss @@ -249,14 +249,14 @@ class TestSingleLoss0(object): def __init__(self): self.name = 'test0' - assert SINGLE_LOSS_CLASS_DICT['TestSingleLoss0'] == TestSingleLoss0 + assert SINGLE_LOSS_DICT['TestSingleLoss0'] == TestSingleLoss0 @register_single_loss() class TestSingleLoss1(object): def __init__(self): self.name = 'test1' - assert SINGLE_LOSS_CLASS_DICT['TestSingleLoss1'] == TestSingleLoss1 + assert SINGLE_LOSS_DICT['TestSingleLoss1'] == TestSingleLoss1 random_name = 'custom_single_loss_class_name2' @register_single_loss(key=random_name) @@ -264,7 +264,7 @@ class TestSingleLoss2(object): def __init__(self): self.name = 'test2' - assert SINGLE_LOSS_CLASS_DICT[random_name] == TestSingleLoss2 + assert SINGLE_LOSS_DICT[random_name] == TestSingleLoss2 def test_register_org_loss(self): @register_org_loss @@ -272,7 +272,7 @@ class TestOrgLoss0(object): def __init__(self): self.name = 'test0' - assert SINGLE_LOSS_CLASS_DICT['TestOrgLoss0'] == TestOrgLoss0 + assert SINGLE_LOSS_DICT['TestOrgLoss0'] == TestOrgLoss0 assert TestOrgLoss0 in ORG_LOSS_LIST @register_org_loss() @@ -280,7 +280,7 @@ class TestOrgLoss1(object): def __init__(self): self.name = 'test1' - assert SINGLE_LOSS_CLASS_DICT['TestOrgLoss1'] == TestOrgLoss1 + assert SINGLE_LOSS_DICT['TestOrgLoss1'] == TestOrgLoss1 assert TestOrgLoss1 in ORG_LOSS_LIST random_name = 'custom_org_loss_class_name2' @@ -289,7 +289,7 @@ class TestOrgLoss2(object): def __init__(self): self.name = 'test2' - assert SINGLE_LOSS_CLASS_DICT[random_name] == TestOrgLoss2 + assert SINGLE_LOSS_DICT[random_name] == TestOrgLoss2 assert TestOrgLoss2 in ORG_LOSS_LIST def test_func2extract_org_output(self): @@ -364,14 +364,14 @@ class TestAdaptationModule0(object): def __init__(self): self.name = 'test0' - assert ADAPTATION_CLASS_DICT['TestAdaptationModule0'] == TestAdaptationModule0 + assert ADAPTATION_MODULE_DICT['TestAdaptationModule0'] == TestAdaptationModule0 @register_adaptation_module() class TestAdaptationModule1(object): def __init__(self): self.name = 'test1' - assert ADAPTATION_CLASS_DICT['TestAdaptationModule1'] == TestAdaptationModule1 + assert ADAPTATION_MODULE_DICT['TestAdaptationModule1'] == TestAdaptationModule1 random_name = 'custom_adaptation_module_class_name2' @register_adaptation_module(key=random_name) @@ -379,7 +379,7 @@ class TestAdaptationModule2(object): def __init__(self): self.name = 'test2' - assert ADAPTATION_CLASS_DICT[random_name] == TestAdaptationModule2 + assert ADAPTATION_MODULE_DICT[random_name] == TestAdaptationModule2 def 
test_register_model_class(self): @register_model_class @@ -430,14 +430,14 @@ class TestSpecialModule0(object): def __init__(self): self.name = 'test0' - assert SPECIAL_CLASS_DICT['TestSpecialModule0'] == TestSpecialModule0 + assert SPECIAL_MODULE_DICT['TestSpecialModule0'] == TestSpecialModule0 @register_special_module() class TestSpecialModule1(object): def __init__(self): self.name = 'test1' - assert SPECIAL_CLASS_DICT['TestSpecialModule1'] == TestSpecialModule1 + assert SPECIAL_MODULE_DICT['TestSpecialModule1'] == TestSpecialModule1 random_name = 'custom_special_module_class_name2' @register_special_module(key=random_name) @@ -445,4 +445,4 @@ class TestSpecialModule2(object): def __init__(self): self.name = 'test2' - assert SPECIAL_CLASS_DICT[random_name] == TestSpecialModule2 + assert SPECIAL_MODULE_DICT[random_name] == TestSpecialModule2 diff --git a/torchdistill/core/registry.py b/torchdistill/core/registry.py index 6710cc81..56a2e89b 100644 --- a/torchdistill/core/registry.py +++ b/torchdistill/core/registry.py @@ -1,4 +1,4 @@ -PROC_FUNC_DICT = dict() +FORWARD_PROC_FUNC_DICT = dict() def register_forward_proc_func(arg=None, **kwargs): @@ -7,7 +7,7 @@ def _register_forward_proc_func(func): if key is None: key = func.__name__ - PROC_FUNC_DICT[key] = func + FORWARD_PROC_FUNC_DICT[key] = func return func if callable(arg): @@ -15,9 +15,9 @@ def _register_forward_proc_func(func): return _register_forward_proc_func -def get_forward_proc_func(func_name): - if func_name is None: - return PROC_FUNC_DICT['forward_batch_only'] - elif func_name in PROC_FUNC_DICT: - return PROC_FUNC_DICT[func_name] - raise ValueError('No forward process function `{}` registered'.format(func_name)) +def get_forward_proc_func(key): + if key is None: + return FORWARD_PROC_FUNC_DICT['forward_batch_only'] + elif key in FORWARD_PROC_FUNC_DICT: + return FORWARD_PROC_FUNC_DICT[key] + raise ValueError('No forward process function `{}` registered'.format(key)) diff --git a/torchdistill/datasets/registry.py b/torchdistill/datasets/registry.py index 85806ee5..c2398266 100644 --- a/torchdistill/datasets/registry.py +++ b/torchdistill/datasets/registry.py @@ -1,17 +1,17 @@ from types import BuiltinFunctionType, BuiltinMethodType, FunctionType -import torch -import torchvision +from ..common import misc_util DATASET_DICT = dict() COLLATE_FUNC_DICT = dict() SAMPLE_LOADER_CLASS_DICT = dict() SAMPLE_LOADER_FUNC_DICT = dict() -BATCH_SAMPLER_CLASS_DICT = dict() -TRANSFORM_CLASS_DICT = dict() -WRAPPER_CLASS_DICT = dict() -DATASET_DICT.update(torchvision.datasets.__dict__) -BATCH_SAMPLER_CLASS_DICT.update(torch.utils.data.sampler.__dict__) +BATCH_SAMPLER_DICT = dict() +TRANSFORM_DICT = dict() +DATASET_WRAPPER_DICT = dict() + +DATASET_DICT.update(misc_util.get_classes_as_dict('torchvision.datasets')) +BATCH_SAMPLER_DICT.update(misc_util.get_classes_as_dict('torch.utils.data.sampler')) def register_dataset(arg=None, **kwargs): @@ -71,84 +71,84 @@ def _register_sample_loader_func(func): return _register_sample_loader_func -def register_batch_sampler_class(arg=None, **kwargs): - def _register_batch_sampler_class(cls): +def register_batch_sampler(arg=None, **kwargs): + def _register_batch_sampler(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - BATCH_SAMPLER_CLASS_DICT[key] = cls - return cls + BATCH_SAMPLER_DICT[key] = cls_or_func + return cls_or_func if callable(arg): - return _register_batch_sampler_class(arg) - return _register_batch_sampler_class + return 
_register_batch_sampler(arg) + return _register_batch_sampler -def register_transform_class(arg=None, **kwargs): - def _register_transform_class(cls): +def register_transform(arg=None, **kwargs): + def _register_transform(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - TRANSFORM_CLASS_DICT[key] = cls - return cls + TRANSFORM_DICT[key] = cls_or_func + return cls_or_func if callable(arg): - return _register_transform_class(arg) - return _register_transform_class + return _register_transform(arg) + return _register_transform def register_dataset_wrapper(arg=None, **kwargs): - def _register_dataset_wrapper(cls): + def _register_dataset_wrapper(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - WRAPPER_CLASS_DICT[key] = cls - return cls + DATASET_WRAPPER_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_dataset_wrapper(arg) return _register_dataset_wrapper -def get_collate_func(func_name): - if func_name is None: +def get_collate_func(key): + if key is None: return None - elif func_name in COLLATE_FUNC_DICT: - return COLLATE_FUNC_DICT[func_name] - raise ValueError('No collate function `{}` registered'.format(func_name)) + elif key in COLLATE_FUNC_DICT: + return COLLATE_FUNC_DICT[key] + raise ValueError('No collate function `{}` registered'.format(key)) -def get_sample_loader(obj_name, *args, **kwargs): - if obj_name is None: +def get_sample_loader(key, *args, **kwargs): + if key is None: return None - elif obj_name in SAMPLE_LOADER_CLASS_DICT: - return SAMPLE_LOADER_CLASS_DICT[obj_name](*args, **kwargs) - elif obj_name in SAMPLE_LOADER_FUNC_DICT: - return SAMPLE_LOADER_FUNC_DICT[obj_name] - raise ValueError('No sample loader `{}` registered.'.format(obj_name)) + elif key in SAMPLE_LOADER_CLASS_DICT: + return SAMPLE_LOADER_CLASS_DICT[key](*args, **kwargs) + elif key in SAMPLE_LOADER_FUNC_DICT: + return SAMPLE_LOADER_FUNC_DICT[key] + raise ValueError('No sample loader `{}` registered.'.format(key)) -def get_batch_sampler(class_name, *args, **kwargs): - if class_name is None: +def get_batch_sampler(key, *args, **kwargs): + if key is None: return None - if class_name not in BATCH_SAMPLER_CLASS_DICT and class_name != 'BatchSampler': - raise ValueError('No batch sampler `{}` registered.'.format(class_name)) + if key not in BATCH_SAMPLER_DICT and key != 'BatchSampler': + raise ValueError('No batch sampler `{}` registered.'.format(key)) - batch_sampler_cls = BATCH_SAMPLER_CLASS_DICT[class_name] + batch_sampler_cls = BATCH_SAMPLER_DICT[key] return batch_sampler_cls(*args, **kwargs) -def get_transform(obj_name, *args, **kwargs): - if obj_name in TRANSFORM_CLASS_DICT: - return TRANSFORM_CLASS_DICT[obj_name](*args, **kwargs) - raise ValueError('No transform `{}` registered.'.format(obj_name)) +def get_transform(key, *args, **kwargs): + if key in TRANSFORM_DICT: + return TRANSFORM_DICT[key](*args, **kwargs) + raise ValueError('No transform `{}` registered.'.format(key)) -def get_dataset_wrapper(class_name, *args, **kwargs): - if class_name not in WRAPPER_CLASS_DICT: - return WRAPPER_CLASS_DICT[class_name](*args, **kwargs) - raise ValueError('No dataset wrapper `{}` registered.'.format(class_name)) +def get_dataset_wrapper(key, *args, **kwargs): + if key not in DATASET_WRAPPER_DICT: + return DATASET_WRAPPER_DICT[key](*args, **kwargs) + raise ValueError('No dataset wrapper `{}` registered.'.format(key)) diff --git a/torchdistill/datasets/sampler.py 
b/torchdistill/datasets/sampler.py index c2cc8002..2a614e71 100644 --- a/torchdistill/datasets/sampler.py +++ b/torchdistill/datasets/sampler.py @@ -10,14 +10,14 @@ from torch.utils.data.sampler import BatchSampler, Sampler from torch.utils.model_zoo import tqdm -from .registry import register_batch_sampler_class +from .registry import register_batch_sampler from ..common.constant import def_logger from ..datasets.wrapper import BaseDatasetWrapper logger = def_logger.getChild(__name__) -@register_batch_sampler_class +@register_batch_sampler class GroupedBatchSampler(BatchSampler): """ Wraps another sampler to yield a mini-batch of indices. diff --git a/torchdistill/datasets/transform.py b/torchdistill/datasets/transform.py index 8696a8e8..1f66d04d 100644 --- a/torchdistill/datasets/transform.py +++ b/torchdistill/datasets/transform.py @@ -9,7 +9,7 @@ from torchvision.transforms import functional as F from torchvision.transforms.functional import InterpolationMode -from .registry import register_transform_class +from .registry import register_transform from ..common.constant import def_logger logger = def_logger.getChild(__name__) @@ -34,7 +34,7 @@ def pad_if_smaller(img, size, fill=0): return img -@register_transform_class +@register_transform class CustomCompose(object): def __init__(self, transforms): self.transforms = transforms @@ -45,7 +45,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomRandomResize(object): def __init__(self, min_size, max_size=None, square=False, jpeg_quality=None): self.min_size = min_size @@ -71,7 +71,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomRandomHorizontalFlip(object): def __init__(self, p): self.p = p @@ -83,7 +83,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomRandomCrop(object): def __init__(self, size): self.size = size @@ -97,7 +97,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomCenterCrop(object): def __init__(self, size): self.size = size @@ -108,7 +108,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomToTensor(object): def __call__(self, image, target): image = F.to_tensor(image) @@ -116,7 +116,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class CustomNormalize(object): def __init__(self, mean, std): self.mean = mean @@ -127,7 +127,7 @@ def __call__(self, image, target): return image, target -@register_transform_class +@register_transform class WrappedRandomResizedCrop(RandomResizedCrop): def __init__(self, interpolation=None, **kwargs): if isinstance(interpolation, str): @@ -135,7 +135,7 @@ def __init__(self, interpolation=None, **kwargs): super().__init__(**kwargs, interpolation=interpolation) -@register_transform_class +@register_transform class WrappedResize(Resize): def __init__(self, interpolation=None, **kwargs): if isinstance(interpolation, str): diff --git a/torchdistill/datasets/util.py b/torchdistill/datasets/util.py index 1e5c673c..867d94c7 100644 --- a/torchdistill/datasets/util.py +++ b/torchdistill/datasets/util.py @@ -10,14 +10,14 @@ from ..common.constant import def_logger from ..datasets.coco import ImageToTensor, Compose, CocoRandomHorizontalFlip, get_coco -from ..datasets.registry import DATASET_DICT, 
TRANSFORM_CLASS_DICT, \ +from ..datasets.registry import DATASET_DICT, TRANSFORM_DICT, \ get_collate_func, get_sample_loader, get_batch_sampler, get_dataset_wrapper from ..datasets.transform import CustomCompose from ..datasets.wrapper import default_idx2subpath, BaseDatasetWrapper, CacheableDataset logger = def_logger.getChild(__name__) -TRANSFORM_CLASS_DICT.update(torchvision.transforms.__dict__) +TRANSFORM_DICT.update(torchvision.transforms.__dict__) def load_coco_dataset(img_dir_path, ann_file_path, annotated_only, random_horizontal_flip=None, is_segment=False, @@ -36,7 +36,7 @@ def build_transform(transform_params_config, compose_cls=None): return None if isinstance(compose_cls, str): - compose_cls = TRANSFORM_CLASS_DICT[compose_cls] + compose_cls = TRANSFORM_DICT[compose_cls] component_list = list() if isinstance(transform_params_config, dict): @@ -46,7 +46,7 @@ def build_transform(transform_params_config, compose_cls=None): if params_config is None: params_config = dict() - component = TRANSFORM_CLASS_DICT[component_config['type']](**params_config) + component = TRANSFORM_DICT[component_config['type']](**params_config) component_list.append(component) else: for component_config in transform_params_config: @@ -54,7 +54,7 @@ def build_transform(transform_params_config, compose_cls=None): if params_config is None: params_config = dict() - component = TRANSFORM_CLASS_DICT[component_config['type']](**params_config) + component = TRANSFORM_DICT[component_config['type']](**params_config) component_list.append(component) return torchvision.transforms.Compose(component_list) if compose_cls is None else compose_cls(component_list) diff --git a/torchdistill/losses/registry.py b/torchdistill/losses/registry.py index e2c06d12..8a2de484 100644 --- a/torchdistill/losses/registry.py +++ b/torchdistill/losses/registry.py @@ -1,21 +1,21 @@ from ..common import misc_util LOSS_DICT = misc_util.get_classes_as_dict('torch.nn.modules.loss') -CUSTOM_LOSS_CLASS_DICT = dict() -LOSS_WRAPPER_CLASS_DICT = dict() -SINGLE_LOSS_CLASS_DICT = dict() +CUSTOM_LOSS_DICT = dict() +LOSS_WRAPPER_DICT = dict() +SINGLE_LOSS_DICT = dict() ORG_LOSS_LIST = list() FUNC2EXTRACT_ORG_OUTPUT_DICT = dict() def register_custom_loss(arg=None, **kwargs): - def _register_custom_loss(cls): + def _register_custom_loss(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - CUSTOM_LOSS_CLASS_DICT[key] = cls - return cls + CUSTOM_LOSS_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_custom_loss(arg) @@ -23,13 +23,13 @@ def _register_custom_loss(cls): def register_loss_wrapper(arg=None, **kwargs): - def _register_loss_wrapper(cls): + def _register_loss_wrapper(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - LOSS_WRAPPER_CLASS_DICT[key] = cls - return cls + LOSS_WRAPPER_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_loss_wrapper(arg) @@ -37,13 +37,13 @@ def _register_loss_wrapper(cls): def register_single_loss(arg=None, **kwargs): - def _register_single_loss(cls): + def _register_single_loss(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - SINGLE_LOSS_CLASS_DICT[key] = cls - return cls + SINGLE_LOSS_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_single_loss(arg) @@ -51,14 +51,14 @@ def _register_single_loss(cls): def register_org_loss(arg=None, **kwargs): - def _register_org_loss(cls): 
+ def _register_org_loss(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - SINGLE_LOSS_CLASS_DICT[key] = cls - ORG_LOSS_LIST.append(cls) - return cls + SINGLE_LOSS_DICT[key] = cls_or_func + ORG_LOSS_LIST.append(cls_or_func) + return cls_or_func if callable(arg): return _register_org_loss(arg) @@ -79,43 +79,43 @@ def _register_func2extract_org_output(func): return _register_func2extract_org_output -def get_loss(loss_type, param_dict=None, **kwargs): +def get_loss(key, param_dict=None, **kwargs): if param_dict is None: param_dict = dict() - lower_loss_type = loss_type.lower() + lower_loss_type = key.lower() if lower_loss_type in LOSS_DICT: return LOSS_DICT[lower_loss_type](**param_dict, **kwargs) - raise ValueError('No loss `{}` registered'.format(loss_type)) + raise ValueError('No loss `{}` registered'.format(key)) def get_custom_loss(criterion_config): criterion_type = criterion_config['type'] - if criterion_type in CUSTOM_LOSS_CLASS_DICT: - return CUSTOM_LOSS_CLASS_DICT[criterion_type](criterion_config) + if criterion_type in CUSTOM_LOSS_DICT: + return CUSTOM_LOSS_DICT[criterion_type](criterion_config) raise ValueError('No custom loss `{}` registered'.format(criterion_type)) def get_loss_wrapper(single_loss, params_config, wrapper_config): wrapper_type = wrapper_config.get('type', None) if wrapper_type is None: - return LOSS_WRAPPER_CLASS_DICT['SimpleLossWrapper'](single_loss, params_config) - elif wrapper_type in LOSS_WRAPPER_CLASS_DICT: - return LOSS_WRAPPER_CLASS_DICT[wrapper_type](single_loss, params_config, **wrapper_config.get('params', dict())) + return LOSS_WRAPPER_DICT['SimpleLossWrapper'](single_loss, params_config) + elif wrapper_type in LOSS_WRAPPER_DICT: + return LOSS_WRAPPER_DICT[wrapper_type](single_loss, params_config, **wrapper_config.get('params', dict())) raise ValueError('No loss wrapper `{}` registered'.format(wrapper_type)) def get_single_loss(single_criterion_config, params_config=None): loss_type = single_criterion_config['type'] - single_loss = SINGLE_LOSS_CLASS_DICT[loss_type](**single_criterion_config['params']) \ - if loss_type in SINGLE_LOSS_CLASS_DICT else get_loss(loss_type, single_criterion_config['params']) + single_loss = SINGLE_LOSS_DICT[loss_type](**single_criterion_config['params']) \ + if loss_type in SINGLE_LOSS_DICT else get_loss(loss_type, single_criterion_config['params']) if params_config is None: return single_loss return get_loss_wrapper(single_loss, params_config, params_config.get('wrapper', dict())) -def get_func2extract_org_output(func_name): - if func_name is None: - func_name = 'extract_simple_org_loss' - if func_name in FUNC2EXTRACT_ORG_OUTPUT_DICT: - return FUNC2EXTRACT_ORG_OUTPUT_DICT[func_name] - raise ValueError('No function to extract original output `{}` registered'.format(func_name)) +def get_func2extract_org_output(key): + if key is None: + key = 'extract_simple_org_loss' + if key in FUNC2EXTRACT_ORG_OUTPUT_DICT: + return FUNC2EXTRACT_ORG_OUTPUT_DICT[key] + raise ValueError('No function to extract original output `{}` registered'.format(key)) diff --git a/torchdistill/models/__init__.py b/torchdistill/models/__init__.py index 66deb076..17abb72e 100644 --- a/torchdistill/models/__init__.py +++ b/torchdistill/models/__init__.py @@ -1,11 +1,11 @@ -from .registry import ADAPTATION_CLASS_DICT, SPECIAL_CLASS_DICT +from .registry import ADAPTATION_MODULE_DICT, SPECIAL_MODULE_DICT from .classification import CLASSIFICATION_MODEL_FUNC_DICT from .custom import CUSTOM_MODEL_CLASS_DICT, 
CUSTOM_MODEL_FUNC_DICT MODEL_DICT = dict() -MODEL_DICT.update(ADAPTATION_CLASS_DICT) -MODEL_DICT.update(SPECIAL_CLASS_DICT) +MODEL_DICT.update(ADAPTATION_MODULE_DICT) +MODEL_DICT.update(SPECIAL_MODULE_DICT) MODEL_DICT.update(CUSTOM_MODEL_CLASS_DICT) MODEL_DICT.update(CUSTOM_MODEL_FUNC_DICT) MODEL_DICT.update(CLASSIFICATION_MODEL_FUNC_DICT) diff --git a/torchdistill/models/custom/bottleneck/registry.py b/torchdistill/models/custom/bottleneck/registry.py index 7b1ee772..74a77c29 100644 --- a/torchdistill/models/custom/bottleneck/registry.py +++ b/torchdistill/models/custom/bottleneck/registry.py @@ -5,15 +5,24 @@ logger = def_logger.getChild(__name__) -def register_bottleneck_processor(cls): - BOTTLENECK_PROCESSOR_DICT[cls.__name__] = cls - return cls +def register_bottleneck_processor(arg=None, **kwargs): + def _register_bottleneck_processor(cls_or_func): + key = kwargs.get('key') + if key is None: + key = cls_or_func.__name__ + BOTTLENECK_PROCESSOR_DICT[key] = cls_or_func + return cls_or_func -def get_bottleneck_processor(class_name, *args, **kwargs): - if class_name not in BOTTLENECK_PROCESSOR_DICT: - logger.info('No bottleneck processor called `{}` is registered.'.format(class_name)) + if callable(arg): + return _register_bottleneck_processor(arg) + return _register_bottleneck_processor + + +def get_bottleneck_processor(key, *args, **kwargs): + if key not in BOTTLENECK_PROCESSOR_DICT: + logger.info('No bottleneck processor called `{}` is registered.'.format(key)) return None - instance = BOTTLENECK_PROCESSOR_DICT[class_name](*args, **kwargs) + instance = BOTTLENECK_PROCESSOR_DICT[key](*args, **kwargs) return instance diff --git a/torchdistill/models/registry.py b/torchdistill/models/registry.py index 8ddba145..18727f4c 100644 --- a/torchdistill/models/registry.py +++ b/torchdistill/models/registry.py @@ -1,12 +1,12 @@ import torch -from torch import nn +from ..common import misc_util MODEL_CLASS_DICT = dict() MODEL_FUNC_DICT = dict() -ADAPTATION_CLASS_DICT = dict() -SPECIAL_CLASS_DICT = dict() -MODULE_CLASS_DICT = nn.__dict__ +ADAPTATION_MODULE_DICT = dict() +SPECIAL_MODULE_DICT = dict() +MODULE_DICT = misc_util.get_classes_as_dict('torch.nn') def register_model_class(arg=None, **kwargs): @@ -38,13 +38,13 @@ def _register_model_func(func): def register_adaptation_module(arg=None, **kwargs): - def _register_adaptation_module(cls): + def _register_adaptation_module(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - ADAPTATION_CLASS_DICT[key] = cls - return cls + ADAPTATION_MODULE_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_adaptation_module(arg) @@ -52,38 +52,38 @@ def _register_adaptation_module(cls): def register_special_module(arg=None, **kwargs): - def _register_special_module(cls): + def _register_special_module(cls_or_func): key = kwargs.get('key') if key is None: - key = cls.__name__ + key = cls_or_func.__name__ - SPECIAL_CLASS_DICT[key] = cls - return cls + SPECIAL_MODULE_DICT[key] = cls_or_func + return cls_or_func if callable(arg): return _register_special_module(arg) return _register_special_module -def get_model(model_name, repo_or_dir=None, **kwargs): - if model_name in MODEL_CLASS_DICT: - return MODEL_CLASS_DICT[model_name](**kwargs) - elif model_name in MODEL_FUNC_DICT: - return MODEL_FUNC_DICT[model_name](**kwargs) +def get_model(key, repo_or_dir=None, **kwargs): + if key in MODEL_CLASS_DICT: + return MODEL_CLASS_DICT[key](**kwargs) + elif key in MODEL_FUNC_DICT: + return 
MODEL_FUNC_DICT[key](**kwargs) elif repo_or_dir is not None: - return torch.hub.load(repo_or_dir, model_name, **kwargs) - raise ValueError('model_name `{}` is not expected'.format(model_name)) + return torch.hub.load(repo_or_dir, key, **kwargs) + raise ValueError('model_name `{}` is not expected'.format(key)) -def get_adaptation_module(class_name, *args, **kwargs): - if class_name in ADAPTATION_CLASS_DICT: - return ADAPTATION_CLASS_DICT[class_name](*args, **kwargs) - elif class_name in MODULE_CLASS_DICT: - return MODULE_CLASS_DICT[class_name](*args, **kwargs) - raise ValueError('No adaptation module `{}` registered'.format(class_name)) +def get_adaptation_module(key, *args, **kwargs): + if key in ADAPTATION_MODULE_DICT: + return ADAPTATION_MODULE_DICT[key](*args, **kwargs) + elif key in MODULE_DICT: + return MODULE_DICT[key](*args, **kwargs) + raise ValueError('No adaptation module `{}` registered'.format(key)) -def get_special_module(class_name, *args, **kwargs): - if class_name in SPECIAL_CLASS_DICT: - return SPECIAL_CLASS_DICT[class_name](*args, **kwargs) - raise ValueError('No special module `{}` registered'.format(class_name)) +def get_special_module(key, *args, **kwargs): + if key in SPECIAL_MODULE_DICT: + return SPECIAL_MODULE_DICT[key](*args, **kwargs) + raise ValueError('No special module `{}` registered'.format(key)) diff --git a/torchdistill/optim/registry.py b/torchdistill/optim/registry.py index 9314c314..c2362706 100644 --- a/torchdistill/optim/registry.py +++ b/torchdistill/optim/registry.py @@ -34,12 +34,12 @@ def _register_scheduler(cls_or_func): return _register_scheduler -def get_optimizer(module, optim_type, param_dict=None, filters_params=True, **kwargs): +def get_optimizer(module, key, param_dict=None, filters_params=True, **kwargs): if param_dict is None: param_dict = dict() is_module = isinstance(module, nn.Module) - lower_optim_type = optim_type.lower() + lower_optim_type = key.lower() if lower_optim_type in OPTIM_DICT: optim_cls_or_func = OPTIM_DICT[lower_optim_type] if is_module and filters_params: @@ -47,14 +47,14 @@ def get_optimizer(module, optim_type, param_dict=None, filters_params=True, **kw updatable_params = [p for p in params if p.requires_grad] return optim_cls_or_func(updatable_params, **param_dict, **kwargs) return optim_cls_or_func(module, **param_dict, **kwargs) - raise ValueError('No optimizer `{}` registered'.format(optim_type)) + raise ValueError('No optimizer `{}` registered'.format(key)) -def get_scheduler(optimizer, scheduler_type, param_dict=None, **kwargs): +def get_scheduler(optimizer, key, param_dict=None, **kwargs): if param_dict is None: param_dict = dict() - lower_scheduler_type = scheduler_type.lower() + lower_scheduler_type = key.lower() if lower_scheduler_type in SCHEDULER_DICT: return SCHEDULER_DICT[lower_scheduler_type](optimizer, **param_dict, **kwargs) - raise ValueError('No scheduler `{}` registered'.format(scheduler_type)) + raise ValueError('No scheduler `{}` registered'.format(key))
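For reference, a minimal usage sketch of the renamed registry API, following the decorators and getters shown in the hunks above (the `ExampleTransform` class, its `scale` parameter, and the `'example_transform'` key are hypothetical and only illustrate the pattern exercised in tests/registry_test.py):

    from torchdistill.datasets.registry import TRANSFORM_DICT, get_transform, register_transform

    # Register a hypothetical transform class under an explicit key; without the
    # `key` kwarg, the class name ('ExampleTransform') is used as the key.
    @register_transform(key='example_transform')
    class ExampleTransform(object):
        def __init__(self, scale=1.0):
            self.scale = scale

        def __call__(self, sample):
            return sample * self.scale

    # The decorator stores the class in TRANSFORM_DICT (formerly TRANSFORM_CLASS_DICT),
    # and get_transform looks it up by key and instantiates it with *args/**kwargs.
    assert TRANSFORM_DICT['example_transform'] is ExampleTransform
    transform = get_transform('example_transform', scale=0.5)

Each register_* decorator keeps the dual form seen throughout the patch: it can be applied bare (`@register_transform`) or called with keyword arguments (`@register_transform(key=...)`), which is why every inner `_register_*` helper is invoked immediately when `callable(arg)` is true and returned otherwise.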