-
Notifications
You must be signed in to change notification settings - Fork 75
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
fix(enums): do actual indexing (#1267)
- Loading branch information
1 parent
0195451
commit b5b7968
Showing
9 changed files
with
279 additions
and
68 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
from __future__ import annotations | ||
|
||
from typing import final | ||
|
||
import numpy | ||
|
||
from . import types as t | ||
|
||
|
||
def _item_list(enum_class: type[t.Enum]) -> t.ItemList: | ||
"""Return the non-vectorised list of enum items.""" | ||
return [ | ||
(index, name, value) | ||
for index, (name, value) in enumerate(enum_class.__members__.items()) | ||
] | ||
|
||
|
||
def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType: | ||
"""Return the dtype of the indexed enum's items.""" | ||
size = max(map(len, enum_class.__members__.keys())) | ||
return numpy.dtype( | ||
( | ||
numpy.generic, | ||
{ | ||
"index": (t.EnumDType, 0), | ||
"name": (f"U{size}", 2), | ||
"enum": (enum_class, 2 + size * 4), | ||
}, | ||
) | ||
) | ||
|
||
|
||
def _item_array(enum_class: type[t.Enum]) -> t.RecArray: | ||
"""Return the indexed enum's items.""" | ||
items = _item_list(enum_class) | ||
dtype = _item_dtype(enum_class) | ||
array = numpy.array(items, dtype=dtype) | ||
return array.view(numpy.recarray) | ||
|
||
|
||
@final | ||
class EnumType(t.EnumType): | ||
"""Meta class for creating an indexed :class:`.Enum`. | ||
Examples: | ||
>>> from openfisca_core import indexed_enums as enum | ||
>>> class Enum(enum.Enum, metaclass=enum.EnumType): | ||
... pass | ||
>>> Enum.items | ||
Traceback (most recent call last): | ||
AttributeError: type object 'Enum' has no attribute 'items' | ||
>>> class Housing(Enum): | ||
... OWNER = "Owner" | ||
... TENANT = "Tenant" | ||
>>> Housing.items | ||
rec.array([(0, 'OWNER', <Housing.OWNER: 'Owner'>), ...]) | ||
>>> Housing.indices | ||
array([0, 1], dtype=int16) | ||
>>> Housing.names | ||
array(['OWNER', 'TENANT'], dtype='<U6') | ||
>>> Housing.enums | ||
array([<Housing.OWNER: 'Owner'>, <Housing.TENANT: 'Tenant'>], dtype...) | ||
""" | ||
|
||
#: The items of the indexed enum class. | ||
items: t.RecArray | ||
|
||
@property | ||
def indices(cls) -> t.IndexArray: | ||
"""Return the indices of the indexed enum class.""" | ||
return cls.items.index | ||
|
||
@property | ||
def names(cls) -> t.StrArray: | ||
"""Return the names of the indexed enum class.""" | ||
return cls.items.name | ||
|
||
@property | ||
def enums(cls) -> t.ObjArray: | ||
"""Return the members of the indexed enum class.""" | ||
return cls.items.enum | ||
|
||
def __new__( | ||
metacls, | ||
cls: str, | ||
bases: tuple[type, ...], | ||
classdict: t.EnumDict, | ||
**kwds: object, | ||
) -> t.EnumType: | ||
"""Create a new indexed enum class.""" | ||
# Create the enum class. | ||
enum_class = super().__new__(metacls, cls, bases, classdict, **kwds) | ||
|
||
# If the enum class has no members, return it as is. | ||
if not enum_class.__members__: | ||
return enum_class | ||
|
||
# Add the items attribute to the enum class. | ||
enum_class.items = _item_array(enum_class) | ||
|
||
# Return the modified enum class. | ||
return enum_class | ||
|
||
def __dir__(cls) -> list[str]: | ||
return sorted({"items", "indices", "names", "enums", *super().__dir__()}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
from __future__ import annotations | ||
|
||
from typing_extensions import TypeIs | ||
|
||
import numpy | ||
|
||
from . import types as t | ||
|
||
|
||
def _is_int_array(array: t.AnyArray) -> TypeIs[t.IndexArray | t.IntArray]: | ||
"""Narrow the type of a given array to an array of :obj:`numpy.integer`. | ||
Args: | ||
array: Array to check. | ||
Returns: | ||
bool: True if ``array`` is an array of :obj:`numpy.integer`, False otherwise. | ||
Examples: | ||
>>> import numpy | ||
>>> array = numpy.array([1], dtype=numpy.int16) | ||
>>> _is_int_array(array) | ||
True | ||
>>> array = numpy.array([1], dtype=numpy.int32) | ||
>>> _is_int_array(array) | ||
True | ||
>>> array = numpy.array([1.0]) | ||
>>> _is_int_array(array) | ||
False | ||
""" | ||
return numpy.issubdtype(array.dtype, numpy.integer) | ||
|
||
|
||
def _is_str_array(array: t.AnyArray) -> TypeIs[t.StrArray]: | ||
"""Narrow the type of a given array to an array of :obj:`numpy.str_`. | ||
Args: | ||
array: Array to check. | ||
Returns: | ||
bool: True if ``array`` is an array of :obj:`numpy.str_`, False otherwise. | ||
Examples: | ||
>>> import numpy | ||
>>> from openfisca_core import indexed_enums as enum | ||
>>> class Housing(enum.Enum): | ||
... OWNER = "owner" | ||
... TENANT = "tenant" | ||
>>> array = numpy.array([Housing.OWNER]) | ||
>>> _is_str_array(array) | ||
False | ||
>>> array = numpy.array(["owner"]) | ||
>>> _is_str_array(array) | ||
True | ||
""" | ||
return numpy.issubdtype(array.dtype, str) | ||
|
||
|
||
__all__ = ["_is_int_array", "_is_str_array"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.