Skip to content

Commit

Permalink
replace immutables.Map with immutabledict (#461)
Browse files Browse the repository at this point in the history
* replace immutables.Map with immutabledict

* Use Mapping rather than any concrete type in public type annotations

---------

Co-authored-by: Andreas Kloeckner <inform@tiker.net>
  • Loading branch information
matthiasdiener and inducer authored Oct 6, 2023
1 parent 5274e91 commit 6adf9ae
Show file tree
Hide file tree
Showing 21 changed files with 163 additions and 146 deletions.
4 changes: 2 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
nitpick_ignore_regex = [
["py:class", r"numpy.(u?)int[\d]+"],
["py:class", r"typing_extensions(.+)"],
# As of 2022-10-20, it doesn't look like there's sphinx documentation
# As of 2023-10-05, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],
["py:class", r"immutabledict(.*)"],
# https://github.com/python-attrs/attrs/issues/1073
["py:mod", "attrs"],
]
99 changes: 53 additions & 46 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
ScalarExpression, IntegralT,
INT_CLASSES, get_reduction_induction_variables)
import re
from immutables import Map
from immutabledict import immutabledict

# {{{ get a type variable that represents the type of '...'

Expand Down Expand Up @@ -556,15 +556,15 @@ def _unary_op(self, op: Any) -> Array:
indices = tuple(var(f"_{i}") for i in range(self.ndim))
expr = op(var("_in0")[indices])

bindings = Map({"_in0": self})
bindings: Mapping[str, Array] = immutabledict({"_in0": self})
return IndexLambda(
expr=expr,
shape=self.shape,
dtype=self.dtype,
bindings=bindings,
tags=_get_default_tags(),
axes=_get_default_axes(self.ndim),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
Expand Down Expand Up @@ -895,17 +895,18 @@ def with_tagged_reduction(self,
f" '{self.var_to_reduction_descr.keys()}',"
f" got '{reduction_variable}'.")

assert isinstance(self.var_to_reduction_descr, Map)
new_var_to_redn_descr = self.var_to_reduction_descr.set(
reduction_variable,
self.var_to_reduction_descr[reduction_variable].tagged(tag))
assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = \
self.var_to_reduction_descr[reduction_variable].tagged(tag)

return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=new_var_to_redn_descr,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags)

# }}}
Expand Down Expand Up @@ -1006,7 +1007,7 @@ def _access_descr_to_axis_len(self
else:
descr_to_axis_len[descr] = arg_axis_len

return Map(descr_to_axis_len)
return immutabledict(descr_to_axis_len)

@cached_property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1063,14 +1064,16 @@ def with_tagged_reduction(self,

# }}}

assert isinstance(self.redn_axis_to_redn_descr, Map)
new_redn_axis_to_redn_descr = self.redn_axis_to_redn_descr.set(
redn_axis, self.redn_axis_to_redn_descr[redn_axis].tagged(tag))
assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = \
self.redn_axis_to_redn_descr[redn_axis].tagged(tag)

return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=new_redn_axis_to_redn_descr,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
index_to_access_descr=self.index_to_access_descr,
)
Expand All @@ -1079,7 +1082,7 @@ def with_tagged_reduction(self,
EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")


def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str,
EinsumAxisDescriptor]:
"""
Normalizes the output subscript of an einsum (provided in the explicit
Expand Down Expand Up @@ -1119,19 +1122,20 @@ def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
raise ValueError("Used an input more than once to refer to the"
f" output axis in '{subscript}")

return Map({idx: EinsumElementwiseAxis(i)
return immutabledict({idx: EinsumElementwiseAxis(i)
for i, idx in enumerate(normalized_indices)})


def _normalize_einsum_in_subscript(subscript: str,
in_operand: Array,
index_to_descr: Map[str,
index_to_descr: Mapping[str,
EinsumAxisDescriptor],
index_to_axis_length: Map[str,
index_to_axis_length: Mapping[str,
ShapeComponent],
) -> Tuple[Tuple[EinsumAxisDescriptor, ...],
Map[str, EinsumAxisDescriptor],
Map[str, ShapeComponent]]:
immutabledict
[str, EinsumAxisDescriptor],
immutabledict[str, ShapeComponent]]:
"""
Normalizes the subscript for an input operand in an einsum. Returns
``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``,
Expand Down Expand Up @@ -1174,37 +1178,39 @@ def _normalize_einsum_in_subscript(subscript: str,
f"of corresponding operand ({in_operand.ndim}).")

in_operand_axis_descrs = []
index_to_axis_length_dict = dict(index_to_axis_length)
index_to_descr_dict = dict(index_to_descr)

for iaxis, index_char in enumerate(normalized_indices):
in_axis_len = in_operand.shape[iaxis]
if index_char in index_to_descr:
if index_char in index_to_axis_length:
seen_axis_len = index_to_axis_length[index_char]
if index_char in index_to_descr_dict:
if index_char in index_to_axis_length_dict:
seen_axis_len = index_to_axis_length_dict[index_char]
if not are_shape_components_equal(in_axis_len,
seen_axis_len):
if are_shape_components_equal(in_axis_len, 1):
# Broadcast the current axis
pass
elif are_shape_components_equal(seen_axis_len, 1):
# Broadcast to the length of the current axis
index_to_axis_length = (index_to_axis_length
.set(index_char, in_axis_len))
index_to_axis_length_dict[index_char] = in_axis_len
else:
raise ValueError("Got conflicting lengths for"
f" '{index_char}' -- {in_axis_len},"
f" {seen_axis_len}.")
else:
index_to_axis_length = index_to_axis_length.set(index_char,
in_axis_len)
index_to_axis_length_dict[index_char] = in_axis_len
else:
redn_sr_no = len([descr for descr in index_to_descr.values()
redn_sr_no = len([descr for descr in index_to_descr_dict.values()
if isinstance(descr, EinsumReductionAxis)])
redn_axis_descr = EinsumReductionAxis(redn_sr_no)
index_to_descr = index_to_descr.set(index_char, redn_axis_descr)
index_to_axis_length = index_to_axis_length.set(index_char,
in_axis_len)
index_to_descr_dict[index_char] = redn_axis_descr
index_to_axis_length_dict[index_char] = in_axis_len

in_operand_axis_descrs.append(index_to_descr[index_char])
in_operand_axis_descrs.append(index_to_descr_dict[index_char])

index_to_axis_length = immutabledict(index_to_axis_length_dict)
index_to_descr = immutabledict(index_to_descr_dict)

return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length)

Expand Down Expand Up @@ -1239,7 +1245,7 @@ def einsum(subscripts: str, *operands: Array,
)

index_to_descr = _normalize_einsum_out_subscript(out_spec)
index_to_axis_length: Map[str, ShapeComponent] = Map()
index_to_axis_length: Mapping[str, ShapeComponent] = immutabledict()
access_descriptors = []

for in_spec, in_operand in zip(in_specs, operands):
Expand Down Expand Up @@ -1274,7 +1280,7 @@ def einsum(subscripts: str, *operands: Array,
if isinstance(descr,
EinsumElementwiseAxis)})
),
redn_axis_to_redn_descr=Map(redn_axis_to_redn_descr),
redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
index_to_access_descr=index_to_descr,
)

Expand Down Expand Up @@ -2088,10 +2094,10 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType,
fill_value = dtype.type(fill_value)

return IndexLambda(expr=fill_value, shape=shape, dtype=dtype,
bindings=Map(),
bindings=immutabledict(),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())


def zeros(shape: ConvertibleToShape, dtype: Any = float,
Expand Down Expand Up @@ -2134,10 +2140,10 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803
raise ValueError(f"k must be int, got {type(k)}.")

return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"),
shape=(N, M), dtype=dtype, bindings=Map({}),
shape=(N, M), dtype=dtype, bindings=immutabledict({}),
tags=_get_default_tags(),
axes=_get_default_axes(2),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2229,10 +2235,10 @@ def arange(*args: Any, **kwargs: Any) -> Array:

from pymbolic.primitives import Variable
return IndexLambda(expr=start + Variable("_0") * step,
shape=(size,), dtype=dtype, bindings=Map(),
shape=(size,), dtype=dtype, bindings=immutabledict(),
tags=_get_default_tags(),
axes=_get_default_axes(1),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2343,7 +2349,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]:
bindings={"_in0": x},
tags=_get_default_tags(),
axes=_get_default_axes(len(x.shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2396,10 +2402,10 @@ def where(condition: ArrayOrScalar,
expr=prim.If(expr1, expr2, expr3),
shape=result_shape,
dtype=dtype,
bindings=Map(bindings),
bindings=immutabledict(bindings),
tags=_get_default_tags(),
axes=_get_default_axes(len(result_shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2492,12 +2498,13 @@ def make_index_lambda(
# }}}

return IndexLambda(expr=expression,
bindings=Map(bindings),
bindings=immutabledict(bindings),
shape=shape,
dtype=dtype,
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map(processed_var_to_reduction_descr))
var_to_reduction_descr=immutabledict
(processed_var_to_reduction_descr))

# }}}

Expand Down Expand Up @@ -2578,10 +2585,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array:
shape)),
shape=shape,
dtype=array.dtype,
bindings=Map({"in": array}),
bindings=immutabledict({"in": array}),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down
6 changes: 3 additions & 3 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
_get_default_axes, _get_default_tags)
from pytato.scalar_expr import SCALAR_CLASSES
from pymbolic import var
from immutables import Map
from immutabledict import immutabledict


def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
Expand Down Expand Up @@ -113,10 +113,10 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
return IndexLambda(
expr=prim.Call(var(f"pytato.c99.{func_name}"),
tuple(sym_args)),
shape=shape, dtype=ret_dtype, bindings=Map(bindings),
shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map(),
var_to_reduction_descr=immutabledict(),
)


Expand Down
12 changes: 7 additions & 5 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
"""

import dataclasses
from typing import Union, Dict, Tuple, List, Any, Optional
from typing import Union, Dict, Tuple, List, Any, Optional, Mapping
from immutabledict import immutabledict

from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
DataInterface, SizeParam, InputArgumentBase,
Expand Down Expand Up @@ -173,9 +174,10 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:

# }}}

bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
bindings: Mapping[str, Any] = immutabledict(
{name: (self.rec(subexpr) if isinstance(subexpr, Array)
else subexpr)
for name, subexpr in sorted(expr.bindings.items())}
for name, subexpr in sorted(expr.bindings.items())})

return LoopyCall(translation_unit=translation_unit,
bindings=bindings,
Expand Down Expand Up @@ -282,8 +284,8 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
for out in outputs.values()))

# only look for dependencies between the outputs
deps = {name: get_deps(output.expr)
for name, output in outputs.items()}
deps: Mapping[str, Any] = immutabledict({name: get_deps(output.expr)
for name, output in outputs.items()})

# represent deps in terms of output names
output_expr_to_name = {output.expr: name for name, output in outputs.items()}
Expand Down
6 changes: 3 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)

import attrs
from immutables import Map
from immutabledict import immutabledict

from pytools.graph import CycleError
from pytools import memoize_method
Expand Down Expand Up @@ -405,11 +405,11 @@ def _make_distributed_partition(
partition_input_names=frozenset(
comm_replacer.partition_input_name_to_placeholder.keys()),
output_names=frozenset(name_to_ouput.keys()),
name_to_recv_node=Map({
name_to_recv_node=immutabledict({
recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]:
local_recv_id_to_recv_node[recv_id]
for recv_id in comm_ids.recv_ids}),
name_to_send_nodes=Map(name_to_send_nodes))
name_to_send_nodes=immutabledict(name_to_send_nodes))

result = DistributedGraphPartition(
parts=parts,
Expand Down
Loading

0 comments on commit 6adf9ae

Please sign in to comment.