From f1f4d61acba7e8d532798697079b102055fce44a Mon Sep 17 00:00:00 2001 From: Eugene Khvedchenya Date: Tue, 16 Jan 2024 16:35:03 +0200 Subject: [PATCH] Added test_enum_factory unit test and fixed implementation of TypeEnumFactory (#1765) * Added test_enum_factory * Added test_enum_factory --- .../common/exceptions/factory_exceptions.py | 1 + .../common/factories/type_factory.py | 2 +- src/super_gradients/training/utils/utils.py | 4 +-- tests/unit_tests/factories_test.py | 25 +++++++++++++++++++ 4 files changed, 29 insertions(+), 3 deletions(-) diff --git a/src/super_gradients/common/exceptions/factory_exceptions.py b/src/super_gradients/common/exceptions/factory_exceptions.py index ede5baa1da..87a7ab342f 100644 --- a/src/super_gradients/common/exceptions/factory_exceptions.py +++ b/src/super_gradients/common/exceptions/factory_exceptions.py @@ -12,6 +12,7 @@ class UnknownTypeException(Exception): """ def __init__(self, unknown_type: str, choices: List, message: str = None): + choices = [str(choice) for choice in choices] # Ensure all choices are strings message = message or f"Unknown object type: {unknown_type} in configuration. valid types are: {choices}" err_msg_tip = "" if isinstance(unknown_type, str): diff --git a/src/super_gradients/common/factories/type_factory.py b/src/super_gradients/common/factories/type_factory.py index 0409da1bc0..b56b92a33c 100644 --- a/src/super_gradients/common/factories/type_factory.py +++ b/src/super_gradients/common/factories/type_factory.py @@ -20,7 +20,7 @@ def __init__(self, type_dict: Dict[str, type]): @classmethod def from_enum_cls(cls, enum_cls: Type[Enum]): - return cls({entity.name: entity.value for entity in enum_cls}) + return cls({entity.value: entity for entity in enum_cls}) def get(self, conf: Union[str, type]): """ diff --git a/src/super_gradients/training/utils/utils.py b/src/super_gradients/training/utils/utils.py index dd404b88c3..939c4b8171 100755 --- a/src/super_gradients/training/utils/utils.py +++ b/src/super_gradients/training/utils/utils.py @@ -249,7 +249,7 @@ def fuzzy_keys(params: Mapping) -> List[str]: :param params: Mapping, the mapping containing the keys to be returned. :return: List[str], list of keys as discussed above. """ - return [fuzzy_str(s) for s in params.keys()] + return [fuzzy_str(str(s)) for s in params.keys()] def fuzzy_str(s: str): @@ -276,7 +276,7 @@ def get_fuzzy_mapping_param(name: str, params: Mapping): :param params: Mapping, the mapping containing param. :return: """ - fuzzy_params = {fuzzy_str(key): params[key] for key in params.keys()} + fuzzy_params = {fuzzy_str(str(key)): params[key] for key in params.keys()} return fuzzy_params[fuzzy_str(name)] diff --git a/tests/unit_tests/factories_test.py b/tests/unit_tests/factories_test.py index e3b7babba0..f5364d6e52 100644 --- a/tests/unit_tests/factories_test.py +++ b/tests/unit_tests/factories_test.py @@ -3,8 +3,11 @@ import torch from super_gradients import Trainer +from super_gradients.common import StrictLoad from super_gradients.common.decorators.factory_decorator import resolve_param +from super_gradients.common.exceptions import UnknownTypeException from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory +from super_gradients.common.factories.type_factory import TypeFactory from super_gradients.common.object_names import Models from super_gradients.training import models from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader @@ -86,6 +89,28 @@ def __init__(self, activation_in_head): model = DummyModel(activation_in_head=nn.LeakyReLU) self.assertIsInstance(model.activation_in_head, nn.LeakyReLU) + def test_enum_factory(self): + @resolve_param("v", TypeFactory.from_enum_cls(StrictLoad)) + def get_enum_value_from_string(v): + return v + + self.assertEqual(StrictLoad.ON, get_enum_value_from_string(StrictLoad.ON)) + self.assertEqual(StrictLoad.ON, get_enum_value_from_string(True)) + self.assertEqual(StrictLoad.ON, get_enum_value_from_string("True")) + + self.assertEqual(StrictLoad.OFF, get_enum_value_from_string(StrictLoad.OFF)) + self.assertEqual(StrictLoad.OFF, get_enum_value_from_string(False)) + self.assertEqual(StrictLoad.OFF, get_enum_value_from_string("False")) + + self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string(StrictLoad.KEY_MATCHING)) + self.assertEqual(StrictLoad.NO_KEY_MATCHING, get_enum_value_from_string(StrictLoad.NO_KEY_MATCHING)) + + self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string("KEY_MATCHING")) + self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string("key_matching")) + + with self.assertRaises(UnknownTypeException): + print(get_enum_value_from_string("ABCABABABA")) + if __name__ == "__main__": unittest.main()