diff --git a/src/latexify/analyzers.py b/src/latexify/analyzers.py index 76ef75e..f6283b2 100644 --- a/src/latexify/analyzers.py +++ b/src/latexify/analyzers.py @@ -4,6 +4,7 @@ import ast import dataclasses +import sys from latexify import ast_utils, exceptions @@ -62,3 +63,43 @@ def analyze_range(node: ast.Call) -> RangeInfo: stop_int=ast_utils.extract_int_or_none(stop), step_int=ast_utils.extract_int_or_none(step), ) + + +def reduce_stop_parameter(node: ast.expr) -> ast.expr: + """Adjusts the stop expression of the range. + + This function tries to convert the syntax as follows: + * n + 1 --> n + * n + 2 --> n + 1 + * n - 1 --> n - 2 + + Args: + node: The target expression. + + Returns: + Converted expression. + """ + if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub))): + return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) + + # Treatment for Python 3.7. + rhs = ( + ast.Constant(value=node.right.n) + if sys.version_info.minor < 8 and isinstance(node.right, ast.Num) + else node.right + ) + + if not isinstance(rhs, ast.Constant): + return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) + + shift = 1 if isinstance(node.op, ast.Add) else -1 + + return ( + node.left + if rhs.value == shift + else ast.BinOp( + left=node.left, + op=node.op, + right=ast_utils.make_constant(value=rhs.value - shift), + ) + ) diff --git a/src/latexify/analyzers_test.py b/src/latexify/analyzers_test.py index ad13db0..52802d7 100644 --- a/src/latexify/analyzers_test.py +++ b/src/latexify/analyzers_test.py @@ -150,3 +150,20 @@ def test_analyze_range_invalid(code: str) -> None: exceptions.LatexifySyntaxError, match=r"^Unsupported AST for analyze_range\.$" ): analyzers.analyze_range(node) + + +@pytest.mark.parametrize( + "before,after", + [ + ("n + 1", "n"), + ("n + 2", "n + 1"), + ("n - (-1)", "n - (-1) - 1"), + ("n - 1", "n - 2"), + ("1 * 2", "1 * 2 - 1"), + ], +) +def test_reduce_stop_parameter(before: str, after: str) -> None: + test_utils.assert_ast_equal( + analyzers.reduce_stop_parameter(ast_utils.parse_expr(before)), + ast_utils.parse_expr(after), + ) diff --git a/src/latexify/codegen/__init__.py b/src/latexify/codegen/__init__.py index 3d06ec2..cddad8d 100644 --- a/src/latexify/codegen/__init__.py +++ b/src/latexify/codegen/__init__.py @@ -1,5 +1,6 @@ """Package latexify.codegen.""" -from latexify.codegen import function_codegen +from latexify.codegen import expression_codegen, function_codegen +ExpressionCodegen = expression_codegen.ExpressionCodegen FunctionCodegen = function_codegen.FunctionCodegen diff --git a/src/latexify/codegen/codegen_utils.py b/src/latexify/codegen/codegen_utils.py new file mode 100644 index 0000000..94faf25 --- /dev/null +++ b/src/latexify/codegen/codegen_utils.py @@ -0,0 +1,28 @@ +from typing import Any + +from latexify import exceptions + + +def convert_constant(value: Any) -> str: + """Helper to convert constant values to LaTeX. + + Args: + value: A constant value. + + Returns: + The LaTeX representation of `value`. + """ + if value is None or isinstance(value, bool): + return r"\mathrm{" + str(value) + "}" + if isinstance(value, (int, float, complex)): + # TODO(odashi): Support other symbols for the imaginary unit than j. + return str(value) + if isinstance(value, str): + return r'\textrm{"' + value + '"}' + if isinstance(value, bytes): + return r"\textrm{" + str(value) + "}" + if value is ...: + return r"\cdots" + raise exceptions.LatexifyNotSupportedError( + f"Unrecognized constant: {type(value).__name__}" + ) diff --git a/src/latexify/codegen/codegen_utils_test.py b/src/latexify/codegen/codegen_utils_test.py new file mode 100644 index 0000000..5b0f909 --- /dev/null +++ b/src/latexify/codegen/codegen_utils_test.py @@ -0,0 +1,34 @@ +"""Tests for latexify.codegen.codegen_utils.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from latexify import exceptions +from latexify.codegen.codegen_utils import convert_constant + + +@pytest.mark.parametrize( + "constant,latex", + [ + (None, r"\mathrm{None}"), + (True, r"\mathrm{True}"), + (False, r"\mathrm{False}"), + (123, "123"), + (456.789, "456.789"), + (-3 + 4j, "(-3+4j)"), + ("string", r'\textrm{"string"}'), + (..., r"\cdots"), + ], +) +def test_convert_constant(constant: Any, latex: str) -> None: + assert convert_constant(constant) == latex + + +def test_convert_constant_unsupported_constant() -> None: + with pytest.raises( + exceptions.LatexifyNotSupportedError, match="^Unrecognized constant: " + ): + convert_constant({}) diff --git a/src/latexify/codegen/expression_codegen.py b/src/latexify/codegen/expression_codegen.py new file mode 100644 index 0000000..54bee1b --- /dev/null +++ b/src/latexify/codegen/expression_codegen.py @@ -0,0 +1,672 @@ +"""Codegen for single expressions.""" + +from __future__ import annotations + +import ast +import dataclasses + +from latexify import analyzers, ast_utils, constants, exceptions +from latexify.codegen import codegen_utils, identifier_converter + +# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes. +# Note that this value affects only the appearance of surrounding parentheses for each +# expression, and does not affect the AST itself. +# See also: +# https://docs.python.org/3/reference/expressions.html#operator-precedence +_PRECEDENCES: dict[type[ast.AST], int] = { + ast.Pow: 120, + ast.UAdd: 110, + ast.USub: 110, + ast.Invert: 110, + ast.Mult: 100, + ast.MatMult: 100, + ast.Div: 100, + ast.FloorDiv: 100, + ast.Mod: 100, + ast.Add: 90, + ast.Sub: 90, + ast.LShift: 80, + ast.RShift: 80, + ast.BitAnd: 70, + ast.BitXor: 60, + ast.BitOr: 50, + ast.In: 40, + ast.NotIn: 40, + ast.Is: 40, + ast.IsNot: 40, + ast.Lt: 40, + ast.LtE: 40, + ast.Gt: 40, + ast.GtE: 40, + ast.NotEq: 40, + ast.Eq: 40, + # NOTE(odashi): + # We assume that the `not` operator has the same precedence with other unary + # operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ looks to have a + # high precedence. + # ast.Not: 30, + ast.Not: 110, + ast.And: 20, + ast.Or: 10, +} + +# NOTE(odashi): +# Function invocation is treated as a unary operator with a higher precedence. +# This ensures that the argument with a unary operator is wrapped: +# exp(x) --> \exp x +# exp(-x) --> \exp (-x) +# -exp(x) --> - \exp x +_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1 + + +def _get_precedence(node: ast.AST) -> int: + """Obtains the precedence of the subtree. + + Args: + node: Subtree to investigate. + + Returns: + If `node` is a subtree with some operator, returns the precedence of the + operator. Otherwise, returns a number larger enough from other precedences. + """ + if isinstance(node, ast.Call): + return _CALL_PRECEDENCE + + if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)): + return _PRECEDENCES[type(node.op)] + + if isinstance(node, ast.Compare): + # Compare operators have the same precedence. It is enough to check only the + # first operator. + return _PRECEDENCES[type(node.ops[0])] + + return 1_000_000 + + +@dataclasses.dataclass(frozen=True) +class BinOperandRule: + """Syntax rules for operands of BinOp.""" + + # Whether to require wrapping operands by parentheses according to the precedence. + wrap: bool = True + + # Whether to require wrapping operands by parentheses if the operand has the same + # precedence with this operator. + # This is used to control the behavior of non-associative operators. + force: bool = False + + +@dataclasses.dataclass(frozen=True) +class BinOpRule: + """Syntax rules for BinOp.""" + + # Left/middle/right syntaxes to wrap operands. + latex_left: str + latex_middle: str + latex_right: str + + # Operand rules. + operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) + operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) + + # Whether to assume the resulting syntax is wrapped by some bracket operators. + # If True, the parent operator can avoid wrapping this operator by parentheses. + is_wrapped: bool = False + + +_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { + ast.Pow: BinOpRule( + "", + "^{", + "}", + operand_left=BinOperandRule(force=True), + operand_right=BinOperandRule(wrap=False), + ), + ast.Mult: BinOpRule("", " ", ""), + ast.MatMult: BinOpRule("", " ", ""), + ast.Div: BinOpRule( + r"\frac{", + "}{", + "}", + operand_left=BinOperandRule(wrap=False), + operand_right=BinOperandRule(wrap=False), + ), + ast.FloorDiv: BinOpRule( + r"\left\lfloor\frac{", + "}{", + r"}\right\rfloor", + operand_left=BinOperandRule(wrap=False), + operand_right=BinOperandRule(wrap=False), + is_wrapped=True, + ), + ast.Mod: BinOpRule( + "", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True) + ), + ast.Add: BinOpRule("", " + ", ""), + ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)), + ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)), + ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)), + ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""), + ast.BitXor: BinOpRule("", r" \oplus ", ""), + ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""), +} + +# Typeset for BinOp of sets. +_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { + **_BIN_OP_RULES, + ast.Sub: BinOpRule( + "", r" \setminus ", "", operand_right=BinOperandRule(force=True) + ), + ast.BitAnd: BinOpRule("", r" \cap ", ""), + ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""), + ast.BitOr: BinOpRule("", r" \cup ", ""), +} + +_UNARY_OPS: dict[type[ast.unaryop], str] = { + ast.Invert: r"\mathord{\sim} ", + ast.UAdd: "+", # Explicitly adds the $+$ operator. + ast.USub: "-", + ast.Not: r"\lnot ", +} + +_COMPARE_OPS: dict[type[ast.cmpop], str] = { + ast.Eq: "=", + ast.Gt: ">", + ast.GtE: r"\ge", + ast.In: r"\in", + ast.Is: r"\equiv", + ast.IsNot: r"\not\equiv", + ast.Lt: "<", + ast.LtE: r"\le", + ast.NotEq: r"\ne", + ast.NotIn: r"\notin", +} + +# Typeset for Compare of sets. +_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = { + **_COMPARE_OPS, + ast.Gt: r"\supset", + ast.GtE: r"\supseteq", + ast.Lt: r"\subset", + ast.LtE: r"\subseteq", +} + +_BOOL_OPS: dict[type[ast.boolop], str] = { + ast.And: r"\land", + ast.Or: r"\lor", +} + + +class ExpressionCodegen(ast.NodeVisitor): + """Codegen for single expressions.""" + + _identifier_converter: identifier_converter.IdentifierConverter + + _bin_op_rules: dict[type[ast.operator], BinOpRule] + _compare_ops: dict[type[ast.cmpop], str] + + def __init__( + self, *, use_math_symbols: bool = False, use_set_symbols: bool = False + ) -> None: + """Initializer. + + Args: + use_math_symbols: Whether to convert identifiers with a math symbol surface + (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). + use_set_symbols: Whether to use set symbols or not. + """ + self._identifier_converter = identifier_converter.IdentifierConverter( + use_math_symbols=use_math_symbols + ) + + self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES + self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS + + def generic_visit(self, node: ast.AST) -> str: + raise exceptions.LatexifyNotSupportedError( + f"Unsupported AST: {type(node).__name__}" + ) + + def visit_Tuple(self, node: ast.Tuple) -> str: + """Visit a Tuple node.""" + elts = [self.visit(elt) for elt in node.elts] + return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)" + + def visit_List(self, node: ast.List) -> str: + """Visit a List node.""" + elts = [self.visit(elt) for elt in node.elts] + return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]" + + def visit_Set(self, node: ast.Set) -> str: + """Visit a Set node.""" + elts = [self.visit(elt) for elt in node.elts] + return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}" + + def visit_ListComp(self, node: ast.ListComp) -> str: + """Visit a ListComp node.""" + generators = [self.visit(comp) for comp in node.generators] + return ( + r"\mathopen{}\left[ " + + self.visit(node.elt) + + r" \mid " + + ", ".join(generators) + + r" \mathclose{}\right]" + ) + + def visit_SetComp(self, node: ast.SetComp) -> str: + """Visit a SetComp node.""" + generators = [self.visit(comp) for comp in node.generators] + return ( + r"\mathopen{}\left\{ " + + self.visit(node.elt) + + r" \mid " + + ", ".join(generators) + + r" \mathclose{}\right\}" + ) + + def visit_comprehension(self, node: ast.comprehension) -> str: + """Visit a comprehension node.""" + target = rf"{self.visit(node.target)} \in {self.visit(node.iter)}" + + if not node.ifs: + # Returns the source without parenthesis. + return target + + conds = [target] + [self.visit(cond) for cond in node.ifs] + wrapped = [r"\mathopen{}\left( " + s + r" \mathclose{}\right)" for s in conds] + return r" \land ".join(wrapped) + + def _generate_sum_prod(self, node: ast.Call) -> str | None: + """Generates sum/prod expression. + + Args: + node: ast.Call node containing the sum/prod invocation. + + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + """ + if not isinstance(node.args[0], ast.GeneratorExp): + return None + + name = ast_utils.extract_function_name_or_none(node) + assert name in ("fsum", "sum", "prod") + + command = { + "fsum": r"\sum", + "sum": r"\sum", + "prod": r"\prod", + }[name] + + elt, scripts = self._get_sum_prod_info(node.args[0]) + scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts] + return ( + " ".join(scripts_str) + + rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)" + ) + + def _generate_matrix(self, node: ast.Call) -> str | None: + """Generates matrix expression. + + Args: + node: ast.Call node containing the ndarray invocation. + + Returns: + Generated LaTeX, or None if the node has unsupported syntax. + """ + + def generate_matrix_from_array(data: list[list[str]]) -> str: + """Helper to generate a bmatrix environment.""" + contents = r" \\ ".join(" & ".join(row) for row in data) + return r"\begin{bmatrix} " + contents + r" \end{bmatrix}" + + arg = node.args[0] + if not isinstance(arg, ast.List) or not arg.elts: + # Not an array or no rows + return None + + row0 = arg.elts[0] + + if not isinstance(row0, ast.List): + # Maybe 1 x N array + return generate_matrix_from_array([[self.visit(x) for x in arg.elts]]) + + if not row0.elts: + # No columns + return None + + ncols = len(row0.elts) + + rows: list[list[str]] = [] + + for row in arg.elts: + if not isinstance(row, ast.List) or len(row.elts) != ncols: + # Length mismatch + return None + + rows.append([self.visit(x) for x in row.elts]) + + return generate_matrix_from_array(rows) + + def visit_Call(self, node: ast.Call) -> str: + """Visit a Call node.""" + func_name = ast_utils.extract_function_name_or_none(node) + + # Special treatments for some functions. + if func_name in ("fsum", "sum", "prod"): + special_latex = self._generate_sum_prod(node) + elif func_name in ("array", "ndarray"): + special_latex = self._generate_matrix(node) + else: + special_latex = None + + if special_latex is not None: + return special_latex + + # Obtains the codegen rule. + rule = constants.BUILTIN_FUNCS.get(func_name) if func_name is not None else None + + if rule is None: + rule = constants.FunctionRule(self.visit(node.func)) + + if rule.is_unary and len(node.args) == 1: + # Unary function. Applies the same wrapping policy with the unary operators. + # NOTE(odashi): + # Factorial "x!" is treated as a special case: it requires both inner/outer + # parentheses for correct interpretation. + precedence = _get_precedence(node) + arg = node.args[0] + force_wrap = isinstance(arg, ast.Call) and ( + func_name == "factorial" + or ast_utils.extract_function_name_or_none(arg) == "factorial" + ) + arg_latex = self._wrap_operand(arg, precedence, force_wrap) + elements = [rule.left, arg_latex, rule.right] + else: + arg_latex = ", ".join(self.visit(arg) for arg in node.args) + if rule.is_wrapped: + elements = [rule.left, arg_latex, rule.right] + else: + elements = [ + rule.left, + r"\mathopen{}\left(", + arg_latex, + r"\mathclose{}\right)", + rule.right, + ] + + return " ".join(x for x in elements if x) + + def visit_Attribute(self, node: ast.Attribute) -> str: + """Visit an Attribute node.""" + vstr = self.visit(node.value) + astr = self._identifier_converter.convert(node.attr)[0] + return vstr + "." + astr + + def visit_Name(self, node: ast.Name) -> str: + """Visit a Name node.""" + return self._identifier_converter.convert(node.id)[0] + + # From Python 3.8 + def visit_Constant(self, node: ast.Constant) -> str: + """Visit a Constant node.""" + return codegen_utils.convert_constant(node.value) + + # Until Python 3.7 + def visit_Num(self, node: ast.Num) -> str: + """Visit a Num node.""" + return codegen_utils.convert_constant(node.n) + + # Until Python 3.7 + def visit_Str(self, node: ast.Str) -> str: + """Visit a Str node.""" + return codegen_utils.convert_constant(node.s) + + # Until Python 3.7 + def visit_Bytes(self, node: ast.Bytes) -> str: + """Visit a Bytes node.""" + return codegen_utils.convert_constant(node.s) + + # Until Python 3.7 + def visit_NameConstant(self, node: ast.NameConstant) -> str: + """Visit a NameConstant node.""" + return codegen_utils.convert_constant(node.value) + + # Until Python 3.7 + def visit_Ellipsis(self, node: ast.Ellipsis) -> str: + """Visit an Ellipsis node.""" + return codegen_utils.convert_constant(...) + + def _wrap_operand( + self, child: ast.expr, parent_prec: int, force_wrap: bool = False + ) -> str: + """Wraps the operand subtree with parentheses. + + Args: + child: Operand subtree. + parent_prec: Precedence of the parent operator. + force_wrap: Whether to wrap the operand or not when the precedence is equal. + + Returns: + LaTeX form of `child`, with or without surrounding parentheses. + """ + latex = self.visit(child) + child_prec = _get_precedence(child) + + if child_prec < parent_prec or force_wrap and child_prec == parent_prec: + return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" + + return latex + + def _wrap_binop_operand( + self, + child: ast.expr, + parent_prec: int, + operand_rule: BinOperandRule, + ) -> str: + """Wraps the operand subtree of BinOp with parentheses. + + Args: + child: Operand subtree. + parent_prec: Precedence of the parent operator. + operand_rule: Syntax rule of this operand. + + Returns: + LaTeX form of the `child`, with or without surrounding parentheses. + """ + if not operand_rule.wrap: + return self.visit(child) + + if isinstance(child, ast.Call): + child_fn_name = ast_utils.extract_function_name_or_none(child) + rule = ( + constants.BUILTIN_FUNCS.get(child_fn_name) + if child_fn_name is not None + else None + ) + if rule is not None and rule.is_wrapped: + return self.visit(child) + + if not isinstance(child, ast.BinOp): + return self._wrap_operand(child, parent_prec) + + latex = self.visit(child) + + if _BIN_OP_RULES[type(child.op)].is_wrapped: + return latex + + child_prec = _get_precedence(child) + + if child_prec > parent_prec or ( + child_prec == parent_prec and not operand_rule.force + ): + return latex + + return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" + + def visit_BinOp(self, node: ast.BinOp) -> str: + """Visit a BinOp node.""" + prec = _get_precedence(node) + rule = self._bin_op_rules[type(node.op)] + lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left) + rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right) + return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}" + + def visit_UnaryOp(self, node: ast.UnaryOp) -> str: + """Visit a UnaryOp node.""" + latex = self._wrap_operand(node.operand, _get_precedence(node)) + return _UNARY_OPS[type(node.op)] + latex + + def visit_Compare(self, node: ast.Compare) -> str: + """Visit a Compare node.""" + parent_prec = _get_precedence(node) + lhs = self._wrap_operand(node.left, parent_prec) + ops = [self._compare_ops[type(x)] for x in node.ops] + rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators] + ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)] + return lhs + "".join(ops_rhs) + + def visit_BoolOp(self, node: ast.BoolOp) -> str: + """Visit a BoolOp node.""" + parent_prec = _get_precedence(node) + values = [self._wrap_operand(x, parent_prec) for x in node.values] + op = f" {_BOOL_OPS[type(node.op)]} " + return op.join(values) + + def visit_IfExp(self, node: ast.IfExp) -> str: + """Visit an IfExp node""" + latex = r"\left\{ \begin{array}{ll} " + + current_expr: ast.expr = node + + while isinstance(current_expr, ast.IfExp): + cond_latex = self.visit(current_expr.test) + true_latex = self.visit(current_expr.body) + latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ " + current_expr = current_expr.orelse + + latex += self.visit(current_expr) + return latex + r", & \mathrm{otherwise} \end{array} \right." + + def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None: + """Helper to process range(...) for sum and prod functions. + + Args: + node: comprehension node to be analyzed. + + Returns: + Tuple of following strings: + - lower_rhs + - upper + which are used in _get_sum_prod_info, or None if the analysis failed. + """ + if not ( + isinstance(node.iter, ast.Call) + and isinstance(node.iter.func, ast.Name) + and node.iter.func.id == "range" + ): + return None + + try: + range_info = analyzers.analyze_range(node.iter) + except exceptions.LatexifyError: + return None + + if ( + # Only accepts ascending order with step size 1. + range_info.step_int != 1 + or ( + range_info.start_int is not None + and range_info.stop_int is not None + and range_info.start_int >= range_info.stop_int + ) + ): + return None + + if range_info.start_int is None: + lower_rhs = self.visit(range_info.start) + else: + lower_rhs = str(range_info.start_int) + + if range_info.stop_int is None: + upper = self.visit(analyzers.reduce_stop_parameter(range_info.stop)) + else: + upper = str(range_info.stop_int - 1) + + return lower_rhs, upper + + def _get_sum_prod_info( + self, node: ast.GeneratorExp + ) -> tuple[str, list[tuple[str, str]]]: + r"""Process GeneratorExp for sum and prod functions. + + Args: + node: GeneratorExp node to be analyzed. + + Returns: + Tuple of following strings: + - elt + - scripts + which are used to represent sum/prod operators as follows: + \sum_{scripts[0][0]}^{scripts[0][1]} + \sum_{scripts[1][0]}^{scripts[1][1]} + ... + {elt} + + Raises: + LateixfyError: Unsupported AST is given. + """ + elt = self.visit(node.elt) + + scripts: list[tuple[str, str]] = [] + + for comp in node.generators: + range_args = self._get_sum_prod_range(comp) + + if range_args is not None and not comp.ifs: + target = self.visit(comp.target) + lower_rhs, upper = range_args + lower = f"{target} = {lower_rhs}" + else: + lower = self.visit(comp) # Use a usual comprehension form. + upper = "" + + scripts.append((lower, upper)) + + return elt, scripts + + # Until 3.8 + def visit_Index(self, node: ast.Index) -> str: + """Visit an Index node.""" + return self.visit(node.value) # type: ignore[attr-defined] + + def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]: + """Helper function to convert nested subscription. + + This function converts x[i][j][...] to "x" and ["i", "j", ...] + + Args: + node: ast.Subscript node to be converted. + + Returns: + Tuple of following strings: + - The root value of the subscription. + - Sequence of incices. + """ + if isinstance(node.value, ast.Subscript): + value, indices = self._convert_nested_subscripts(node.value) + else: + value = self.visit(node.value) + indices = [] + + indices.append(self.visit(node.slice)) + return value, indices + + def visit_Subscript(self, node: ast.Subscript) -> str: + """Visitor a Subscript node.""" + value, indices = self._convert_nested_subscripts(node) + + # TODO(odashi): + # "[i][j][...]" may be a possible representation as well as "i, j. ..." + indices_str = ", ".join(indices) + + return f"{value}_{{{indices_str}}}" diff --git a/src/latexify/codegen/expression_codegen_test.py b/src/latexify/codegen/expression_codegen_test.py new file mode 100644 index 0000000..5368a13 --- /dev/null +++ b/src/latexify/codegen/expression_codegen_test.py @@ -0,0 +1,909 @@ +"""Tests for latexify.codegen.expression_codegen.""" + +from __future__ import annotations + +import ast + +import pytest + +from latexify import ast_utils, exceptions, test_utils +from latexify.codegen import ExpressionCodegen + + +def test_generic_visit() -> None: + class UnknownNode(ast.AST): + pass + + with pytest.raises( + exceptions.LatexifyNotSupportedError, + match=r"^Unsupported AST: UnknownNode$", + ): + ExpressionCodegen().visit(UnknownNode()) + + +@pytest.mark.parametrize( + "code,latex", + [ + ("()", r"\mathopen{}\left( \mathclose{}\right)"), + ("(x,)", r"\mathopen{}\left( x \mathclose{}\right)"), + ("(x, y)", r"\mathopen{}\left( x, y \mathclose{}\right)"), + ("(x, y, z)", r"\mathopen{}\left( x, y, z \mathclose{}\right)"), + ], +) +def test_visit_tuple(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Tuple) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("[]", r"\mathopen{}\left[ \mathclose{}\right]"), + ("[x]", r"\mathopen{}\left[ x \mathclose{}\right]"), + ("[x, y]", r"\mathopen{}\left[ x, y \mathclose{}\right]"), + ("[x, y, z]", r"\mathopen{}\left[ x, y, z \mathclose{}\right]"), + ], +) +def test_visit_list(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.List) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # TODO(odashi): Support set(). + # ("set()", r"\mathopen{}\left\{ \mathclose{}\right\}"), + ("{x}", r"\mathopen{}\left\{ x \mathclose{}\right\}"), + ("{x, y}", r"\mathopen{}\left\{ x, y \mathclose{}\right\}"), + ("{x, y, z}", r"\mathopen{}\left\{ x, y, z \mathclose{}\right\}"), + ], +) +def test_visit_set(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Set) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("[i for i in n]", r"\mathopen{}\left[ i \mid i \in n \mathclose{}\right]"), + ( + "[i for i in n if i > 0]", + r"\mathopen{}\left[ i \mid" + r" \mathopen{}\left( i \in n \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right]", + ), + ( + "[i for i in n if i > 0 if f(i)]", + r"\mathopen{}\left[ i \mid" + r" \mathopen{}\left( i \in n \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)" + r" \mathclose{}\right]", + ), + ( + "[i for k in n for i in k]", + r"\mathopen{}\left[ i \mid k \in n, i \in k" r" \mathclose{}\right]", + ), + ( + "[i for k in n for i in k if i > 0]", + r"\mathopen{}\left[ i \mid" + r" k \in n," + r" \mathopen{}\left( i \in k \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right]", + ), + ( + "[i for k in n if f(k) for i in k if i > 0]", + r"\mathopen{}\left[ i \mid" + r" \mathopen{}\left( k \in n \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" k \mathclose{}\right) \mathclose{}\right)," + r" \mathopen{}\left( i \in k \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right]", + ), + ], +) +def test_visit_listcomp(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.ListComp) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("{i for i in n}", r"\mathopen{}\left\{ i \mid i \in n \mathclose{}\right\}"), + ( + "{i for i in n if i > 0}", + r"\mathopen{}\left\{ i \mid" + r" \mathopen{}\left( i \in n \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right\}", + ), + ( + "{i for i in n if i > 0 if f(i)}", + r"\mathopen{}\left\{ i \mid" + r" \mathopen{}\left( i \in n \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)" + r" \mathclose{}\right\}", + ), + ( + "{i for k in n for i in k}", + r"\mathopen{}\left\{ i \mid k \in n, i \in k" r" \mathclose{}\right\}", + ), + ( + "{i for k in n for i in k if i > 0}", + r"\mathopen{}\left\{ i \mid" + r" k \in n," + r" \mathopen{}\left( i \in k \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right\}", + ), + ( + "{i for k in n if f(k) for i in k if i > 0}", + r"\mathopen{}\left\{ i \mid" + r" \mathopen{}\left( k \in n \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" k \mathclose{}\right) \mathclose{}\right)," + r" \mathopen{}\left( i \in k \mathclose{}\right)" + r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" + r" \mathclose{}\right\}", + ), + ], +) +def test_visit_setcomp(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.SetComp) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("foo(x)", r"\mathrm{foo} \mathopen{}\left( x \mathclose{}\right)"), + ("f(x)", r"f \mathopen{}\left( x \mathclose{}\right)"), + ("f(-x)", r"f \mathopen{}\left( -x \mathclose{}\right)"), + ("f(x + y)", r"f \mathopen{}\left( x + y \mathclose{}\right)"), + ( + "f(f(x))", + r"f \mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right)", + ), + ("f(sqrt(x))", r"f \mathopen{}\left( \sqrt{ x } \mathclose{}\right)"), + ("f(sin(x))", r"f \mathopen{}\left( \sin x \mathclose{}\right)"), + ("f(factorial(x))", r"f \mathopen{}\left( x ! \mathclose{}\right)"), + ("f(x, y)", r"f \mathopen{}\left( x, y \mathclose{}\right)"), + ("sqrt(x)", r"\sqrt{ x }"), + ("sqrt(-x)", r"\sqrt{ -x }"), + ("sqrt(x + y)", r"\sqrt{ x + y }"), + ("sqrt(f(x))", r"\sqrt{ f \mathopen{}\left( x \mathclose{}\right) }"), + ("sqrt(sqrt(x))", r"\sqrt{ \sqrt{ x } }"), + ("sqrt(sin(x))", r"\sqrt{ \sin x }"), + ("sqrt(factorial(x))", r"\sqrt{ x ! }"), + ("sin(x)", r"\sin x"), + ("sin(-x)", r"\sin \mathopen{}\left( -x \mathclose{}\right)"), + ("sin(x + y)", r"\sin \mathopen{}\left( x + y \mathclose{}\right)"), + ("sin(f(x))", r"\sin f \mathopen{}\left( x \mathclose{}\right)"), + ("sin(sqrt(x))", r"\sin \sqrt{ x }"), + ("sin(sin(x))", r"\sin \sin x"), + ("sin(factorial(x))", r"\sin \mathopen{}\left( x ! \mathclose{}\right)"), + ("factorial(x)", r"x !"), + ("factorial(-x)", r"\mathopen{}\left( -x \mathclose{}\right) !"), + ("factorial(x + y)", r"\mathopen{}\left( x + y \mathclose{}\right) !"), + ( + "factorial(f(x))", + r"\mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right) !", + ), + ("factorial(sqrt(x))", r"\mathopen{}\left( \sqrt{ x } \mathclose{}\right) !"), + ("factorial(sin(x))", r"\mathopen{}\left( \sin x \mathclose{}\right) !"), + ("factorial(factorial(x))", r"\mathopen{}\left( x ! \mathclose{}\right) !"), + ], +) +def test_visit_call(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Call) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "src_suffix,dest_suffix", + [ + # No comprehension + ("(x)", r" x"), + ( + "([1, 2])", + r" \mathopen{}\left[ 1, 2 \mathclose{}\right]", + ), + ( + "({1, 2})", + r" \mathopen{}\left\{ 1, 2 \mathclose{}\right\}", + ), + ("(f(x))", r" f \mathopen{}\left( x \mathclose{}\right)"), + # Single comprehension + ("(i for i in x)", r"_{i \in x}^{} \mathopen{}\left({i}\mathclose{}\right)"), + ( + "(i for i in [1, 2])", + r"_{i \in \mathopen{}\left[ 1, 2 \mathclose{}\right]}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in {1, 2})", + r"_{i \in \mathopen{}\left\{ 1, 2 \mathclose{}\right\}}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in f(x))", + r"_{i \in f \mathopen{}\left( x \mathclose{}\right)}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n))", + r"_{i = 0}^{n - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n + 1))", + r"_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n + 2))", + r"_{i = 0}^{n + 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + # ast.parse() does not recognize negative integers. + "(i for i in range(n - -1))", + r"_{i = 0}^{n - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n - 1))", + r"_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n + m))", + r"_{i = 0}^{n + m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n - m))", + r"_{i = 0}^{n - m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3))", + r"_{i = 0}^{2} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3 + 1))", + r"_{i = 0}^{3} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3 + 2))", + r"_{i = 0}^{3 + 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3 - 1))", + r"_{i = 0}^{3 - 2} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + # ast.parse() does not recognize negative integers. + "(i for i in range(3 - -1))", + r"_{i = 0}^{3 - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3 + m))", + r"_{i = 0}^{3 + m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(3 - m))", + r"_{i = 0}^{3 - m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n, m))", + r"_{i = n}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(1, m))", + r"_{i = 1}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n, 3))", + r"_{i = n}^{2} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in range(n, m, k))", + r"_{i \in \mathrm{range} \mathopen{}\left( n, m, k \mathclose{}\right)}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", + ), + ], +) +def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: + for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]: + node = ast_utils.parse_expr(src_fn + src_suffix) + assert isinstance(node, ast.Call) + assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix + + +@pytest.mark.parametrize( + "code,latex", + [ + # 2 clauses + ( + "sum(i for y in x for i in y)", + r"\sum_{y \in x}^{} \sum_{i \in y}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "sum(i for y in x for z in y for i in z)", + r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + # 3 clauses + ( + "prod(i for y in x for i in y)", + r"\prod_{y \in x}^{} \prod_{i \in y}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "prod(i for y in x for z in y for i in z)", + r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + # reduce stop parameter + ( + "sum(i for i in range(n+1))", + r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "prod(i for i in range(n-1))", + r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", + ), + # reduce stop parameter + ( + "sum(i for i in range(n+1))", + r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "prod(i for i in range(n-1))", + r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", + ), + ], +) +def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Call) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "src_suffix,dest_suffix", + [ + ( + "(i for i in x if i < y)", + r"_{\mathopen{}\left( i \in x \mathclose{}\right) " + r"\land \mathopen{}\left( i < y \mathclose{}\right)}^{} " + r"\mathopen{}\left({i}\mathclose{}\right)", + ), + ( + "(i for i in x if i < y if f(i))", + r"_{\mathopen{}\left( i \in x \mathclose{}\right)" + r" \land \mathopen{}\left( i < y \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", + ), + ], +) +def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: + for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: + node = ast_utils.parse_expr(src_fn + src_suffix) + assert isinstance(node, ast.Call) + assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix + + +@pytest.mark.parametrize( + "code,latex", + [ + ( + "x if x < y else y", + r"\left\{ \begin{array}{ll}" + r" x, & \mathrm{if} \ x < y \\" + r" y, & \mathrm{otherwise}" + r" \end{array} \right.", + ), + ( + "x if x < y else (y if y < z else z)", + r"\left\{ \begin{array}{ll}" + r" x, & \mathrm{if} \ x < y \\" + r" y, & \mathrm{if} \ y < z \\" + r" z, & \mathrm{otherwise}" + r" \end{array} \right.", + ), + ( + "x if x < y else (y if y < z else (z if z < w else w))", + r"\left\{ \begin{array}{ll}" + r" x, & \mathrm{if} \ x < y \\" + r" y, & \mathrm{if} \ y < z \\" + r" z, & \mathrm{if} \ z < w \\" + r" w, & \mathrm{otherwise}" + r" \end{array} \right.", + ), + ], +) +def test_if_then_else(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.IfExp) + assert ExpressionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # x op y + ("x**y", r"x^{y}"), + ("x * y", r"x y"), + ("x @ y", r"x y"), + ("x / y", r"\frac{x}{y}"), + ("x // y", r"\left\lfloor\frac{x}{y}\right\rfloor"), + ("x % y", r"x \mathbin{\%} y"), + ("x + y", r"x + y"), + ("x - y", r"x - y"), + ("x << y", r"x \ll y"), + ("x >> y", r"x \gg y"), + ("x & y", r"x \mathbin{\&} y"), + ("x ^ y", r"x \oplus y"), + ("x | y", R"x \mathbin{|} y"), + # (x op y) op z + ("(x**y)**z", r"\mathopen{}\left( x^{y} \mathclose{}\right)^{z}"), + ("(x * y) * z", r"x y z"), + ("(x @ y) @ z", r"x y z"), + ("(x / y) / z", r"\frac{\frac{x}{y}}{z}"), + ( + "(x // y) // z", + r"\left\lfloor\frac{\left\lfloor\frac{x}{y}\right\rfloor}{z}\right\rfloor", + ), + ("(x % y) % z", r"x \mathbin{\%} y \mathbin{\%} z"), + ("(x + y) + z", r"x + y + z"), + ("(x - y) - z", r"x - y - z"), + ("(x << y) << z", r"x \ll y \ll z"), + ("(x >> y) >> z", r"x \gg y \gg z"), + ("(x & y) & z", r"x \mathbin{\&} y \mathbin{\&} z"), + ("(x ^ y) ^ z", r"x \oplus y \oplus z"), + ("(x | y) | z", r"x \mathbin{|} y \mathbin{|} z"), + # x op (y op z) + ("x**(y**z)", r"x^{y^{z}}"), + ("x * (y * z)", r"x y z"), + ("x @ (y @ z)", r"x y z"), + ("x / (y / z)", r"\frac{x}{\frac{y}{z}}"), + ( + "x // (y // z)", + r"\left\lfloor\frac{x}{\left\lfloor\frac{y}{z}\right\rfloor}\right\rfloor", + ), + ( + "x % (y % z)", + r"x \mathbin{\%} \mathopen{}\left( y \mathbin{\%} z \mathclose{}\right)", + ), + ("x + (y + z)", r"x + y + z"), + ("x - (y - z)", r"x - \mathopen{}\left( y - z \mathclose{}\right)"), + ("x << (y << z)", r"x \ll \mathopen{}\left( y \ll z \mathclose{}\right)"), + ("x >> (y >> z)", r"x \gg \mathopen{}\left( y \gg z \mathclose{}\right)"), + ("x & (y & z)", r"x \mathbin{\&} y \mathbin{\&} z"), + ("x ^ (y ^ z)", r"x \oplus y \oplus z"), + ("x | (y | z)", r"x \mathbin{|} y \mathbin{|} z"), + # x OP y op z + ("x**y * z", r"x^{y} z"), + ("x * y + z", r"x y + z"), + ("x @ y + z", r"x y + z"), + ("x / y + z", r"\frac{x}{y} + z"), + ("x // y + z", r"\left\lfloor\frac{x}{y}\right\rfloor + z"), + ("x % y + z", r"x \mathbin{\%} y + z"), + ("x + y << z", r"x + y \ll z"), + ("x - y << z", r"x - y \ll z"), + ("x << y & z", r"x \ll y \mathbin{\&} z"), + ("x >> y & z", r"x \gg y \mathbin{\&} z"), + ("x & y ^ z", r"x \mathbin{\&} y \oplus z"), + ("x ^ y | z", r"x \oplus y \mathbin{|} z"), + # x OP (y op z) + ("x**(y * z)", r"x^{y z}"), + ("x * (y + z)", r"x \mathopen{}\left( y + z \mathclose{}\right)"), + ("x @ (y + z)", r"x \mathopen{}\left( y + z \mathclose{}\right)"), + ("x / (y + z)", r"\frac{x}{y + z}"), + ("x // (y + z)", r"\left\lfloor\frac{x}{y + z}\right\rfloor"), + ("x % (y + z)", r"x \mathbin{\%} \mathopen{}\left( y + z \mathclose{}\right)"), + ("x + (y << z)", r"x + \mathopen{}\left( y \ll z \mathclose{}\right)"), + ("x - (y << z)", r"x - \mathopen{}\left( y \ll z \mathclose{}\right)"), + ( + "x << (y & z)", + r"x \ll \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", + ), + ( + "x >> (y & z)", + r"x \gg \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", + ), + ( + "x & (y ^ z)", + r"x \mathbin{\&} \mathopen{}\left( y \oplus z \mathclose{}\right)", + ), + ( + "x ^ (y | z)", + r"x \oplus \mathopen{}\left( y \mathbin{|} z \mathclose{}\right)", + ), + # x op y OP z + ("x * y**z", r"x y^{z}"), + ("x + y * z", r"x + y z"), + ("x + y @ z", r"x + y z"), + ("x + y / z", r"x + \frac{y}{z}"), + ("x + y // z", r"x + \left\lfloor\frac{y}{z}\right\rfloor"), + ("x + y % z", r"x + y \mathbin{\%} z"), + ("x << y + z", r"x \ll y + z"), + ("x << y - z", r"x \ll y - z"), + ("x & y << z", r"x \mathbin{\&} y \ll z"), + ("x & y >> z", r"x \mathbin{\&} y \gg z"), + ("x ^ y & z", r"x \oplus y \mathbin{\&} z"), + ("x | y ^ z", r"x \mathbin{|} y \oplus z"), + # (x op y) OP z + ("(x * y)**z", r"\mathopen{}\left( x y \mathclose{}\right)^{z}"), + ("(x + y) * z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), + ("(x + y) @ z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), + ("(x + y) / z", r"\frac{x + y}{z}"), + ("(x + y) // z", r"\left\lfloor\frac{x + y}{z}\right\rfloor"), + ("(x + y) % z", r"\mathopen{}\left( x + y \mathclose{}\right) \mathbin{\%} z"), + ("(x << y) + z", r"\mathopen{}\left( x \ll y \mathclose{}\right) + z"), + ("(x << y) - z", r"\mathopen{}\left( x \ll y \mathclose{}\right) - z"), + ( + "(x & y) << z", + r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \ll z", + ), + ( + "(x & y) >> z", + r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \gg z", + ), + ( + "(x ^ y) & z", + r"\mathopen{}\left( x \oplus y \mathclose{}\right) \mathbin{\&} z", + ), + ( + "(x | y) ^ z", + r"\mathopen{}\left( x \mathbin{|} y \mathclose{}\right) \oplus z", + ), + # is_wrapped + ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"), + # With Call + ("x**f(y)", r"x^{f \mathopen{}\left( y \mathclose{}\right)}"), + ( + "f(x)**y", + r"\mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right)^{y}", + ), + ("x * f(y)", r"x f \mathopen{}\left( y \mathclose{}\right)"), + ("f(x) * y", r"f \mathopen{}\left( x \mathclose{}\right) y"), + ("x / f(y)", r"\frac{x}{f \mathopen{}\left( y \mathclose{}\right)}"), + ("f(x) / y", r"\frac{f \mathopen{}\left( x \mathclose{}\right)}{y}"), + ("x + f(y)", r"x + f \mathopen{}\left( y \mathclose{}\right)"), + ("f(x) + y", r"f \mathopen{}\left( x \mathclose{}\right) + y"), + # With is_wrapped Call + ("sqrt(x) ** y", r"\sqrt{ x }^{y}"), + # With UnaryOp + ("x**-y", r"x^{-y}"), + ("(-x)**y", r"\mathopen{}\left( -x \mathclose{}\right)^{y}"), + ("x * -y", r"x -y"), # TODO(odashi): google/latexify_py#89 + ("-x * y", r"-x y"), + ("x / -y", r"\frac{x}{-y}"), + ("-x / y", r"\frac{-x}{y}"), + ("x + -y", r"x + -y"), + ("-x + y", r"-x + y"), + # With Compare + ("x**(y == z)", r"x^{y = z}"), + ("(x == y)**z", r"\mathopen{}\left( x = y \mathclose{}\right)^{z}"), + ("x * (y == z)", r"x \mathopen{}\left( y = z \mathclose{}\right)"), + ("(x == y) * z", r"\mathopen{}\left( x = y \mathclose{}\right) z"), + ("x / (y == z)", r"\frac{x}{y = z}"), + ("(x == y) / z", r"\frac{x = y}{z}"), + ("x + (y == z)", r"x + \mathopen{}\left( y = z \mathclose{}\right)"), + ("(x == y) + z", r"\mathopen{}\left( x = y \mathclose{}\right) + z"), + # With BoolOp + ("x**(y and z)", r"x^{y \land z}"), + ("(x and y)**z", r"\mathopen{}\left( x \land y \mathclose{}\right)^{z}"), + ("x * (y and z)", r"x \mathopen{}\left( y \land z \mathclose{}\right)"), + ("(x and y) * z", r"\mathopen{}\left( x \land y \mathclose{}\right) z"), + ("x / (y and z)", r"\frac{x}{y \land z}"), + ("(x and y) / z", r"\frac{x \land y}{z}"), + ("x + (y and z)", r"x + \mathopen{}\left( y \land z \mathclose{}\right)"), + ("(x and y) + z", r"\mathopen{}\left( x \land y \mathclose{}\right) + z"), + ], +) +def test_visit_binop(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.BinOp) + assert ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # With literals + ("+x", r"+x"), + ("-x", r"-x"), + ("~x", r"\mathord{\sim} x"), + ("not x", r"\lnot x"), + # With Call + ("+f(x)", r"+f \mathopen{}\left( x \mathclose{}\right)"), + ("-f(x)", r"-f \mathopen{}\left( x \mathclose{}\right)"), + ("~f(x)", r"\mathord{\sim} f \mathopen{}\left( x \mathclose{}\right)"), + ("not f(x)", r"\lnot f \mathopen{}\left( x \mathclose{}\right)"), + # With BinOp + ("+(x + y)", r"+\mathopen{}\left( x + y \mathclose{}\right)"), + ("-(x + y)", r"-\mathopen{}\left( x + y \mathclose{}\right)"), + ("~(x + y)", r"\mathord{\sim} \mathopen{}\left( x + y \mathclose{}\right)"), + ("not x + y", r"\lnot \mathopen{}\left( x + y \mathclose{}\right)"), + # With Compare + ("+(x == y)", r"+\mathopen{}\left( x = y \mathclose{}\right)"), + ("-(x == y)", r"-\mathopen{}\left( x = y \mathclose{}\right)"), + ("~(x == y)", r"\mathord{\sim} \mathopen{}\left( x = y \mathclose{}\right)"), + ("not x == y", r"\lnot \mathopen{}\left( x = y \mathclose{}\right)"), + # With BoolOp + ("+(x and y)", r"+\mathopen{}\left( x \land y \mathclose{}\right)"), + ("-(x and y)", r"-\mathopen{}\left( x \land y \mathclose{}\right)"), + ( + "~(x and y)", + r"\mathord{\sim} \mathopen{}\left( x \land y \mathclose{}\right)", + ), + ("not (x and y)", r"\lnot \mathopen{}\left( x \land y \mathclose{}\right)"), + ], +) +def test_visit_unaryop(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.UnaryOp) + assert ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # 1 comparator + ("a == b", "a = b"), + ("a > b", "a > b"), + ("a >= b", r"a \ge b"), + ("a in b", r"a \in b"), + ("a is b", r"a \equiv b"), + ("a is not b", r"a \not\equiv b"), + ("a < b", "a < b"), + ("a <= b", r"a \le b"), + ("a != b", r"a \ne b"), + ("a not in b", r"a \notin b"), + # 2 comparators + ("a == b == c", "a = b = c"), + ("a == b > c", "a = b > c"), + ("a == b >= c", r"a = b \ge c"), + ("a == b < c", "a = b < c"), + ("a == b <= c", r"a = b \le c"), + ("a > b == c", "a > b = c"), + ("a > b > c", "a > b > c"), + ("a > b >= c", r"a > b \ge c"), + ("a >= b == c", r"a \ge b = c"), + ("a >= b > c", r"a \ge b > c"), + ("a >= b >= c", r"a \ge b \ge c"), + ("a < b == c", "a < b = c"), + ("a < b < c", "a < b < c"), + ("a < b <= c", r"a < b \le c"), + ("a <= b == c", r"a \le b = c"), + ("a <= b < c", r"a \le b < c"), + ("a <= b <= c", r"a \le b \le c"), + # With Call + ("a == f(b)", r"a = f \mathopen{}\left( b \mathclose{}\right)"), + ("f(a) == b", r"f \mathopen{}\left( a \mathclose{}\right) = b"), + # With BinOp + ("a == b + c", r"a = b + c"), + ("a + b == c", r"a + b = c"), + # With UnaryOp + ("a == -b", r"a = -b"), + ("-a == b", r"-a = b"), + ("a == (not b)", r"a = \lnot b"), + ("(not a) == b", r"\lnot a = b"), + # With BoolOp + ("a == (b and c)", r"a = \mathopen{}\left( b \land c \mathclose{}\right)"), + ("(a and b) == c", r"\mathopen{}\left( a \land b \mathclose{}\right) = c"), + ], +) +def test_visit_compare(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Compare) + assert ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # With literals + ("a and b", r"a \land b"), + ("a and b and c", r"a \land b \land c"), + ("a or b", r"a \lor b"), + ("a or b or c", r"a \lor b \lor c"), + ("a or b and c", r"a \lor b \land c"), + ( + "(a or b) and c", + r"\mathopen{}\left( a \lor b \mathclose{}\right) \land c", + ), + ("a and b or c", r"a \land b \lor c"), + ( + "a and (b or c)", + r"a \land \mathopen{}\left( b \lor c \mathclose{}\right)", + ), + # With Call + ("a and f(b)", r"a \land f \mathopen{}\left( b \mathclose{}\right)"), + ("f(a) and b", r"f \mathopen{}\left( a \mathclose{}\right) \land b"), + ("a or f(b)", r"a \lor f \mathopen{}\left( b \mathclose{}\right)"), + ("f(a) or b", r"f \mathopen{}\left( a \mathclose{}\right) \lor b"), + # With BinOp + ("a and b + c", r"a \land b + c"), + ("a + b and c", r"a + b \land c"), + ("a or b + c", r"a \lor b + c"), + ("a + b or c", r"a + b \lor c"), + # With UnaryOp + ("a and not b", r"a \land \lnot b"), + ("not a and b", r"\lnot a \land b"), + ("a or not b", r"a \lor \lnot b"), + ("not a or b", r"\lnot a \lor b"), + # With Compare + ("a and b == c", r"a \land b = c"), + ("a == b and c", r"a = b \land c"), + ("a or b == c", r"a \lor b = c"), + ("a == b or c", r"a = b \lor c"), + ], +) +def test_visit_boolop(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.BoolOp) + assert ExpressionCodegen().visit(tree) == latex + + +@test_utils.require_at_most(7) +@pytest.mark.parametrize( + "code,cls,latex", + [ + ("0", ast.Num, "0"), + ("1", ast.Num, "1"), + ("0.0", ast.Num, "0.0"), + ("1.5", ast.Num, "1.5"), + ("0.0j", ast.Num, "0j"), + ("1.0j", ast.Num, "1j"), + ("1.5j", ast.Num, "1.5j"), + ('"abc"', ast.Str, r'\textrm{"abc"}'), + ('b"abc"', ast.Bytes, r"\textrm{b'abc'}"), + ("None", ast.NameConstant, r"\mathrm{None}"), + ("False", ast.NameConstant, r"\mathrm{False}"), + ("True", ast.NameConstant, r"\mathrm{True}"), + ("...", ast.Ellipsis, r"\cdots"), + ], +) +def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, cls) + assert ExpressionCodegen().visit(tree) == latex + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "code,latex", + [ + ("0", "0"), + ("1", "1"), + ("0.0", "0.0"), + ("1.5", "1.5"), + ("0.0j", "0j"), + ("1.0j", "1j"), + ("1.5j", "1.5j"), + ('"abc"', r'\textrm{"abc"}'), + ('b"abc"', r"\textrm{b'abc'}"), + ("None", r"\mathrm{None}"), + ("False", r"\mathrm{False}"), + ("True", r"\mathrm{True}"), + ("...", r"\cdots"), + ], +) +def test_visit_constant(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Constant) + assert ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("x[0]", "x_{0}"), + ("x[0][1]", "x_{0, 1}"), + ("x[0][1][2]", "x_{0, 1, 2}"), + ("x[foo]", r"x_{\mathrm{foo}}"), + ("x[floor(x)]", r"x_{\mathopen{}\left\lfloor x \mathclose{}\right\rfloor}"), + ], +) +def test_visit_subscript(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Subscript) + assert ExpressionCodegen().visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("a - b", r"a \setminus b"), + ("a & b", r"a \cap b"), + ("a ^ b", r"a \mathbin{\triangle} b"), + ("a | b", r"a \cup b"), + ], +) +def test_visit_binop_use_set_symbols(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.BinOp) + assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("a < b", r"a \subset b"), + ("a <= b", r"a \subseteq b"), + ("a > b", r"a \supset b"), + ("a >= b", r"a \supseteq b"), + ], +) +def test_visit_compare_use_set_symbols(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Compare) + assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("array(1)", r"\mathrm{array} \mathopen{}\left( 1 \mathclose{}\right)"), + ( + "array([])", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[ \mathclose{}\right]" + r" \mathclose{}\right)", + ), + ("array([1])", r"\begin{bmatrix} 1 \end{bmatrix}"), + ("array([1, 2, 3])", r"\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}"), + ( + "array([[]])", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[ \mathopen{}\left[" + r" \mathclose{}\right] \mathclose{}\right]" + r" \mathclose{}\right)", + ), + ("array([[1]])", r"\begin{bmatrix} 1 \end{bmatrix}"), + ("array([[1], [2], [3]])", r"\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}"), + ( + "array([[1], [2], [3, 4]])", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[" + r" \mathopen{}\left[ 1 \mathclose{}\right]," + r" \mathopen{}\left[ 2 \mathclose{}\right]," + r" \mathopen{}\left[ 3, 4 \mathclose{}\right]" + r" \mathclose{}\right]" + r" \mathclose{}\right)", + ), + ( + "array([[1, 2], [3, 4], [5, 6]])", + r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}", + ), + # Only checks two cases for ndarray. + ("ndarray(1)", r"\mathrm{ndarray} \mathopen{}\left( 1 \mathclose{}\right)"), + ("ndarray([1])", r"\begin{bmatrix} 1 \end{bmatrix}"), + ], +) +def test_numpy_array(code: str, latex: str) -> None: + tree = ast_utils.parse_expr(code) + assert isinstance(tree, ast.Call) + assert ExpressionCodegen().visit(tree) == latex diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index 419bf6c..c9b01e2 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -1,202 +1,12 @@ -"""Codegen for single function.""" +"""Codegen for single functions.""" from __future__ import annotations import ast -import dataclasses import sys -from typing import Any - -from latexify import analyzers, ast_utils, constants, exceptions -from latexify.codegen import identifier_converter - -# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes. -# Note that this value affects only the appearance of surrounding parentheses for each -# expression, and does not affect the AST itself. -# See also: -# https://docs.python.org/3/reference/expressions.html#operator-precedence -_PRECEDENCES: dict[type[ast.AST], int] = { - ast.Pow: 120, - ast.UAdd: 110, - ast.USub: 110, - ast.Invert: 110, - ast.Mult: 100, - ast.MatMult: 100, - ast.Div: 100, - ast.FloorDiv: 100, - ast.Mod: 100, - ast.Add: 90, - ast.Sub: 90, - ast.LShift: 80, - ast.RShift: 80, - ast.BitAnd: 70, - ast.BitXor: 60, - ast.BitOr: 50, - ast.In: 40, - ast.NotIn: 40, - ast.Is: 40, - ast.IsNot: 40, - ast.Lt: 40, - ast.LtE: 40, - ast.Gt: 40, - ast.GtE: 40, - ast.NotEq: 40, - ast.Eq: 40, - # NOTE(odashi): - # We assume that the `not` operator has the same precedence with other unary - # operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ looks to have a - # high precedence. - # ast.Not: 30, - ast.Not: 110, - ast.And: 20, - ast.Or: 10, -} - -# NOTE(odashi): -# Function invocation is treated as a unary operator with a higher precedence. -# This ensures that the argument with a unary operator is wrapped: -# exp(x) --> \exp x -# exp(-x) --> \exp (-x) -# -exp(x) --> - \exp x -_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1 - - -def _get_precedence(node: ast.AST) -> int: - """Obtains the precedence of the subtree. - - Args: - node: Subtree to investigate. - - Returns: - If `node` is a subtree with some operator, returns the precedence of the - operator. Otherwise, returns a number larger enough from other precedences. - """ - if isinstance(node, ast.Call): - return _CALL_PRECEDENCE - - if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)): - return _PRECEDENCES[type(node.op)] - - if isinstance(node, ast.Compare): - # Compare operators have the same precedence. It is enough to check only the - # first operator. - return _PRECEDENCES[type(node.ops[0])] - - return 1_000_000 - - -@dataclasses.dataclass(frozen=True) -class BinOperandRule: - """Syntax rules for operands of BinOp.""" - - # Whether to require wrapping operands by parentheses according to the precedence. - wrap: bool = True - - # Whether to require wrapping operands by parentheses if the operand has the same - # precedence with this operator. - # This is used to control the behavior of non-associative operators. - force: bool = False - - -@dataclasses.dataclass(frozen=True) -class BinOpRule: - """Syntax rules for BinOp.""" - - # Left/middle/right syntaxes to wrap operands. - latex_left: str - latex_middle: str - latex_right: str - - # Operand rules. - operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) - operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule) - - # Whether to assume the resulting syntax is wrapped by some bracket operators. - # If True, the parent operator can avoid wrapping this operator by parentheses. - is_wrapped: bool = False - - -_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { - ast.Pow: BinOpRule( - "", - "^{", - "}", - operand_left=BinOperandRule(force=True), - operand_right=BinOperandRule(wrap=False), - ), - ast.Mult: BinOpRule("", " ", ""), - ast.MatMult: BinOpRule("", " ", ""), - ast.Div: BinOpRule( - r"\frac{", - "}{", - "}", - operand_left=BinOperandRule(wrap=False), - operand_right=BinOperandRule(wrap=False), - ), - ast.FloorDiv: BinOpRule( - r"\left\lfloor\frac{", - "}{", - r"}\right\rfloor", - operand_left=BinOperandRule(wrap=False), - operand_right=BinOperandRule(wrap=False), - is_wrapped=True, - ), - ast.Mod: BinOpRule( - "", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True) - ), - ast.Add: BinOpRule("", " + ", ""), - ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)), - ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)), - ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)), - ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""), - ast.BitXor: BinOpRule("", r" \oplus ", ""), - ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""), -} - -# Typeset for BinOp of sets. -_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = { - **_BIN_OP_RULES, - ast.Sub: BinOpRule( - "", r" \setminus ", "", operand_right=BinOperandRule(force=True) - ), - ast.BitAnd: BinOpRule("", r" \cap ", ""), - ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""), - ast.BitOr: BinOpRule("", r" \cup ", ""), -} - -_UNARY_OPS: dict[type[ast.unaryop], str] = { - ast.Invert: r"\mathord{\sim} ", - ast.UAdd: "+", # Explicitly adds the $+$ operator. - ast.USub: "-", - ast.Not: r"\lnot ", -} - -_COMPARE_OPS: dict[type[ast.cmpop], str] = { - ast.Eq: "=", - ast.Gt: ">", - ast.GtE: r"\ge", - ast.In: r"\in", - ast.Is: r"\equiv", - ast.IsNot: r"\not\equiv", - ast.Lt: "<", - ast.LtE: r"\le", - ast.NotEq: r"\ne", - ast.NotIn: r"\notin", -} - -# Typeset for Compare of sets. -_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = { - **_COMPARE_OPS, - ast.Gt: r"\supset", - ast.GtE: r"\supseteq", - ast.Lt: r"\subset", - ast.LtE: r"\subseteq", -} - -_BOOL_OPS: dict[type[ast.boolop], str] = { - ast.And: r"\land", - ast.Or: r"\lor", -} + +from latexify import ast_utils, exceptions +from latexify.codegen import codegen_utils, expression_codegen, identifier_converter class FunctionCodegen(ast.NodeVisitor): @@ -209,9 +19,6 @@ class FunctionCodegen(ast.NodeVisitor): _identifier_converter: identifier_converter.IdentifierConverter _use_signature: bool - _bin_op_rules: dict[type[ast.operator], BinOpRule] - _compare_ops: dict[type[ast.cmpop], str] - def __init__( self, *, @@ -228,23 +35,25 @@ def __init__( or not. use_set_symbols: Whether to use set symbols or not. """ + self._expression_codegen = expression_codegen.ExpressionCodegen( + use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols + ) self._identifier_converter = identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols ) self._use_signature = use_signature - self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES - self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS - def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" ) def visit_Module(self, node: ast.Module) -> str: + """Visit a Module node.""" return self.visit(node.body[0]) def visit_FunctionDef(self, node: ast.FunctionDef) -> str: + """Visit a FunctionDef node.""" # Function name name_str = self._identifier_converter.convert(node.name)[0] @@ -275,7 +84,6 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: f"Unsupported last statement: {type(return_stmt).__name__}" ) else: - if not isinstance(return_stmt, (ast.Return, ast.If)): raise exceptions.LatexifySyntaxError( f"Unsupported last statement: {type(return_stmt).__name__}" @@ -298,395 +106,21 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> str: return r"\begin{array}{l} " + r" \\ ".join(body_strs) + r" \end{array}" def visit_Assign(self, node: ast.Assign) -> str: - operands: list[str] = [self.visit(t) for t in node.targets] - operands.append(self.visit(node.value)) + """Visit an Assign node.""" + operands: list[str] = [self._expression_codegen.visit(t) for t in node.targets] + operands.append(self._expression_codegen.visit(node.value)) return " = ".join(operands) def visit_Return(self, node: ast.Return) -> str: + """Visit a Return node.""" return ( - self.visit(node.value) + self._expression_codegen.visit(node.value) if node.value is not None - else self._convert_constant(None) - ) - - def visit_Tuple(self, node: ast.Tuple) -> str: - elts = [self.visit(i) for i in node.elts] - return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)" - - def visit_List(self, node: ast.List) -> str: - elts = [self.visit(i) for i in node.elts] - return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]" - - def visit_Set(self, node: ast.Set) -> str: - elts = [self.visit(i) for i in node.elts] - return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}" - - def visit_ListComp(self, node: ast.ListComp) -> str: - generators = [self.visit(comp) for comp in node.generators] - return ( - r"\mathopen{}\left[ " - + self.visit(node.elt) - + r" \mid " - + ", ".join(generators) - + r" \mathclose{}\right]" - ) - - def visit_SetComp(self, node: ast.SetComp) -> str: - generators = [self.visit(comp) for comp in node.generators] - return ( - r"\mathopen{}\left\{ " - + self.visit(node.elt) - + r" \mid " - + ", ".join(generators) - + r" \mathclose{}\right\}" + else codegen_utils.convert_constant(None) ) - def visit_comprehension(self, node: ast.comprehension) -> str: - target = rf"{self.visit(node.target)} \in {self.visit(node.iter)}" - - if not node.ifs: - # Returns the source without parenthesis. - return target - - conds = [target] + [self.visit(cond) for cond in node.ifs] - wrapped = [r"\mathopen{}\left( " + s + r" \mathclose{}\right)" for s in conds] - return r" \land ".join(wrapped) - - def _generate_sum_prod(self, node: ast.Call) -> str | None: - """Generates sum/prod expression. - - Args: - node: ast.Call node containing the sum/prod invocation. - - Returns: - Generated LaTeX, or None if the node has unsupported syntax. - """ - if not isinstance(node.args[0], ast.GeneratorExp): - return None - - name = ast_utils.extract_function_name_or_none(node) - assert name in ("fsum", "sum", "prod") - - command = { - "fsum": r"\sum", - "sum": r"\sum", - "prod": r"\prod", - }[name] - - elt, scripts = self._get_sum_prod_info(node.args[0]) - scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts] - return ( - " ".join(scripts_str) - + rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)" - ) - - def _generate_matrix(self, node: ast.Call) -> str | None: - """Generates matrix expression. - - Args: - node: ast.Call node containing the ndarray invocation. - - Returns: - Generated LaTeX, or None if the node has unsupported syntax. - """ - - def generate_matrix_from_array(data: list[list[str]]) -> str: - """Helper to generate a bmatrix environment.""" - contents = r" \\ ".join(" & ".join(row) for row in data) - return r"\begin{bmatrix} " + contents + r" \end{bmatrix}" - - arg = node.args[0] - if not isinstance(arg, ast.List) or not arg.elts: - # Not an array or no rows - return None - - row0 = arg.elts[0] - - if not isinstance(row0, ast.List): - # Maybe 1 x N array - return generate_matrix_from_array([[self.visit(x) for x in arg.elts]]) - - if not row0.elts: - # No columns - return None - - ncols = len(row0.elts) - - rows: list[list[str]] = [] - - for row in arg.elts: - if not isinstance(row, ast.List) or len(row.elts) != ncols: - # Length mismatch - return None - - rows.append([self.visit(x) for x in row.elts]) - - return generate_matrix_from_array(rows) - - def _generate_zeros(self, node: ast.Call) -> str | None: - """Generates LaTeX for numpy.zeros. - - Args: - node: ast.Call node containing the appropriate method invocation. - - Returns: - Generated LaTeX, or None if the node has unsupported syntax. - """ - name = ast_utils.extract_function_name_or_none(node) - assert name == "zeros" - - if len(node.args) != 1: - return None - - # All args to np.zeros should be numeric. - if isinstance(node.args[0], ast.Tuple): - dims = [ast_utils.extract_int_or_none(x) for x in node.args[0].elts] - if any(x is None for x in dims): - return None - if not dims: - return "0" - if len(dims) == 1: - dims = [1, dims[0]] - - dims_latex = r" \times ".join(str(x) for x in dims) - else: - dim = ast_utils.extract_int_or_none(node.args[0]) - if not isinstance(dim, int): - return None - # 1 x N array of zeros - dims_latex = rf"1 \times {dim}" - - return rf"\mathbf{{0}}^{{{dims_latex}}}" - - def _generate_identity(self, node: ast.Call) -> str | None: - """Generates LaTeX for numpy.identity. - - Args: - node: ast.Call node containing the appropriate method invocation. - - Returns: - Generated LaTeX, or None if the node has unsupported syntax. - """ - name = ast_utils.extract_function_name_or_none(node) - assert name == "identity" - - if len(node.args) != 1: - return None - - ndims = ast_utils.extract_int_or_none(node.args[0]) - if ndims is None: - return None - - return rf"\mathbf{{I}}_{{{ndims}}}" - - def visit_Call(self, node: ast.Call) -> str: - """Visit a call node.""" - func_name = ast_utils.extract_function_name_or_none(node) - - # Special treatments for some functions. - # TODO(odashi): Move these functions to some separate utility. - if func_name in ("fsum", "sum", "prod"): - special_latex = self._generate_sum_prod(node) - elif func_name in ("array", "ndarray"): - special_latex = self._generate_matrix(node) - elif func_name == "zeros": - special_latex = self._generate_zeros(node) - elif func_name == "identity": - special_latex = self._generate_identity(node) - else: - special_latex = None - - if special_latex is not None: - return special_latex - - # Obtains the codegen rule. - rule = constants.BUILTIN_FUNCS.get(func_name) if func_name is not None else None - - if rule is None: - rule = constants.FunctionRule(self.visit(node.func)) - - if rule.is_unary and len(node.args) == 1: - # Unary function. Applies the same wrapping policy with the unary operators. - # NOTE(odashi): - # Factorial "x!" is treated as a special case: it requires both inner/outer - # parentheses for correct interpretation. - precedence = _get_precedence(node) - arg = node.args[0] - force_wrap = isinstance(arg, ast.Call) and ( - func_name == "factorial" - or ast_utils.extract_function_name_or_none(arg) == "factorial" - ) - arg_latex = self._wrap_operand(arg, precedence, force_wrap) - elements = [rule.left, arg_latex, rule.right] - else: - arg_latex = ", ".join(self.visit(arg) for arg in node.args) - if rule.is_wrapped: - elements = [rule.left, arg_latex, rule.right] - else: - elements = [ - rule.left, - r"\mathopen{}\left(", - arg_latex, - r"\mathclose{}\right)", - rule.right, - ] - - return " ".join(x for x in elements if x) - - def visit_Attribute(self, node: ast.Attribute) -> str: - vstr = self.visit(node.value) - astr = self._identifier_converter.convert(node.attr)[0] - return vstr + "." + astr - - def visit_Name(self, node: ast.Name) -> str: - return self._identifier_converter.convert(node.id)[0] - - def _convert_constant(self, value: Any) -> str: - """Helper to convert constant values to LaTeX. - - Args: - value: A constant value. - - Returns: - The LaTeX representation of `value`. - """ - if value is None or isinstance(value, bool): - return r"\mathrm{" + str(value) + "}" - if isinstance(value, (int, float, complex)): - # TODO(odashi): Support other symbols for the imaginary unit than j. - return str(value) - if isinstance(value, str): - return r'\textrm{"' + value + '"}' - if isinstance(value, bytes): - return r"\textrm{" + str(value) + "}" - if value is ...: - return r"\cdots" - raise exceptions.LatexifyNotSupportedError( - f"Unrecognized constant: {type(value).__name__}" - ) - - # From Python 3.8 - def visit_Constant(self, node: ast.Constant) -> str: - return self._convert_constant(node.value) - - # Until Python 3.7 - def visit_Num(self, node: ast.Num) -> str: - return self._convert_constant(node.n) - - # Until Python 3.7 - def visit_Str(self, node: ast.Str) -> str: - return self._convert_constant(node.s) - - # Until Python 3.7 - def visit_Bytes(self, node: ast.Bytes) -> str: - return self._convert_constant(node.s) - - # Until Python 3.7 - def visit_NameConstant(self, node: ast.NameConstant) -> str: - return self._convert_constant(node.value) - - # Until Python 3.7 - def visit_Ellipsis(self, node: ast.Ellipsis) -> str: - return self._convert_constant(...) - - def _wrap_operand( - self, child: ast.expr, parent_prec: int, force_wrap: bool = False - ) -> str: - """Wraps the operand subtree with parentheses. - - Args: - child: Operand subtree. - parent_prec: Precedence of the parent operator. - force_wrap: Whether to wrap the operand or not when the precedence is equal. - - Returns: - LaTeX form of `child`, with or without surrounding parentheses. - """ - latex = self.visit(child) - child_prec = _get_precedence(child) - - if child_prec < parent_prec or force_wrap and child_prec == parent_prec: - return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" - - return latex - - def _wrap_binop_operand( - self, - child: ast.expr, - parent_prec: int, - operand_rule: BinOperandRule, - ) -> str: - """Wraps the operand subtree of BinOp with parentheses. - - Args: - child: Operand subtree. - parent_prec: Precedence of the parent operator. - operand_rule: Syntax rule of this operand. - - Returns: - LaTeX form of the `child`, with or without surrounding parentheses. - """ - if not operand_rule.wrap: - return self.visit(child) - - if isinstance(child, ast.Call): - child_fn_name = ast_utils.extract_function_name_or_none(child) - rule = ( - constants.BUILTIN_FUNCS.get(child_fn_name) - if child_fn_name is not None - else None - ) - if rule is not None and rule.is_wrapped: - return self.visit(child) - - if not isinstance(child, ast.BinOp): - return self._wrap_operand(child, parent_prec) - - latex = self.visit(child) - - if _BIN_OP_RULES[type(child.op)].is_wrapped: - return latex - - child_prec = _get_precedence(child) - - if child_prec > parent_prec or ( - child_prec == parent_prec and not operand_rule.force - ): - return latex - - return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" - - def visit_BinOp(self, node: ast.BinOp) -> str: - """Visit a BinOp node.""" - prec = _get_precedence(node) - rule = self._bin_op_rules[type(node.op)] - lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left) - rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right) - return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}" - - def visit_UnaryOp(self, node: ast.UnaryOp) -> str: - """Visit a unary op node.""" - latex = self._wrap_operand(node.operand, _get_precedence(node)) - return _UNARY_OPS[type(node.op)] + latex - - def visit_Compare(self, node: ast.Compare) -> str: - """Visit a compare node.""" - parent_prec = _get_precedence(node) - lhs = self._wrap_operand(node.left, parent_prec) - ops = [self._compare_ops[type(x)] for x in node.ops] - rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators] - ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)] - return lhs + "".join(ops_rhs) - - def visit_BoolOp(self, node: ast.BoolOp) -> str: - """Visit a BoolOp node.""" - parent_prec = _get_precedence(node) - values = [self._wrap_operand(x, parent_prec) for x in node.values] - op = f" {_BOOL_OPS[type(node.op)]} " - return op.join(values) - def visit_If(self, node: ast.If) -> str: - """Visit an if node.""" + """Visit an If node.""" latex = r"\left\{ \begin{array}{ll} " current_stmt: ast.stmt = node @@ -697,7 +131,7 @@ def visit_If(self, node: ast.If) -> str: "Multiple statements are not supported in If nodes." ) - cond_latex = self.visit(current_stmt.test) + cond_latex = self._expression_codegen.visit(current_stmt.test) true_latex = self.visit(current_stmt.body[0]) latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ " current_stmt = current_stmt.orelse[0] @@ -705,23 +139,8 @@ def visit_If(self, node: ast.If) -> str: latex += self.visit(current_stmt) return latex + r", & \mathrm{otherwise} \end{array} \right." - def visit_IfExp(self, node: ast.IfExp) -> str: - """Visit an ifexp node""" - latex = r"\left\{ \begin{array}{ll} " - - current_expr: ast.expr = node - - while isinstance(current_expr, ast.IfExp): - cond_latex = self.visit(current_expr.test) - true_latex = self.visit(current_expr.body) - latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ " - current_expr = current_expr.orelse - - latex += self.visit(current_expr) - return latex + r", & \mathrm{otherwise} \end{array} \right." - def visit_Match(self, node: ast.Match) -> str: - """Visit a match node""" + """Visit a Match node""" if not ( len(node.cases) >= 2 and isinstance(node.cases[-1].pattern, ast.MatchAs) @@ -731,7 +150,7 @@ def visit_Match(self, node: ast.Match) -> str: "Match statement must contain the wildcard." ) - subject_latex = self.visit(node.subject) + subject_latex = self._expression_codegen.visit(node.subject) case_latexes: list[str] = [] for i, case in enumerate(node.cases): @@ -759,169 +178,5 @@ def visit_Match(self, node: ast.Match) -> str: def visit_MatchValue(self, node: ast.MatchValue) -> str: """Visit a MatchValue node""" - latex = self.visit(node.value) + latex = self._expression_codegen.visit(node.value) return " = " + latex - - def _reduce_stop_parameter(self, node: ast.expr) -> ast.expr: - """Adjusts the stop expression of the range. - - This function tries to convert the syntax as follows: - * n + 1 --> n - * n + 2 --> n + 1 - * n - (-1) --> n - * n - 1 --> n - 2 - - Args: - node: The target expression. - - Returns: - Converted expression. - """ - if not ( - isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub)) - ): - return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) - - # Treatment for Python 3.7. - rhs = ( - ast.Constant(value=node.right.n) - if sys.version_info.minor < 8 and isinstance(node.right, ast.Num) - else node.right - ) - - if not isinstance(rhs, ast.Constant): - return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) - - shift = 1 if isinstance(node.op, ast.Add) else -1 - - return ( - node.left - if rhs.value == shift - else ast.BinOp( - left=node.left, op=node.op, right=ast.Constant(value=rhs.value - shift) - ) - ) - - def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None: - """Helper to process range(...) for sum and prod functions. - - Args: - node: comprehension node to be analyzed. - - Returns: - Tuple of following strings: - - lower_rhs - - upper - which are used in _get_sum_prod_info, or None if the analysis failed. - """ - if not ( - isinstance(node.iter, ast.Call) - and isinstance(node.iter.func, ast.Name) - and node.iter.func.id == "range" - ): - return None - - try: - range_info = analyzers.analyze_range(node.iter) - except exceptions.LatexifyError: - return None - - if ( - # Only accepts ascending order with step size 1. - range_info.step_int != 1 - or ( - range_info.start_int is not None - and range_info.stop_int is not None - and range_info.start_int >= range_info.stop_int - ) - ): - return None - - if range_info.start_int is None: - lower_rhs = self.visit(range_info.start) - else: - lower_rhs = str(range_info.start_int) - - if range_info.stop_int is None: - upper = self.visit(self._reduce_stop_parameter(range_info.stop)) - else: - upper = str(range_info.stop_int - 1) - - return lower_rhs, upper - - def _get_sum_prod_info( - self, node: ast.GeneratorExp - ) -> tuple[str, list[tuple[str, str]]]: - r"""Process GeneratorExp for sum and prod functions. - - Args: - node: GeneratorExp node to be analyzed. - - Returns: - Tuple of following strings: - - elt - - scripts - which are used to represent sum/prod operators as follows: - \sum_{scripts[0][0]}^{scripts[0][1]} - \sum_{scripts[1][0]}^{scripts[1][1]} - ... - {elt} - - Raises: - LateixfyError: Unsupported AST is given. - """ - elt = self.visit(node.elt) - - scripts: list[tuple[str, str]] = [] - - for comp in node.generators: - range_args = self._get_sum_prod_range(comp) - - if range_args is not None and not comp.ifs: - target = self.visit(comp.target) - lower_rhs, upper = range_args - lower = f"{target} = {lower_rhs}" - else: - lower = self.visit(comp) # Use a usual comprehension form. - upper = "" - - scripts.append((lower, upper)) - - return elt, scripts - - # Until 3.8 - def visit_Index(self, node: ast.Index) -> str: - """Visitor for the Index nodes.""" - return self.visit(node.value) # type: ignore[attr-defined] - - def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]: - """Helper function to convert nested subscription. - - This function converts x[i][j][...] to "x" and ["i", "j", ...] - - Args: - node: ast.Subscript node to be converted. - - Returns: - Tuple of following strings: - - The root value of the subscription. - - Sequence of incices. - """ - if isinstance(node.value, ast.Subscript): - value, indices = self._convert_nested_subscripts(node.value) - else: - value = self.visit(node.value) - indices = [] - - indices.append(self.visit(node.slice)) - return value, indices - - def visit_Subscript(self, node: ast.Subscript) -> str: - """Visitor of the Subscript nodes.""" - value, indices = self._convert_nested_subscripts(node) - - # TODO(odashi): - # "[i][j][...]" may be a possible representation as well as "i, j. ..." - indices_str = ", ".join(indices) - - return f"{value}_{{{indices_str}}}" diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 02245d2..b8e363a 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -7,7 +7,7 @@ import pytest -from latexify import ast_utils, exceptions, test_utils +from latexify import exceptions from latexify.codegen import function_codegen @@ -78,945 +78,3 @@ def f(x): latex = r"f(x) = x" assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("()", r"\mathopen{}\left( \mathclose{}\right)"), - ("(x,)", r"\mathopen{}\left( x \mathclose{}\right)"), - ("(x, y)", r"\mathopen{}\left( x, y \mathclose{}\right)"), - ("(x, y, z)", r"\mathopen{}\left( x, y, z \mathclose{}\right)"), - ], -) -def test_tuple(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.Tuple) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("[]", r"\mathopen{}\left[ \mathclose{}\right]"), - ("[x]", r"\mathopen{}\left[ x \mathclose{}\right]"), - ("[x, y]", r"\mathopen{}\left[ x, y \mathclose{}\right]"), - ("[x, y, z]", r"\mathopen{}\left[ x, y, z \mathclose{}\right]"), - ], -) -def test_list(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.List) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - # TODO(odashi): Support set(). - # ("set()", r"\mathopen{}\left\{ \mathclose{}\right\}"), - ("{x}", r"\mathopen{}\left\{ x \mathclose{}\right\}"), - ("{x, y}", r"\mathopen{}\left\{ x, y \mathclose{}\right\}"), - ("{x, y, z}", r"\mathopen{}\left\{ x, y, z \mathclose{}\right\}"), - ], -) -def test_set(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.Set) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("[i for i in n]", r"\mathopen{}\left[ i \mid i \in n \mathclose{}\right]"), - ( - "[i for i in n if i > 0]", - r"\mathopen{}\left[ i \mid" - r" \mathopen{}\left( i \in n \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right]", - ), - ( - "[i for i in n if i > 0 if f(i)]", - r"\mathopen{}\left[ i \mid" - r" \mathopen{}\left( i \in n \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \land \mathopen{}\left( f \mathopen{}\left(" - r" i \mathclose{}\right) \mathclose{}\right)" - r" \mathclose{}\right]", - ), - ( - "[i for k in n for i in k]", - r"\mathopen{}\left[ i \mid k \in n, i \in k" r" \mathclose{}\right]", - ), - ( - "[i for k in n for i in k if i > 0]", - r"\mathopen{}\left[ i \mid" - r" k \in n," - r" \mathopen{}\left( i \in k \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right]", - ), - ( - "[i for k in n if f(k) for i in k if i > 0]", - r"\mathopen{}\left[ i \mid" - r" \mathopen{}\left( k \in n \mathclose{}\right)" - r" \land \mathopen{}\left( f \mathopen{}\left(" - r" k \mathclose{}\right) \mathclose{}\right)," - r" \mathopen{}\left( i \in k \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right]", - ), - ], -) -def test_visit_listcomp(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.ListComp) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("{i for i in n}", r"\mathopen{}\left\{ i \mid i \in n \mathclose{}\right\}"), - ( - "{i for i in n if i > 0}", - r"\mathopen{}\left\{ i \mid" - r" \mathopen{}\left( i \in n \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right\}", - ), - ( - "{i for i in n if i > 0 if f(i)}", - r"\mathopen{}\left\{ i \mid" - r" \mathopen{}\left( i \in n \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \land \mathopen{}\left( f \mathopen{}\left(" - r" i \mathclose{}\right) \mathclose{}\right)" - r" \mathclose{}\right\}", - ), - ( - "{i for k in n for i in k}", - r"\mathopen{}\left\{ i \mid k \in n, i \in k" r" \mathclose{}\right\}", - ), - ( - "{i for k in n for i in k if i > 0}", - r"\mathopen{}\left\{ i \mid" - r" k \in n," - r" \mathopen{}\left( i \in k \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right\}", - ), - ( - "{i for k in n if f(k) for i in k if i > 0}", - r"\mathopen{}\left\{ i \mid" - r" \mathopen{}\left( k \in n \mathclose{}\right)" - r" \land \mathopen{}\left( f \mathopen{}\left(" - r" k \mathclose{}\right) \mathclose{}\right)," - r" \mathopen{}\left( i \in k \mathclose{}\right)" - r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" - r" \mathclose{}\right\}", - ), - ], -) -def test_visit_setcomp(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.SetComp) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("foo(x)", r"\mathrm{foo} \mathopen{}\left( x \mathclose{}\right)"), - ("f(x)", r"f \mathopen{}\left( x \mathclose{}\right)"), - ("f(-x)", r"f \mathopen{}\left( -x \mathclose{}\right)"), - ("f(x + y)", r"f \mathopen{}\left( x + y \mathclose{}\right)"), - ( - "f(f(x))", - r"f \mathopen{}\left(" - r" f \mathopen{}\left( x \mathclose{}\right)" - r" \mathclose{}\right)", - ), - ("f(sqrt(x))", r"f \mathopen{}\left( \sqrt{ x } \mathclose{}\right)"), - ("f(sin(x))", r"f \mathopen{}\left( \sin x \mathclose{}\right)"), - ("f(factorial(x))", r"f \mathopen{}\left( x ! \mathclose{}\right)"), - ("f(x, y)", r"f \mathopen{}\left( x, y \mathclose{}\right)"), - ("sqrt(x)", r"\sqrt{ x }"), - ("sqrt(-x)", r"\sqrt{ -x }"), - ("sqrt(x + y)", r"\sqrt{ x + y }"), - ("sqrt(f(x))", r"\sqrt{ f \mathopen{}\left( x \mathclose{}\right) }"), - ("sqrt(sqrt(x))", r"\sqrt{ \sqrt{ x } }"), - ("sqrt(sin(x))", r"\sqrt{ \sin x }"), - ("sqrt(factorial(x))", r"\sqrt{ x ! }"), - ("sin(x)", r"\sin x"), - ("sin(-x)", r"\sin \mathopen{}\left( -x \mathclose{}\right)"), - ("sin(x + y)", r"\sin \mathopen{}\left( x + y \mathclose{}\right)"), - ("sin(f(x))", r"\sin f \mathopen{}\left( x \mathclose{}\right)"), - ("sin(sqrt(x))", r"\sin \sqrt{ x }"), - ("sin(sin(x))", r"\sin \sin x"), - ("sin(factorial(x))", r"\sin \mathopen{}\left( x ! \mathclose{}\right)"), - ("factorial(x)", r"x !"), - ("factorial(-x)", r"\mathopen{}\left( -x \mathclose{}\right) !"), - ("factorial(x + y)", r"\mathopen{}\left( x + y \mathclose{}\right) !"), - ( - "factorial(f(x))", - r"\mathopen{}\left(" - r" f \mathopen{}\left( x \mathclose{}\right)" - r" \mathclose{}\right) !", - ), - ("factorial(sqrt(x))", r"\mathopen{}\left( \sqrt{ x } \mathclose{}\right) !"), - ("factorial(sin(x))", r"\mathopen{}\left( \sin x \mathclose{}\right) !"), - ("factorial(factorial(x))", r"\mathopen{}\left( x ! \mathclose{}\right) !"), - ], -) -def test_visit_call(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.Call) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "src_suffix,dest_suffix", - [ - # No comprehension - ("(x)", r" x"), - ( - "([1, 2])", - r" \mathopen{}\left[ 1, 2 \mathclose{}\right]", - ), - ( - "({1, 2})", - r" \mathopen{}\left\{ 1, 2 \mathclose{}\right\}", - ), - ("(f(x))", r" f \mathopen{}\left( x \mathclose{}\right)"), - # Single comprehension - ("(i for i in x)", r"_{i \in x}^{} \mathopen{}\left({i}\mathclose{}\right)"), - ( - "(i for i in [1, 2])", - r"_{i \in \mathopen{}\left[ 1, 2 \mathclose{}\right]}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in {1, 2})", - r"_{i \in \mathopen{}\left\{ 1, 2 \mathclose{}\right\}}^{}" - r" \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in f(x))", - r"_{i \in f \mathopen{}\left( x \mathclose{}\right)}^{}" - r" \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n))", - r"_{i = 0}^{n - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n + 1))", - r"_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n + 2))", - r"_{i = 0}^{n + 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - # ast.parse() does not recognize negative integers. - "(i for i in range(n - -1))", - r"_{i = 0}^{n - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n - 1))", - r"_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n + m))", - r"_{i = 0}^{n + m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n - m))", - r"_{i = 0}^{n - m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3))", - r"_{i = 0}^{2} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3 + 1))", - r"_{i = 0}^{3} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3 + 2))", - r"_{i = 0}^{3 + 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3 - 1))", - r"_{i = 0}^{3 - 2} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - # ast.parse() does not recognize negative integers. - "(i for i in range(3 - -1))", - r"_{i = 0}^{3 - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3 + m))", - r"_{i = 0}^{3 + m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(3 - m))", - r"_{i = 0}^{3 - m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n, m))", - r"_{i = n}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(1, m))", - r"_{i = 1}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n, 3))", - r"_{i = n}^{2} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in range(n, m, k))", - r"_{i \in \mathrm{range} \mathopen{}\left( n, m, k \mathclose{}\right)}^{}" - r" \mathopen{}\left({i}\mathclose{}\right)", - ), - ], -) -def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: - for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]: - node = ast_utils.parse_expr(src_fn + src_suffix) - assert isinstance(node, ast.Call) - assert function_codegen.FunctionCodegen().visit(node) == dest_fn + dest_suffix - - -@pytest.mark.parametrize( - "code,latex", - [ - # 2 clauses - ( - "sum(i for y in x for i in y)", - r"\sum_{y \in x}^{} \sum_{i \in y}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "sum(i for y in x for z in y for i in z)", - r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - # 3 clauses - ( - "prod(i for y in x for i in y)", - r"\prod_{y \in x}^{} \prod_{i \in y}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "prod(i for y in x for z in y for i in z)", - r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - # reduce stop parameter - ( - "sum(i for i in range(n+1))", - r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "prod(i for i in range(n-1))", - r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", - ), - # reduce stop parameter - ( - "sum(i for i in range(n+1))", - r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "prod(i for i in range(n-1))", - r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", - ), - ], -) -def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.Call) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "src_suffix,dest_suffix", - [ - ( - "(i for i in x if i < y)", - r"_{\mathopen{}\left( i \in x \mathclose{}\right) " - r"\land \mathopen{}\left( i < y \mathclose{}\right)}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", - ), - ( - "(i for i in x if i < y if f(i))", - r"_{\mathopen{}\left( i \in x \mathclose{}\right)" - r" \land \mathopen{}\left( i < y \mathclose{}\right)" - r" \land \mathopen{}\left( f \mathopen{}\left(" - r" i \mathclose{}\right) \mathclose{}\right)}^{}" - r" \mathopen{}\left({i}\mathclose{}\right)", - ), - ], -) -def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: - for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: - node = ast_utils.parse_expr(src_fn + src_suffix) - assert isinstance(node, ast.Call) - assert function_codegen.FunctionCodegen().visit(node) == dest_fn + dest_suffix - - -@pytest.mark.parametrize( - "code,latex", - [ - ( - "x if x < y else y", - r"\left\{ \begin{array}{ll}" - r" x, & \mathrm{if} \ x < y \\" - r" y, & \mathrm{otherwise}" - r" \end{array} \right.", - ), - ( - "x if x < y else (y if y < z else z)", - r"\left\{ \begin{array}{ll}" - r" x, & \mathrm{if} \ x < y \\" - r" y, & \mathrm{if} \ y < z \\" - r" z, & \mathrm{otherwise}" - r" \end{array} \right.", - ), - ( - "x if x < y else (y if y < z else (z if z < w else w))", - r"\left\{ \begin{array}{ll}" - r" x, & \mathrm{if} \ x < y \\" - r" y, & \mathrm{if} \ y < z \\" - r" z, & \mathrm{if} \ z < w \\" - r" w, & \mathrm{otherwise}" - r" \end{array} \right.", - ), - ], -) -def test_if_then_else(code: str, latex: str) -> None: - node = ast_utils.parse_expr(code) - assert isinstance(node, ast.IfExp) - assert function_codegen.FunctionCodegen().visit(node) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - # x op y - ("x**y", r"x^{y}"), - ("x * y", r"x y"), - ("x @ y", r"x y"), - ("x / y", r"\frac{x}{y}"), - ("x // y", r"\left\lfloor\frac{x}{y}\right\rfloor"), - ("x % y", r"x \mathbin{\%} y"), - ("x + y", r"x + y"), - ("x - y", r"x - y"), - ("x << y", r"x \ll y"), - ("x >> y", r"x \gg y"), - ("x & y", r"x \mathbin{\&} y"), - ("x ^ y", r"x \oplus y"), - ("x | y", R"x \mathbin{|} y"), - # (x op y) op z - ("(x**y)**z", r"\mathopen{}\left( x^{y} \mathclose{}\right)^{z}"), - ("(x * y) * z", r"x y z"), - ("(x @ y) @ z", r"x y z"), - ("(x / y) / z", r"\frac{\frac{x}{y}}{z}"), - ( - "(x // y) // z", - r"\left\lfloor\frac{\left\lfloor\frac{x}{y}\right\rfloor}{z}\right\rfloor", - ), - ("(x % y) % z", r"x \mathbin{\%} y \mathbin{\%} z"), - ("(x + y) + z", r"x + y + z"), - ("(x - y) - z", r"x - y - z"), - ("(x << y) << z", r"x \ll y \ll z"), - ("(x >> y) >> z", r"x \gg y \gg z"), - ("(x & y) & z", r"x \mathbin{\&} y \mathbin{\&} z"), - ("(x ^ y) ^ z", r"x \oplus y \oplus z"), - ("(x | y) | z", r"x \mathbin{|} y \mathbin{|} z"), - # x op (y op z) - ("x**(y**z)", r"x^{y^{z}}"), - ("x * (y * z)", r"x y z"), - ("x @ (y @ z)", r"x y z"), - ("x / (y / z)", r"\frac{x}{\frac{y}{z}}"), - ( - "x // (y // z)", - r"\left\lfloor\frac{x}{\left\lfloor\frac{y}{z}\right\rfloor}\right\rfloor", - ), - ( - "x % (y % z)", - r"x \mathbin{\%} \mathopen{}\left( y \mathbin{\%} z \mathclose{}\right)", - ), - ("x + (y + z)", r"x + y + z"), - ("x - (y - z)", r"x - \mathopen{}\left( y - z \mathclose{}\right)"), - ("x << (y << z)", r"x \ll \mathopen{}\left( y \ll z \mathclose{}\right)"), - ("x >> (y >> z)", r"x \gg \mathopen{}\left( y \gg z \mathclose{}\right)"), - ("x & (y & z)", r"x \mathbin{\&} y \mathbin{\&} z"), - ("x ^ (y ^ z)", r"x \oplus y \oplus z"), - ("x | (y | z)", r"x \mathbin{|} y \mathbin{|} z"), - # x OP y op z - ("x**y * z", r"x^{y} z"), - ("x * y + z", r"x y + z"), - ("x @ y + z", r"x y + z"), - ("x / y + z", r"\frac{x}{y} + z"), - ("x // y + z", r"\left\lfloor\frac{x}{y}\right\rfloor + z"), - ("x % y + z", r"x \mathbin{\%} y + z"), - ("x + y << z", r"x + y \ll z"), - ("x - y << z", r"x - y \ll z"), - ("x << y & z", r"x \ll y \mathbin{\&} z"), - ("x >> y & z", r"x \gg y \mathbin{\&} z"), - ("x & y ^ z", r"x \mathbin{\&} y \oplus z"), - ("x ^ y | z", r"x \oplus y \mathbin{|} z"), - # x OP (y op z) - ("x**(y * z)", r"x^{y z}"), - ("x * (y + z)", r"x \mathopen{}\left( y + z \mathclose{}\right)"), - ("x @ (y + z)", r"x \mathopen{}\left( y + z \mathclose{}\right)"), - ("x / (y + z)", r"\frac{x}{y + z}"), - ("x // (y + z)", r"\left\lfloor\frac{x}{y + z}\right\rfloor"), - ("x % (y + z)", r"x \mathbin{\%} \mathopen{}\left( y + z \mathclose{}\right)"), - ("x + (y << z)", r"x + \mathopen{}\left( y \ll z \mathclose{}\right)"), - ("x - (y << z)", r"x - \mathopen{}\left( y \ll z \mathclose{}\right)"), - ( - "x << (y & z)", - r"x \ll \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", - ), - ( - "x >> (y & z)", - r"x \gg \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", - ), - ( - "x & (y ^ z)", - r"x \mathbin{\&} \mathopen{}\left( y \oplus z \mathclose{}\right)", - ), - ( - "x ^ (y | z)", - r"x \oplus \mathopen{}\left( y \mathbin{|} z \mathclose{}\right)", - ), - # x op y OP z - ("x * y**z", r"x y^{z}"), - ("x + y * z", r"x + y z"), - ("x + y @ z", r"x + y z"), - ("x + y / z", r"x + \frac{y}{z}"), - ("x + y // z", r"x + \left\lfloor\frac{y}{z}\right\rfloor"), - ("x + y % z", r"x + y \mathbin{\%} z"), - ("x << y + z", r"x \ll y + z"), - ("x << y - z", r"x \ll y - z"), - ("x & y << z", r"x \mathbin{\&} y \ll z"), - ("x & y >> z", r"x \mathbin{\&} y \gg z"), - ("x ^ y & z", r"x \oplus y \mathbin{\&} z"), - ("x | y ^ z", r"x \mathbin{|} y \oplus z"), - # (x op y) OP z - ("(x * y)**z", r"\mathopen{}\left( x y \mathclose{}\right)^{z}"), - ("(x + y) * z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), - ("(x + y) @ z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), - ("(x + y) / z", r"\frac{x + y}{z}"), - ("(x + y) // z", r"\left\lfloor\frac{x + y}{z}\right\rfloor"), - ("(x + y) % z", r"\mathopen{}\left( x + y \mathclose{}\right) \mathbin{\%} z"), - ("(x << y) + z", r"\mathopen{}\left( x \ll y \mathclose{}\right) + z"), - ("(x << y) - z", r"\mathopen{}\left( x \ll y \mathclose{}\right) - z"), - ( - "(x & y) << z", - r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \ll z", - ), - ( - "(x & y) >> z", - r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \gg z", - ), - ( - "(x ^ y) & z", - r"\mathopen{}\left( x \oplus y \mathclose{}\right) \mathbin{\&} z", - ), - ( - "(x | y) ^ z", - r"\mathopen{}\left( x \mathbin{|} y \mathclose{}\right) \oplus z", - ), - # is_wrapped - ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"), - # With Call - ("x**f(y)", r"x^{f \mathopen{}\left( y \mathclose{}\right)}"), - ( - "f(x)**y", - r"\mathopen{}\left(" - r" f \mathopen{}\left( x \mathclose{}\right)" - r" \mathclose{}\right)^{y}", - ), - ("x * f(y)", r"x f \mathopen{}\left( y \mathclose{}\right)"), - ("f(x) * y", r"f \mathopen{}\left( x \mathclose{}\right) y"), - ("x / f(y)", r"\frac{x}{f \mathopen{}\left( y \mathclose{}\right)}"), - ("f(x) / y", r"\frac{f \mathopen{}\left( x \mathclose{}\right)}{y}"), - ("x + f(y)", r"x + f \mathopen{}\left( y \mathclose{}\right)"), - ("f(x) + y", r"f \mathopen{}\left( x \mathclose{}\right) + y"), - # With is_wrapped Call - ("sqrt(x) ** y", r"\sqrt{ x }^{y}"), - # With UnaryOp - ("x**-y", r"x^{-y}"), - ("(-x)**y", r"\mathopen{}\left( -x \mathclose{}\right)^{y}"), - ("x * -y", r"x -y"), # TODO(odashi): google/latexify_py#89 - ("-x * y", r"-x y"), - ("x / -y", r"\frac{x}{-y}"), - ("-x / y", r"\frac{-x}{y}"), - ("x + -y", r"x + -y"), - ("-x + y", r"-x + y"), - # With Compare - ("x**(y == z)", r"x^{y = z}"), - ("(x == y)**z", r"\mathopen{}\left( x = y \mathclose{}\right)^{z}"), - ("x * (y == z)", r"x \mathopen{}\left( y = z \mathclose{}\right)"), - ("(x == y) * z", r"\mathopen{}\left( x = y \mathclose{}\right) z"), - ("x / (y == z)", r"\frac{x}{y = z}"), - ("(x == y) / z", r"\frac{x = y}{z}"), - ("x + (y == z)", r"x + \mathopen{}\left( y = z \mathclose{}\right)"), - ("(x == y) + z", r"\mathopen{}\left( x = y \mathclose{}\right) + z"), - # With BoolOp - ("x**(y and z)", r"x^{y \land z}"), - ("(x and y)**z", r"\mathopen{}\left( x \land y \mathclose{}\right)^{z}"), - ("x * (y and z)", r"x \mathopen{}\left( y \land z \mathclose{}\right)"), - ("(x and y) * z", r"\mathopen{}\left( x \land y \mathclose{}\right) z"), - ("x / (y and z)", r"\frac{x}{y \land z}"), - ("(x and y) / z", r"\frac{x \land y}{z}"), - ("x + (y and z)", r"x + \mathopen{}\left( y \land z \mathclose{}\right)"), - ("(x and y) + z", r"\mathopen{}\left( x \land y \mathclose{}\right) + z"), - ], -) -def test_visit_binop(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.BinOp) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - # With literals - ("+x", r"+x"), - ("-x", r"-x"), - ("~x", r"\mathord{\sim} x"), - ("not x", r"\lnot x"), - # With Call - ("+f(x)", r"+f \mathopen{}\left( x \mathclose{}\right)"), - ("-f(x)", r"-f \mathopen{}\left( x \mathclose{}\right)"), - ("~f(x)", r"\mathord{\sim} f \mathopen{}\left( x \mathclose{}\right)"), - ("not f(x)", r"\lnot f \mathopen{}\left( x \mathclose{}\right)"), - # With BinOp - ("+(x + y)", r"+\mathopen{}\left( x + y \mathclose{}\right)"), - ("-(x + y)", r"-\mathopen{}\left( x + y \mathclose{}\right)"), - ("~(x + y)", r"\mathord{\sim} \mathopen{}\left( x + y \mathclose{}\right)"), - ("not x + y", r"\lnot \mathopen{}\left( x + y \mathclose{}\right)"), - # With Compare - ("+(x == y)", r"+\mathopen{}\left( x = y \mathclose{}\right)"), - ("-(x == y)", r"-\mathopen{}\left( x = y \mathclose{}\right)"), - ("~(x == y)", r"\mathord{\sim} \mathopen{}\left( x = y \mathclose{}\right)"), - ("not x == y", r"\lnot \mathopen{}\left( x = y \mathclose{}\right)"), - # With BoolOp - ("+(x and y)", r"+\mathopen{}\left( x \land y \mathclose{}\right)"), - ("-(x and y)", r"-\mathopen{}\left( x \land y \mathclose{}\right)"), - ( - "~(x and y)", - r"\mathord{\sim} \mathopen{}\left( x \land y \mathclose{}\right)", - ), - ("not (x and y)", r"\lnot \mathopen{}\left( x \land y \mathclose{}\right)"), - ], -) -def test_visit_unaryop(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.UnaryOp) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - # 1 comparator - ("a == b", "a = b"), - ("a > b", "a > b"), - ("a >= b", r"a \ge b"), - ("a in b", r"a \in b"), - ("a is b", r"a \equiv b"), - ("a is not b", r"a \not\equiv b"), - ("a < b", "a < b"), - ("a <= b", r"a \le b"), - ("a != b", r"a \ne b"), - ("a not in b", r"a \notin b"), - # 2 comparators - ("a == b == c", "a = b = c"), - ("a == b > c", "a = b > c"), - ("a == b >= c", r"a = b \ge c"), - ("a == b < c", "a = b < c"), - ("a == b <= c", r"a = b \le c"), - ("a > b == c", "a > b = c"), - ("a > b > c", "a > b > c"), - ("a > b >= c", r"a > b \ge c"), - ("a >= b == c", r"a \ge b = c"), - ("a >= b > c", r"a \ge b > c"), - ("a >= b >= c", r"a \ge b \ge c"), - ("a < b == c", "a < b = c"), - ("a < b < c", "a < b < c"), - ("a < b <= c", r"a < b \le c"), - ("a <= b == c", r"a \le b = c"), - ("a <= b < c", r"a \le b < c"), - ("a <= b <= c", r"a \le b \le c"), - # With Call - ("a == f(b)", r"a = f \mathopen{}\left( b \mathclose{}\right)"), - ("f(a) == b", r"f \mathopen{}\left( a \mathclose{}\right) = b"), - # With BinOp - ("a == b + c", r"a = b + c"), - ("a + b == c", r"a + b = c"), - # With UnaryOp - ("a == -b", r"a = -b"), - ("-a == b", r"-a = b"), - ("a == (not b)", r"a = \lnot b"), - ("(not a) == b", r"\lnot a = b"), - # With BoolOp - ("a == (b and c)", r"a = \mathopen{}\left( b \land c \mathclose{}\right)"), - ("(a and b) == c", r"\mathopen{}\left( a \land b \mathclose{}\right) = c"), - ], -) -def test_visit_compare(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Compare) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - # With literals - ("a and b", r"a \land b"), - ("a and b and c", r"a \land b \land c"), - ("a or b", r"a \lor b"), - ("a or b or c", r"a \lor b \lor c"), - ("a or b and c", r"a \lor b \land c"), - ( - "(a or b) and c", - r"\mathopen{}\left( a \lor b \mathclose{}\right) \land c", - ), - ("a and b or c", r"a \land b \lor c"), - ( - "a and (b or c)", - r"a \land \mathopen{}\left( b \lor c \mathclose{}\right)", - ), - # With Call - ("a and f(b)", r"a \land f \mathopen{}\left( b \mathclose{}\right)"), - ("f(a) and b", r"f \mathopen{}\left( a \mathclose{}\right) \land b"), - ("a or f(b)", r"a \lor f \mathopen{}\left( b \mathclose{}\right)"), - ("f(a) or b", r"f \mathopen{}\left( a \mathclose{}\right) \lor b"), - # With BinOp - ("a and b + c", r"a \land b + c"), - ("a + b and c", r"a + b \land c"), - ("a or b + c", r"a \lor b + c"), - ("a + b or c", r"a + b \lor c"), - # With UnaryOp - ("a and not b", r"a \land \lnot b"), - ("not a and b", r"\lnot a \land b"), - ("a or not b", r"a \lor \lnot b"), - ("not a or b", r"\lnot a \lor b"), - # With Compare - ("a and b == c", r"a \land b = c"), - ("a == b and c", r"a = b \land c"), - ("a or b == c", r"a \lor b = c"), - ("a == b or c", r"a = b \lor c"), - ], -) -def test_visit_boolop(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.BoolOp) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@test_utils.require_at_most(7) -@pytest.mark.parametrize( - "code,cls,latex", - [ - ("0", ast.Num, "0"), - ("1", ast.Num, "1"), - ("0.0", ast.Num, "0.0"), - ("1.5", ast.Num, "1.5"), - ("0.0j", ast.Num, "0j"), - ("1.0j", ast.Num, "1j"), - ("1.5j", ast.Num, "1.5j"), - ('"abc"', ast.Str, r'\textrm{"abc"}'), - ('b"abc"', ast.Bytes, r"\textrm{b'abc'}"), - ("None", ast.NameConstant, r"\mathrm{None}"), - ("False", ast.NameConstant, r"\mathrm{False}"), - ("True", ast.NameConstant, r"\mathrm{True}"), - ("...", ast.Ellipsis, r"\cdots"), - ], -) -def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, cls) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@test_utils.require_at_least(8) -@pytest.mark.parametrize( - "code,latex", - [ - ("0", "0"), - ("1", "1"), - ("0.0", "{0.0}"), - ("1.5", "{1.5}"), - ("0.0j", "{0j}"), - ("1.0j", "{1j}"), - ("1.5j", "{1.5j}"), - ('"abc"', r'\textrm{"abc"}'), - ('b"abc"', r"\textrm{b'abc'}"), - ("None", r"\mathrm{None}"), - ("False", r"\mathrm{False}"), - ("True", r"\mathrm{True}"), - ("...", r"{\cdots}"), - ], -) -def test_visit_constant(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Constant) - - -@pytest.mark.parametrize( - "code,latex", - [ - ("x[0]", "x_{0}"), - ("x[0][1]", "x_{0, 1}"), - ("x[0][1][2]", "x_{0, 1, 2}"), - ("x[foo]", r"x_{\mathrm{foo}}"), - ("x[floor(x)]", r"x_{\mathopen{}\left\lfloor x \mathclose{}\right\rfloor}"), - ], -) -def test_visit_subscript(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Subscript) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("a - b", r"a \setminus b"), - ("a & b", r"a \cap b"), - ("a ^ b", r"a \mathbin{\triangle} b"), - ("a | b", r"a \cup b"), - ], -) -def test_use_set_symbols_binop(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.BinOp) - assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("a < b", r"a \subset b"), - ("a <= b", r"a \subseteq b"), - ("a > b", r"a \supset b"), - ("a >= b", r"a \supseteq b"), - ], -) -def test_use_set_symbols_compare(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Compare) - assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("array(1)", r"\mathrm{array} \mathopen{}\left( 1 \mathclose{}\right)"), - ( - "array([])", - r"\mathrm{array} \mathopen{}\left(" - r" \mathopen{}\left[ \mathclose{}\right]" - r" \mathclose{}\right)", - ), - ("array([1])", r"\begin{bmatrix} 1 \end{bmatrix}"), - ("array([1, 2, 3])", r"\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}"), - ( - "array([[]])", - r"\mathrm{array} \mathopen{}\left(" - r" \mathopen{}\left[ \mathopen{}\left[" - r" \mathclose{}\right] \mathclose{}\right]" - r" \mathclose{}\right)", - ), - ("array([[1]])", r"\begin{bmatrix} 1 \end{bmatrix}"), - ("array([[1], [2], [3]])", r"\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}"), - ( - "array([[1], [2], [3, 4]])", - r"\mathrm{array} \mathopen{}\left(" - r" \mathopen{}\left[" - r" \mathopen{}\left[ 1 \mathclose{}\right]," - r" \mathopen{}\left[ 2 \mathclose{}\right]," - r" \mathopen{}\left[ 3, 4 \mathclose{}\right]" - r" \mathclose{}\right]" - r" \mathclose{}\right)", - ), - ( - "array([[1, 2], [3, 4], [5, 6]])", - r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}", - ), - # Only checks two cases for ndarray. - ("ndarray(1)", r"\mathrm{ndarray} \mathopen{}\left( 1 \mathclose{}\right)"), - ("ndarray([1])", r"\begin{bmatrix} 1 \end{bmatrix}"), - ], -) -def test_numpy_array(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Call) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("zeros(0)", r"\mathbf{0}^{1 \times 0}"), - ("zeros(1)", r"\mathbf{0}^{1 \times 1}"), - ("zeros(2)", r"\mathbf{0}^{1 \times 2}"), - ("zeros(())", r"0"), - ("zeros((0,))", r"\mathbf{0}^{1 \times 0}"), - ("zeros((1,))", r"\mathbf{0}^{1 \times 1}"), - ("zeros((2,))", r"\mathbf{0}^{1 \times 2}"), - ("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"), - ("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"), - ("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"), - ("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"), - ("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"), - ("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"), - # Unsupported - ("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"), - ("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"), - ("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"), - ( - "zeros((x,))", - r"\mathrm{zeros} \mathopen{}\left(" - r" \mathopen{}\left( x \mathclose{}\right)" - r" \mathclose{}\right)", - ), - ], -) -def test_zeros(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Call) - assert function_codegen.FunctionCodegen().visit(tree) == latex - - -@pytest.mark.parametrize( - "code,latex", - [ - ("identity(0)", r"\mathbf{I}_{0}"), - ("identity(1)", r"\mathbf{I}_{1}"), - ("identity(2)", r"\mathbf{I}_{2}"), - # Unsupported - ("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"), - ("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"), - ( - "identity(0, x)", - r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)", - ), - ], -) -def test_identity(code: str, latex: str) -> None: - tree = ast_utils.parse_expr(code) - assert isinstance(tree, ast.Call) - assert function_codegen.FunctionCodegen().visit(tree) == latex diff --git a/src/latexify/test_utils.py b/src/latexify/test_utils.py index 487ec50..e143a64 100644 --- a/src/latexify/test_utils.py +++ b/src/latexify/test_utils.py @@ -60,8 +60,8 @@ def wrapper(*args, **kwargs): def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: """Checks the equality between two ASTs. - This function checks if `ovserved` contains at least the same subtree with - `expected`. If `ovserved` has some extra branches that `expected` does not cover, + This function checks if `observed` contains at least the same subtree with + `expected`. If `observed` has some extra branches that `expected` does not cover, it is ignored. Args: @@ -75,6 +75,9 @@ def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: assert type(observed) is type(expected) for k, ve in vars(expected).items(): + if k in {"col_offset", "end_col_offset", "end_lineno", "kind", "lineno"}: + continue + vo = getattr(observed, k) # May cause AttributeError. if isinstance(ve, ast.AST): diff --git a/src/latexify/transformers/assignment_reducer.py b/src/latexify/transformers/assignment_reducer.py index 2b7616c..ef35af2 100644 --- a/src/latexify/transformers/assignment_reducer.py +++ b/src/latexify/transformers/assignment_reducer.py @@ -33,7 +33,7 @@ def f(x): # comprehensions or lambdas, which introduces inner scopes. # It may cause some mistakes in the resulting AST. def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: - """Visitor of FunctionDef nodes.""" + """Visit a FunctionDef node.""" # Push stack parent_assignments = self._assignments self._assignments = {} @@ -75,7 +75,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> Any: ) def visit_Name(self, node: ast.Name) -> Any: - """Visitor of Name nodes.""" + """Visit a Name node.""" if self._assignments is not None: return self._assignments.get(node.id, node) diff --git a/src/latexify/transformers/function_expander.py b/src/latexify/transformers/function_expander.py index cedee60..529de43 100644 --- a/src/latexify/transformers/function_expander.py +++ b/src/latexify/transformers/function_expander.py @@ -27,7 +27,7 @@ def __init__(self, functions: set[str]) -> None: self._functions = functions def visit_Call(self, node: ast.Call) -> ast.AST: - """Visitor of Call nodes.""" + """Visit a Call node.""" func_name = ast_utils.extract_function_name_or_none(node) if ( func_name is not None diff --git a/src/latexify/transformers/identifier_replacer.py b/src/latexify/transformers/identifier_replacer.py index cefd271..430e50e 100644 --- a/src/latexify/transformers/identifier_replacer.py +++ b/src/latexify/transformers/identifier_replacer.py @@ -48,8 +48,7 @@ def _replace_args(self, args: list[ast.arg]) -> list[ast.arg]: return [ast.arg(arg=self._mapping.get(a.arg, a.arg)) for a in args] def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: - """Visitor of FunctionDef.""" - + """Visit a FunctionDef node.""" visited = cast(ast.FunctionDef, super().generic_visit(node)) if sys.version_info.minor < 8: @@ -76,7 +75,7 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: ) def visit_Name(self, node: ast.Name) -> ast.Name: - """Visitor of Name.""" + """Visit a Name node.""" return ast.Name( id=self._mapping.get(node.id, node.id), ctx=node.ctx, diff --git a/src/latexify/transformers/prefix_trimmer.py b/src/latexify/transformers/prefix_trimmer.py index 3ed8e76..c04dd59 100644 --- a/src/latexify/transformers/prefix_trimmer.py +++ b/src/latexify/transformers/prefix_trimmer.py @@ -79,6 +79,7 @@ def _make_attribute(self, prefix: tuple[str, ...], name: str) -> ast.expr: return ast_utils.make_attribute(parent, name) def visit_Attribute(self, node: ast.Attribute) -> ast.expr: + """Visit an Attribute node.""" prefix = self._get_prefix(node.value) if prefix is None: return node