Skip to content

Commit

Permalink
use a class for CachedMapper caches instead of using a dict directly
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm committed Dec 20, 2024
1 parent 4451da8 commit df7db39
Show file tree
Hide file tree
Showing 5 changed files with 169 additions and 50 deletions.
5 changes: 3 additions & 2 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@


if TYPE_CHECKING:
from collections.abc import Hashable, Mapping
from collections.abc import Mapping

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.target import Target
Expand Down Expand Up @@ -135,12 +135,13 @@ class CodeGenPreprocessor(ToIndexLambdaMixin, CopyMapper): # type: ignore[misc]
:class:`~pytato.array.Stack` :class:`~pytato.array.IndexLambda`
====================================== =====================================
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(
self,
target: Target,
kernels_seen: dict[str, lp.LoopKernel] | None = None,
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_function_cache: _FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.bound_arguments: dict[str, DataInterface] = {}
Expand Down
9 changes: 6 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@


if TYPE_CHECKING:
from typing import TypeAlias

import mpi4py.MPI

from pytato.function import FunctionDefinition, NamedCallResult
Expand Down Expand Up @@ -283,12 +285,13 @@ class _DistributedInputReplacer(CopyMapper):
instances for their assigned names. Also gathers names for
user-supplied inputs needed by the part
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(self,
recvd_ary_to_name: Mapping[Array, str],
sptpo_ary_to_name: Mapping[Array, str],
name_to_output: Mapping[str, Array],
_function_cache: dict[Hashable, FunctionDefinition] | None = None,
_function_cache: _FunctionCacheT | None = None,
) -> None:
super().__init__(_function_cache=_function_cache)

Expand Down Expand Up @@ -337,9 +340,9 @@ def map_distributed_send(self, expr: DistributedSend) -> DistributedSend:
return new_send

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
pass

Expand Down
187 changes: 150 additions & 37 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"""
import dataclasses
import logging
from collections.abc import Hashable
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -79,17 +80,21 @@


if TYPE_CHECKING:
from collections.abc import Callable, Hashable, Iterable, Mapping
from collections.abc import Callable, Iterable, Mapping


ArrayOrNames: TypeAlias = Array | AbstractResultWithNamedArrays
MappedT = TypeVar("MappedT",
Array, AbstractResultWithNamedArrays, ArrayOrNames)
CacheExprT = TypeVar("CacheExprT") # used in CachedMapperCache
CacheKeyT = TypeVar("CacheKeyT") # used in CachedMapperCache
CacheResultT = TypeVar("CacheResultT") # used in CachedMapperCache
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
R = frozenset[Array]

__doc__ = """
.. autoclass:: Mapper
.. autoclass:: CachedMapperCache
.. autoclass:: CachedMapper
.. autoclass:: TransformMapper
.. autoclass:: TransformMapperWithExtraArgs
Expand Down Expand Up @@ -246,61 +251,147 @@ def __call__(self,

# {{{ CachedMapper

class CachedMapperCache(Generic[CacheExprT, CacheKeyT, CacheResultT, P]):
    """
    Cache for :class:`CachedMapper`.

    Results are stored in a dictionary indexed by the key that *key_func*
    computes from an input expression and any extra arguments.

    .. automethod:: __init__
    .. automethod:: get_key
    .. automethod:: add
    .. automethod:: retrieve
    """
    def __init__(
            self,
            # FIXME: Figure out the right way to type annotate this
            key_func: Callable[..., CacheKeyT]) -> None:
        """
        Initialize the cache.

        :arg key_func: Function to compute a hashable cache key from an input
            expression and any extra arguments.
        """
        self._key_func = key_func
        self._expr_key_to_result: dict[CacheKeyT, CacheResultT] = {}

    # FIXME: Can this be inlined?
    def get_key(
            self, expr: CacheExprT, *args: P.args, **kwargs: P.kwargs) -> CacheKeyT:
        """Compute the key for an input expression."""
        return self._key_func(expr, *args, **kwargs)

    def _resolve_key(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            key: CacheKeyT | None) -> CacheKeyT:
        """
        Return *key* if it was supplied, otherwise compute it from *key_inputs*.

        Shared by :meth:`add` and :meth:`retrieve` so that both derive keys in
        exactly the same way.

        .. note::

            Any :class:`tuple` passed as *key_inputs* is unpacked as
            ``(expr, args, kwargs)``. An expression that is itself a tuple
            therefore cannot be passed bare; wrap it as ``(expr, (), {})`` or
            supply *key* explicitly.
        """
        if key is not None:
            # Caller already computed the key (e.g. via get_key); reuse it
            # rather than calling the key function a second time.
            return key

        if isinstance(key_inputs, tuple):
            expr, key_args, key_kwargs = key_inputs
            return self._key_func(expr, *key_args, **key_kwargs)

        return self._key_func(key_inputs)

    def add(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            result: CacheResultT,
            key: CacheKeyT | None = None) -> CacheResultT:
        """
        Cache a mapping result.

        :arg key_inputs: Either the bare input expression, or a tuple
            ``(expr, args, kwargs)`` holding the expression and the extra
            arguments to pass to the key function.
        :arg result: The mapping result to store.
        :arg key: A precomputed cache key. If *None*, the key is computed
            from *key_inputs*.
        :returns: *result*, so that the call can be used directly in a
            ``return`` statement.
        """
        # Any existing entry stored under the same key is overwritten.
        self._expr_key_to_result[self._resolve_key(key_inputs, key)] = result
        return result

    def retrieve(
            self,
            key_inputs:
                CacheExprT
                # FIXME: Figure out the right way to type annotate these
                | tuple[CacheExprT, tuple[Any, ...], dict[str, Any]],
            key: CacheKeyT | None = None) -> CacheResultT:
        """
        Retrieve the cached mapping result.

        :arg key_inputs: Either the bare input expression, or a tuple
            ``(expr, args, kwargs)`` holding the expression and the extra
            arguments to pass to the key function.
        :arg key: A precomputed cache key. If *None*, the key is computed
            from *key_inputs*.
        :raises KeyError: if no result has been stored under the key.
        """
        return self._expr_key_to_result[self._resolve_key(key_inputs, key)]


class CachedMapper(Mapper[ResultT, FunctionResultT, P]):
"""Mapper class that maps each node in the DAG exactly once. This loses some
information compared to :class:`Mapper` as a node is visited only from
one of its predecessors.
.. automethod:: get_cache_key
.. automethod:: get_function_definition_cache_key
.. automethod:: clone_for_callee
"""
# Not sure if there's a way to simplify this stuff?
_OtherP = ParamSpec("_OtherP")

_CacheType: type[Any] = CachedMapperCache[
ArrayOrNames,
Hashable,
ResultT, P]
_OtherResultT = TypeVar("_OtherResultT")
_CacheT: TypeAlias = CachedMapperCache[
ArrayOrNames,
Hashable,
_OtherResultT, _OtherP]

_FunctionCacheType: type[Any] = CachedMapperCache[
FunctionDefinition,
Hashable,
FunctionResultT, P]
_OtherFunctionResultT = TypeVar("_OtherFunctionResultT")
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition,
Hashable,
_OtherFunctionResultT, _OtherP]

def __init__(
self,
# Arrays are cached separately for each call stack frame, but
# functions are cached globally
_function_cache: dict[Hashable, FunctionResultT] | None = None
_function_cache: _FunctionCacheT[FunctionResultT, P] | None = None
) -> None:
super().__init__()
self._cache: dict[Hashable, ResultT] = {}

def key_func(
expr: ArrayOrNames | FunctionDefinition,
*args: Any, **kwargs: Any) -> Hashable:
return (expr, args, tuple(sorted(kwargs.items())))

self._cache: CachedMapper._CacheT[ResultT, P] = \
CachedMapper._CacheType(key_func)

if _function_cache is not None:
function_cache = _function_cache
else:
function_cache = {}
function_cache = CachedMapper._FunctionCacheType(key_func)

self._function_cache: dict[Hashable, FunctionResultT] = function_cache

def get_cache_key(
self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))

def get_function_definition_cache_key(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> Hashable:
return (expr, *args, tuple(sorted(kwargs.items())))
self._function_cache: CachedMapper._FunctionCacheT[FunctionResultT, P] = \
function_cache

def rec(self, expr: ArrayOrNames, *args: P.args, **kwargs: P.kwargs) -> ResultT:
key = self.get_cache_key(expr, *args, **kwargs)
key = self._cache.get_key(expr, *args, **kwargs)
try:
return self._cache[key]
return self._cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec(expr, *args, **kwargs)
self._cache[key] = result
return result
return self._cache.add(
(expr, args, kwargs),
super().rec(expr, *args, **kwargs),
key=key)

def rec_function_definition(
self, expr: FunctionDefinition, *args: P.args, **kwargs: P.kwargs
) -> FunctionResultT:
key = self.get_function_definition_cache_key(expr, *args, **kwargs)
key = self._function_cache.get_key(expr, *args, **kwargs)
try:
return self._function_cache[key]
return self._function_cache.retrieve((expr, args, kwargs), key=key)
except KeyError:
result = super().rec_function_definition(expr, *args, **kwargs)
self._function_cache[key] = result
return result
return self._function_cache.add(
(expr, args, kwargs),
super().rec_function_definition(expr, *args, **kwargs),
key=key)

def clone_for_callee(
self, function: FunctionDefinition) -> Self:
Expand All @@ -320,10 +411,19 @@ class TransformMapper(CachedMapper[ArrayOrNames, FunctionDefinition, []]):
other :class:`pytato.array.Array`\\ s.
Enables certain operations that can only be done if the mapping results are also
arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not
implement default mapper methods; for that, see :class:`CopyMapper`.
arrays (e.g., computing a cache key from them). Does not implement default
mapper methods; for that, see :class:`CopyMapper`.
"""
_CacheType: type[Any] = CachedMapperCache[
ArrayOrNames, Hashable, ArrayOrNames, []]
_CacheT: TypeAlias = CachedMapperCache[
ArrayOrNames, Hashable, ArrayOrNames, []]

_FunctionCacheType: type[Any] = CachedMapperCache[
FunctionDefinition, Hashable, FunctionDefinition, []]
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition, Hashable, FunctionDefinition, []]

def rec_ary(self, expr: Array) -> Array:
res = self.rec(expr)
assert isinstance(res, Array)
Expand All @@ -345,6 +445,18 @@ class TransformMapperWithExtraArgs(
The logic in :class:`TransformMapper` purposely does not take the extra
arguments to keep the cost of its each call frame low.
"""
_OtherP = ParamSpec("_OtherP")

_CacheType: type[Any] = CachedMapperCache[
ArrayOrNames, Hashable, ArrayOrNames, P]
_CacheT: TypeAlias = CachedMapperCache[
ArrayOrNames, Hashable, ArrayOrNames, _OtherP]

_FunctionCacheType: type[Any] = CachedMapperCache[
FunctionDefinition, Hashable, FunctionDefinition, P]
_FunctionCacheT: TypeAlias = CachedMapperCache[
FunctionDefinition, Hashable, FunctionDefinition, _OtherP]

def rec_ary(self, expr: Array, *args: P.args, **kwargs: P.kwargs) -> Array:
res = self.rec(expr, *args, **kwargs)
assert isinstance(res, Array)
Expand Down Expand Up @@ -1381,11 +1493,12 @@ class CachedMapAndCopyMapper(CopyMapper):
Mapper that applies *map_fn* to each node and copies it. Results of
traversals are memoized i.e. each node is mapped via *map_fn* exactly once.
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(
self,
map_fn: Callable[[ArrayOrNames], ArrayOrNames],
_function_cache: dict[Hashable, FunctionDefinition] | None = None
_function_cache: _FunctionCacheT | None = None
) -> None:
super().__init__(_function_cache=_function_cache)
self.map_fn: Callable[[ArrayOrNames], ArrayOrNames] = map_fn
Expand All @@ -1395,12 +1508,12 @@ def clone_for_callee(
return type(self)(self.map_fn, _function_cache=self._function_cache)

def rec(self, expr: ArrayOrNames) -> ArrayOrNames:
if expr in self._cache:
return self._cache[expr]

result = super().rec(self.map_fn(expr))
self._cache[expr] = result
return result
key = self._cache.get_key(expr)
try:
return self._cache.retrieve(expr, key=key)
except KeyError:
return self._cache.add(
expr, super().rec(self.map_fn(expr)), key=key)

# }}}

Expand Down
16 changes: 9 additions & 7 deletions pytato/transform/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,10 @@


if TYPE_CHECKING:
from collections.abc import Collection, Hashable, Mapping
from collections.abc import Collection, Mapping
from typing import TypeAlias

from pytato.function import FunctionDefinition, NamedCallResult
from pytato.function import NamedCallResult
from pytato.loopy import LoopyCall


Expand Down Expand Up @@ -593,10 +594,12 @@ class AxisTagAttacher(CopyMapper):
"""
A mapper that tags the axes in a DAG as prescribed by *axis_to_tags*.
"""
_FunctionCacheT: TypeAlias = CopyMapper._FunctionCacheT

def __init__(self,
axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]],
tag_corresponding_redn_descr: bool,
_function_cache: dict[Hashable, FunctionDefinition] | None = None):
_function_cache: _FunctionCacheT | None = None):
super().__init__(_function_cache=_function_cache)
self.axis_to_tags: Mapping[tuple[Array, int], Collection[Tag]] = axis_to_tags
self.tag_corresponding_redn_descr: bool = tag_corresponding_redn_descr
Expand Down Expand Up @@ -644,18 +647,17 @@ def _attach_tags(self, expr: Array, rec_expr: Array) -> Array:
return result

def rec(self, expr: ArrayOrNames) -> Any:
key = self.get_cache_key(expr)
key = self._cache.get_key(expr)
try:
return self._cache[key]
return self._cache.retrieve(expr, key=key)
except KeyError:
result = Mapper.rec(self, expr)
if not isinstance(
expr, AbstractResultWithNamedArrays | DistributedSendRefHolder):
assert isinstance(expr, Array)
# type-ignore reason: passed "ArrayOrNames"; expected "Array"
result = self._attach_tags(expr, result) # type: ignore[arg-type]
self._cache[key] = result
return result
return self._cache.add(expr, result, key=key)

def map_named_call_result(self, expr: NamedCallResult) -> Array:
raise NotImplementedError(
Expand Down
2 changes: 1 addition & 1 deletion test/test_apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def __init__(self, fft_vec_gatherer):
arrays = fft_vec_gatherer.level_to_arrays[lev]
rec_arrays = [self.rec(ary) for ary in arrays]
# reset cache so that the partial subs are not stored
self._cache = {}
self._cache = type(self._cache)(lambda expr: expr)
lev_array = pt.concatenate(rec_arrays, axis=0)
assert lev_array.shape == (fft_vec_gatherer.n,)

Expand Down

0 comments on commit df7db39

Please sign in to comment.