Skip to content

Commit

Permalink
separate InlinedExpressionGenMapper recursion from DAG recursion
Browse files Browse the repository at this point in the history
  • Loading branch information
majosm authored and inducer committed Jul 27, 2023
1 parent 8a91363 commit 4743d05
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,17 +156,17 @@ class LocalExpressionContext:
.. automethod:: lookup
"""
num_indices: int
local_namespace: Mapping[str, Array]
local_namespace: Mapping[str, ImplementedResult]
reduction_bounds: ReductionBounds
var_to_reduction_descr: Mapping[str, ReductionDescriptor]

def lookup(self, name: str) -> Array:
def lookup(self, name: str) -> ImplementedResult:
return self.local_namespace[name]

def copy(self, *,
reduction_bounds: Optional[ReductionBounds] = None,
num_indices: Optional[int] = None,
local_namespace: Optional[Mapping[str, Array]] = None,
local_namespace: Optional[Mapping[str, ImplementedResult]] = None,
var_to_reduction_descr: Optional[
Mapping[str, ReductionDescriptor]] = None,
) -> LocalExpressionContext:
Expand Down Expand Up @@ -347,7 +347,7 @@ class CodeGenMapper(Mapper):
def __init__(self,
array_tag_t_to_not_propagate: FrozenSet[Type[Tag]],
axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]) -> None:
self.exprgen_mapper = InlinedExpressionGenMapper(self)
self.exprgen_mapper = InlinedExpressionGenMapper(axis_tag_t_to_not_propagate)
self.array_tag_t_to_not_propagate = array_tag_t_to_not_propagate
self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate
self.has_loopy_call = False
Expand Down Expand Up @@ -401,7 +401,9 @@ def map_index_lambda(self, expr: IndexLambda,

prstnt_ctx = PersistentExpressionContext(state)
local_ctx = LocalExpressionContext(
local_namespace=expr.bindings,
local_namespace={
name: self.rec(subexpr, state)
for name, subexpr in expr.bindings.items()},
num_indices=expr.ndim,
reduction_bounds={},
var_to_reduction_descr=expr.var_to_reduction_descr)
Expand Down Expand Up @@ -622,10 +624,10 @@ class InlinedExpressionGenMapper(scalar_expr.IdentityMapper):
The outputs of this mapper are scalar expressions suitable for wrapping in
:class:`InlinedResult`.
"""
codegen_mapper: CodeGenMapper
axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]

def __init__(self, codegen_mapper: CodeGenMapper):
self.codegen_mapper = codegen_mapper
def __init__(self, axis_tag_t_to_not_propagate: FrozenSet[Type[Tag]]):
self.axis_tag_t_to_not_propagate = axis_tag_t_to_not_propagate

if TYPE_CHECKING:
def __call__(self, expr: ScalarExpression,
Expand All @@ -639,11 +641,8 @@ def map_subscript(self, expr: prim.Subscript,
local_ctx: LocalExpressionContext,
) -> ScalarExpression:
assert isinstance(expr.aggregate, prim.Variable)
result: ImplementedResult = self.codegen_mapper(
local_ctx.lookup(expr.aggregate.name), prstnt_ctx.state)
return result.to_loopy_expression(self.rec(expr.index, prstnt_ctx,
local_ctx),
prstnt_ctx)
return local_ctx.lookup(expr.aggregate.name).to_loopy_expression(
self.rec(expr.index, prstnt_ctx, local_ctx), prstnt_ctx)

def map_variable(self, expr: prim.Variable,
prstnt_ctx: PersistentExpressionContext,
Expand All @@ -660,10 +659,7 @@ def map_variable(self, expr: prim.Variable,
elif expr.name in local_ctx.reduction_bounds:
return expr
else:
array = local_ctx.lookup(expr.name)
impl_result: ImplementedResult = self.codegen_mapper(array,
prstnt_ctx.state)
return impl_result.to_loopy_expression((), prstnt_ctx)
return local_ctx.lookup(expr.name).to_loopy_expression((), prstnt_ctx)

def map_call(self, expr: prim.Call,
prstnt_ctx: PersistentExpressionContext,
Expand Down Expand Up @@ -718,7 +714,7 @@ def map_reduce(self, expr: scalar_expr.Reduce,
for name_in_expr, name_in_kernel in sorted(unique_names_mapping.items()):
for tag in local_ctx.var_to_reduction_descr[name_in_expr].tags:
if all(not isinstance(tag, tag_t)
for tag_t in self.codegen_mapper.axis_tag_t_to_not_propagate):
for tag_t in self.axis_tag_t_to_not_propagate):
state.update_kernel(lp.tag_inames(state.kernel,
{name_in_kernel: tag}))

Expand Down

0 comments on commit 4743d05

Please sign in to comment.