Skip to content

Commit

Permalink
replace immutables.Map with immutabledict
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Oct 5, 2023
1 parent 5274e91 commit 787f839
Show file tree
Hide file tree
Showing 21 changed files with 161 additions and 144 deletions.
4 changes: 2 additions & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
nitpick_ignore_regex = [
["py:class", r"numpy.(u?)int[\d]+"],
["py:class", r"typing_extensions(.+)"],
# As of 2022-10-20, it doesn't look like there's sphinx documentation
# As of 2023-10-05, it doesn't look like there's sphinx documentation
# available.
["py:class", r"immutables\.(.+)"],
["py:class", r"immutabledict(.*)"],
# https://github.com/python-attrs/attrs/issues/1073
["py:mod", "attrs"],
]
99 changes: 53 additions & 46 deletions pytato/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@
ScalarExpression, IntegralT,
INT_CLASSES, get_reduction_induction_variables)
import re
from immutables import Map
from immutabledict import immutabledict

# {{{ get a type variable that represents the type of '...'

Expand Down Expand Up @@ -556,15 +556,15 @@ def _unary_op(self, op: Any) -> Array:
indices = tuple(var(f"_{i}") for i in range(self.ndim))
expr = op(var("_in0")[indices])

bindings = Map({"_in0": self})
bindings: immutabledict[str, Array] = immutabledict({"_in0": self})
return IndexLambda(
expr=expr,
shape=self.shape,
dtype=self.dtype,
bindings=bindings,
tags=_get_default_tags(),
axes=_get_default_axes(self.ndim),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

__mul__ = partialmethod(_binary_op, operator.mul)
__rmul__ = partialmethod(_binary_op, operator.mul, reverse=True)
Expand Down Expand Up @@ -895,17 +895,18 @@ def with_tagged_reduction(self,
f" '{self.var_to_reduction_descr.keys()}',"
f" got '{reduction_variable}'.")

assert isinstance(self.var_to_reduction_descr, Map)
new_var_to_redn_descr = self.var_to_reduction_descr.set(
reduction_variable,
self.var_to_reduction_descr[reduction_variable].tagged(tag))
assert isinstance(self.var_to_reduction_descr, immutabledict)
new_var_to_redn_descr = dict(self.var_to_reduction_descr)
new_var_to_redn_descr[reduction_variable] = \
self.var_to_reduction_descr[reduction_variable].tagged(tag)

return type(self)(expr=self.expr,
shape=self.shape,
dtype=self.dtype,
bindings=self.bindings,
axes=self.axes,
var_to_reduction_descr=new_var_to_redn_descr,
var_to_reduction_descr=immutabledict
(new_var_to_redn_descr),
tags=self.tags)

# }}}
Expand Down Expand Up @@ -1006,7 +1007,7 @@ def _access_descr_to_axis_len(self
else:
descr_to_axis_len[descr] = arg_axis_len

return Map(descr_to_axis_len)
return immutabledict(descr_to_axis_len)

@cached_property
def shape(self) -> ShapeType:
Expand Down Expand Up @@ -1063,14 +1064,16 @@ def with_tagged_reduction(self,

# }}}

assert isinstance(self.redn_axis_to_redn_descr, Map)
new_redn_axis_to_redn_descr = self.redn_axis_to_redn_descr.set(
redn_axis, self.redn_axis_to_redn_descr[redn_axis].tagged(tag))
assert isinstance(self.redn_axis_to_redn_descr, immutabledict)
new_redn_axis_to_redn_descr = dict(self.redn_axis_to_redn_descr)
new_redn_axis_to_redn_descr[redn_axis] = \
self.redn_axis_to_redn_descr[redn_axis].tagged(tag)

return type(self)(access_descriptors=self.access_descriptors,
args=self.args,
axes=self.axes,
redn_axis_to_redn_descr=new_redn_axis_to_redn_descr,
redn_axis_to_redn_descr=immutabledict
(new_redn_axis_to_redn_descr),
tags=self.tags,
index_to_access_descr=self.index_to_access_descr,
)
Expand All @@ -1079,7 +1082,7 @@ def with_tagged_reduction(self,
EINSUM_FIRST_INDEX = re.compile(r"^\s*((?P<alpha>[a-zA-Z])|(?P<ellipsis>\.\.\.))\s*")


def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
def _normalize_einsum_out_subscript(subscript: str) -> immutabledict[str,
EinsumAxisDescriptor]:
"""
Normalizes the output subscript of an einsum (provided in the explicit
Expand Down Expand Up @@ -1119,19 +1122,20 @@ def _normalize_einsum_out_subscript(subscript: str) -> Map[str,
raise ValueError("Used an input more than once to refer to the"
f" output axis in '{subscript}")

return Map({idx: EinsumElementwiseAxis(i)
return immutabledict({idx: EinsumElementwiseAxis(i)
for i, idx in enumerate(normalized_indices)})


def _normalize_einsum_in_subscript(subscript: str,
in_operand: Array,
index_to_descr: Map[str,
index_to_descr: immutabledict[str,
EinsumAxisDescriptor],
index_to_axis_length: Map[str,
index_to_axis_length: immutabledict[str,
ShapeComponent],
) -> Tuple[Tuple[EinsumAxisDescriptor, ...],
Map[str, EinsumAxisDescriptor],
Map[str, ShapeComponent]]:
immutabledict
[str, EinsumAxisDescriptor],
immutabledict[str, ShapeComponent]]:
"""
Normalizes the subscript for an input operand in an einsum. Returns
``(access_descrs, updated_index_to_descr, updated_to_index_to_axis_length)``,
Expand Down Expand Up @@ -1174,37 +1178,39 @@ def _normalize_einsum_in_subscript(subscript: str,
f"of corresponding operand ({in_operand.ndim}).")

in_operand_axis_descrs = []
index_to_axis_length_dict = dict(index_to_axis_length)
index_to_descr_dict = dict(index_to_descr)

for iaxis, index_char in enumerate(normalized_indices):
in_axis_len = in_operand.shape[iaxis]
if index_char in index_to_descr:
if index_char in index_to_axis_length:
seen_axis_len = index_to_axis_length[index_char]
if index_char in index_to_descr_dict:
if index_char in index_to_axis_length_dict:
seen_axis_len = index_to_axis_length_dict[index_char]
if not are_shape_components_equal(in_axis_len,
seen_axis_len):
if are_shape_components_equal(in_axis_len, 1):
# Broadcast the current axis
pass
elif are_shape_components_equal(seen_axis_len, 1):
# Broadcast to the length of the current axis
index_to_axis_length = (index_to_axis_length
.set(index_char, in_axis_len))
index_to_axis_length_dict[index_char] = in_axis_len
else:
raise ValueError("Got conflicting lengths for"
f" '{index_char}' -- {in_axis_len},"
f" {seen_axis_len}.")
else:
index_to_axis_length = index_to_axis_length.set(index_char,
in_axis_len)
index_to_axis_length_dict[index_char] = in_axis_len
else:
redn_sr_no = len([descr for descr in index_to_descr.values()
redn_sr_no = len([descr for descr in index_to_descr_dict.values()
if isinstance(descr, EinsumReductionAxis)])
redn_axis_descr = EinsumReductionAxis(redn_sr_no)
index_to_descr = index_to_descr.set(index_char, redn_axis_descr)
index_to_axis_length = index_to_axis_length.set(index_char,
in_axis_len)
index_to_descr_dict[index_char] = redn_axis_descr
index_to_axis_length_dict[index_char] = in_axis_len

in_operand_axis_descrs.append(index_to_descr[index_char])
in_operand_axis_descrs.append(index_to_descr_dict[index_char])

index_to_axis_length = immutabledict(index_to_axis_length_dict)
index_to_descr = immutabledict(index_to_descr_dict)

return (tuple(in_operand_axis_descrs), index_to_descr, index_to_axis_length)

Expand Down Expand Up @@ -1239,7 +1245,7 @@ def einsum(subscripts: str, *operands: Array,
)

index_to_descr = _normalize_einsum_out_subscript(out_spec)
index_to_axis_length: Map[str, ShapeComponent] = Map()
index_to_axis_length: immutabledict[str, ShapeComponent] = immutabledict()
access_descriptors = []

for in_spec, in_operand in zip(in_specs, operands):
Expand Down Expand Up @@ -1274,7 +1280,7 @@ def einsum(subscripts: str, *operands: Array,
if isinstance(descr,
EinsumElementwiseAxis)})
),
redn_axis_to_redn_descr=Map(redn_axis_to_redn_descr),
redn_axis_to_redn_descr=immutabledict(redn_axis_to_redn_descr),
index_to_access_descr=index_to_descr,
)

Expand Down Expand Up @@ -2088,10 +2094,10 @@ def full(shape: ConvertibleToShape, fill_value: ScalarType,
fill_value = dtype.type(fill_value)

return IndexLambda(expr=fill_value, shape=shape, dtype=dtype,
bindings=Map(),
bindings=immutabledict(),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())


def zeros(shape: ConvertibleToShape, dtype: Any = float,
Expand Down Expand Up @@ -2134,10 +2140,10 @@ def eye(N: int, M: Optional[int] = None, k: int = 0, # noqa: N803
raise ValueError(f"k must be int, got {type(k)}.")

return IndexLambda(expr=parse(f"1 if ((_1 - _0) == {k}) else 0"),
shape=(N, M), dtype=dtype, bindings=Map({}),
shape=(N, M), dtype=dtype, bindings=immutabledict({}),
tags=_get_default_tags(),
axes=_get_default_axes(2),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2229,10 +2235,10 @@ def arange(*args: Any, **kwargs: Any) -> Array:

from pymbolic.primitives import Variable
return IndexLambda(expr=start + Variable("_0") * step,
shape=(size,), dtype=dtype, bindings=Map(),
shape=(size,), dtype=dtype, bindings=immutabledict(),
tags=_get_default_tags(),
axes=_get_default_axes(1),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2343,7 +2349,7 @@ def logical_not(x: ArrayOrScalar) -> Union[Array, bool]:
bindings={"_in0": x},
tags=_get_default_tags(),
axes=_get_default_axes(len(x.shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2396,10 +2402,10 @@ def where(condition: ArrayOrScalar,
expr=prim.If(expr1, expr2, expr3),
shape=result_shape,
dtype=dtype,
bindings=Map(bindings),
bindings=immutabledict(bindings),
tags=_get_default_tags(),
axes=_get_default_axes(len(result_shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down Expand Up @@ -2492,12 +2498,13 @@ def make_index_lambda(
# }}}

return IndexLambda(expr=expression,
bindings=Map(bindings),
bindings=immutabledict(bindings),
shape=shape,
dtype=dtype,
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map(processed_var_to_reduction_descr))
var_to_reduction_descr=immutabledict
(processed_var_to_reduction_descr))

# }}}

Expand Down Expand Up @@ -2578,10 +2585,10 @@ def broadcast_to(array: Array, shape: ShapeType) -> Array:
shape)),
shape=shape,
dtype=array.dtype,
bindings=Map({"in": array}),
bindings=immutabledict({"in": array}),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map())
var_to_reduction_descr=immutabledict())

# }}}

Expand Down
6 changes: 3 additions & 3 deletions pytato/cmath.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@
_get_default_axes, _get_default_tags)
from pytato.scalar_expr import SCALAR_CLASSES
from pymbolic import var
from immutables import Map
from immutabledict import immutabledict


def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
Expand Down Expand Up @@ -113,10 +113,10 @@ def _apply_elem_wise_func(inputs: Tuple[ArrayOrScalar, ...],
return IndexLambda(
expr=prim.Call(var(f"pytato.c99.{func_name}"),
tuple(sym_args)),
shape=shape, dtype=ret_dtype, bindings=Map(bindings),
shape=shape, dtype=ret_dtype, bindings=immutabledict(bindings),
tags=_get_default_tags(),
axes=_get_default_axes(len(shape)),
var_to_reduction_descr=Map(),
var_to_reduction_descr=immutabledict(),
)


Expand Down
10 changes: 6 additions & 4 deletions pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

import dataclasses
from typing import Union, Dict, Tuple, List, Any, Optional
from immutabledict import immutabledict

from pytato.array import (Array, DictOfNamedArrays, DataWrapper, Placeholder,
DataInterface, SizeParam, InputArgumentBase,
Expand Down Expand Up @@ -173,9 +174,10 @@ def map_loopy_call(self, expr: LoopyCall) -> LoopyCall:

# }}}

bindings = {name: (self.rec(subexpr) if isinstance(subexpr, Array)
bindings: immutabledict[str, Any] = immutabledict(
{name: (self.rec(subexpr) if isinstance(subexpr, Array)
else subexpr)
for name, subexpr in sorted(expr.bindings.items())}
for name, subexpr in sorted(expr.bindings.items())})

return LoopyCall(translation_unit=translation_unit,
bindings=bindings,
Expand Down Expand Up @@ -282,8 +284,8 @@ def preprocess(outputs: DictOfNamedArrays, target: Target) -> PreprocessResult:
for out in outputs.values()))

# only look for dependencies between the outputs
deps = {name: get_deps(output.expr)
for name, output in outputs.items()}
deps: immutabledict[str, Any] = immutabledict({name: get_deps(output.expr)
for name, output in outputs.items()})

# represent deps in terms of output names
output_expr_to_name = {output.expr: name for name, output in outputs.items()}
Expand Down
6 changes: 3 additions & 3 deletions pytato/distributed/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
List, AbstractSet, TypeVar, TYPE_CHECKING, Hashable, Optional)

import attrs
from immutables import Map
from immutabledict import immutabledict

from pytools.graph import CycleError
from pytools import memoize_method
Expand Down Expand Up @@ -405,11 +405,11 @@ def _make_distributed_partition(
partition_input_names=frozenset(
comm_replacer.partition_input_name_to_placeholder.keys()),
output_names=frozenset(name_to_ouput.keys()),
name_to_recv_node=Map({
name_to_recv_node=immutabledict({
recvd_ary_to_name[local_recv_id_to_recv_node[recv_id]]:
local_recv_id_to_recv_node[recv_id]
for recv_id in comm_ids.recv_ids}),
name_to_send_nodes=Map(name_to_send_nodes))
name_to_send_nodes=immutabledict(name_to_send_nodes))

result = DistributedGraphPartition(
parts=parts,
Expand Down
Loading

0 comments on commit 787f839

Please sign in to comment.