Skip to content

Commit

Permalink
Fix mypy error in pytato.loopy regarding get_dependencies arg
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 20, 2024
1 parent 3f90cbd commit e9c70d9
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions pytato/loopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,7 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
from loopy.kernel.array import ArrayBase
from loopy.symbolic import get_dependencies as lpy_get_deps
from pymbolic.mapper.substitutor import make_subst_func
from pymbolic.primitives import is_expression

from pytato.transform import SizeParamGatherer

Expand All @@ -426,8 +427,9 @@ def extend_bindings_with_shape_inference(knl: lp.LoopKernel,
lp_size_params: frozenset[str] = reduce(frozenset.union,
(lpy_get_deps(not_none(arg.shape))
for arg in knl.args
if isinstance(arg, ArrayBase)),
frozenset())
if isinstance(arg, ArrayBase)
and is_expression(arg.shape)
), frozenset())

pt_size_params: frozenset[SizeParam] = reduce(frozenset.union,
(get_size_param_deps(bnd)
Expand Down

0 comments on commit e9c70d9

Please sign in to comment.