Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename registry dicts and arguments for registry key #269

Merged
merged 1 commit into from
Dec 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 44 additions & 44 deletions tests/registry_test.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
from unittest import TestCase

from torchdistill.core.registry import register_forward_proc_func, PROC_FUNC_DICT
from torchdistill.core.registry import register_forward_proc_func, FORWARD_PROC_FUNC_DICT
from torchdistill.datasets.registry import register_dataset, register_collate_func, register_sample_loader_class, \
register_sample_loader_func, register_batch_sampler_class, register_transform_class, register_dataset_wrapper, \
DATASET_DICT, COLLATE_FUNC_DICT, SAMPLE_LOADER_CLASS_DICT, SAMPLE_LOADER_FUNC_DICT, BATCH_SAMPLER_CLASS_DICT, \
TRANSFORM_CLASS_DICT, WRAPPER_CLASS_DICT
from torchdistill.losses.registry import register_custom_loss, CUSTOM_LOSS_CLASS_DICT, register_loss_wrapper, \
register_sample_loader_func, register_batch_sampler, register_transform, register_dataset_wrapper, \
DATASET_DICT, COLLATE_FUNC_DICT, SAMPLE_LOADER_CLASS_DICT, SAMPLE_LOADER_FUNC_DICT, BATCH_SAMPLER_DICT, \
TRANSFORM_DICT, DATASET_WRAPPER_DICT
from torchdistill.losses.registry import register_custom_loss, CUSTOM_LOSS_DICT, register_loss_wrapper, \
register_single_loss, register_org_loss, \
LOSS_WRAPPER_CLASS_DICT, SINGLE_LOSS_CLASS_DICT, ORG_LOSS_LIST, register_func2extract_org_output, \
LOSS_WRAPPER_DICT, SINGLE_LOSS_DICT, ORG_LOSS_LIST, register_func2extract_org_output, \
FUNC2EXTRACT_ORG_OUTPUT_DICT
from torchdistill.models.registry import get_model, register_adaptation_module, ADAPTATION_CLASS_DICT, \
from torchdistill.models.registry import get_model, register_adaptation_module, ADAPTATION_MODULE_DICT, \
register_model_class, register_model_func, MODEL_CLASS_DICT, MODEL_FUNC_DICT, register_special_module, \
SPECIAL_CLASS_DICT
SPECIAL_MODULE_DICT
from torchdistill.optim.registry import register_optimizer, register_scheduler, OPTIM_DICT, SCHEDULER_DICT


Expand Down Expand Up @@ -51,20 +51,20 @@ def test_register_forward_proc_func(self):
def test_forward_proc0(model, batch):
return model(batch)

assert PROC_FUNC_DICT['test_forward_proc0'] == test_forward_proc0
assert FORWARD_PROC_FUNC_DICT['test_forward_proc0'] == test_forward_proc0

@register_forward_proc_func()
def test_forward_proc1(model, batch):
return model(batch)

assert PROC_FUNC_DICT['test_forward_proc1'] == test_forward_proc1
assert FORWARD_PROC_FUNC_DICT['test_forward_proc1'] == test_forward_proc1
random_name = 'custom_forward_proc_name2'

@register_forward_proc_func(key=random_name)
def test_forward_proc2(model, batch, label):
return model(batch, label)

assert PROC_FUNC_DICT[random_name] == test_forward_proc2
assert FORWARD_PROC_FUNC_DICT[random_name] == test_forward_proc2

def test_register_collate_func(self):
@register_collate_func
Expand Down Expand Up @@ -129,158 +129,158 @@ def test_sample_loader2(batch, label):
assert SAMPLE_LOADER_FUNC_DICT[random_name] == test_sample_loader2

def test_register_sampler(self):
@register_batch_sampler_class
@register_batch_sampler
class TestBatchSampler0(object):
def __init__(self):
self.name = 'test0'

assert BATCH_SAMPLER_CLASS_DICT['TestBatchSampler0'] == TestBatchSampler0
assert BATCH_SAMPLER_DICT['TestBatchSampler0'] == TestBatchSampler0

@register_batch_sampler_class()
@register_batch_sampler()
class TestBatchSampler1(object):
def __init__(self):
self.name = 'test1'

assert BATCH_SAMPLER_CLASS_DICT['TestBatchSampler1'] == TestBatchSampler1
assert BATCH_SAMPLER_DICT['TestBatchSampler1'] == TestBatchSampler1
random_name = 'custom_batch_sampler_class_name2'

@register_batch_sampler_class(key=random_name)
@register_batch_sampler(key=random_name)
class TestBatchSampler2(object):
def __init__(self):
self.name = 'test2'

assert BATCH_SAMPLER_CLASS_DICT[random_name] == TestBatchSampler2
assert BATCH_SAMPLER_DICT[random_name] == TestBatchSampler2

def test_register_transform(self):
@register_transform_class()
@register_transform()
class TestTransform0(object):
def __init__(self):
self.name = 'test0'

assert TRANSFORM_CLASS_DICT['TestTransform0'] == TestTransform0
assert TRANSFORM_DICT['TestTransform0'] == TestTransform0

@register_transform_class()
@register_transform()
class TestTransform1(object):
def __init__(self):
self.name = 'test1'

assert TRANSFORM_CLASS_DICT['TestTransform1'] == TestTransform1
assert TRANSFORM_DICT['TestTransform1'] == TestTransform1
random_name = 'custom_transform_class_name2'

@register_transform_class(key=random_name)
@register_transform(key=random_name)
class TestTransform2(object):
def __init__(self):
self.name = 'test2'

assert TRANSFORM_CLASS_DICT[random_name] == TestTransform2
assert TRANSFORM_DICT[random_name] == TestTransform2

def test_register_dataset_wrapper(self):
@register_dataset_wrapper
class TestDatasetWrapper0(object):
def __init__(self):
self.name = 'test0'

assert WRAPPER_CLASS_DICT['TestDatasetWrapper0'] == TestDatasetWrapper0
assert DATASET_WRAPPER_DICT['TestDatasetWrapper0'] == TestDatasetWrapper0

@register_dataset_wrapper()
class TestDatasetWrapper1(object):
def __init__(self):
self.name = 'test1'

assert WRAPPER_CLASS_DICT['TestDatasetWrapper1'] == TestDatasetWrapper1
assert DATASET_WRAPPER_DICT['TestDatasetWrapper1'] == TestDatasetWrapper1
random_name = 'custom_dataset_wrapper_class_name2'

@register_dataset_wrapper(key=random_name)
class TestDatasetWrapper2(object):
def __init__(self):
self.name = 'test2'

assert WRAPPER_CLASS_DICT[random_name] == TestDatasetWrapper2
assert DATASET_WRAPPER_DICT[random_name] == TestDatasetWrapper2

def test_register_custom_loss_class(self):
@register_custom_loss
class TestCustomLoss0(object):
def __init__(self):
self.name = 'test0'

assert CUSTOM_LOSS_CLASS_DICT['TestCustomLoss0'] == TestCustomLoss0
assert CUSTOM_LOSS_DICT['TestCustomLoss0'] == TestCustomLoss0

@register_custom_loss()
class TestCustomLoss1(object):
def __init__(self):
self.name = 'test1'

assert CUSTOM_LOSS_CLASS_DICT['TestCustomLoss1'] == TestCustomLoss1
assert CUSTOM_LOSS_DICT['TestCustomLoss1'] == TestCustomLoss1
random_name = 'custom_loss_class_name2'

@register_custom_loss(key=random_name)
class TestCustomLoss2(object):
def __init__(self):
self.name = 'test2'

assert CUSTOM_LOSS_CLASS_DICT[random_name] == TestCustomLoss2
assert CUSTOM_LOSS_DICT[random_name] == TestCustomLoss2

def test_register_loss_wrapper_class(self):
@register_loss_wrapper
class TestLossWrapper0(object):
def __init__(self):
self.name = 'test0'

assert LOSS_WRAPPER_CLASS_DICT['TestLossWrapper0'] == TestLossWrapper0
assert LOSS_WRAPPER_DICT['TestLossWrapper0'] == TestLossWrapper0

@register_loss_wrapper()
class TestLossWrapper1(object):
def __init__(self):
self.name = 'test1'

assert LOSS_WRAPPER_CLASS_DICT['TestLossWrapper1'] == TestLossWrapper1
assert LOSS_WRAPPER_DICT['TestLossWrapper1'] == TestLossWrapper1
random_name = 'custom_loss_wrapper_class_name2'

@register_loss_wrapper(key=random_name)
class TestLossWrapper2(object):
def __init__(self):
self.name = 'test2'

assert LOSS_WRAPPER_CLASS_DICT[random_name] == TestLossWrapper2
assert LOSS_WRAPPER_DICT[random_name] == TestLossWrapper2

def test_register_single_loss(self):
@register_single_loss
class TestSingleLoss0(object):
def __init__(self):
self.name = 'test0'

assert SINGLE_LOSS_CLASS_DICT['TestSingleLoss0'] == TestSingleLoss0
assert SINGLE_LOSS_DICT['TestSingleLoss0'] == TestSingleLoss0

@register_single_loss()
class TestSingleLoss1(object):
def __init__(self):
self.name = 'test1'

assert SINGLE_LOSS_CLASS_DICT['TestSingleLoss1'] == TestSingleLoss1
assert SINGLE_LOSS_DICT['TestSingleLoss1'] == TestSingleLoss1
random_name = 'custom_single_loss_class_name2'

@register_single_loss(key=random_name)
class TestSingleLoss2(object):
def __init__(self):
self.name = 'test2'

assert SINGLE_LOSS_CLASS_DICT[random_name] == TestSingleLoss2
assert SINGLE_LOSS_DICT[random_name] == TestSingleLoss2

def test_register_org_loss(self):
@register_org_loss
class TestOrgLoss0(object):
def __init__(self):
self.name = 'test0'

assert SINGLE_LOSS_CLASS_DICT['TestOrgLoss0'] == TestOrgLoss0
assert SINGLE_LOSS_DICT['TestOrgLoss0'] == TestOrgLoss0
assert TestOrgLoss0 in ORG_LOSS_LIST

@register_org_loss()
class TestOrgLoss1(object):
def __init__(self):
self.name = 'test1'

assert SINGLE_LOSS_CLASS_DICT['TestOrgLoss1'] == TestOrgLoss1
assert SINGLE_LOSS_DICT['TestOrgLoss1'] == TestOrgLoss1
assert TestOrgLoss1 in ORG_LOSS_LIST
random_name = 'custom_org_loss_class_name2'

Expand All @@ -289,7 +289,7 @@ class TestOrgLoss2(object):
def __init__(self):
self.name = 'test2'

assert SINGLE_LOSS_CLASS_DICT[random_name] == TestOrgLoss2
assert SINGLE_LOSS_DICT[random_name] == TestOrgLoss2
assert TestOrgLoss2 in ORG_LOSS_LIST

def test_func2extract_org_output(self):
Expand Down Expand Up @@ -364,22 +364,22 @@ class TestAdaptationModule0(object):
def __init__(self):
self.name = 'test0'

assert ADAPTATION_CLASS_DICT['TestAdaptationModule0'] == TestAdaptationModule0
assert ADAPTATION_MODULE_DICT['TestAdaptationModule0'] == TestAdaptationModule0

@register_adaptation_module()
class TestAdaptationModule1(object):
def __init__(self):
self.name = 'test1'

assert ADAPTATION_CLASS_DICT['TestAdaptationModule1'] == TestAdaptationModule1
assert ADAPTATION_MODULE_DICT['TestAdaptationModule1'] == TestAdaptationModule1
random_name = 'custom_adaptation_module_class_name2'

@register_adaptation_module(key=random_name)
class TestAdaptationModule2(object):
def __init__(self):
self.name = 'test2'

assert ADAPTATION_CLASS_DICT[random_name] == TestAdaptationModule2
assert ADAPTATION_MODULE_DICT[random_name] == TestAdaptationModule2

def test_register_model_class(self):
@register_model_class
Expand Down Expand Up @@ -430,19 +430,19 @@ class TestSpecialModule0(object):
def __init__(self):
self.name = 'test0'

assert SPECIAL_CLASS_DICT['TestSpecialModule0'] == TestSpecialModule0
assert SPECIAL_MODULE_DICT['TestSpecialModule0'] == TestSpecialModule0

@register_special_module()
class TestSpecialModule1(object):
def __init__(self):
self.name = 'test1'

assert SPECIAL_CLASS_DICT['TestSpecialModule1'] == TestSpecialModule1
assert SPECIAL_MODULE_DICT['TestSpecialModule1'] == TestSpecialModule1
random_name = 'custom_special_module_class_name2'

@register_special_module(key=random_name)
class TestSpecialModule2(object):
def __init__(self):
self.name = 'test2'

assert SPECIAL_CLASS_DICT[random_name] == TestSpecialModule2
assert SPECIAL_MODULE_DICT[random_name] == TestSpecialModule2
16 changes: 8 additions & 8 deletions torchdistill/core/registry.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
PROC_FUNC_DICT = dict()
FORWARD_PROC_FUNC_DICT = dict()


def register_forward_proc_func(arg=None, **kwargs):
Expand All @@ -7,17 +7,17 @@ def _register_forward_proc_func(func):
if key is None:
key = func.__name__

PROC_FUNC_DICT[key] = func
FORWARD_PROC_FUNC_DICT[key] = func
return func

if callable(arg):
return _register_forward_proc_func(arg)
return _register_forward_proc_func


def get_forward_proc_func(func_name):
if func_name is None:
return PROC_FUNC_DICT['forward_batch_only']
elif func_name in PROC_FUNC_DICT:
return PROC_FUNC_DICT[func_name]
raise ValueError('No forward process function `{}` registered'.format(func_name))
def get_forward_proc_func(key):
if key is None:
return FORWARD_PROC_FUNC_DICT['forward_batch_only']
elif key in FORWARD_PROC_FUNC_DICT:
return FORWARD_PROC_FUNC_DICT[key]
raise ValueError('No forward process function `{}` registered'.format(key))
Loading