Skip to content

Commit

Permalink
Fix population types
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko committed Dec 1, 2022
1 parent d23f09d commit 63863d3
Show file tree
Hide file tree
Showing 12 changed files with 119 additions and 70 deletions.
20 changes: 13 additions & 7 deletions openfisca_core/data_storage/in_memory_storage.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,24 @@
from __future__ import annotations

from typing import Any, Dict, KeysView, Optional

import numpy

from openfisca_core import periods
from openfisca_core import periods, types


class InMemoryStorage:
"""
Low-level class responsible for storing and retrieving calculated vectors in memory
"""

def __init__(self, is_eternal = False):
_arrays: Dict[types.Period, numpy.ndarray]

def __init__(self, is_eternal: bool = False) -> None:
self._arrays = {}
self.is_eternal = is_eternal

def get(self, period):
def get(self, period: types.Period) -> Any:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)
Expand All @@ -23,14 +29,14 @@ def get(self, period):

return values

def put(self, value, period):
def put(self, value: Any, period: types.Period) -> None:
if self.is_eternal:
period = periods.period(periods.ETERNITY)
period = periods.period(period)

self._arrays[period] = value

def delete(self, period = None):
def delete(self, period: Optional[types.Period] = None) -> None:
if period is None:
self._arrays = {}
return
Expand All @@ -45,10 +51,10 @@ def delete(self, period = None):
if not period.contains(period_item)
}

def get_known_periods(self):
def get_known_periods(self) -> KeysView[types.Period]:
return self._arrays.keys()

def get_memory_usage(self):
def get_memory_usage(self) -> types.MemoryUsage:
if not self._arrays:
return dict(
nb_arrays = 0,
Expand Down
5 changes: 4 additions & 1 deletion openfisca_core/data_storage/on_disk_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Dict, KeysView, Optional, Type
from typing import Any, Dict, KeysView, NoReturn, Optional, Type

import os
import shutil
Expand Down Expand Up @@ -82,6 +82,9 @@ def delete(self, period: Optional[types.Period] = None) -> None:
def get_known_periods(self) -> KeysView[types.Period]:
return self._files.keys()

def get_memory_usage(self) -> NoReturn:
raise NotImplementedError

def restore(self) -> None:
self._files = {}
# Restore self._files from content of storage_dir.
Expand Down
1 change: 0 additions & 1 deletion openfisca_core/holders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@

from .helpers import set_input_dispatch_by_period, set_input_divide_by_period # noqa: F401
from .holder import Holder # noqa: F401
from .memory_usage import MemoryUsage # noqa: F401
8 changes: 3 additions & 5 deletions openfisca_core/holders/holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@
types,
)

from .memory_usage import MemoryUsage


class Holder:
"""
A holder keeps tracks of a variable values after they have been calculated, or set as an input.
"""

_disk_storage: types.Storage
_disk_storage: Optional[types.Storage]
_do_not_store: bool
_memory_storage: types.Storage
_on_disk_storable: bool
Expand Down Expand Up @@ -112,7 +110,7 @@ def get_array(self, period):
if self._disk_storage:
return self._disk_storage.get(period)

def get_memory_usage(self) -> MemoryUsage:
def get_memory_usage(self) -> types.MemoryUsage:
"""Get data about the virtual memory usage of the Holder.
Returns:
Expand Down Expand Up @@ -154,7 +152,7 @@ def get_memory_usage(self) -> MemoryUsage:
"""

usage = MemoryUsage(
usage = types.MemoryUsage(
nb_cells_by_array = self.population.count,
dtype = self.variable.dtype,
)
Expand Down
26 changes: 0 additions & 26 deletions openfisca_core/holders/memory_usage.py

This file was deleted.

8 changes: 6 additions & 2 deletions openfisca_core/periods/period_.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,11 +319,15 @@ def offset(self, offset, unit = None):
"""
return self.__class__((self[0], self[1].offset(offset, self[0] if unit is None else unit), self[2]))

def contains(self, other: Period) -> bool:
def contains(self, other: object) -> bool:
"""
Returns ``True`` if the period contains ``other``. For instance, ``period(2015)`` contains ``period(2015-01)``
"""
return self.start <= other.start and self.stop >= other.stop

if isinstance(other, Period):
return self.start <= other.start and self.stop >= other.stop

return NotImplemented

@property
def size(self):
Expand Down
73 changes: 47 additions & 26 deletions openfisca_core/populations/population.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,45 +7,44 @@

import numpy

from openfisca_core import periods, projectors
from openfisca_core.holders import Holder, MemoryUsage
from openfisca_core import errors, periods, projectors, types
from openfisca_core.holders import Holder
from openfisca_core.projectors import Projector
from openfisca_core.types import Array, Entity, Period, Role, Simulation

from . import config


class Population:

simulation: Optional[Simulation]
entity: Entity
simulation: Optional[types.Simulation]
entity: types.Entity
_holders: Dict[str, Holder]
count: int
ids: Array[str]
ids: types.Array[str]

def __init__(self, entity: Entity) -> None:
def __init__(self, entity: types.Entity) -> None:
self.simulation = None
self.entity = entity
self._holders = {}
self.count = 0
self.ids = []

def clone(self, simulation: Simulation) -> Population:
def clone(self, simulation: types.Simulation) -> Population:
result = Population(self.entity)
result.simulation = simulation
result._holders = {variable: holder.clone(result) for (variable, holder) in self._holders.items()}
result.count = self.count
result.ids = self.ids
return result

def empty_array(self) -> Array[float]:
def empty_array(self) -> types.Array[float]:
return numpy.zeros(self.count)

def filled_array(
self,
value: Union[float, bool],
dtype: Optional[numpy.dtype] = None,
) -> Union[Array[float], Array[bool]]:
) -> Union[types.Array[float], types.Array[bool]]:
return numpy.full(self.count, value, dtype)

def __getattr__(self, attribute: str) -> Projector:
Expand All @@ -64,7 +63,7 @@ def get_index(self, id: str) -> int:

def check_array_compatible_with_entity(
self,
array: Array[float],
array: types.Array[float],
) -> None:
if self.count == array.size:
return None
Expand All @@ -75,9 +74,9 @@ def check_array_compatible_with_entity(
def check_period_validity(
self,
variable_name: str,
period: Optional[Union[int, str, Period]],
period: Optional[Union[int, str, types.Period]],
) -> None:
if isinstance(period, (int, str, Period)):
if isinstance(period, (int, str, types.Period)):
return None

stack = traceback.extract_stack()
Expand All @@ -93,9 +92,9 @@ def check_period_validity(
def __call__(
self,
variable_name: str,
period: Optional[Union[int, str, Period]] = None,
period: Optional[Union[int, str, types.Period]] = None,
options: Optional[Sequence[str]] = None,
) -> Optional[Array[float]]:
) -> Optional[types.Array[float]]:
"""
Calculate the variable ``variable_name`` for the entity and the period ``period``, using the variable formula if it exists.
Expand Down Expand Up @@ -141,13 +140,35 @@ def __call__(
# Helpers

def get_holder(self, variable_name: str) -> Holder:
holder: Optional[types.Holder]
variable: Optional[types.Variable]
simulation: Optional[types.Simulation]
tax_benefit_system: Optional[types.TaxBenefitSystem]

self.entity.check_variable_defined_for_entity(variable_name)
holder = self._holders.get(variable_name)
if holder:

if holder is not None:
return holder

variable = self.entity.get_variable(variable_name)
self._holders[variable_name] = holder = Holder(variable, self)
return holder

if variable is not None:
holder = Holder(variable, self)
self._holders[variable_name] = holder
return holder

simulation = self.simulation

if simulation is None:
raise TypeError("Simulation can't be None.")

tax_benefit_system = simulation.tax_benefit_system

if tax_benefit_system is None:
raise TypeError("TaxBenefitSystem can't be None.")

raise errors.VariableNotFoundError(variable_name, tax_benefit_system)

def get_memory_usage(
self,
Expand All @@ -169,7 +190,7 @@ def get_memory_usage(
})

@projectors.projectable
def has_role(self, role: Role) -> Optional[Array[bool]]:
def has_role(self, role: types.Role) -> Optional[types.Array[bool]]:
"""
Check if a person has a given role within its `GroupEntity`
Expand All @@ -195,10 +216,10 @@ def has_role(self, role: Role) -> Optional[Array[bool]]:
@projectors.projectable
def value_from_partner(
self,
array: Array[float],
array: types.Array[float],
entity: Projector,
role: Role,
) -> Optional[Array[float]]:
role: types.Role,
) -> Optional[types.Array[float]]:
self.check_array_compatible_with_entity(array)
self.entity.check_role_validity(role)

Expand All @@ -218,9 +239,9 @@ def value_from_partner(
def get_rank(
self,
entity: Population,
criteria: Array[float],
criteria: types.Array[float],
condition: bool = True,
) -> Array[int]:
) -> types.Array[int]:
"""
Get the rank of a person within an entity according to a criteria.
The person with rank 0 has the minimum value of criteria.
Expand Down Expand Up @@ -265,10 +286,10 @@ def get_rank(

class Calculate(NamedTuple):
variable: str
period: Period
period: types.Period
option: Optional[Sequence[str]]


class MemoryUsageByVariable(TypedDict, total = False):
by_variable: Dict[str, MemoryUsage]
by_variable: Dict[str, types.MemoryUsage]
total_nb_bytes: int
3 changes: 3 additions & 0 deletions openfisca_core/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* :attr:`.Formula`
* :attr:`.Holder`
* :attr:`.Instant`
* :attr:`.MemoryUsage`
* :attr:`.ParameterNodeAtInstant`
* :attr:`.Params`
* :attr:`.Period`
Expand Down Expand Up @@ -54,6 +55,7 @@
Array,
ArrayLike,
Instant,
MemoryUsage,
Period,
)

Expand Down Expand Up @@ -81,6 +83,7 @@
"Formula",
"Holder",
"Instant",
"MemoryUsage",
"ParameterNodeAtInstant",
"Params",
"Period",
Expand Down
25 changes: 24 additions & 1 deletion openfisca_core/types/_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

import typing_extensions
from typing import Any, Sequence, TypeVar, Union
from typing_extensions import Protocol
from typing_extensions import Protocol, TypedDict

import abc

Expand Down Expand Up @@ -76,6 +76,29 @@ class Instant(Protocol):
"""Instant protocol."""


class MemoryUsage(TypedDict, total = False):
"""Virtual memory usage of a Holder.
Attributes:
cell_size: The amount of bytes assigned to each value.
dtype: The :mod:`numpy.dtype` of any, each, and every value.
nb_arrays: The number of periods for which the Holder contains values.
nb_cells_by_array: The number of entities in the current Simulation.
nb_requests: The number of times the Variable has been computed.
nb_requests_by_array: Average times a stored array has been read.
total_nb_bytes: The total number of bytes used by the Holder.
"""

cell_size: float
dtype: numpy.dtype[Any]
nb_arrays: int
nb_cells_by_array: int
nb_requests: int
nb_requests_by_array: int
total_nb_bytes: int


@typing_extensions.runtime_checkable
class Period(Protocol):
"""Period protocol."""
Expand Down
Loading

0 comments on commit 63863d3

Please sign in to comment.