Skip to content

Commit

Permalink
Refactor StrEnum.from_str (#102)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
  • Loading branch information
4 people authored Feb 14, 2023
1 parent 5fcd7c2 commit 78eb098
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 35 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed `StrEnum.from_str` with source as key ([#99](https://github.com/Lightning-AI/utilities/pull/99))
- Fixed `StrEnum.from_str` with source as key (
[#99](https://github.com/Lightning-AI/utilities/pull/99),
[#102](https://github.com/Lightning-AI/utilities/pull/102)
)


## [0.6.0] - 2023-01-23
Expand Down
54 changes: 23 additions & 31 deletions src/lightning_utilities/core/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#
import warnings
from enum import Enum
from typing import Optional
from typing import List, Optional

from typing_extensions import Literal

Expand All @@ -19,46 +19,30 @@ class StrEnum(str, Enum):
True
>>> MySE.from_str("t-2", source="value") == MySE.t2
True
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any")
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
"""

@classmethod
def from_str(
cls, value: str, source: Literal["key", "value", "any"] = "key", strict: bool = False
) -> Optional["StrEnum"]:
"""Create StrEnum from a sting matching the key or value.
def from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> "StrEnum":
"""Create ``StrEnum`` from a string matching the key or value.
Args:
value: matching string
source: compare with:
- ``"key"``: validates only with Enum keys, typical alphanumeric with "_"
- ``"value"``: validates only with Enum values, could be any string
- ``"key"``: validates with any key or value, but key has priority
strict: allow not matching string and returns None; if false raises exceptions
- ``"key"``: validates only from the enum keys, typical alphanumeric with "_"
- ``"value"``: validates only from the values, could be any string
- ``"any"``: validates with any key or value, but key has priority
Raises:
ValueError:
if requested string does not match any option based on selected source and use ``"strict=True"``
UserWarning:
if requested string does not match any option based on selected source and use ``"strict=False"``
Example:
>>> class MySE(StrEnum):
... t1 = "T-1"
... t2 = "T-2"
>>> MySE.from_str("t-1", source="key")
>>> MySE.from_str("t-2", source="value")
<MySE.t2: 'T-2'>
>>> MySE.from_str("t-3", source="any", strict=True)
Traceback (most recent call last):
...
ValueError: Invalid match: expected one of ['t1', 't2', 'T-1', 'T-2'], but got t-3.
if requested string does not match any option based on selected source.
"""
allowed = cls._allowed_matches(source)
if strict and not any(enum_.lower() == value.lower() for enum_ in allowed):
raise ValueError(f"Invalid match: expected one of {allowed}, but got {value}.")

if source in ("key", "any"):
for enum_key in cls.__members__.keys():
if enum_key.lower() == value.lower():
Expand All @@ -67,12 +51,20 @@ def from_str(
for enum_key, enum_val in cls.__members__.items():
if enum_val == value:
return cls[enum_key]
raise ValueError(f"Invalid match: expected one of {cls._allowed_matches(source)}, but got {value}.")

warnings.warn(UserWarning(f"Invalid string: expected one of {allowed}, but got {value}."))
@classmethod
def try_from_str(cls, value: str, source: Literal["key", "value", "any"] = "key") -> Optional["StrEnum"]:
try:
return cls.from_str(value, source)
except ValueError:
warnings.warn(
UserWarning(f"Invalid string: expected one of {cls._allowed_matches(source)}, but got {value}.")
)
return None

@classmethod
def _allowed_matches(cls, source: str) -> list:
def _allowed_matches(cls, source: str) -> List[str]:
keys, vals = [], []
for enum_key, enum_val in cls.__members__.items():
keys.append(enum_key)
Expand Down
Empty file removed tests/unittests/core/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion tests/unittests/core/test_apply_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Any, ClassVar, List, Optional

import pytest
from unittests.core.mocks import torch
from unittests.mocks import torch

from lightning_utilities.core.apply_func import apply_to_collection, apply_to_collections

Expand Down
4 changes: 2 additions & 2 deletions tests/unittests/core/test_enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ class MyEnum(StrEnum):
T2 = "t:2"

assert MyEnum.from_str("T1", source="key")
assert MyEnum.from_str("T1", source="value") is None
assert MyEnum.try_from_str("T1", source="value") is None
assert MyEnum.from_str("T1", source="any")

assert MyEnum.from_str("T:2", source="key") is None
assert MyEnum.try_from_str("T:2", source="key") is None
assert MyEnum.from_str("T:2", source="value")
assert MyEnum.from_str("T:2", source="any")
File renamed without changes.

0 comments on commit 78eb098

Please sign in to comment.