Skip to content

Commit

Permalink
fix(enums): do actual indexing (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Oct 9, 2024
1 parent 0195451 commit b5b7968
Show file tree
Hide file tree
Showing 9 changed files with 279 additions and 68 deletions.
2 changes: 2 additions & 0 deletions openfisca_core/indexed_enums/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Enumerations for variables with a limited set of possible values."""

from . import types
from ._enum_type import EnumType
from .config import ENUM_ARRAY_DTYPE
from .enum import Enum
from .enum_array import EnumArray
Expand All @@ -9,5 +10,6 @@
"ENUM_ARRAY_DTYPE",
"Enum",
"EnumArray",
"EnumType",
"types",
]
113 changes: 113 additions & 0 deletions openfisca_core/indexed_enums/_enum_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from __future__ import annotations

from typing import final

import numpy

from . import types as t


def _item_list(enum_class: type[t.Enum]) -> t.ItemList:
"""Return the non-vectorised list of enum items."""
return [
(index, name, value)
for index, (name, value) in enumerate(enum_class.__members__.items())
]


def _item_dtype(enum_class: type[t.Enum]) -> t.RecDType:
"""Return the dtype of the indexed enum's items."""
size = max(map(len, enum_class.__members__.keys()))
return numpy.dtype(
(
numpy.generic,
{
"index": (t.EnumDType, 0),
"name": (f"U{size}", 2),
"enum": (enum_class, 2 + size * 4),
},
)
)


def _item_array(enum_class: type[t.Enum]) -> t.RecArray:
"""Return the indexed enum's items."""
items = _item_list(enum_class)
dtype = _item_dtype(enum_class)
array = numpy.array(items, dtype=dtype)
return array.view(numpy.recarray)


@final
class EnumType(t.EnumType):
"""Meta class for creating an indexed :class:`.Enum`.
Examples:
>>> from openfisca_core import indexed_enums as enum
>>> class Enum(enum.Enum, metaclass=enum.EnumType):
... pass
>>> Enum.items
Traceback (most recent call last):
AttributeError: type object 'Enum' has no attribute 'items'
>>> class Housing(Enum):
... OWNER = "Owner"
... TENANT = "Tenant"
>>> Housing.items
rec.array([(0, 'OWNER', <Housing.OWNER: 'Owner'>), ...])
>>> Housing.indices
array([0, 1], dtype=int16)
>>> Housing.names
array(['OWNER', 'TENANT'], dtype='<U6')
>>> Housing.enums
array([<Housing.OWNER: 'Owner'>, <Housing.TENANT: 'Tenant'>], dtype...)
"""

#: The items of the indexed enum class.
items: t.RecArray

@property
def indices(cls) -> t.IndexArray:
"""Return the indices of the indexed enum class."""
return cls.items.index

@property
def names(cls) -> t.StrArray:
"""Return the names of the indexed enum class."""
return cls.items.name

@property
def enums(cls) -> t.ObjArray:
"""Return the members of the indexed enum class."""
return cls.items.enum

def __new__(
metacls,
cls: str,
bases: tuple[type, ...],
classdict: t.EnumDict,
**kwds: object,
) -> t.EnumType:
"""Create a new indexed enum class."""
# Create the enum class.
enum_class = super().__new__(metacls, cls, bases, classdict, **kwds)

# If the enum class has no members, return it as is.
if not enum_class.__members__:
return enum_class

# Add the items attribute to the enum class.
enum_class.items = _item_array(enum_class)

# Return the modified enum class.
return enum_class

def __dir__(cls) -> list[str]:
return sorted({"items", "indices", "names", "enums", *super().__dir__()})
68 changes: 68 additions & 0 deletions openfisca_core/indexed_enums/_type_guards.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

from typing_extensions import TypeIs

import numpy

from . import types as t


def _is_int_array(array: t.AnyArray) -> TypeIs[t.IndexArray | t.IntArray]:
"""Narrow the type of a given array to an array of :obj:`numpy.integer`.
Args:
array: Array to check.
Returns:
bool: True if ``array`` is an array of :obj:`numpy.integer`, False otherwise.
Examples:
>>> import numpy
>>> array = numpy.array([1], dtype=numpy.int16)
>>> _is_int_array(array)
True
>>> array = numpy.array([1], dtype=numpy.int32)
>>> _is_int_array(array)
True
>>> array = numpy.array([1.0])
>>> _is_int_array(array)
False
"""
return numpy.issubdtype(array.dtype, numpy.integer)


def _is_str_array(array: t.AnyArray) -> TypeIs[t.StrArray]:
"""Narrow the type of a given array to an array of :obj:`numpy.str_`.
Args:
array: Array to check.
Returns:
bool: True if ``array`` is an array of :obj:`numpy.str_`, False otherwise.
Examples:
>>> import numpy
>>> from openfisca_core import indexed_enums as enum
>>> class Housing(enum.Enum):
... OWNER = "owner"
... TENANT = "tenant"
>>> array = numpy.array([Housing.OWNER])
>>> _is_str_array(array)
False
>>> array = numpy.array(["owner"])
>>> _is_str_array(array)
True
"""
return numpy.issubdtype(array.dtype, str)


__all__ = ["_is_int_array", "_is_str_array"]
91 changes: 39 additions & 52 deletions openfisca_core/indexed_enums/enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import numpy

from . import types as t
from .config import ENUM_ARRAY_DTYPE
from ._enum_type import EnumType
from ._type_guards import _is_int_array, _is_str_array
from .enum_array import EnumArray


class Enum(t.Enum):
class Enum(t.Enum, metaclass=EnumType):
"""Enum based on `enum34 <https://pypi.python.org/pypi/enum34/>`_.
Its items have an :class:`int` index, useful and performant when running
Expand Down Expand Up @@ -115,20 +116,19 @@ def __ne__(self, other: object) -> bool:
return NotImplemented
return self.index != other.index

#: :meth:`.__hash__` must also be defined so as to stay hashable.
__hash__ = object.__hash__
def __hash__(self) -> int:
return hash(self.index)

@classmethod
def encode(
cls,
array: (
EnumArray
| t.Array[t.DTypeStr]
| t.Array[t.DTypeInt]
| t.Array[t.DTypeEnum]
| t.Array[t.DTypeObject]
| t.ArrayLike[str]
| t.IntArray
| t.StrArray
| t.ObjArray
| t.ArrayLike[int]
| t.ArrayLike[str]
| t.ArrayLike[t.Enum]
),
) -> EnumArray:
Expand All @@ -143,7 +143,6 @@ def encode(
Raises:
TypeError: If ``array`` is a scalar :class:`~numpy.ndarray`.
TypeError: If ``array`` is of a diffent :class:`.Enum` type.
NotImplementedError: If ``array`` is of an unsupported type.
Examples:
>>> import numpy
Expand Down Expand Up @@ -187,7 +186,7 @@ def encode(
>>> array = numpy.array([b"TENANT"])
>>> enum_array = Housing.encode(array)
Traceback (most recent call last):
NotImplementedError: Unsupported encoding: bytes48.
TypeError: Failed to encode "[b'TENANT']" of type 'bytes_', as i...
.. seealso::
:meth:`.EnumArray.decode` for decoding.
Expand All @@ -200,7 +199,7 @@ def encode(
return cls.encode(numpy.array(array))

if array.size == 0:
return EnumArray(array, cls)
return EnumArray(numpy.array([]), cls)

if array.ndim == 0:
msg = (
Expand All @@ -209,49 +208,37 @@ def encode(
)
raise TypeError(msg)

# Enum data type array
if numpy.issubdtype(array.dtype, t.DTypeEnum):
indexes = numpy.array([item.index for item in cls], t.DTypeEnum)
return EnumArray(indexes[array[array < indexes.size]], cls)

# Integer array
if numpy.issubdtype(array.dtype, int):
array = numpy.array(array, dtype=t.DTypeEnum)
return cls.encode(array)
if _is_int_array(array):
indices = numpy.array(array[array < len(cls.items)], dtype=t.EnumDType)
return EnumArray(indices, cls)

# String array
if numpy.issubdtype(array.dtype, t.DTypeStr):
enums = [cls.__members__[key] for key in array if key in cls.__members__]
return cls.encode(enums)

# Enum items arrays
if numpy.issubdtype(array.dtype, t.DTypeObject):
# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
if cls.__name__ is array[0].__class__.__name__:
array = numpy.select(
[array == item for item in array[0].__class__],
[item.index for item in array[0].__class__],
).astype(ENUM_ARRAY_DTYPE)
return EnumArray(array, cls)

msg = (
f"Diverging enum types are not supported: expected {cls.__name__}, "
f"but got {array[0].__class__.__name__} instead."
)
raise TypeError(msg)

msg = f"Unsupported encoding: {array.dtype.name}."
raise NotImplementedError(msg)
if _is_str_array(array):
indices = cls.items[numpy.isin(cls.names, array)].index
return EnumArray(indices, cls)

# Ensure we are comparing the comparable. The problem this fixes:
# On entering this method "cls" will generally come from
# variable.possible_values, while the array values may come from
# directly importing a module containing an Enum class. However,
# variables (and hence their possible_values) are loaded by a call
# to load_module, which gives them a different identity from the
# ones imported in the usual way.
#
# So, instead of relying on the "cls" passed in, we use only its
# name to check that the values in the array, if non-empty, are of
# the right type.
if cls.__name__ is array[0].__class__.__name__:
indices = cls.items[numpy.isin(cls.enums, array)].index
return EnumArray(indices, cls)

msg = (
f"Failed to encode \"{array}\" of type '{array[0].__class__.__name__}', "
"as it is not supported. Please, try again with an array of "
f"'{int.__name__}', '{str.__name__}', or '{cls.__name__}'."
)
raise TypeError(msg)


__all__ = ["Enum"]
2 changes: 1 addition & 1 deletion openfisca_core/indexed_enums/enum_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ class EnumArray(t.EnumArray):

def __new__(
cls,
input_array: t.Array[t.DTypeEnum],
input_array: t.IndexArray,
possible_values: None | type[t.Enum] = None,
) -> Self:
"""See comment above."""
Expand Down
8 changes: 4 additions & 4 deletions openfisca_core/indexed_enums/tests/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def test_enum_encode_with_array_of_enum():

def test_enum_encode_with_enum_sequence():
"""Does encode when called with an enum sequence."""
sequence = list(Animal)
sequence = list(Animal) + list(Colour)
enum_array = Animal.encode(sequence)
assert Animal.DOG in enum_array

Expand Down Expand Up @@ -89,7 +89,7 @@ def test_enum_encode_with_array_of_string():

def test_enum_encode_with_str_sequence():
"""Does encode when called with a str sequence."""
sequence = ("DOG",)
sequence = ("DOG", "JAIBA")
enum_array = Animal.encode(sequence)
assert Animal.DOG in enum_array

Expand Down Expand Up @@ -130,5 +130,5 @@ def test_enum_encode_with_any_scalar_array():
def test_enum_encode_with_any_sequence():
"""Does not encode when called with unsupported types."""
sequence = memoryview(b"DOG")
with pytest.raises(NotImplementedError):
Animal.encode(sequence)
enum_array = Animal.encode(sequence)
assert len(enum_array) == 0
Loading

0 comments on commit b5b7968

Please sign in to comment.