Skip to content

Commit 3a82518

Browse files
author
Yusuke Oda
authoredNov 6, 2022
Supports if clause in comprehension with sum/prod. (#79)
* add some utils to analyze range. * Support some forms in sum/prod * support multi-clause comprehension in sum and prod * support if in sum and prod. * fix
1 parent ee02185 commit 3a82518

File tree

2 files changed

+33
-13
lines changed

2 files changed

+33
-13
lines changed
 

‎src/latexify/codegen/function_codegen.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -384,23 +384,28 @@ def _get_sum_prod_info(
384384
scripts: list[tuple[str, str]] = []
385385

386386
for comp in node.generators:
387-
# TODO(odashi): This could be supported.
388-
if comp.ifs:
389-
raise exceptions.LatexifyNotSupportedError(
390-
"If-clause in comprehension is not supported."
391-
)
392-
393387
target = self.visit(comp.target)
394388
range_args = self._get_sum_prod_range(comp)
395389

396-
if range_args is not None:
390+
if range_args is not None and not comp.ifs:
397391
lower_rhs, upper = range_args
398392
lower = f"{target} = {lower_rhs}"
399393
else:
400394
lower_rhs = self.visit(comp.iter)
401-
lower = rf"{target} \in {lower_rhs}"
395+
lower_in = rf"{target} \in {lower_rhs}"
402396
upper = ""
403397

398+
if comp.ifs:
399+
conds = [lower_in] + [self.visit(cond) for cond in comp.ifs]
400+
conds_wrapped = [r"\left(" + cond + r"\right)" for cond in conds]
401+
lower = r" \land ".join(conds_wrapped)
402+
# TODO(odashi):
403+
# Following form may be prettier, but requires amsmath.
404+
# It would be good if we have an option to switch the behavior.
405+
# lower = r"\substack{" + r" \\ ".join(lowers) + "}"
406+
else:
407+
lower = lower_in
408+
404409
scripts.append((lower, upper))
405410

406411
return elt, scripts

‎src/latexify/codegen/function_codegen_test.py

+20-5
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,26 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
108108
assert FunctionCodegen().visit(node) == latex
109109

110110

111-
def test_visit_call_sum_prod_with_if() -> None:
112-
for fn_name in ["sum", "math.prod"]:
113-
node = ast.parse(f"{fn_name}(i for y in x if y == 0)").body[0].value
114-
with pytest.raises(exceptions.LatexifyNotSupportedError, match="^If-clause"):
115-
FunctionCodegen().visit(node)
111+
@pytest.mark.parametrize(
112+
"src_suffix,dest_suffix",
113+
[
114+
(
115+
"(i for i in x if i < y)",
116+
r"_{\left(i \in x\right) \land \left({i < y}\right)}^{} \left({i}\right)",
117+
),
118+
(
119+
"(i for i in x if i < y if f(i))",
120+
r"_{\left(i \in x\right) \land \left({i < y}\right)"
121+
r" \land \left(\mathrm{f}\left(i\right)\right)}^{}"
122+
r" \left({i}\right)",
123+
),
124+
],
125+
)
126+
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
127+
for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]:
128+
node = ast.parse(src_fn + src_suffix).body[0].value
129+
assert isinstance(node, ast.Call)
130+
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix
116131

117132

118133
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)
Please sign in to comment.