diff --git a/openfisca_core/populations/_core_population.py b/openfisca_core/populations/_core_population.py index 8dff04137..7e160d314 100644 --- a/openfisca_core/populations/_core_population.py +++ b/openfisca_core/populations/_core_population.py @@ -10,8 +10,7 @@ from openfisca_core import holders, periods from . import types as t -from ._errors import InvalidArraySizeError -from ._errors import PeriodValidityError +from ._errors import InvalidArraySizeError, PeriodValidityError #: Type variable for a covariant data type. _DT_co = TypeVar("_DT_co", covariant=True, bound=t.VarDType) @@ -160,7 +159,35 @@ def get_index(self, id: str) -> int: # Calculations - def check_array_compatible_with_entity(self, array: t.FloatArray) -> None: + def check_array_compatible_with_entity(self, array: t.VarArray) -> None: + """Check if an array is compatible with the population. + + Args: + array: The array to check. + + Raises: + InvalidArraySizeError: If the array is not compatible. + + Examples: + >>> import numpy + + >>> from openfisca_core import entities, populations + + >>> class Person(entities.SingleEntity): ... + + >>> person = Person("person", "people", "", "") + >>> population = populations.CorePopulation(person) + >>> population.count = 3 + + >>> array = numpy.array([1, 2, 3]) + >>> population.check_array_compatible_with_entity(array) + + >>> array = numpy.array([1, 2, 3, 4]) + >>> population.check_array_compatible_with_entity(array) + Traceback (most recent call last): + InvalidArraySizeError: Input [1 2 3 4] is not a valid value for t... + + """ if self.count == array.size: return raise InvalidArraySizeError(array, self.entity.key, self.count) diff --git a/openfisca_core/populations/_errors.py b/openfisca_core/populations/_errors.py index 77e6c424b..b48569681 100644 --- a/openfisca_core/populations/_errors.py +++ b/openfisca_core/populations/_errors.py @@ -4,7 +4,7 @@ class InvalidArraySizeError(ValueError): """Raised when an array has an invalid size.""" - def __init__(self, array: t.FloatArray, entity: t.EntityKey, count: int) -> None: + def __init__(self, array: t.VarArray, entity: t.EntityKey, count: int) -> None: msg = ( f"Input {array} is not a valid value for the entity {entity} " f"(size = {array.size} != {count} = count)." diff --git a/openfisca_core/populations/types.py b/openfisca_core/populations/types.py index 0cfccef36..9b90c43ef 100644 --- a/openfisca_core/populations/types.py +++ b/openfisca_core/populations/types.py @@ -47,6 +47,9 @@ #: Type alias for an array of floats. FloatArray: TypeAlias = Array[FloatDType] +#: Type alias for an array of generic objects. +VarArray: TypeAlias = Array[VarDType] + # Periods #: New type for a period integer.