diff --git a/CHANGELOG.md b/CHANGELOG.md index c9596e72..1374db1a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,7 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99)) +- Fixed `StrEnum.from_str` with source as key ( + [#99](https://github.com/Lightning-AI/utilities/pull/99), + [#102](https://github.com/Lightning-AI/utilities/pull/102) +) ## [0.6.0] - 2023-01-23 diff --git a/src/lightning_utilities/core/enums.py b/src/lightning_utilities/core/enums.py index 6c0953fa..e0c9b873 100644 --- a/src/lightning_utilities/core/enums.py +++ b/src/lightning_utilities/core/enums.py @@ -4,7 +4,7 @@ # import warnings from enum import Enum -from typing import Optional +from typing import List, Optional from typing_extensions import Literal @@ -19,46 +19,30 @@ class StrEnum(str, Enum): True >>> MySE.from_str("t-2", source="value") == MySE.t2 True + >>> MySE.from_str("t-2", source="value") + + >>> MySE.from_str("t-3", source="any") + Traceback (most recent call last): + ... + ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3. """ @classmethod - def from_str( - cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False - ) -> Optional["StrEnum"]: - """Create StrEnum from a sting matching the key or value. + def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum": + """Create ``StrEnum`` from a string matching the key or value. Args: value: matching string source: compare with: - - ``"key"``: validates only with Enum keys, typical alphanumeric with "_" - - ``"value"``: validates only with Enum values, could be any string - - ``"key"``: validates with any key or value, but key has priority - - strict: allow not matching string and returns None; if false raises exceptions + - ``"key"``: validates only from the enum keys, typical alphanumeric with "_" + - ``"value"``: validates only from the values, could be any string + - ``"any"``: validates with any key or value, but key has priority Raises: ValueError: - if requested string does not match any option based on selected source and use ``"strict=True"`` - UserWarning: - if requested string does not match any option based on selected source and use ``"strict=False"`` - - Example: - >>> class MySE(StrEnum): - ... t1 = "T-1" - ... t2 = "T-2" - >>> MySE.from_str("t-1", source="key") - >>> MySE.from_str("t-2", source="value") - - >>> MySE.from_str("t-3", source="any", strict=True) - Traceback (most recent call last): - ... - ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3. + if requested string does not match any option based on selected source. """ - allowed = cls._allowed_matches(source) - if strict and not any(enum_.lower() == value.lower() for enum_ in allowed): - raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.") - if source in ("key", "any"): for enum_key in cls.__members__.keys(): if enum_key.lower() == value.lower(): @@ -67,12 +51,20 @@ def from_str( for enum_key, enum_val in cls.__members__.items(): if enum_val == value: return cls[enum_key] + raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.") - warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}.")) + @classmethod + def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]: + try: + return cls.from_str(value, source) + except ValueError: + warnings.warn( + UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.") + ) return None @classmethod - def _allowed_matches(cls, source: str) -> list: + def _allowed_matches(cls, source: str) -> List[str]: keys, vals = [], [] for enum_key, enum_val in cls.__members__.items(): keys.append(enum_key) diff --git a/tests/unittests/core/__init__.py b/tests/unittests/core/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/unittests/core/test_apply_func.py b/tests/unittests/core/test_apply_func.py index 1dfc71ed..f5e0a5e2 100644 --- a/tests/unittests/core/test_apply_func.py +++ b/tests/unittests/core/test_apply_func.py @@ -5,7 +5,7 @@ from typing import Any, ClassVar, List, Optional import pytest -from unittests.core.mocks import torch +from unittests.mocks import torch from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections diff --git a/tests/unittests/core/test_enums.py b/tests/unittests/core/test_enums.py index 6ddba5ca..3bd3638c 100644 --- a/tests/unittests/core/test_enums.py +++ b/tests/unittests/core/test_enums.py @@ -47,9 +47,9 @@ class MyEnum(StrEnum): T2 = "t:2" assert MyEnum.from_str("T1", source="key") - assert MyEnum.from_str("T1", source="value") is None + assert MyEnum.try_from_str("T1", source="value") is None assert MyEnum.from_str("T1", source="any") - assert MyEnum.from_str("T:2", source="key") is None + assert MyEnum.try_from_str("T:2", source="key") is None assert MyEnum.from_str("T:2", source="value") assert MyEnum.from_str("T:2", source="any") diff --git a/tests/unittests/core/mocks.py b/tests/unittests/mocks.py similarity index 100% rename from tests/unittests/core/mocks.py rename to tests/unittests/mocks.py