diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index bf4e012..cbdb9d6 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -106,6 +106,28 @@ def sum_with_limit(a, n): _check_function(sum_with_limit, latex) +def test_sum_with_reducible_limit() -> None: + def sum_with_limit(n): + return sum(i for i in range(n + 1)) + + latex = ( + r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n}} " + r"\mathopen{}\left({i}\mathclose{}\right)" + ) + _check_function(sum_with_limit, latex) + + +def test_sum_with_irreducible_limit() -> None: + def sum_with_limit(n): + return sum(i for i in range(n * 3)) + + latex = ( + r"\mathrm{sum_with_limit}(n) = \sum_{i = {0}}^{{n {3} - 1}} " + r"\mathopen{}\left({i}\mathclose{}\right)" + ) + _check_function(sum_with_limit, latex) + + def test_prod_with_limit_1arg() -> None: def prod_with_limit(n): return math.prod(i**2 for i in range(n)) @@ -128,6 +150,28 @@ def prod_with_limit(a, n): _check_function(prod_with_limit, latex) +def test_prod_with_reducible_limits() -> None: + def prod_with_limit(n): + return math.prod(i for i in range(n - 1)) + + latex = ( + r"\mathrm{prod_with_limit}(n) = " + r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)" + ) + _check_function(prod_with_limit, latex) + + +def test_prod_with_irreducible_limit() -> None: + def prod_with_limit(n): + return math.prod(i for i in range(n * 3)) + + latex = ( + r"\mathrm{prod_with_limit}(n) = " + r"\prod_{i = {0}}^{{n {3} - 1}} \mathopen{}\left({i}\mathclose{}\right)" + ) + _check_function(prod_with_limit, latex) + + def test_nested_function() -> None: def nested(x): return 3 * x diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 3cd5f47..c96ee31 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -3,6 +3,7 @@ from __future__ import annotations import ast +import sys import dataclasses from typing import Any @@ -536,6 +537,48 @@ def visit_IfExp(self, node: ast.IfExp) -> str: latex += self.visit(node) return latex + r", & \mathrm{otherwise} \end{array} \right." + def _reduce_stop_parameter(self, node: ast.BinOp) -> str: + # ast.Constant class is added in Python 3.8 + # ast.Num is the relevant node type in previous versions + if sys.version_info.minor < 8: + if isinstance(node.right, ast.Num): + if isinstance(node.op, ast.Add): + if node.right.n == 1: + upper = "{" + self.visit(node.left) + "}" + else: + reduced_constant = ast.Num(node.right.n - 1) + new_node = ast.BinOp(node.left, node.op, reduced_constant) + upper = "{" + self.visit(new_node) + "}" + else: + if node.right.n == -1: + upper = "{" + self.visit(node.left) + "}" + else: + reduced_constant = ast.Num(node.right.n + 1) + new_node = ast.BinOp(node.left, node.op, reduced_constant) + upper = "{" + self.visit(new_node) + "}" + else: + upper = "{" + self.visit(node) + "}" + else: + if isinstance(node.right, ast.Constant): + if isinstance(node.op, ast.Add): + if node.right.value == 1: + upper = "{" + self.visit(node.left) + "}" + else: + reduced_constant = ast.Constant(node.right.value - 1) + new_node = ast.BinOp(node.left, node.op, reduced_constant) + upper = "{" + self.visit(new_node) + "}" + else: + if node.right.value == -1: + upper = "{" + self.visit(node.left) + "}" + else: + reduced_constant = ast.Constant(node.right.value + 1) + new_node = ast.BinOp(node.left, node.op, reduced_constant) + upper = "{" + self.visit(new_node) + "}" + else: + upper = "{" + self.visit(node) + "}" + + return upper + def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None: """Helper to process range(...) for sum and prod functions. @@ -577,7 +620,13 @@ def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None lower_rhs = f"{{{range_info.start_int}}}" if range_info.stop_int is None: - upper = "{" + self.visit(range_info.stop) + " - 1}" + # use special processing if range_info.stop involves addition or subtraction + if isinstance(range_info.stop, ast.BinOp) and isinstance( + range_info.stop.op, (ast.Add, ast.Sub) + ): + upper = self._reduce_stop_parameter(range_info.stop) + else: + upper = "{" + self.visit(range_info.stop) + " - 1}" else: upper = f"{{{range_info.stop_int - 1}}}" diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 613755c..9dc9945 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -219,6 +219,24 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), + # reduce stop parameter + ( + "sum(i for i in range(n+1))", + r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "math.prod(i for i in range(n-1))", + r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)", + ), + # reduce stop parameter + ( + "sum(i for i in range(n+1))", + r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "math.prod(i for i in range(n-1))", + r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)", + ), ], ) def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: