diff --git a/torchdistill/core/interfaces/registry.py b/torchdistill/core/interfaces/registry.py index 4995c9e6..06eb3551 100644 --- a/torchdistill/core/interfaces/registry.py +++ b/torchdistill/core/interfaces/registry.py @@ -247,9 +247,7 @@ def get_forward_proc_func(key): :return: registered forward process function. :rtype: typing.Callable """ - if key is None: - return FORWARD_PROC_FUNC_DICT['forward_batch_only'] - elif key in FORWARD_PROC_FUNC_DICT: + if key in FORWARD_PROC_FUNC_DICT: return FORWARD_PROC_FUNC_DICT[key] raise ValueError('No forward process function `{}` registered'.format(key))