Skip to content

Commit

Permalink
docs: add types to tracers (1/3) (#1280)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonjourmauko authored Oct 18, 2024
2 parents 2bd629b + d859a62 commit 809affe
Show file tree
Hide file tree
Showing 12 changed files with 475 additions and 226 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

### 43.2.2 [#1280](https://github.com/openfisca/openfisca-core/pull/1280)

#### Documentation

- Add types to common tracers (`SimpleTracer`, `FlatTracer`, etc.)

### 43.2.1 [#1283](https://github.com/openfisca/openfisca-core/pull/1283)

#### Technical changes
Expand Down
7 changes: 6 additions & 1 deletion openfisca_core/populations/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,9 @@ def __init__(
super().__init__(msg)


__all__ = ["InvalidArraySizeError", "PeriodValidityError"]
__all__ = [
"IncompatibleOptionsError",
"InvalidArraySizeError",
"InvalidOptionError",
"PeriodValidityError",
]
6 changes: 2 additions & 4 deletions openfisca_core/populations/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from collections.abc import Iterable, MutableMapping, Sequence
from typing import NamedTuple, Union
from typing_extensions import NewType, TypeAlias, TypedDict
from typing_extensions import TypeAlias, TypedDict

from openfisca_core.types import (
Array,
Expand All @@ -14,6 +14,7 @@
Holder,
MemoryUsage,
Period,
PeriodInt,
PeriodStr,
Role,
Simulation,
Expand Down Expand Up @@ -52,9 +53,6 @@

# Periods

#: New type for a period integer.
PeriodInt = NewType("PeriodInt", int)

#: Type alias for a period-like object.
PeriodLike: TypeAlias = Union[Period, PeriodStr, PeriodInt]

Expand Down
28 changes: 19 additions & 9 deletions openfisca_core/tracers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,22 @@
#
# See: https://www.python.org/dev/peps/pep-0008/#imports

from .computation_log import ComputationLog # noqa: F401
from .flat_trace import FlatTrace # noqa: F401
from .full_tracer import FullTracer # noqa: F401
from .performance_log import PerformanceLog # noqa: F401
from .simple_tracer import SimpleTracer # noqa: F401
from .trace_node import TraceNode # noqa: F401
from .tracing_parameter_node_at_instant import ( # noqa: F401
TracingParameterNodeAtInstant,
)
from . import types
from .computation_log import ComputationLog
from .flat_trace import FlatTrace
from .full_tracer import FullTracer
from .performance_log import PerformanceLog
from .simple_tracer import SimpleTracer
from .trace_node import TraceNode
from .tracing_parameter_node_at_instant import TracingParameterNodeAtInstant

__all__ = [
"ComputationLog",
"FlatTrace",
"FullTracer",
"PerformanceLog",
"SimpleTracer",
"TraceNode",
"TracingParameterNodeAtInstant",
"types",
]
66 changes: 27 additions & 39 deletions openfisca_core/tracers/computation_log.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,24 @@
from __future__ import annotations

import typing
from typing import Union
import sys

import numpy

from openfisca_core.indexed_enums import EnumArray

if typing.TYPE_CHECKING:
from numpy.typing import ArrayLike

from openfisca_core import tracers

Array = Union[EnumArray, ArrayLike]
from . import types as t


class ComputationLog:
_full_tracer: tracers.FullTracer
_full_tracer: t.FullTracer

def __init__(self, full_tracer: tracers.FullTracer) -> None:
def __init__(self, full_tracer: t.FullTracer) -> None:
self._full_tracer = full_tracer

def display(
self,
value: Array | None,
) -> str:
if isinstance(value, EnumArray):
value = value.decode_to_str()

return numpy.array2string(value, max_line_width=float("inf"))

def lines(
self,
aggregate: bool = False,
max_depth: int | None = None,
max_depth: int = sys.maxsize,
) -> list[str]:
depth = 1

Expand All @@ -44,7 +29,7 @@ def lines(

return self._flatten(lines_by_tree)

def print_log(self, aggregate=False, max_depth=None) -> None:
def print_log(self, aggregate: bool = False, max_depth: int = sys.maxsize) -> None:
"""Print the computation log of a simulation.
If ``aggregate`` is ``False`` (default), print the value of each
Expand All @@ -60,20 +45,20 @@ def print_log(self, aggregate=False, max_depth=None) -> None:
If ``max_depth`` is set, for example to ``3``, only print computed
vectors up to a depth of ``max_depth``.
"""
for _line in self.lines(aggregate, max_depth):
for _ in self.lines(aggregate, max_depth):
pass

def _get_node_log(
self,
node: tracers.TraceNode,
node: t.TraceNode,
depth: int,
aggregate: bool,
max_depth: int | None,
max_depth: int = sys.maxsize,
) -> list[str]:
if max_depth is not None and depth > max_depth:
if depth > max_depth:
return []

node_log = [self._print_line(depth, node, aggregate, max_depth)]
node_log = [self._print_line(depth, node, aggregate)]

children_logs = [
self._get_node_log(child, depth + 1, aggregate, max_depth)
Expand All @@ -82,13 +67,7 @@ def _get_node_log(

return node_log + self._flatten(children_logs)

def _print_line(
self,
depth: int,
node: tracers.TraceNode,
aggregate: bool,
max_depth: int | None,
) -> str:
def _print_line(self, depth: int, node: t.TraceNode, aggregate: bool) -> str:
indent = " " * depth
value = node.value

Expand All @@ -97,9 +76,11 @@ def _print_line(

elif aggregate:
try:
formatted_value = str(
formatted_value = str( # pyright: ignore[reportCallIssue]
{
"avg": numpy.mean(value),
"avg": numpy.mean(
value
), # pyright: ignore[reportArgumentType,reportCallIssue]
"max": numpy.max(value),
"min": numpy.min(value),
},
Expand All @@ -113,8 +94,15 @@ def _print_line(

return f"{indent}{node.name}<{node.period}> >> {formatted_value}"

def _flatten(
self,
lists: list[list[str]],
) -> list[str]:
@staticmethod
def display(value: t.VarArray, max_depth: int = sys.maxsize) -> str:
if isinstance(value, EnumArray):
value = value.decode_to_str()
return numpy.array2string(value, max_line_width=max_depth)

@staticmethod
def _flatten(lists: list[list[str]]) -> list[str]:
return [item for list_ in lists for item in list_]


__all__ = ["ComputationLog"]
79 changes: 39 additions & 40 deletions openfisca_core/tracers/flat_trace.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,20 @@
from __future__ import annotations

import typing
from typing import Union

import numpy

from openfisca_core.indexed_enums import EnumArray

if typing.TYPE_CHECKING:
from numpy.typing import ArrayLike

from openfisca_core import tracers

Array = Union[EnumArray, ArrayLike]
Trace = dict[str, dict]
from . import types as t


class FlatTrace:
_full_tracer: tracers.FullTracer
_full_tracer: t.FullTracer

def __init__(self, full_tracer: tracers.FullTracer) -> None:
def __init__(self, full_tracer: t.FullTracer) -> None:
self._full_tracer = full_tracer

def key(self, node: tracers.TraceNode) -> str:
name = node.name
period = node.period
return f"{name}<{period}>"

def get_trace(self) -> dict:
trace = {}
def get_trace(self) -> t.FlatNodeMap:
trace: t.FlatNodeMap = {}

for node in self._full_tracer.browse_trace():
# We don't want cache read to overwrite data about the initial
Expand All @@ -45,34 +31,16 @@ def get_trace(self) -> dict:

return trace

def get_serialized_trace(self) -> dict:
def get_serialized_trace(self) -> t.SerializedNodeMap:
return {
key: {**flat_trace, "value": self.serialize(flat_trace["value"])}
for key, flat_trace in self.get_trace().items()
}

def serialize(
self,
value: Array | None,
) -> Array | None | list:
if isinstance(value, EnumArray):
value = value.decode_to_str()

if isinstance(value, numpy.ndarray) and numpy.issubdtype(
value.dtype,
numpy.dtype(bytes),
):
value = value.astype(numpy.dtype(str))

if isinstance(value, numpy.ndarray):
value = value.tolist()

return value

def _get_flat_trace(
self,
node: tracers.TraceNode,
) -> Trace:
node: t.TraceNode,
) -> t.FlatNodeMap:
key = self.key(node)

return {
Expand All @@ -87,3 +55,34 @@ def _get_flat_trace(
"formula_time": node.formula_time(),
},
}

@staticmethod
def key(node: t.TraceNode) -> t.NodeKey:
"""Return the key of a node."""
name = node.name
period = node.period
return t.NodeKey(f"{name}<{period}>")

@staticmethod
def serialize(
value: None | t.VarArray | t.ArrayLike[object],
) -> None | t.ArrayLike[object]:
if value is None:
return None

if isinstance(value, EnumArray):
return value.decode_to_str().tolist()

if isinstance(value, numpy.ndarray) and numpy.issubdtype(
value.dtype,
numpy.dtype(bytes),
):
return value.astype(numpy.dtype(str)).tolist()

if isinstance(value, numpy.ndarray):
return value.tolist()

return value


__all__ = ["FlatTrace"]
Loading

0 comments on commit 809affe

Please sign in to comment.