diff --git a/doc/conf.py b/doc/conf.py index b5f851e11..081642f1d 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -46,9 +46,9 @@ nitpick_ignore_regex = [ ["py:class", r"numpy.(u?)int[\d]+"], ["py:class", r"typing_extensions(.+)"], - # As of 2022-10-20, it doesn't look like there's sphinx documentation + # As of 2023-10-05, it doesn't look like there's sphinx documentation # available. - ["py:class", r"immutables\.(.+)"], + ["py:class", r"immutabledict(.*)"], # https://github.com/python-attrs/attrs/issues/1073 ["py:mod", "attrs"], ] diff --git a/pytato/array.py b/pytato/array.py index 1271631fe..73c4c427e 100644 --- a/pytato/array.py +++ b/pytato/array.py @@ -178,7 +178,7 @@ ScalarExpression, IntegralT, INT_CLASSES, get_reduction_induction_variables) import re -from immutables import Map +from immutabledict import immutabledict # {{{ get a type variable that represents the type of '...' @@ -556,7 +556,7 @@ def _unary_op(self, op: Any) -> Array: indices = tuple(var(f"_{i}") for i in range(self.ndim)) expr = op(var("_in0")[indices]) - bindings = Map({"_in0": self}) + bindings: immutabledict[str, Array] = immutabledict({"_in0": self}) return IndexLambda( expr=expr, shape=self.shape, @@ -564,7 +564,7 @@ def _unary_op(self, op: Any) -> Array: bindings=bindings, tags=_get_default_tags(), axes=_get_default_axes(self.ndim), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) __mul__ = partialmethod(_binary_op, operator.mul) __rmul__ = partialmethod(_binary_op, operator.mul, reverse=True) @@ -895,17 +895,18 @@ def with_tagged_reduction(self, f" '{self.var_to_reduction_descr.keys()}'," f" got '{reduction_variable}'.") - assert isinstance(self.var_to_reduction_descr, Map) - new_var_to_redn_descr = self.var_to_reduction_descr.set( - reduction_variable, - self.var_to_reduction_descr[reduction_variable].tagged(tag)) + assert isinstance(self.var_to_reduction_descr, immutabledict) + new_var_to_redn_descr = dict(self.var_to_reduction_descr) + new_var_to_redn_descr[reduction_variable] = \ + self.var_to_reduction_descr[reduction_variable].tagged(tag) return type(self)(expr=self.expr, shape=self.shape, dtype=self.dtype, bindings=self.bindings, axes=self.axes, - var_to_reduction_descr=new_var_to_redn_descr, + var_to_reduction_descr=immutabledict + (new_var_to_redn_descr), tags=self.tags) # }}} @@ -1006,7 +1007,7 @@ def _access_descr_to_axis_len(self else: descr_to_axis_len[descr] = arg_axis_len - return Map(descr_to_axis_len) + return immutabledict(descr_to_axis_len) @cached_property def shape(self) -> ShapeType: @@ -1063,14 +1064,16 @@ def with_tagged_reduction(self, # }}} - assert isinstance(self.redn_axis_to_redn_descr, Map) - new_redn_axis_to_redn_descr = self.redn_axis_to_redn_descr.set( - redn_axis, self.redn_axis_to_redn_descr[redn_axis].tagged(tag)) + assert isinstance(self.redn_axis_to_redn_descr, immutabledict) + new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr) + new_redn_axis_to_redn_descr[redn_axis] = \ + self.redn_axis_to_redn_descr[redn_axis].tagged(tag) return type(self)(access_descriptors=self.access_descriptors, args=self.args, axes=self.axes, - redn_axis_to_redn_descr=new_redn_axis_to_redn_descr, + redn_axis_to_redn_descr=immutabledict + (new_redn_axis_to_redn_descr), tags=self.tags, index_to_access_descr=self.index_to_access_descr, ) @@ -1079,7 +1082,7 @@ def with_tagged_reduction(self, EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P[a-zA-Z])|(?P\.\.\.))\s*") -def _normalize_einsum_out_subscript(subscript: str) -> Map[str, +def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str, EinsumAxisDescriptor]: """ Normalizes the output subscript of an einsum (provided in the explicit @@ -1119,19 +1122,20 @@ def _normalize_einsum_out_subscript(subscript: str) -> Map[str, raise ValueError("Used an input more than once to refer to the" f" output axis in '{subscript}") - return Map({idx: EinsumElementwiseAxis(i) + return immutabledict({idx: EinsumElementwiseAxis(i) for i, idx in enumerate(normalized_indices)}) def _normalize_einsum_in_subscript(subscript: str, in_operand: Array, - index_to_descr: Map[str, + index_to_descr: immutabledict[str, EinsumAxisDescriptor], - index_to_axis_length: Map[str, + index_to_axis_length: immutabledict[str, ShapeComponent], ) -> Tuple[Tuple[EinsumAxisDescriptor, ...], - Map[str, EinsumAxisDescriptor], - Map[str, ShapeComponent]]: + immutabledict + [str, EinsumAxisDescriptor], + immutabledict[str, ShapeComponent]]: """ Normalizes the subscript for an input operand in an einsum. Returns ``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``, @@ -1174,12 +1178,14 @@ def _normalize_einsum_in_subscript(subscript: str, f"of corresponding operand ({in_operand.ndim}).") in_operand_axis_descrs = [] + index_to_axis_length_dict = dict(index_to_axis_length) + index_to_descr_dict = dict(index_to_descr) for iaxis, index_char in enumerate(normalized_indices): in_axis_len = in_operand.shape[iaxis] - if index_char in index_to_descr: - if index_char in index_to_axis_length: - seen_axis_len = index_to_axis_length[index_char] + if index_char in index_to_descr_dict: + if index_char in index_to_axis_length_dict: + seen_axis_len = index_to_axis_length_dict[index_char] if not are_shape_components_equal(in_axis_len, seen_axis_len): if are_shape_components_equal(in_axis_len, 1): @@ -1187,24 +1193,24 @@ def _normalize_einsum_in_subscript(subscript: str, pass elif are_shape_components_equal(seen_axis_len, 1): # Broadcast to the length of the current axis - index_to_axis_length = (index_to_axis_length - .set(index_char, in_axis_len)) + index_to_axis_length_dict[index_char] = in_axis_len else: raise ValueError("Got conflicting lengths for" f" '{index_char}' -- {in_axis_len}," f" {seen_axis_len}.") else: - index_to_axis_length = index_to_axis_length.set(index_char, - in_axis_len) + index_to_axis_length_dict[index_char] = in_axis_len else: - redn_sr_no = len([descr for descr in index_to_descr.values() + redn_sr_no = len([descr for descr in index_to_descr_dict.values() if isinstance(descr, EinsumReductionAxis)]) redn_axis_descr = EinsumReductionAxis(redn_sr_no) - index_to_descr = index_to_descr.set(index_char, redn_axis_descr) - index_to_axis_length = index_to_axis_length.set(index_char, - in_axis_len) + index_to_descr_dict[index_char] = redn_axis_descr + index_to_axis_length_dict[index_char] = in_axis_len - in_operand_axis_descrs.append(index_to_descr[index_char]) + in_operand_axis_descrs.append(index_to_descr_dict[index_char]) + + index_to_axis_length = immutabledict(index_to_axis_length_dict) + index_to_descr = immutabledict(index_to_descr_dict) return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length) @@ -1239,7 +1245,7 @@ def einsum(subscripts: str, *operands: Array, ) index_to_descr = _normalize_einsum_out_subscript(out_spec) - index_to_axis_length: Map[str, ShapeComponent] = Map() + index_to_axis_length: immutabledict[str, ShapeComponent] = immutabledict() access_descriptors = [] for in_spec, in_operand in zip(in_specs, operands): @@ -1274,7 +1280,7 @@ def einsum(subscripts: str, *operands: Array, if isinstance(descr, EinsumElementwiseAxis)}) ), - redn_axis_to_redn_descr=Map(redn_axis_to_redn_descr), + redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr), index_to_access_descr=index_to_descr, ) @@ -2088,10 +2094,10 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType, fill_value = dtype.type(fill_value) return IndexLambda(expr=fill_value, shape=shape, dtype=dtype, - bindings=Map(), + bindings=immutabledict(), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) def zeros(shape: ConvertibleToShape, dtype: Any = float, @@ -2134,10 +2140,10 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803 raise ValueError(f"k must be int, got {type(k)}.") return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"), - shape=(N, M), dtype=dtype, bindings=Map({}), + shape=(N, M), dtype=dtype, bindings=immutabledict({}), tags=_get_default_tags(), axes=_get_default_axes(2), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) # }}} @@ -2229,10 +2235,10 @@ def arange(*args: Any, **kwargs: Any) -> Array: from pymbolic.primitives import Variable return IndexLambda(expr=start + Variable("_0") * step, - shape=(size,), dtype=dtype, bindings=Map(), + shape=(size,), dtype=dtype, bindings=immutabledict(), tags=_get_default_tags(), axes=_get_default_axes(1), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) # }}} @@ -2343,7 +2349,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]: bindings={"_in0": x}, tags=_get_default_tags(), axes=_get_default_axes(len(x.shape)), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) # }}} @@ -2396,10 +2402,10 @@ def where(condition: ArrayOrScalar, expr=prim.If(expr1, expr2, expr3), shape=result_shape, dtype=dtype, - bindings=Map(bindings), + bindings=immutabledict(bindings), tags=_get_default_tags(), axes=_get_default_axes(len(result_shape)), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) # }}} @@ -2492,12 +2498,13 @@ def make_index_lambda( # }}} return IndexLambda(expr=expression, - bindings=Map(bindings), + bindings=immutabledict(bindings), shape=shape, dtype=dtype, tags=_get_default_tags(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=Map(processed_var_to_reduction_descr)) + var_to_reduction_descr=immutabledict + (processed_var_to_reduction_descr)) # }}} @@ -2578,10 +2585,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array: shape)), shape=shape, dtype=array.dtype, - bindings=Map({"in": array}), + bindings=immutabledict({"in": array}), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=Map()) + var_to_reduction_descr=immutabledict()) # }}} diff --git a/pytato/cmath.py b/pytato/cmath.py index 88fca8fa6..38c520c7e 100644 --- a/pytato/cmath.py +++ b/pytato/cmath.py @@ -62,7 +62,7 @@ _get_default_axes, _get_default_tags) from pytato.scalar_expr import SCALAR_CLASSES from pymbolic import var -from immutables import Map +from immutabledict import immutabledict def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], @@ -113,10 +113,10 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...], return IndexLambda( expr=prim.Call(var(f"pytato.c99.{func_name}"), tuple(sym_args)), - shape=shape, dtype=ret_dtype, bindings=Map(bindings), + shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings), tags=_get_default_tags(), axes=_get_default_axes(len(shape)), - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), ) diff --git a/pytato/codegen.py b/pytato/codegen.py index 0bc85d649..5a76cd3a5 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -24,6 +24,7 @@ import dataclasses from typing import Union, Dict, Tuple, List, Any, Optional +from immutabledict import immutabledict from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder, DataInterface, SizeParam, InputArgumentBase, @@ -173,9 +174,10 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: # }}} - bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array) + bindings: immutabledict[str, Any] = immutabledict( + {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) - for name, subexpr in sorted(expr.bindings.items())} + for name, subexpr in sorted(expr.bindings.items())}) return LoopyCall(translation_unit=translation_unit, bindings=bindings, @@ -282,8 +284,8 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult: for out in outputs.values())) # only look for dependencies between the outputs - deps = {name: get_deps(output.expr) - for name, output in outputs.items()} + deps: immutabledict[str, Any] = immutabledict({name: get_deps(output.expr) + for name, output in outputs.items()}) # represent deps in terms of output names output_expr_to_name = {output.expr: name for name, output in outputs.items()} diff --git a/pytato/distributed/partition.py b/pytato/distributed/partition.py index c8ab048f3..426a8cff5 100644 --- a/pytato/distributed/partition.py +++ b/pytato/distributed/partition.py @@ -68,7 +68,7 @@ List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional) import attrs -from immutables import Map +from immutabledict import immutabledict from pytools.graph import CycleError from pytools import memoize_method @@ -405,11 +405,11 @@ def _make_distributed_partition( partition_input_names=frozenset( comm_replacer.partition_input_name_to_placeholder.keys()), output_names=frozenset(name_to_ouput.keys()), - name_to_recv_node=Map({ + name_to_recv_node=immutabledict({ recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]: local_recv_id_to_recv_node[recv_id] for recv_id in comm_ids.recv_ids}), - name_to_send_nodes=Map(name_to_send_nodes)) + name_to_send_nodes=immutabledict(name_to_send_nodes)) result = DistributedGraphPartition( parts=parts, diff --git a/pytato/function.py b/pytato/function.py index b053831a0..34151945f 100644 --- a/pytato/function.py +++ b/pytato/function.py @@ -49,7 +49,7 @@ from typing import (Callable, Dict, FrozenSet, Tuple, Union, TypeVar, Optional, Hashable, Sequence, ClassVar, Iterator, Iterable, Mapping) -from immutables import Map +from immutabledict import immutabledict from functools import cached_property from pytato.array import (Array, AbstractResultWithNamedArrays, Placeholder, NamedArray, ShapeType, _dtype_any, @@ -126,7 +126,7 @@ class FunctionDefinition(Taggable): """ parameters: FrozenSet[str] return_type: ReturnType - returns: Map[str, Array] + returns: immutabledict[str, Array] tags: FrozenSet[Tag] = attrs.field(kw_only=True) @cached_property @@ -142,7 +142,7 @@ def _placeholders(self) -> Mapping[str, Placeholder]: frozenset() ) - return Map({input_arg.name: input_arg + return immutabledict({input_arg.name: input_arg for input_arg in all_input_args if isinstance(input_arg, Placeholder)}) @@ -188,7 +188,8 @@ def __call__(self, **kwargs: Array # }}} - call_site = Call(self, bindings=Map(kwargs), tags=_get_default_tags()) + call_site = Call(self, bindings=immutabledict(kwargs), + tags=_get_default_tags()) if self.return_type == ReturnType.ARRAY: return call_site["_"] @@ -253,7 +254,7 @@ def dtype(self) -> _dtype_any: return self._container.function.returns[self.name].dtype -# eq=False to avoid equality comparison without EqualityMaper +# eq=False to avoid equality comparison without EqualityMapper @attrs.define(frozen=True, eq=False, hash=True, cache_hash=True, repr=False) class Call(AbstractResultWithNamedArrays): """ @@ -270,7 +271,7 @@ class Call(AbstractResultWithNamedArrays): """ function: FunctionDefinition - bindings: Map[str, Array] + bindings: immutabledict[str, Array] _mapper_method: ClassVar[str] = "map_call" @@ -371,7 +372,7 @@ def trace_call(f: Callable[..., ReturnT], function = FunctionDefinition( frozenset(pl_arg.name for pl_arg in pl_args) | frozenset(pl_kwargs), return_type, - Map(returns), + immutabledict(returns), tags=_get_default_tags() | (frozenset([FunctionIdentifier(identifier)]) if identifier else frozenset()) diff --git a/pytato/loopy.py b/pytato/loopy.py index 3d1ee1572..2158721c7 100644 --- a/pytato/loopy.py +++ b/pytato/loopy.py @@ -37,7 +37,7 @@ from pytato.scalar_expr import (SubstitutionMapper, ScalarExpression, EvaluationMapper, IntegralT) from pytools import memoize_method -from immutables import Map +from immutabledict import immutabledict import islpy as isl __doc__ = r""" @@ -78,7 +78,7 @@ class LoopyCall(AbstractResultWithNamedArrays): :mod:`loopy` translation unit. """ translation_unit: "lp.TranslationUnit" - bindings: Dict[str, ArrayOrScalar] + bindings: immutabledict[str, ArrayOrScalar] entrypoint: str _mapper_method: ClassVar[str] = "map_loopy_call" @@ -212,18 +212,19 @@ def call_loopy(translation_unit: "lp.TranslationUnit", # {{{ perform shape inference here - bindings = extend_bindings_with_shape_inference(translation_unit[entrypoint], - Map(bindings)) + bindings_new = extend_bindings_with_shape_inference(translation_unit[entrypoint], + immutabledict(bindings)) + del bindings # }}} for arg in translation_unit[entrypoint].args: if arg.is_input: - if arg.name not in bindings: + if arg.name not in bindings_new: raise ValueError(f"Kernel '{entrypoint}' expects an input" f" '{arg.name}'") - arg_binding = bindings[arg.name] + arg_binding = bindings_new[arg.name] if isinstance(arg, (lp.ArrayArg, lp.ConstantArg)): if not isinstance(arg_binding, Array): @@ -242,7 +243,7 @@ def call_loopy(translation_unit: "lp.TranslationUnit", # {{{ infer types of the translation_unit - for name, ary in bindings.items(): + for name, ary in bindings_new.items(): if translation_unit[entrypoint].arg_dict[name].dtype not in [lp.auto, None]: continue @@ -265,7 +266,7 @@ def call_loopy(translation_unit: "lp.TranslationUnit", translation_unit = translation_unit.with_entrypoints(frozenset()) - return LoopyCall(translation_unit, bindings, entrypoint, + return LoopyCall(translation_unit, bindings_new, entrypoint, tags=_get_default_tags()) @@ -379,8 +380,8 @@ def _get_pt_dim_expr(dim: Union[IntegralT, Array]) -> ScalarExpression: def extend_bindings_with_shape_inference(knl: lp.LoopKernel, - bindings: Map[str, ArrayOrScalar] - ) -> Dict[str, ArrayOrScalar]: + bindings: immutabledict[str, ArrayOrScalar] + ) -> immutabledict[str, ArrayOrScalar]: from functools import reduce from loopy.symbolic import get_dependencies as lpy_get_deps from loopy.kernel.array import ArrayBase @@ -478,6 +479,8 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel, as_pt_size_param = EvaluationMapper({_pt_var_to_global_namespace(arg.name): arg for arg in pt_size_params}) + bindings_dict = dict(bindings) + for var, val in solutions.items(): # map the pymbolic expression back into an expression in terms of # pt.SizeParams @@ -494,9 +497,9 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel, # }}} - bindings = bindings.set(var, val) + bindings_dict[var] = val - return dict(bindings) + return immutabledict(bindings_dict) # }}} diff --git a/pytato/raising.py b/pytato/raising.py index 3cdf77e54..188ccdaff 100644 --- a/pytato/raising.py +++ b/pytato/raising.py @@ -11,7 +11,7 @@ from pytato.scalar_expr import ScalarType, ScalarExpression, Reduce, SCALAR_CLASSES from pytato.reductions import ReductionOperation from dataclasses import dataclass -from immutables import Map +from immutabledict import immutabledict __doc__ = """ @@ -101,7 +101,7 @@ class ReduceOp(HighLevelOp): """ op: ReductionOperation x: Array - axes: Map[int, str] + axes: immutabledict[int, str] # }}} @@ -319,7 +319,7 @@ def index_lambda_to_high_level_op(expr: IndexLambda) -> HighLevelOp: .expr .inner_expr .aggregate.name], - axes=Map({i: idx.name + axes=immutabledict({i: idx.name for i, idx in enumerate(expr .expr .inner_expr diff --git a/pytato/reductions.py b/pytato/reductions.py index cd4f509f6..6497e1db9 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -34,7 +34,7 @@ from pytato.array import ShapeType, Array, make_index_lambda, ReductionDescriptor from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES -from immutables import Map +from immutabledict import immutabledict import pymbolic.primitives as prim # {{{ docs @@ -213,7 +213,7 @@ def _get_reduction_indices_bounds(shape: ShapeType, indices.append(prim.Variable(f"_{n_out_dims}")) n_out_dims += 1 - return indices, Map(redn_bounds) + return indices, immutabledict(redn_bounds) def _get_var_to_redn_descr(shape: ShapeType, @@ -258,7 +258,7 @@ def _get_var_to_redn_descr(shape: ShapeType, var_to_redn_descr[idx] = redn_descr n_redn_dims += 1 - return Map(var_to_redn_descr) + return immutabledict(var_to_redn_descr) def _make_reduction_lambda( diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index 98724ad87..0e21c2856 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -42,7 +42,7 @@ StringifyMapperBase) from pymbolic.mapper import CombineMapper as CombineMapperBase from pymbolic.mapper.collector import TermCollector as TermCollectorBase -from immutables import Map +from immutabledict import immutabledict import pymbolic.primitives as prim import numpy as np import re @@ -113,7 +113,7 @@ class SubstitutionMapper(SubstitutionMapperBase): def map_reduce(self, expr: Reduce) -> ScalarExpression: return Reduce(self.rec(expr.inner_expr), op=expr.op, - bounds=Map( + bounds=immutabledict( {name: self.rec(bound) for name, bound in expr.bounds.items()})) diff --git a/pytato/stringifier.py b/pytato/stringifier.py index e37027922..8aac8d340 100644 --- a/pytato/stringifier.py +++ b/pytato/stringifier.py @@ -31,7 +31,7 @@ from pytato.array import (Array, DataWrapper, DictOfNamedArrays, Axis, IndexLambda, ReductionDescriptor) from pytato.loopy import LoopyCall -from immutables import Map +from immutabledict import immutabledict import attrs @@ -77,7 +77,7 @@ def __call__(self, expr: Any, depth: int = 0) -> str: # type: ignore[override] def map_foreign(self, expr: Any, depth: int) -> str: # type: ignore[override] if isinstance(expr, tuple): return "(" + ", ".join(self.rec(el, depth) for el in expr) + ")" - elif isinstance(expr, (dict, Map)): + elif isinstance(expr, (dict, immutabledict)): return ("{" + ", ".join(f"{key!r}: {self.rec(val, depth)}" for key, val diff --git a/pytato/target/loopy/__init__.py b/pytato/target/loopy/__init__.py index 68d051bb6..080eaa2c5 100644 --- a/pytato/target/loopy/__init__.py +++ b/pytato/target/loopy/__init__.py @@ -54,7 +54,7 @@ from functools import cached_property from typing import Any, Mapping, Optional, Callable, Dict, TYPE_CHECKING -from immutables import Map +from immutabledict import immutabledict from pytato.target import Target, BoundProgram from pytato.tags import ImplementationStrategy @@ -137,7 +137,8 @@ class BoundPyOpenCLProgram(BoundProgram): """ program: loopy.TranslationUnit _processed_bound_args_cache: Dict[pyopencl.Context, - Map[str, Any]] = field(default_factory=dict) + immutabledict[str, Any]] = \ + field(default_factory=dict) def copy(self, *, program: Optional[loopy.TranslationUnit] = None, @@ -169,7 +170,7 @@ def _get_processed_bound_arguments(self, queue: pyopencl.CommandQueue, allocator: Optional[Callable[ [int], pyopencl.MemoryObject]], - ) -> Map[str, Any]: + ) -> immutabledict[str, Any]: import pyopencl.array as cla cache_key = queue.context @@ -193,7 +194,7 @@ def _get_processed_bound_arguments(self, " numpy array, pyopencl array or scalar." f" Got {type(bnd_arg).__name__} for '{name}'.") - result = Map(proc_bnd_args) + result: immutabledict[str, Any] = immutabledict(proc_bnd_args) assert set(result.keys()) == set(self.bound_arguments.keys()) self._processed_bound_args_cache[cache_key] = result return result diff --git a/pytato/target/python/numpy_like.py b/pytato/target/python/numpy_like.py index 10c5fe4f9..998208b96 100644 --- a/pytato/target/python/numpy_like.py +++ b/pytato/target/python/numpy_like.py @@ -38,7 +38,7 @@ Reshape, Array, DictOfNamedArrays, IndexBase, DataInterface, NormalizedSlice, ShapeComponent, IndexExpr, ArrayOrScalar, NamedArray) -from immutables import Map +from immutabledict import immutabledict from pytato.scalar_expr import SCALAR_CLASSES from pytato.utils import are_shape_components_equal from pytato.raising import BinaryOpType, C99CallOp @@ -601,4 +601,4 @@ def generate_numpy_like(expr: Union[Array, Mapping[str, Array], DictOfNamedArray program, function_name, expected_arguments=frozenset(cgen_mapper.arg_names), - bound_arguments=Map(cgen_mapper.bound_arguments)) + bound_arguments=immutabledict(cgen_mapper.bound_arguments)) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index b7ac02b2f..10fc3b086 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -30,7 +30,7 @@ import logging import numpy as np -from immutables import Map +from immutabledict import immutabledict from typing import (Any, Callable, Dict, FrozenSet, Union, TypeVar, Set, Generic, List, Mapping, Iterable, Tuple, Optional, TYPE_CHECKING, Hashable) @@ -261,7 +261,7 @@ def rec_idx_or_size_tuple(self, situp: Tuple[IndexOrShapeExpr, ...] for s in situp) def map_index_lambda(self, expr: IndexLambda) -> Array: - bindings: Mapping[str, Array] = Map({ + bindings: Mapping[str, Array] = immutabledict({ name: self.rec(subexpr) for name, subexpr in sorted(expr.bindings.items())}) return IndexLambda(expr=expr.expr, @@ -354,9 +354,10 @@ def map_dict_of_named_arrays(self, ) def map_loopy_call(self, expr: LoopyCall) -> LoopyCall: - bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array) + bindings: immutabledict[Any, Any] = immutabledict( + {name: (self.rec(subexpr) if isinstance(subexpr, Array) else subexpr) - for name, subexpr in sorted(expr.bindings.items())} + for name, subexpr in sorted(expr.bindings.items())}) return LoopyCall(translation_unit=expr.translation_unit, bindings=bindings, @@ -406,13 +407,13 @@ def map_function_definition(self, for name, ret in expr.returns.items()} return FunctionDefinition(expr.parameters, expr.return_type, - Map(new_returns), + immutabledict(new_returns), tags=expr.tags ) def map_call(self, expr: Call) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function), - Map({name: self.rec(bnd) + immutabledict({name: self.rec(bnd) for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -578,10 +579,11 @@ def map_dict_of_named_arrays(self, def map_loopy_call(self, expr: LoopyCall, *args: Any, **kwargs: Any) -> LoopyCall: - bindings = {name: (self.rec(subexpr, *args, **kwargs) + bindings: immutabledict[Any, Any] = immutabledict( + {name: (self.rec(subexpr, *args, **kwargs) if isinstance(subexpr, Array) else subexpr) - for name, subexpr in sorted(expr.bindings.items())} + for name, subexpr in sorted(expr.bindings.items())}) return LoopyCall(translation_unit=expr.translation_unit, bindings=bindings, @@ -634,7 +636,7 @@ def map_function_definition(self, expr: FunctionDefinition, def map_call(self, expr: Call, *args: Any, **kwargs: Any) -> AbstractResultWithNamedArrays: return Call(self.map_function_definition(expr.function, *args, **kwargs), - Map({name: self.rec(bnd, *args, **kwargs) + immutabledict({name: self.rec(bnd, *args, **kwargs) for name, bnd in expr.bindings.items()}), tags=expr.tags, ) @@ -1312,7 +1314,7 @@ def map_index_lambda(self, expr: IndexLambda) -> MPMSMaterializerAccumulator: new_expr = IndexLambda(expr=expr.expr, shape=expr.shape, dtype=expr.dtype, - bindings=Map({bnd_name: bnd.expr + bindings=immutabledict({bnd_name: bnd.expr for bnd_name, bnd in sorted(children_rec.items())}), axes=expr.axes, var_to_reduction_descr=expr.var_to_reduction_descr, @@ -1441,13 +1443,13 @@ def map_function_definition(self, expr: FunctionDefinition new_returns = {name: new_mapper(ret) for name, ret in expr.returns.items()} return FunctionDefinition(expr.parameters, expr.return_type, - Map(new_returns), + immutabledict(new_returns), tags=expr.tags) @memoize_method def map_call(self, expr: Call) -> Call: return Call(self.map_function_definition(expr.function), - Map({name: self.rec(bnd).expr + immutabledict({name: self.rec(bnd).expr for name, bnd in expr.bindings.items()}), tags=expr.tags) diff --git a/pytato/transform/calls.py b/pytato/transform/calls.py index a25dd8311..8ba0ccdc7 100644 --- a/pytato/transform/calls.py +++ b/pytato/transform/calls.py @@ -26,7 +26,7 @@ THE SOFTWARE. """ -from immutables import Map +from immutabledict import immutabledict from pytato.transform import (ArrayOrNames, CopyMapper) from pytato.array import (AbstractResultWithNamedArrays, Array, DictOfNamedArrays, Placeholder) @@ -44,7 +44,8 @@ class PlaceholderSubstitutor(CopyMapper): A mapping from the placeholder name to the array that it is to be substituted with. """ - def __init__(self, substitutions: Map[str, Array]) -> None: + + def __init__(self, substitutions: immutabledict[str, Array]) -> None: super().__init__() self.substitutions = substitutions diff --git a/pytato/transform/einsum_distributive_law.py b/pytato/transform/einsum_distributive_law.py index 106c5a4fe..e81550e1e 100644 --- a/pytato/transform/einsum_distributive_law.py +++ b/pytato/transform/einsum_distributive_law.py @@ -41,7 +41,7 @@ Stack, Concatenate, Roll, AxisPermutation, IndexBase, Reshape, InputArgumentBase) from pytato.raising import HighLevelOp -from immutables import Map +from immutabledict import immutabledict from pytools.tag import Tag from pytato.utils import are_shapes_equal import numpy as np @@ -74,10 +74,10 @@ class DoDistribute(EinsumDistributiveLawDescriptor): @attrs.frozen class _EinsumDistributiveLawMapperContext: access_descriptors: Tuple[Tuple[EinsumAxisDescriptor, ...], ...] - surrounding_args: Map[int, Array] - redn_axis_to_redn_descr: Map[EinsumReductionAxis, + surrounding_args: immutabledict[int, Array] + redn_axis_to_redn_descr: immutabledict[EinsumReductionAxis, ReductionDescriptor] - index_to_access_descr: Map[str, EinsumAxisDescriptor] + index_to_access_descr: immutabledict[str, EinsumAxisDescriptor] axes: AxesT = attrs.field(kw_only=True) tags: FrozenSet[Tag] = attrs.field(kw_only=True) @@ -223,7 +223,7 @@ def map_index_lambda(self, expr=expr.expr, shape=expr.shape, dtype=expr.dtype, - bindings=Map({name: self.rec(bnd, None) + bindings=immutabledict({name: self.rec(bnd, None) for name, bnd in sorted(expr.bindings.items())}), var_to_reduction_descr=expr.var_to_reduction_descr, tags=expr.tags, @@ -243,11 +243,11 @@ def map_einsum(self, else: ctx = _EinsumDistributiveLawMapperContext( expr.access_descriptors, - Map({iarg: arg + immutabledict({iarg: arg for iarg, arg in enumerate(expr.args) if iarg != distributive_law_descr.ioperand}), - Map(expr.redn_axis_to_redn_descr), - Map(expr.index_to_access_descr), + immutabledict(expr.redn_axis_to_redn_descr), + immutabledict(expr.index_to_access_descr), tags=expr.tags, axes=expr.axes, ) diff --git a/pytato/transform/lower_to_index_lambda.py b/pytato/transform/lower_to_index_lambda.py index 4aa1d4cca..ad16facd2 100644 --- a/pytato/transform/lower_to_index_lambda.py +++ b/pytato/transform/lower_to_index_lambda.py @@ -29,7 +29,7 @@ import pymbolic.primitives as prim from typing import List, Any, Dict, Tuple, TypeVar, TYPE_CHECKING -from immutables import Map +from immutabledict import immutabledict from pytools import UniqueNameGenerator from pytato.array import (Array, IndexLambda, Stack, Concatenate, Einsum, Reshape, Roll, AxisPermutation, @@ -96,7 +96,7 @@ def map_index_lambda(self, expr: IndexLambda) -> IndexLambda: return IndexLambda(expr=expr.expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map({name: self.rec(bnd) + bindings=immutabledict({name: self.rec(bnd) for name, bnd in sorted(expr.bindings.items())}), axes=expr.axes, @@ -133,8 +133,8 @@ def map_stack(self, expr: Stack) -> IndexLambda: shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - bindings=Map(bindings), - var_to_reduction_descr=Map(), + bindings=immutabledict(bindings), + var_to_reduction_descr=immutabledict(), tags=expr.tags) def map_concatenate(self, expr: Concatenate) -> IndexLambda: @@ -180,9 +180,9 @@ def get_subscript(array_index: int, offset: ScalarExpression) -> Subscript: return IndexLambda(expr=concat_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map(bindings), + bindings=immutabledict(bindings), axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags) def map_einsum(self, expr: Einsum) -> IndexLambda: @@ -249,9 +249,9 @@ def map_einsum(self, expr: Einsum) -> IndexLambda: return IndexLambda(expr=inner_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map(bindings), + bindings=immutabledict(bindings), axes=expr.axes, - var_to_reduction_descr=Map(var_to_redn_descr), + var_to_reduction_descr=immutabledict(var_to_redn_descr), tags=expr.tags) def map_roll(self, expr: Roll) -> IndexLambda: @@ -275,10 +275,10 @@ def map_roll(self, expr: Roll) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map({name: self.rec(bnd) + bindings=immutabledict({name: self.rec(bnd) for name, bnd in bindings.items()}), axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags) def map_contiguous_advanced_index(self, @@ -338,11 +338,11 @@ def map_contiguous_advanced_index(self, return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=Map(bindings), + bindings=immutabledict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags, ) @@ -400,11 +400,11 @@ def map_non_contiguous_advanced_index( return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=Map(bindings), + bindings=immutabledict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags, ) @@ -433,11 +433,11 @@ def map_basic_index(self, expr: BasicIndex) -> IndexLambda: return IndexLambda(expr=prim.Subscript(prim.Variable(in_ary), tuple(indices)), - bindings=Map(bindings), + bindings=immutabledict(bindings), shape=self._rec_shape(expr.shape), dtype=expr.dtype, axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags, ) @@ -447,9 +447,9 @@ def map_reshape(self, expr: Reshape) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": self.rec(expr.array)}), axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags) def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: @@ -462,9 +462,9 @@ def map_axis_permutation(self, expr: AxisPermutation) -> IndexLambda: return IndexLambda(expr=index_expr, shape=self._rec_shape(expr.shape), dtype=expr.dtype, - bindings=Map({"_in0": self.rec(expr.array)}), + bindings=immutabledict({"_in0": self.rec(expr.array)}), axes=expr.axes, - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), tags=expr.tags) diff --git a/pytato/transform/metadata.py b/pytato/transform/metadata.py index 591e7055c..ff3cd2d65 100644 --- a/pytato/transform/metadata.py +++ b/pytato/transform/metadata.py @@ -553,14 +553,14 @@ def map_distributed_recv(self, def _get_propagation_graph_from_constraints( equations: List[Tuple[str, str]]) -> Mapping[str, FrozenSet[str]]: - import immutables + from immutabledict import immutabledict propagation_graph: Dict[str, Set[str]] = {} for lhs, rhs in equations: assert lhs != rhs propagation_graph.setdefault(lhs, set()).add(rhs) propagation_graph.setdefault(rhs, set()).add(lhs) - return immutables.Map({k: frozenset(v) + return immutabledict({k: frozenset(v) for k, v in propagation_graph.items()}) diff --git a/pytato/utils.py b/pytato/utils.py index 58bcb5803..4c96e3730 100644 --- a/pytato/utils.py +++ b/pytato/utils.py @@ -38,7 +38,7 @@ SCALAR_CLASSES, INT_CLASSES, BoolT, ScalarType) from pytools import UniqueNameGenerator from pytato.transform import Mapper -from immutables import Map +from immutabledict import immutabledict __doc__ = """ @@ -205,9 +205,9 @@ def broadcast_binary_op(a1: ArrayOrScalar, a2: ArrayOrScalar, return IndexLambda(expr=op(expr1, expr2), shape=result_shape, dtype=result_dtype, - bindings=Map(bindings), + bindings=immutabledict(bindings), tags=_get_default_tags(), - var_to_reduction_descr=Map(), + var_to_reduction_descr=immutabledict(), axes=_get_default_axes(len(result_shape))) diff --git a/setup.py b/setup.py index e00ab5254..098aa4a8c 100644 --- a/setup.py +++ b/setup.py @@ -36,7 +36,7 @@ install_requires=[ "loopy>=2020.2", "pytools>=2022.1.13", - "immutables", + "immutabledict", "attrs", "bidict", ], diff --git a/test/test_pytato.py b/test/test_pytato.py index 98393b95b..19eade09c 100644 --- a/test/test_pytato.py +++ b/test/test_pytato.py @@ -691,7 +691,7 @@ def test_basic_index_equality_traverses_underlying_arrays(): def test_idx_lambda_to_hlo(): from pytato.raising import index_lambda_to_high_level_op - from immutables import Map + from immutabledict import immutabledict from pytato.raising import (BinaryOp, BinaryOpType, FullOp, ReduceOp, C99CallOp, BroadcastOp) @@ -734,11 +734,11 @@ def test_idx_lambda_to_hlo(): assert (index_lambda_to_high_level_op(pt.sum(b, axis=1)) == ReduceOp(SumReductionOperation(), b, - Map({1: "_r0"}))) + immutabledict({1: "_r0"}))) assert (index_lambda_to_high_level_op(pt.prod(a)) == ReduceOp(ProductReductionOperation(), a, - Map({0: "_r0", + immutabledict({0: "_r0", 1: "_r1"}))) assert index_lambda_to_high_level_op(pt.sinh(a)) == C99CallOp("sinh", (a,)) assert index_lambda_to_high_level_op(pt.arctan2(b, a)) == C99CallOp("atan2",