Skip to content

Commit

Permalink
Added test_enum_factory unit test and fixed implementation of TypeEnu…
Browse files Browse the repository at this point in the history
…mFactory (#1765)

* Added test_enum_factory

* Added test_enum_factory
  • Loading branch information
BloodAxe authored Jan 16, 2024
1 parent ce7a357 commit f1f4d61
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ class UnknownTypeException(Exception):
"""

def __init__(self, unknown_type: str, choices: List, message: str = None):
choices = [str(choice) for choice in choices] # Ensure all choices are strings
message = message or f"Unknown object type: {unknown_type} in configuration. valid types are: {choices}"
err_msg_tip = ""
if isinstance(unknown_type, str):
Expand Down
2 changes: 1 addition & 1 deletion src/super_gradients/common/factories/type_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def __init__(self, type_dict: Dict[str, type]):

@classmethod
def from_enum_cls(cls, enum_cls: Type[Enum]):
return cls({entity.name: entity.value for entity in enum_cls})
return cls({entity.value: entity for entity in enum_cls})

def get(self, conf: Union[str, type]):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/super_gradients/training/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ def fuzzy_keys(params: Mapping) -> List[str]:
:param params: Mapping, the mapping containing the keys to be returned.
:return: List[str], list of keys as discussed above.
"""
return [fuzzy_str(s) for s in params.keys()]
return [fuzzy_str(str(s)) for s in params.keys()]


def fuzzy_str(s: str):
Expand All @@ -276,7 +276,7 @@ def get_fuzzy_mapping_param(name: str, params: Mapping):
:param params: Mapping, the mapping containing param.
:return:
"""
fuzzy_params = {fuzzy_str(key): params[key] for key in params.keys()}
fuzzy_params = {fuzzy_str(str(key)): params[key] for key in params.keys()}
return fuzzy_params[fuzzy_str(name)]


Expand Down
25 changes: 25 additions & 0 deletions tests/unit_tests/factories_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import torch

from super_gradients import Trainer
from super_gradients.common import StrictLoad
from super_gradients.common.decorators.factory_decorator import resolve_param
from super_gradients.common.exceptions import UnknownTypeException
from super_gradients.common.factories.activations_type_factory import ActivationsTypeFactory
from super_gradients.common.factories.type_factory import TypeFactory
from super_gradients.common.object_names import Models
from super_gradients.training import models
from super_gradients.training.dataloaders.dataloaders import classification_test_dataloader
Expand Down Expand Up @@ -86,6 +89,28 @@ def __init__(self, activation_in_head):
model = DummyModel(activation_in_head=nn.LeakyReLU)
self.assertIsInstance(model.activation_in_head, nn.LeakyReLU)

def test_enum_factory(self):
@resolve_param("v", TypeFactory.from_enum_cls(StrictLoad))
def get_enum_value_from_string(v):
return v

self.assertEqual(StrictLoad.ON, get_enum_value_from_string(StrictLoad.ON))
self.assertEqual(StrictLoad.ON, get_enum_value_from_string(True))
self.assertEqual(StrictLoad.ON, get_enum_value_from_string("True"))

self.assertEqual(StrictLoad.OFF, get_enum_value_from_string(StrictLoad.OFF))
self.assertEqual(StrictLoad.OFF, get_enum_value_from_string(False))
self.assertEqual(StrictLoad.OFF, get_enum_value_from_string("False"))

self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string(StrictLoad.KEY_MATCHING))
self.assertEqual(StrictLoad.NO_KEY_MATCHING, get_enum_value_from_string(StrictLoad.NO_KEY_MATCHING))

self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string("KEY_MATCHING"))
self.assertEqual(StrictLoad.KEY_MATCHING, get_enum_value_from_string("key_matching"))

with self.assertRaises(UnknownTypeException):
print(get_enum_value_from_string("ABCABABABA"))


if __name__ == "__main__":
unittest.main()

0 comments on commit f1f4d61

Please sign in to comment.