|
2 | 2 |
|
3 | 3 | __all__ = ["EnumProperty"]
|
4 | 4 |
|
5 |
| -from typing import Any, ClassVar, List, Union, cast |
| 5 | +from typing import Any, ClassVar, List, Sequence, Union, cast |
6 | 6 |
|
7 | 7 | from attr import evolve
|
8 | 8 | from attrs import define
|
@@ -121,7 +121,7 @@ def build( # noqa: PLR0911
|
121 | 121 | if parent_name:
|
122 | 122 | class_name = f"{utils.pascal_case(parent_name)}{utils.pascal_case(class_name)}"
|
123 | 123 | class_info = Class.from_string(string=class_name, config=config)
|
124 |
| - values = EnumProperty.values_from_list(value_list) |
| 124 | + values = EnumProperty.values_from_list(value_list, case_sensitive_enums=config.case_sensitive_enums) |
125 | 125 |
|
126 | 126 | if class_info.name in schemas.classes_by_name:
|
127 | 127 | existing = schemas.classes_by_name[class_info.name]
|
@@ -183,24 +183,30 @@ def get_imports(self, *, prefix: str) -> set[str]:
|
183 | 183 | return imports
|
184 | 184 |
|
185 | 185 | @staticmethod
|
186 |
| - def values_from_list(values: list[str] | list[int]) -> dict[str, ValueType]: |
| 186 | + def values_from_list( |
| 187 | + values: Sequence[str] | Sequence[int], case_sensitive_enums: bool = False |
| 188 | + ) -> dict[str, ValueType]: |
187 | 189 | """Convert a list of values into dict of {name: value}, where value can sometimes be None"""
|
188 | 190 | output: dict[str, ValueType] = {}
|
189 | 191 |
|
190 |
| - for i, value in enumerate(values): |
191 |
| - value = cast(Union[str, int], value) |
| 192 | + for value in values: |
192 | 193 | if isinstance(value, int):
|
193 | 194 | if value < 0:
|
194 | 195 | output[f"VALUE_NEGATIVE_{-value}"] = value
|
195 | 196 | else:
|
196 | 197 | output[f"VALUE_{value}"] = value
|
197 | 198 | continue
|
198 |
| - if value and value[0].isalpha(): |
199 |
| - key = value.upper() |
| 199 | + |
| 200 | + if case_sensitive_enums: |
| 201 | + sanitized_key = utils.case_insensitive_snake_case(value) |
200 | 202 | else:
|
201 |
| - key = f"VALUE_{i}" |
202 |
| - if key in output: |
203 |
| - raise ValueError(f"Duplicate key {key} in Enum") |
204 |
| - sanitized_key = utils.snake_case(key).upper() |
| 203 | + sanitized_key = utils.snake_case(value.lower()).upper() |
| 204 | + if not value or not value[0].isalpha(): |
| 205 | + sanitized_key = f"LITERAL_{sanitized_key}" |
| 206 | + |
| 207 | + if sanitized_key in output: |
| 208 | + raise ValueError(f"Duplicate key {sanitized_key} in Enum") |
| 209 | + |
205 | 210 | output[sanitized_key] = utils.remove_string_escapes(value)
|
| 211 | + |
206 | 212 | return output
|
0 commit comments