Skip to content

Commit

Permalink
Document & test eval_expression (#1225)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko authored Sep 25, 2024
2 parents eb154a0 + ecf904f commit 7d451a1
Show file tree
Hide file tree
Showing 39 changed files with 402 additions and 313 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 41.5.7 [#1225](https://github.com/openfisca/openfisca-core/pull/1225)

#### Technical changes

- Refactor & test `eval_expression`

### 41.5.6 [#1185](https://github.com/openfisca/openfisca-core/pull/1185)

#### Technical changes
Expand Down
28 changes: 16 additions & 12 deletions openfisca_core/commons/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* :func:`.average_rate`
* :func:`.concat`
* :func:`.empty_clone`
* :func:`.eval_expression`
* :func:`.marginal_rate`
* :func:`.stringify_array`
* :func:`.switch`
Expand Down Expand Up @@ -50,18 +51,21 @@
"""

# Official Public API

from . import types
from .dummy import Dummy
from .formulas import apply_thresholds, concat, switch
from .misc import empty_clone, stringify_array
from .misc import empty_clone, eval_expression, stringify_array
from .rates import average_rate, marginal_rate

__all__ = ["apply_thresholds", "concat", "switch"]
__all__ = ["empty_clone", "stringify_array", *__all__]
__all__ = ["average_rate", "marginal_rate", *__all__]

# Deprecated

from .dummy import Dummy

__all__ = ["Dummy", *__all__]
__all__ = [
"Dummy",
"apply_thresholds",
"average_rate",
"concat",
"empty_clone",
"eval_expression",
"marginal_rate",
"stringify_array",
"switch",
"types",
]
3 changes: 3 additions & 0 deletions openfisca_core/commons/dummy.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,6 @@ def __init__(self) -> None:
"and will be removed in the future.",
]
warnings.warn(" ".join(message), DeprecationWarning, stacklevel=2)


__all__ = ["Dummy"]
35 changes: 21 additions & 14 deletions openfisca_core/commons/formulas.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from __future__ import annotations

from collections.abc import Mapping
from typing import Union

import numpy

from openfisca_core import types as t
from . import types as t


def apply_thresholds(
input: t.Array[numpy.float64],
input: t.Array[numpy.float32],
thresholds: t.ArrayLike[float],
choices: t.ArrayLike[float],
) -> t.Array[numpy.float64]:
) -> t.Array[numpy.float32]:
"""Makes a choice based on an input and thresholds.
From a list of ``choices``, this function selects one of these values
Expand Down Expand Up @@ -38,26 +39,29 @@ def apply_thresholds(
array([10, 10, 15, 15, 20])
"""
condlist: list[Union[t.Array[numpy.bool_], bool]]

condlist: list[t.Array[numpy.bool_] | bool]
condlist = [input <= threshold for threshold in thresholds]

if len(condlist) == len(choices) - 1:
# If a choice is provided for input > highest threshold, last condition
# must be true to return it.
condlist += [True]

assert len(condlist) == len(
choices
), "'apply_thresholds' must be called with the same number of thresholds than choices, or one more choice."
msg = (
"'apply_thresholds' must be called with the same number of thresholds "
"than choices, or one more choice."
)
assert len(condlist) == len(choices), msg

return numpy.select(condlist, choices)


def concat(
this: Union[t.Array[numpy.str_], t.ArrayLike[str]],
that: Union[t.Array[numpy.str_], t.ArrayLike[str]],
this: t.Array[numpy.str_] | t.ArrayLike[str],
that: t.Array[numpy.str_] | t.ArrayLike[str],
) -> t.Array[numpy.str_]:
"""Concatenates the values of two arrays.
"""Concatenate the values of two arrays.
Args:
this: An array to concatenate.
Expand All @@ -84,10 +88,10 @@ def concat(


def switch(
conditions: t.Array[numpy.float64],
conditions: t.Array[numpy.float32],
value_by_condition: Mapping[float, float],
) -> t.Array[numpy.float64]:
"""Mimicks a switch statement.
) -> t.Array[numpy.float32]:
"""Mimick a switch statement.
Given an array of conditions, returns an array of the same size,
replacing each condition item with the matching given value.
Expand Down Expand Up @@ -117,3 +121,6 @@ def switch(
condlist = [conditions == condition for condition in value_by_condition]

return numpy.select(condlist, tuple(value_by_condition.values()))


__all__ = ["apply_thresholds", "concat", "switch"]
47 changes: 37 additions & 10 deletions openfisca_core/commons/misc.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from typing import Optional, TypeVar
from __future__ import annotations

import numexpr
import numpy

from openfisca_core import types as t

T = TypeVar("T")


def empty_clone(original: T) -> T:
"""Creates an empty instance of the same class of the original object.
def empty_clone(original: object) -> object:
"""Create an empty instance of the same class of the original object.
Args:
original: An object to clone.
Expand All @@ -30,22 +29,20 @@ def empty_clone(original: T) -> T:
True
"""
Dummy: object
new: T

Dummy = type(
"Dummy",
(original.__class__,),
{"__init__": lambda self: None},
{"__init__": lambda _: None},
)

new = Dummy()
new.__class__ = original.__class__
return new


def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str:
"""Generates a clean string representation of a numpy array.
def stringify_array(array: None | t.Array[numpy.generic]) -> str:
"""Generate a clean string representation of a numpy array.
Args:
array: An array.
Expand Down Expand Up @@ -76,3 +73,33 @@ def stringify_array(array: Optional[t.Array[numpy.generic]]) -> str:
return "None"

return f"[{', '.join(str(cell) for cell in array)}]"


def eval_expression(
expression: str,
) -> str | t.Array[numpy.bool_] | t.Array[numpy.int32] | t.Array[numpy.float32]:
"""Evaluate a string expression to a numpy array.
Args:
expression(str): An expression to evaluate.
Returns:
:obj:`object`: The result of the evaluation.
Examples:
>>> eval_expression("1 + 2")
array(3, dtype=int32)
>>> eval_expression("salary")
'salary'
"""

try:
return numexpr.evaluate(expression)

except (KeyError, TypeError):
return expression


__all__ = ["empty_clone", "eval_expression", "stringify_array"]
33 changes: 18 additions & 15 deletions openfisca_core/commons/rates.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from typing import Optional

from openfisca_core.types import Array, ArrayLike
from __future__ import annotations

import numpy

from . import types as t


def average_rate(
target: Array[numpy.float64],
varying: ArrayLike[float],
trim: Optional[ArrayLike[float]] = None,
) -> Array[numpy.float64]:
"""Computes the average rate of a target net income.
target: t.Array[numpy.float32],
varying: t.ArrayLike[float],
trim: None | t.ArrayLike[float] = None,
) -> t.Array[numpy.float32]:
"""Compute the average rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
income. Optionally, a ``trim`` can be applied consisting of the lower and
Expand Down Expand Up @@ -40,8 +40,8 @@ def average_rate(
array([ nan, 0. , -0.5])
"""
average_rate: Array[numpy.float64]

average_rate: t.Array[numpy.float32]
average_rate = 1 - target / varying

if trim is not None:
Expand All @@ -61,11 +61,11 @@ def average_rate(


def marginal_rate(
target: Array[numpy.float64],
varying: Array[numpy.float64],
trim: Optional[ArrayLike[float]] = None,
) -> Array[numpy.float64]:
"""Computes the marginal rate of a target net income.
target: t.Array[numpy.float32],
varying: t.Array[numpy.float32],
trim: None | t.ArrayLike[float] = None,
) -> t.Array[numpy.float32]:
"""Compute the marginal rate of a target net income.
Given a ``target`` net income, and according to the ``varying`` gross
income. Optionally, a ``trim`` can be applied consisting of the lower and
Expand Down Expand Up @@ -95,8 +95,8 @@ def marginal_rate(
array([nan, 0.5])
"""
marginal_rate: Array[numpy.float64]

marginal_rate: t.Array[numpy.float32]
marginal_rate = +1 - (target[:-1] - target[1:]) / (varying[:-1] - varying[1:])

if trim is not None:
Expand All @@ -113,3 +113,6 @@ def marginal_rate(
)

return marginal_rate


__all__ = ["average_rate", "marginal_rate"]
3 changes: 3 additions & 0 deletions openfisca_core/commons/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from openfisca_core.types import Array, ArrayLike

__all__ = ["Array", "ArrayLike"]
49 changes: 30 additions & 19 deletions openfisca_core/entities/_core_entity.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from __future__ import annotations

from typing import ClassVar

import abc
import os
from abc import abstractmethod

from . import types as t
from .role import Role
Expand All @@ -12,29 +14,30 @@ class _CoreEntity:

#: A key to identify the entity.
key: t.EntityKey

#: The ``key``, pluralised.
plural: t.EntityPlural | None
plural: t.EntityPlural

#: A summary description.
label: str | None
label: str

#: A full description.
doc: str | None
doc: str

#: Whether the entity is a person or not.
is_person: bool
is_person: ClassVar[bool]

#: A TaxBenefitSystem instance.
_tax_benefit_system: t.TaxBenefitSystem | None = None
_tax_benefit_system: None | t.TaxBenefitSystem = None

@abstractmethod
@abc.abstractmethod
def __init__(
self,
key: str,
plural: str,
label: str,
doc: str,
*args: object,
__key: str,
__plural: str,
__label: str,
__doc: str,
*__args: object,
) -> None: ...

def __repr__(self) -> str:
Expand All @@ -46,7 +49,7 @@ def set_tax_benefit_system(self, tax_benefit_system: t.TaxBenefitSystem) -> None

def get_variable(
self,
variable_name: str,
variable_name: t.VariableName,
check_existence: bool = False,
) -> t.Variable | None:
"""Get a ``variable_name`` from ``variables``."""
Expand All @@ -57,16 +60,20 @@ def get_variable(
)
return self._tax_benefit_system.get_variable(variable_name, check_existence)

def check_variable_defined_for_entity(self, variable_name: str) -> None:
def check_variable_defined_for_entity(self, variable_name: t.VariableName) -> None:
"""Check if ``variable_name`` is defined for ``self``."""
variable: t.Variable | None
entity: t.CoreEntity

variable = self.get_variable(variable_name, check_existence=True)
entity: None | t.CoreEntity = None
variable: None | t.Variable = self.get_variable(
variable_name,
check_existence=True,
)

if variable is not None:
entity = variable.entity

if entity is None:
return

if entity.key != self.key:
message = (
f"You tried to compute the variable '{variable_name}' for",
Expand All @@ -77,8 +84,12 @@ def check_variable_defined_for_entity(self, variable_name: str) -> None:
)
raise ValueError(os.linesep.join(message))

def check_role_validity(self, role: object) -> None:
@staticmethod
def check_role_validity(role: object) -> None:
"""Check if a ``role`` is an instance of Role."""
if role is not None and not isinstance(role, Role):
msg = f"{role} is not a valid role"
raise ValueError(msg)


__all__ = ["_CoreEntity"]
Loading

0 comments on commit 7d451a1

Please sign in to comment.