From 21f48505b0804a789ab8df65bc57137d2d07ebff Mon Sep 17 00:00:00 2001 From: Mauko Quiroga Date: Wed, 20 Nov 2024 00:37:59 +0100 Subject: [PATCH] refactor: consolidate tracers module types --- openfisca_core/tracers/__init__.py | 2 - openfisca_core/tracers/computation_log.py | 3 +- openfisca_core/tracers/flat_trace.py | 3 +- openfisca_core/tracers/full_tracer.py | 3 +- openfisca_core/tracers/simple_tracer.py | 2 +- openfisca_core/tracers/trace_node.py | 2 +- openfisca_core/tracers/types.py | 108 ---------------------- openfisca_core/types.py | 93 ++++++++++++++++++- 8 files changed, 98 insertions(+), 118 deletions(-) delete mode 100644 openfisca_core/tracers/types.py diff --git a/openfisca_core/tracers/__init__.py b/openfisca_core/tracers/__init__.py index 76e36b55c..7220dc8c6 100644 --- a/openfisca_core/tracers/__init__.py +++ b/openfisca_core/tracers/__init__.py @@ -21,7 +21,6 @@ # # See: https://www.python.org/dev/peps/pep-0008/#imports -from . import types from .computation_log import ComputationLog from .flat_trace import FlatTrace from .full_tracer import FullTracer @@ -38,5 +37,4 @@ "SimpleTracer", "TraceNode", "TracingParameterNodeAtInstant", - "types", ] diff --git a/openfisca_core/tracers/computation_log.py b/openfisca_core/tracers/computation_log.py index bd68b4b21..5d3849c9a 100644 --- a/openfisca_core/tracers/computation_log.py +++ b/openfisca_core/tracers/computation_log.py @@ -4,10 +4,9 @@ import numpy +from openfisca_core import types as t from openfisca_core.indexed_enums import EnumArray -from . import types as t - class ComputationLog: _full_tracer: t.FullTracer diff --git a/openfisca_core/tracers/flat_trace.py b/openfisca_core/tracers/flat_trace.py index 412ac8b02..e63301a27 100644 --- a/openfisca_core/tracers/flat_trace.py +++ b/openfisca_core/tracers/flat_trace.py @@ -2,10 +2,9 @@ import numpy +from openfisca_core import types as t from openfisca_core.indexed_enums import EnumArray -from . import types as t - class FlatTrace: _full_tracer: t.FullTracer diff --git a/openfisca_core/tracers/full_tracer.py b/openfisca_core/tracers/full_tracer.py index f6f793e19..6187b1134 100644 --- a/openfisca_core/tracers/full_tracer.py +++ b/openfisca_core/tracers/full_tracer.py @@ -5,7 +5,8 @@ import sys import time -from . import types as t +from openfisca_core import types as t + from .computation_log import ComputationLog from .flat_trace import FlatTrace from .performance_log import PerformanceLog diff --git a/openfisca_core/tracers/simple_tracer.py b/openfisca_core/tracers/simple_tracer.py index 174dd3119..d096b0311 100644 --- a/openfisca_core/tracers/simple_tracer.py +++ b/openfisca_core/tracers/simple_tracer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from . import types as t +from openfisca_core import types as t class SimpleTracer: diff --git a/openfisca_core/tracers/trace_node.py b/openfisca_core/tracers/trace_node.py index de81825e8..8e8ce1fdc 100644 --- a/openfisca_core/tracers/trace_node.py +++ b/openfisca_core/tracers/trace_node.py @@ -2,7 +2,7 @@ import dataclasses -from . import types as t +from openfisca_core import types as t @dataclasses.dataclass diff --git a/openfisca_core/tracers/types.py b/openfisca_core/tracers/types.py deleted file mode 100644 index f26c85424..000000000 --- a/openfisca_core/tracers/types.py +++ /dev/null @@ -1,108 +0,0 @@ -from __future__ import annotations - -from collections.abc import Iterator -from typing import NewType, Protocol -from typing_extensions import TypeAlias, TypedDict - -from openfisca_core.types import ( - Array, - ArrayLike, - ParameterNode, - ParameterNodeChild, - Period, - PeriodInt, - VariableName, -) - -from numpy import generic as VarDType - -#: A type of a generic array. -VarArray: TypeAlias = Array[VarDType] - -#: A type representing a unit time. -Time: TypeAlias = float - -#: A type representing a mapping of flat traces. -FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"] - -#: A type representing a mapping of serialized traces. -SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"] - -#: A stack of simple traces. -SimpleStack: TypeAlias = list["SimpleTraceMap"] - -#: Key of a trace. -NodeKey = NewType("NodeKey", str) - - -class FlatTraceMap(TypedDict, total=True): - dependencies: list[NodeKey] - parameters: dict[NodeKey, None | ArrayLike[object]] - value: None | VarArray - calculation_time: Time - formula_time: Time - - -class SerializedTraceMap(TypedDict, total=True): - dependencies: list[NodeKey] - parameters: dict[NodeKey, None | ArrayLike[object]] - value: None | ArrayLike[object] - calculation_time: Time - formula_time: Time - - -class SimpleTraceMap(TypedDict, total=True): - name: VariableName - period: int | Period - - -class ComputationLog(Protocol): - def print_log(self, aggregate: bool = ..., max_depth: int = ..., /) -> None: ... - - -class FlatTrace(Protocol): - def get_trace(self, /) -> FlatNodeMap: ... - def get_serialized_trace(self, /) -> SerializedNodeMap: ... - - -class FullTracer(Protocol): - @property - def trees(self, /) -> list[TraceNode]: ... - def browse_trace(self, /) -> Iterator[TraceNode]: ... - - -class PerformanceLog(Protocol): - def generate_graph(self, dir_path: str, /) -> None: ... - def generate_performance_tables(self, dir_path: str, /) -> None: ... - - -class SimpleTracer(Protocol): - @property - def stack(self, /) -> SimpleStack: ... - def record_calculation_start( - self, variable: VariableName, period: PeriodInt | Period, / - ) -> None: ... - def record_calculation_end(self, /) -> None: ... - - -class TraceNode(Protocol): - children: list[TraceNode] - end: Time - name: str - parameters: list[TraceNode] - parent: None | TraceNode - period: PeriodInt | Period - start: Time - value: None | VarArray - - def calculation_time(self, *, round_: bool = ...) -> Time: ... - def formula_time(self, /) -> Time: ... - def append_child(self, node: TraceNode, /) -> None: ... - - -__all__ = [ - "ArrayLike", - "ParameterNode", - "ParameterNodeChild", - "PeriodInt", -] diff --git a/openfisca_core/types.py b/openfisca_core/types.py index 60573ccd3..f02df2801 100644 --- a/openfisca_core/types.py +++ b/openfisca_core/types.py @@ -1,6 +1,6 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence, Sized +from collections.abc import Iterable, Iterator, Sequence, Sized from numpy.typing import DTypeLike, NDArray from typing import NewType, TypeVar, Union from typing_extensions import Protocol, Required, Self, TypeAlias, TypedDict @@ -309,6 +309,97 @@ def get_variable( ) -> None | Variable: ... +# Tracers + +#: A type representing a unit time. +Time: TypeAlias = float + +#: A type representing a mapping of flat traces. +FlatNodeMap: TypeAlias = dict["NodeKey", "FlatTraceMap"] + +#: A type representing a mapping of serialized traces. +SerializedNodeMap: TypeAlias = dict["NodeKey", "SerializedTraceMap"] + +#: Key of a trace. +NodeKey = NewType("NodeKey", str) + + +class FlatTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | VarArray + calculation_time: Time + formula_time: Time + + +class SerializedTraceMap(TypedDict, total=True): + dependencies: list[NodeKey] + parameters: dict[NodeKey, None | ArrayLike[object]] + value: None | ArrayLike[object] + calculation_time: Time + formula_time: Time + + +class SimpleTraceMap(TypedDict, total=True): + name: VariableName + period: int | Period + + +class ComputationLog(Protocol): + def print_log(self, __aggregate: bool = ..., __max_depth: int = ..., /) -> None: ... + + +class FlatTrace(Protocol): + def get_trace(self, /) -> FlatNodeMap: ... + def get_serialized_trace(self, /) -> SerializedNodeMap: ... + + +class FullTracer(Protocol): + @property + def trees(self, /) -> list[TraceNode]: ... + def browse_trace(self, /) -> Iterator[TraceNode]: ... + def get_nb_requests(self, __name: VariableName, /) -> int: ... + + +class PerformanceLog(Protocol): + def generate_graph(self, __dir_path: str, /) -> None: ... + def generate_performance_tables(self, __dir_path: str, /) -> None: ... + + +class SimpleTracer(Protocol): + @property + def stack(self, /) -> SimpleStack: ... + def record_calculation_start( + self, __name: VariableName, __period: PeriodInt | Period, / + ) -> None: ... + def record_calculation_end(self, /) -> None: ... + + +class TraceNode(Protocol): + @property + def children(self, /) -> list[TraceNode]: ... + @property + def end(self, /) -> Time: ... + @property + def name(self, /) -> str: ... + @property + def parameters(self, /) -> list[TraceNode]: ... + @property + def parent(self, /) -> None | TraceNode: ... + @property + def period(self, /) -> PeriodInt | Period: ... + @property + def start(self, /) -> Time: ... + @property + def value(self, /) -> None | VarArray: ... + def calculation_time(self, *, __round: bool = ...) -> Time: ... + def formula_time(self, /) -> Time: ... + def append_child(self, __node: TraceNode, /) -> None: ... + + +#: A stack of simple traces. +SimpleStack: TypeAlias = list[SimpleTraceMap] + # Variables #: For example "salary".