diff --git a/src/integration_tests/function_expansion_test.py b/src/integration_tests/function_expansion_test.py index b4a7c8e..5744334 100644 --- a/src/integration_tests/function_expansion_test.py +++ b/src/integration_tests/function_expansion_test.py @@ -3,23 +3,29 @@ from integration_tests import utils -def test_expand_atan2_function() -> None: +def test_atan2() -> None: def solve(x, y): return math.atan2(y, x) - latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{y}{x}}\right)}" + latex = ( + r"\mathrm{solve}(x, y) =" + r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)" + ) utils.check_function(solve, latex, expand_functions={"atan2"}) -def test_expand_atan2_nested_function() -> None: +def test_atan2_nested() -> None: def solve(x, y): return math.atan2(math.exp(y), math.exp(x)) - latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{e^{y}}{e^{x}}}\right)}" + latex = ( + r"\mathrm{solve}(x, y) =" + r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)" + ) utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) -def test_expand_exp_function() -> None: +def test_exp() -> None: def solve(x): return math.exp(x) @@ -27,7 +33,7 @@ def solve(x): utils.check_function(solve, latex, expand_functions={"exp"}) -def test_expand_exp_nested_function() -> None: +def test_exp_nested() -> None: def solve(x): return math.exp(math.exp(x)) @@ -35,7 +41,7 @@ def solve(x): utils.check_function(solve, latex, expand_functions={"exp"}) -def test_expand_exp2_function() -> None: +def test_exp2() -> None: def solve(x): return math.exp2(x) @@ -43,7 +49,7 @@ def solve(x): utils.check_function(solve, latex, expand_functions={"exp2"}) -def test_expand_exp2_nested_function() -> None: +def test_exp2_nested() -> None: def solve(x): return math.exp2(math.exp2(x)) @@ -51,15 +57,15 @@ def solve(x): utils.check_function(solve, latex, expand_functions={"exp2"}) -def test_expand_expm1_function() -> None: +def test_expm1() -> None: def solve(x): return math.expm1(x) - latex = r"\mathrm{solve}(x) = \exp{\left({x}\right)} - {1}" + latex = r"\mathrm{solve}(x) = \exp x - {1}" utils.check_function(solve, latex, expand_functions={"expm1"}) -def test_expand_expm1_nested_function() -> None: +def test_expm1_nested() -> None: def solve(x, y, z): return math.expm1(math.pow(y, z)) @@ -67,54 +73,54 @@ def solve(x, y, z): utils.check_function(solve, latex, expand_functions={"expm1", "exp", "pow"}) -def test_expand_hypot_function_without_attribute_access() -> None: +def test_hypot_without_attribute() -> None: from math import hypot def solve(x, y, z): return hypot(x, y, z) - latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}" + latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }" utils.check_function(solve, latex, expand_functions={"hypot"}) -def test_expand_hypot_function() -> None: +def test_hypot() -> None: def solve(x, y, z): return math.hypot(x, y, z) - latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}" + latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }" utils.check_function(solve, latex, expand_functions={"hypot"}) -def test_expand_hypot_nested_function() -> None: +def test_hypot_nested() -> None: def solve(a, b, x, y): return math.hypot(math.hypot(a, b), x, y) latex = ( - r"\mathrm{solve}(a, b, x, y) = " - r"\sqrt{" - r"\sqrt{a^{{2}} + b^{{2}}}^{{2}} + " - r"x^{{2}} + y^{{2}}}" + r"\mathrm{solve}(a, b, x, y) =" + r" \sqrt{ \sqrt{ a^{{2}} + b^{{2}} }^{{2}} + x^{{2}} + y^{{2}} }" ) utils.check_function(solve, latex, expand_functions={"hypot"}) -def test_expand_log1p_function() -> None: +def test_log1p() -> None: def solve(x): return math.log1p(x) - latex = r"\mathrm{solve}(x) = \log{\left({{1} + x}\right)}" + latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + x \mathclose{}\right)" utils.check_function(solve, latex, expand_functions={"log1p"}) -def test_expand_log1p_nested_function() -> None: +def test_log1p_nested() -> None: def solve(x): return math.log1p(math.exp(x)) - latex = r"\mathrm{solve}(x) = \log{\left({{1} + e^{x}}\right)}" + latex = ( + r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + e^{x} \mathclose{}\right)" + ) utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) -def test_expand_pow_nested_function() -> None: +def test_pow_nested() -> None: def solve(w, x, y, z): return math.pow(math.pow(w, x), math.pow(y, z)) @@ -125,7 +131,7 @@ def solve(w, x, y, z): utils.check_function(solve, latex, expand_functions={"pow"}) -def test_expand_pow_function() -> None: +def test_pow() -> None: def solve(x, y): return math.pow(x, y) diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 10244d4..5f8385d 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -11,7 +11,7 @@ def test_quadratic_solution() -> None: def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) - latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}" + latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{{2}} - {4} a c }}{{2} a}" utils.check_function(solve, latex) @@ -26,7 +26,7 @@ def sinc(x): r"\mathrm{sinc}(x) = " r"\left\{ \begin{array}{ll} " r"{1}, & \mathrm{if} \ " - r"{x = {0}} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} " + r"{x = {0}} \\ \frac{\sin x}{x}, & \mathrm{otherwise} " r"\end{array} \right." ) utils.check_function(sinc, latex) @@ -201,9 +201,9 @@ def sigmoid(x): sigmoid, ( r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll} " - r"\frac{{1}}{{1} + \exp{\left({-x}\right)}}, & " + r"\frac{{1}}{{1} + \exp \mathopen{}\left( -x \mathclose{}\right)}, & " r"\mathrm{if} \ {x > {0}} \\ " - r"\frac{\exp{\left({x}\right)}}{\exp{\left({x}\right)} + {1}}, & " + r"\frac{\exp x}{\exp x + {1}}, & " r"\mathrm{otherwise} " r"\end{array} \right." ), diff --git a/src/latexify/ast_utils.py b/src/latexify/ast_utils.py index 59b48e0..718fc88 100644 --- a/src/latexify/ast_utils.py +++ b/src/latexify/ast_utils.py @@ -31,7 +31,7 @@ def make_name(id: str) -> ast.Name: return ast.Name(id=id, ctx=ast.Load()) -def make_attribute(value: ast.Expr, attr: str): +def make_attribute(value: ast.expr, attr: str): """Generates a new Attribute node. Args: diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index cf83122..c5925ec 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -52,6 +52,14 @@ ast.Or: 10, } +# NOTE(odashi): +# Function invocation is treated as a unary operator with a higher precedence. +# This ensures that the argument with a unary operator is wrapped: +# exp(x) --> \exp x +# exp(-x) --> \exp (-x) +# -exp(x) --> - \exp x +_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1 + def _get_precedence(node: ast.AST) -> int: """Obtains the precedence of the subtree. @@ -63,6 +71,9 @@ def _get_precedence(node: ast.AST) -> int: If `node` is a subtree with some operator, returns the precedence of the operator. Otherwise, returns a number larger enough from other precedences. """ + if isinstance(node, ast.Call): + return _CALL_PRECEDENCE + if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)): return _PRECEDENCES[type(node.op)] @@ -289,38 +300,34 @@ def visit_Return(self, node: ast.Return) -> str: def visit_Tuple(self, node: ast.Tuple) -> str: elts = [self.visit(i) for i in node.elts] - return ( - r"\mathopen{}\left( " - + r"\space,\space ".join(elts) - + r"\mathclose{}\right) " - ) + return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)" def visit_List(self, node: ast.List) -> str: elts = [self.visit(i) for i in node.elts] - return r"\left[ " + r"\space,\space ".join(elts) + r"\right] " + return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]" def visit_Set(self, node: ast.Set) -> str: elts = [self.visit(i) for i in node.elts] - return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} " + return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}" def visit_ListComp(self, node: ast.ListComp) -> str: generators = [self.visit(comp) for comp in node.generators] return ( - r"\left[ " + r"\mathopen{}\left[ " + self.visit(node.elt) + r" \mid " + ", ".join(generators) - + r" \right]" + + r" \mathclose{}\right]" ) def visit_SetComp(self, node: ast.SetComp) -> str: generators = [self.visit(comp) for comp in node.generators] return ( - r"\left\{ " + r"\mathopen{}\left\{ " + self.visit(node.elt) + r" \mid " + ", ".join(generators) - + r" \right\}" + + r" \mathclose{}\right\}" ) def visit_comprehension(self, node: ast.comprehension) -> str: @@ -347,10 +354,16 @@ def _generate_sum_prod(self, node: ast.Call) -> str | None: return None name = ast_utils.extract_function_name_or_none(node) - assert name is not None + assert name in ("fsum", "sum", "prod") + + command = { + "fsum": r"\sum", + "sum": r"\sum", + "prod": r"\prod", + }[name] elt, scripts = self._get_sum_prod_info(node.args[0]) - scripts_str = [rf"\{name}_{{{lo}}}^{{{up}}}" for lo, up in scripts] + scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts] return ( " ".join(scripts_str) + rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)" @@ -403,7 +416,7 @@ def visit_Call(self, node: ast.Call) -> str: func_name = ast_utils.extract_function_name_or_none(node) # Special treatments for some functions. - if func_name in ("sum", "prod"): + if func_name in ("fsum", "sum", "prod"): special_latex = self._generate_sum_prod(node) elif func_name in ("array", "ndarray"): special_latex = self._generate_matrix(node) @@ -413,17 +426,38 @@ def visit_Call(self, node: ast.Call) -> str: if special_latex is not None: return special_latex - # Function signature (possibly an expression). - default_func_str = self.visit(node.func) - - # Obtains wrapper syntax: sqrt -> "\sqrt{" and "}" - lstr, rstr = constants.BUILTIN_FUNCS.get( - func_name, - (default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"), - ) + # Obtains the codegen rule. + rule = constants.BUILTIN_FUNCS.get(func_name) + if rule is None: + rule = constants.FunctionRule(self.visit(node.func)) + + if rule.is_unary and len(node.args) == 1: + # Unary function. Applies the same wrapping policy with the unary operators. + # NOTE(odashi): + # Factorial "x!" is treated as a special case: it requires both inner/outer + # parentheses for correct interpretation. + precedence = _get_precedence(node) + arg = node.args[0] + force_wrap = isinstance(arg, ast.Call) and ( + func_name == "factorial" + or ast_utils.extract_function_name_or_none(arg) == "factorial" + ) + arg_latex = self._wrap_operand(arg, precedence, force_wrap) + elements = [rule.left, arg_latex, rule.right] + else: + arg_latex = ", ".join(self.visit(arg) for arg in node.args) + if rule.is_wrapped: + elements = [rule.left, arg_latex, rule.right] + else: + elements = [ + rule.left, + r"\mathopen{}\left(", + arg_latex, + r"\mathclose{}\right)", + rule.right, + ] - arg_strs = [self.visit(arg) for arg in node.args] - return lstr + ", ".join(arg_strs) + rstr + return " ".join(x for x in elements if x) def visit_Attribute(self, node: ast.Attribute) -> str: vstr = self.visit(node.value) @@ -481,20 +515,26 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str: def visit_Ellipsis(self, node: ast.Ellipsis) -> str: return self._convert_constant(...) - def _wrap_operand(self, child: ast.expr, parent_prec: int) -> str: + def _wrap_operand( + self, child: ast.expr, parent_prec: int, force_wrap: bool = False + ) -> str: """Wraps the operand subtree with parentheses. Args: child: Operand subtree. parent_prec: Precedence of the parent operator. + force_wrap: Whether to wrap the operand or not when the precedence is equal. Returns: LaTeX form of `child`, with or without surrounding parentheses. """ latex = self.visit(child) - if _get_precedence(child) >= parent_prec: - return latex - return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" + child_prec = _get_precedence(child) + + if child_prec < parent_prec or force_wrap and child_prec == parent_prec: + return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" + + return latex def _wrap_binop_operand( self, @@ -515,6 +555,13 @@ def _wrap_binop_operand( if not operand_rule.wrap: return self.visit(child) + if isinstance(child, ast.Call): + rule = constants.BUILTIN_FUNCS.get( + ast_utils.extract_function_name_or_none(child) + ) + if rule is not None and rule.is_wrapped: + return self.visit(child) + if not isinstance(child, ast.BinOp): return self._wrap_operand(child, parent_prec) diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 3f9b7e4..406b6ca 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -77,41 +77,90 @@ def f(x): @pytest.mark.parametrize( "code,latex", [ - ("[i for i in n]", r"\left[ i \mid i \in n \right]"), + ("()", r"\mathopen{}\left( \mathclose{}\right)"), + ("(x,)", r"\mathopen{}\left( x \mathclose{}\right)"), + ("(x, y)", r"\mathopen{}\left( x, y \mathclose{}\right)"), + ("(x, y, z)", r"\mathopen{}\left( x, y, z \mathclose{}\right)"), + ], +) +def test_tuple(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Tuple) + assert FunctionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("[]", r"\mathopen{}\left[ \mathclose{}\right]"), + ("[x]", r"\mathopen{}\left[ x \mathclose{}\right]"), + ("[x, y]", r"\mathopen{}\left[ x, y \mathclose{}\right]"), + ("[x, y, z]", r"\mathopen{}\left[ x, y, z \mathclose{}\right]"), + ], +) +def test_list(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.List) + assert FunctionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + # TODO(odashi): Support set(). + # ("set()", r"\mathopen{}\left\{ \mathclose{}\right\}"), + ("{x}", r"\mathopen{}\left\{ x \mathclose{}\right\}"), + ("{x, y}", r"\mathopen{}\left\{ x, y \mathclose{}\right\}"), + ("{x, y, z}", r"\mathopen{}\left\{ x, y, z \mathclose{}\right\}"), + ], +) +def test_set(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Set) + assert FunctionCodegen().visit(node) == latex + + +@pytest.mark.parametrize( + "code,latex", + [ + ("[i for i in n]", r"\mathopen{}\left[ i \mid i \in n \mathclose{}\right]"), ( "[i for i in n if i > 0]", - r"\left[ i \mid" + r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right]", + r" \mathclose{}\right]", ), ( "[i for i in n if i > 0 if f(i)]", - r"\left[ i \mid" + r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \land \mathopen{}\left( f\mathopen{}\left(" - r"i\mathclose{}\right) \mathclose{}\right)" - r" \right]", + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)" + r" \mathclose{}\right]", + ), + ( + "[i for k in n for i in k]", + r"\mathopen{}\left[ i \mid k \in n, i \in k" r" \mathclose{}\right]", ), - ("[i for k in n for i in k]", r"\left[ i \mid k \in n, i \in k" r" \right]"), ( "[i for k in n for i in k if i > 0]", - r"\left[ i \mid" + r"\mathopen{}\left[ i \mid" r" k \in n," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right]", + r" \mathclose{}\right]", ), ( "[i for k in n if f(k) for i in k if i > 0]", - r"\left[ i \mid" + r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( k \in n \mathclose{}\right)" - r" \land \mathopen{}\left( f\mathopen{}\left(" - r"k\mathclose{}\right) \mathclose{}\right)," + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" k \mathclose{}\right) \mathclose{}\right)," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right]", + r" \mathclose{}\right]", ), ], ) @@ -124,41 +173,44 @@ def test_visit_listcomp(code: str, latex: str) -> None: @pytest.mark.parametrize( "code,latex", [ - ("{i for i in n}", r"\left\{ i \mid i \in n \right\}"), + ("{i for i in n}", r"\mathopen{}\left\{ i \mid i \in n \mathclose{}\right\}"), ( "{i for i in n if i > 0}", - r"\left\{ i \mid" + r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right\}", + r" \mathclose{}\right\}", ), ( "{i for i in n if i > 0 if f(i)}", - r"\left\{ i \mid" + r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \land \mathopen{}\left( f\mathopen{}\left(" - r"i\mathclose{}\right) \mathclose{}\right)" - r" \right\}", + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)" + r" \mathclose{}\right\}", + ), + ( + "{i for k in n for i in k}", + r"\mathopen{}\left\{ i \mid k \in n, i \in k" r" \mathclose{}\right\}", ), - ("{i for k in n for i in k}", r"\left\{ i \mid k \in n, i \in k" r" \right\}"), ( "{i for k in n for i in k if i > 0}", - r"\left\{ i \mid" + r"\mathopen{}\left\{ i \mid" r" k \in n," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right\}", + r" \mathclose{}\right\}", ), ( "{i for k in n if f(k) for i in k if i > 0}", - r"\left\{ i \mid" + r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( k \in n \mathclose{}\right)" - r" \land \mathopen{}\left( f\mathopen{}\left(" - r"k\mathclose{}\right) \mathclose{}\right)," + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" k \mathclose{}\right) \mathclose{}\right)," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( {i > {0}} \mathclose{}\right)" - r" \right\}", + r" \mathclose{}\right\}", ), ], ) @@ -168,30 +220,87 @@ def test_visit_setcomp(code: str, latex: str) -> None: assert FunctionCodegen().visit(node) == latex +@pytest.mark.parametrize( + "code,latex", + [ + ("foo(x)", r"\mathrm{foo} \mathopen{}\left( x \mathclose{}\right)"), + ("f(x)", r"f \mathopen{}\left( x \mathclose{}\right)"), + ("f(-x)", r"f \mathopen{}\left( -x \mathclose{}\right)"), + ("f(x + y)", r"f \mathopen{}\left( x + y \mathclose{}\right)"), + ( + "f(f(x))", + r"f \mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right)", + ), + ("f(sqrt(x))", r"f \mathopen{}\left( \sqrt{ x } \mathclose{}\right)"), + ("f(sin(x))", r"f \mathopen{}\left( \sin x \mathclose{}\right)"), + ("f(factorial(x))", r"f \mathopen{}\left( x ! \mathclose{}\right)"), + ("f(x, y)", r"f \mathopen{}\left( x, y \mathclose{}\right)"), + ("sqrt(x)", r"\sqrt{ x }"), + ("sqrt(-x)", r"\sqrt{ -x }"), + ("sqrt(x + y)", r"\sqrt{ x + y }"), + ("sqrt(f(x))", r"\sqrt{ f \mathopen{}\left( x \mathclose{}\right) }"), + ("sqrt(sqrt(x))", r"\sqrt{ \sqrt{ x } }"), + ("sqrt(sin(x))", r"\sqrt{ \sin x }"), + ("sqrt(factorial(x))", r"\sqrt{ x ! }"), + ("sin(x)", r"\sin x"), + ("sin(-x)", r"\sin \mathopen{}\left( -x \mathclose{}\right)"), + ("sin(x + y)", r"\sin \mathopen{}\left( x + y \mathclose{}\right)"), + ("sin(f(x))", r"\sin f \mathopen{}\left( x \mathclose{}\right)"), + ("sin(sqrt(x))", r"\sin \sqrt{ x }"), + ("sin(sin(x))", r"\sin \sin x"), + ("sin(factorial(x))", r"\sin \mathopen{}\left( x ! \mathclose{}\right)"), + ("factorial(x)", r"x !"), + ("factorial(-x)", r"\mathopen{}\left( -x \mathclose{}\right) !"), + ("factorial(x + y)", r"\mathopen{}\left( x + y \mathclose{}\right) !"), + ( + "factorial(f(x))", + r"\mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right) !", + ), + ("factorial(sqrt(x))", r"\mathopen{}\left( \sqrt{ x } \mathclose{}\right) !"), + ("factorial(sin(x))", r"\mathopen{}\left( \sin x \mathclose{}\right) !"), + ("factorial(factorial(x))", r"\mathopen{}\left( x ! \mathclose{}\right) !"), + ], +) +def test_visit_call(code: str, latex: str) -> None: + node = ast_utils.parse_expr(code) + assert isinstance(node, ast.Call) + assert FunctionCodegen().visit(node) == latex + + @pytest.mark.parametrize( "src_suffix,dest_suffix", [ # No comprehension - ("(x)", r" \left({x}\right)"), - ("([1, 2])", r" \left({\left[ {1}\space,\space {2}\right] }\right)"), - ("({1, 2})", r" \left({\left\{ {1}\space,\space {2}\right\} }\right)"), - ("(f(x))", r" \left({f\mathopen{}\left(x\mathclose{}\right)}\right)"), + ("(x)", r" x"), + ( + "([1, 2])", + r" \mathopen{}\left[ {1}, {2} \mathclose{}\right]", + ), + ( + "({1, 2})", + r" \mathopen{}\left\{ {1}, {2} \mathclose{}\right\}", + ), + ("(f(x))", r" f \mathopen{}\left( x \mathclose{}\right)"), # Single comprehension ("(i for i in x)", r"_{i \in x}^{} \mathopen{}\left({i}\mathclose{}\right)"), ( "(i for i in [1, 2])", - r"_{i \in \left[ {1}\space,\space {2}\right] }^{} " + r"_{i \in \mathopen{}\left[ {1}, {2} \mathclose{}\right]}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in {1, 2})", - r"_{i \in \left\{ {1}\space,\space {2}\right\} }^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", + r"_{i \in \mathopen{}\left\{ {1}, {2} \mathclose{}\right\}}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in f(x))", - r"_{i \in f\mathopen{}\left(x\mathclose{}\right)}^{} " - r"\mathopen{}\left({i}\mathclose{}\right)", + r"_{i \in f \mathopen{}\left( x \mathclose{}\right)}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n))", @@ -215,13 +324,13 @@ def test_visit_setcomp(code: str, latex: str) -> None: ), ( "(i for i in range(n, m, k))", - r"_{i \in \mathrm{range}\mathopen{}\left(n, m, k" - r"\mathclose{}\right)}^{} \mathopen{}\left({i}\mathclose{}\right)", + r"_{i \in \mathrm{range} \mathopen{}\left( n, m, k \mathclose{}\right)}^{}" + r" \mathopen{}\left({i}\mathclose{}\right)", ), ], ) def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: - for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: + for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]: node = ast_utils.parse_expr(src_fn + src_suffix) assert isinstance(node, ast.Call) assert FunctionCodegen().visit(node) == dest_fn + dest_suffix @@ -289,10 +398,10 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No ), ( "(i for i in x if i < y if f(i))", - r"_{\mathopen{}\left( i \in x \mathclose{}\right) " - r"\land \mathopen{}\left( {i < y} \mathclose{}\right)" - r" \land \mathopen{}\left( f\mathopen{}\left(" - r"i\mathclose{}\right) \mathclose{}\right)}^{}" + r"_{\mathopen{}\left( i \in x \mathclose{}\right)" + r" \land \mathopen{}\left( {i < y} \mathclose{}\right)" + r" \land \mathopen{}\left( f \mathopen{}\left(" + r" i \mathclose{}\right) \mathclose{}\right)}^{}" r" \mathopen{}\left({i}\mathclose{}\right)", ), ], @@ -469,14 +578,21 @@ def test_if_then_else(code: str, latex: str) -> None: # is_wrapped ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"), # With Call - ("x**f(y)", r"x^{f\mathopen{}\left(y\mathclose{}\right)}"), - ("f(x)**y", r"f\mathopen{}\left(x\mathclose{}\right)^{y}"), - ("x * f(y)", r"x f\mathopen{}\left(y\mathclose{}\right)"), - ("f(x) * y", r"f\mathopen{}\left(x\mathclose{}\right) y"), - ("x / f(y)", r"\frac{x}{f\mathopen{}\left(y\mathclose{}\right)}"), - ("f(x) / y", r"\frac{f\mathopen{}\left(x\mathclose{}\right)}{y}"), - ("x + f(y)", r"x + f\mathopen{}\left(y\mathclose{}\right)"), - ("f(x) + y", r"f\mathopen{}\left(x\mathclose{}\right) + y"), + ("x**f(y)", r"x^{f \mathopen{}\left( y \mathclose{}\right)}"), + ( + "f(x)**y", + r"\mathopen{}\left(" + r" f \mathopen{}\left( x \mathclose{}\right)" + r" \mathclose{}\right)^{y}", + ), + ("x * f(y)", r"x f \mathopen{}\left( y \mathclose{}\right)"), + ("f(x) * y", r"f \mathopen{}\left( x \mathclose{}\right) y"), + ("x / f(y)", r"\frac{x}{f \mathopen{}\left( y \mathclose{}\right)}"), + ("f(x) / y", r"\frac{f \mathopen{}\left( x \mathclose{}\right)}{y}"), + ("x + f(y)", r"x + f \mathopen{}\left( y \mathclose{}\right)"), + ("f(x) + y", r"f \mathopen{}\left( x \mathclose{}\right) + y"), + # With is_wrapped Call + ("sqrt(x) ** y", r"\sqrt{ x }^{y}"), # With UnaryOp ("x**-y", r"x^{-y}"), ("(-x)**y", r"\mathopen{}\left( -x \mathclose{}\right)^{y}"), @@ -521,10 +637,10 @@ def test_visit_binop(code: str, latex: str) -> None: ("~x", r"\mathord{\sim} x"), ("not x", r"\lnot x"), # With Call - ("+f(x)", r"+f\mathopen{}\left(x\mathclose{}\right)"), - ("-f(x)", r"-f\mathopen{}\left(x\mathclose{}\right)"), - ("~f(x)", r"\mathord{\sim} f\mathopen{}\left(x\mathclose{}\right)"), - ("not f(x)", r"\lnot f\mathopen{}\left(x\mathclose{}\right)"), + ("+f(x)", r"+f \mathopen{}\left( x \mathclose{}\right)"), + ("-f(x)", r"-f \mathopen{}\left( x \mathclose{}\right)"), + ("~f(x)", r"\mathord{\sim} f \mathopen{}\left( x \mathclose{}\right)"), + ("not f(x)", r"\lnot f \mathopen{}\left( x \mathclose{}\right)"), # With BinOp ("+(x + y)", r"+\mathopen{}\left( x + y \mathclose{}\right)"), ("-(x + y)", r"-\mathopen{}\left( x + y \mathclose{}\right)"), @@ -584,8 +700,8 @@ def test_visit_unaryop(code: str, latex: str) -> None: ("a <= b < c", r"{a \le b < c}"), ("a <= b <= c", r"{a \le b \le c}"), # With Call - ("a == f(b)", r"{a = f\mathopen{}\left(b\mathclose{}\right)}"), - ("f(a) == b", r"{f\mathopen{}\left(a\mathclose{}\right) = b}"), + ("a == f(b)", r"{a = f \mathopen{}\left( b \mathclose{}\right)}"), + ("f(a) == b", r"{f \mathopen{}\left( a \mathclose{}\right) = b}"), # With BinOp ("a == b + c", r"{a = b + c}"), ("a + b == c", r"{a + b = c}"), @@ -624,10 +740,10 @@ def test_visit_compare(code: str, latex: str) -> None: r"{a \land \mathopen{}\left( {b \lor c} \mathclose{}\right)}", ), # With Call - ("a and f(b)", r"{a \land f\mathopen{}\left(b\mathclose{}\right)}"), - ("f(a) and b", r"{f\mathopen{}\left(a\mathclose{}\right) \land b}"), - ("a or f(b)", r"{a \lor f\mathopen{}\left(b\mathclose{}\right)}"), - ("f(a) or b", r"{f\mathopen{}\left(a\mathclose{}\right) \lor b}"), + ("a and f(b)", r"{a \land f \mathopen{}\left( b \mathclose{}\right)}"), + ("f(a) and b", r"{f \mathopen{}\left( a \mathclose{}\right) \land b}"), + ("a or f(b)", r"{a \lor f \mathopen{}\left( b \mathclose{}\right)}"), + ("f(a) or b", r"{f \mathopen{}\left( a \mathclose{}\right) \lor b}"), # With BinOp ("a and b + c", r"{a \land b + c}"), ("a + b and c", r"{a + b \land c}"), @@ -707,7 +823,7 @@ def test_visit_constant(code: str, latex: str) -> None: ("x[0][1]", "{x_{{0}, {1}}}"), ("x[0][1][2]", "{x_{{0}, {1}, {2}}}"), ("x[foo]", r"{x_{\mathrm{foo}}}"), - ("x[floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"), + ("x[floor(x)]", r"{x_{\mathopen{}\left\lfloor x \mathclose{}\right\rfloor}}"), ], ) def test_visit_subscript(code: str, latex: str) -> None: @@ -749,37 +865,40 @@ def test_use_set_symbols_compare(code: str, latex: str) -> None: @pytest.mark.parametrize( "code,latex", [ - ("array(1)", r"\mathrm{array}\mathopen{}\left({1}\mathclose{}\right)"), + ("array(1)", r"\mathrm{array} \mathopen{}\left( {1} \mathclose{}\right)"), ( "array([])", - r"\mathrm{array}\mathopen{}\left(\left[ \right] \mathclose{}\right)", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[ \mathclose{}\right]" + r" \mathclose{}\right)", ), ("array([1])", r"\begin{bmatrix} {1} \end{bmatrix}"), ("array([1, 2, 3])", r"\begin{bmatrix} {1} & {2} & {3} \end{bmatrix}"), ( "array([[]])", - r"\mathrm{array}\mathopen{}\left(" - r"\left[ \left[ \right] \right] " - r"\mathclose{}\right)", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[ \mathopen{}\left[" + r" \mathclose{}\right] \mathclose{}\right]" + r" \mathclose{}\right)", ), ("array([[1]])", r"\begin{bmatrix} {1} \end{bmatrix}"), ("array([[1], [2], [3]])", r"\begin{bmatrix} {1} \\ {2} \\ {3} \end{bmatrix}"), ( "array([[1], [2], [3, 4]])", - r"\mathrm{array}\mathopen{}\left(" - r"\left[ " - r"\left[ {1}\right] \space,\space " - r"\left[ {2}\right] \space,\space " - r"\left[ {3}\space,\space {4}\right] " - r"\right] " - r"\mathclose{}\right)", + r"\mathrm{array} \mathopen{}\left(" + r" \mathopen{}\left[" + r" \mathopen{}\left[ {1} \mathclose{}\right]," + r" \mathopen{}\left[ {2} \mathclose{}\right]," + r" \mathopen{}\left[ {3}, {4} \mathclose{}\right]" + r" \mathclose{}\right]" + r" \mathclose{}\right)", ), ( "array([[1, 2], [3, 4], [5, 6]])", r"\begin{bmatrix} {1} & {2} \\ {3} & {4} \\ {5} & {6} \end{bmatrix}", ), # Only checks two cases for ndarray. - ("ndarray(1)", r"\mathrm{ndarray}\mathopen{}\left({1}\mathclose{}\right)"), + ("ndarray(1)", r"\mathrm{ndarray} \mathopen{}\left( {1} \mathclose{}\right)"), ("ndarray([1])", r"\begin{bmatrix} {1} \end{bmatrix}"), ], ) diff --git a/src/latexify/constants.py b/src/latexify/constants.py index 7e9e926..4887925 100644 --- a/src/latexify/constants.py +++ b/src/latexify/constants.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import enum @@ -42,37 +43,63 @@ class BuiltinFnName(str, enum.Enum): SUM = "sum" -BUILTIN_FUNCS: dict[BuiltinFnName, tuple[str, str]] = { - BuiltinFnName.ABS: (r"\left|{", r"}\right|"), - BuiltinFnName.ACOS: (r"\arccos{\left({", r"}\right)}"), - BuiltinFnName.ACOSH: (r"\mathrm{arccosh}{\left({", r"}\right)}"), - BuiltinFnName.ARCCOS: (r"\arccos{\left({", r"}\right)}"), - BuiltinFnName.ARCCOSH: (r"\mathrm{arccosh}{\left({", r"}\right)}"), - BuiltinFnName.ARCSIN: (r"\arcsin{\left({", r"}\right)}"), - BuiltinFnName.ARCSINH: (r"\mathrm{arcsinh}{\left({", r"}\right)}"), - BuiltinFnName.ARCTAN: (r"\arctan{\left({", r"}\right)}"), - BuiltinFnName.ARCTANH: (r"\mathrm{arctanh}{\left({", r"}\right)}"), - BuiltinFnName.ASIN: (r"\arcsin{\left({", r"}\right)}"), - BuiltinFnName.ASINH: (r"\mathrm{arcsinh}{\left({", r"}\right)}"), - BuiltinFnName.ATAN: (r"\arctan{\left({", r"}\right)}"), - BuiltinFnName.ATANH: (r"\mathrm{arctanh}{\left({", r"}\right)}"), - BuiltinFnName.CEIL: (r"\left\lceil{", r"}\right\rceil"), - BuiltinFnName.COS: (r"\cos{\left({", r"}\right)}"), - BuiltinFnName.COSH: (r"\cosh{\left({", r"}\right)}"), - BuiltinFnName.EXP: (r"\exp{\left({", r"}\right)}"), - BuiltinFnName.FABS: (r"\left|{", r"}\right|"), - BuiltinFnName.FACTORIAL: (r"\left({", r"}\right)!"), - BuiltinFnName.FLOOR: (r"\left\lfloor{", r"}\right\rfloor"), - BuiltinFnName.FSUM: (r"\sum\left({", r"}\right)"), - BuiltinFnName.GAMMA: (r"\Gamma\left({", r"}\right)"), - BuiltinFnName.LOG: (r"\log{\left({", r"}\right)}"), - BuiltinFnName.LOG10: (r"\log_{10}{\left({", r"}\right)}"), - BuiltinFnName.LOG2: (r"\log_{2}{\left({", r"}\right)}"), - BuiltinFnName.PROD: (r"\prod \left({", r"}\right)"), - BuiltinFnName.SIN: (r"\sin{\left({", r"}\right)}"), - BuiltinFnName.SINH: (r"\sinh{\left({", r"}\right)}"), - BuiltinFnName.SQRT: (r"\sqrt{", "}"), - BuiltinFnName.TAN: (r"\tan{\left({", r"}\right)}"), - BuiltinFnName.TANH: (r"\tanh{\left({", r"}\right)}"), - BuiltinFnName.SUM: (r"\sum \left({", r"}\right)"), +@dataclasses.dataclass(frozen=True) +class FunctionRule: + """Codegen rules for functions. + + Attributes: + left: LaTeX expression concatenated to the left-hand side of the arguments. + right: LaTeX expression concatenated to the right-hand side of the arguments. + is_unary: Whether the function is treated as a unary operator or not. + is_wrapped: Whether the resulting syntax is wrapped by brackets or not. + """ + + left: str + right: str = "" + is_unary: bool = False + is_wrapped: bool = False + + +# name => left_syntax, right_syntax, is_wrapped +BUILTIN_FUNCS: dict[BuiltinFnName, FunctionRule] = { + BuiltinFnName.ABS: FunctionRule( + r"\mathropen{}\left|", r"\mathclose{}\right|", is_wrapped=True + ), + BuiltinFnName.ACOS: FunctionRule(r"\arccos", is_unary=True), + BuiltinFnName.ACOSH: FunctionRule(r"\mathrm{arccosh}", is_unary=True), + BuiltinFnName.ARCCOS: FunctionRule(r"\arccos", is_unary=True), + BuiltinFnName.ARCCOSH: FunctionRule(r"\mathrm{arccosh}", is_unary=True), + BuiltinFnName.ARCSIN: FunctionRule(r"\arcsin", is_unary=True), + BuiltinFnName.ARCSINH: FunctionRule(r"\mathrm{arcsinh}", is_unary=True), + BuiltinFnName.ARCTAN: FunctionRule(r"\arctan", is_unary=True), + BuiltinFnName.ARCTANH: FunctionRule(r"\mathrm{arctanh}", is_unary=True), + BuiltinFnName.ASIN: FunctionRule(r"\arcsin", is_unary=True), + BuiltinFnName.ASINH: FunctionRule(r"\mathrm{arcsinh}", is_unary=True), + BuiltinFnName.ATAN: FunctionRule(r"\arctan", is_unary=True), + BuiltinFnName.ATANH: FunctionRule(r"\mathrm{arctanh}", is_unary=True), + BuiltinFnName.CEIL: FunctionRule( + r"\mathopen{}\left\lceil", r"\mathclose{}\right\rceil", is_wrapped=True + ), + BuiltinFnName.COS: FunctionRule(r"\cos", is_unary=True), + BuiltinFnName.COSH: FunctionRule(r"\cosh", is_unary=True), + BuiltinFnName.EXP: FunctionRule(r"\exp", is_unary=True), + BuiltinFnName.FABS: FunctionRule( + r"\mathopen{}\left|", r"\mathclose{}\right|", is_wrapped=True + ), + BuiltinFnName.FACTORIAL: FunctionRule("", "!", is_unary=True), + BuiltinFnName.FLOOR: FunctionRule( + r"\mathopen{}\left\lfloor", r"\mathclose{}\right\rfloor", is_wrapped=True + ), + BuiltinFnName.FSUM: FunctionRule(r"\sum", is_unary=True), + BuiltinFnName.GAMMA: FunctionRule(r"\Gamma"), + BuiltinFnName.LOG: FunctionRule(r"\log", is_unary=True), + BuiltinFnName.LOG10: FunctionRule(r"\log_{10}", is_unary=True), + BuiltinFnName.LOG2: FunctionRule(r"\log_{2}", is_unary=True), + BuiltinFnName.PROD: FunctionRule(r"\prod", is_unary=True), + BuiltinFnName.SIN: FunctionRule(r"\sin", is_unary=True), + BuiltinFnName.SINH: FunctionRule(r"\sinh", is_unary=True), + BuiltinFnName.SQRT: FunctionRule(r"\sqrt{", "}", is_wrapped=True), + BuiltinFnName.SUM: FunctionRule(r"\sum", is_unary=True), + BuiltinFnName.TAN: FunctionRule(r"\tan", is_unary=True), + BuiltinFnName.TANH: FunctionRule(r"\tanh", is_unary=True), } diff --git a/src/latexify/transformers/function_expander.py b/src/latexify/transformers/function_expander.py index d085be2..ccaa440 100644 --- a/src/latexify/transformers/function_expander.py +++ b/src/latexify/transformers/function_expander.py @@ -36,7 +36,17 @@ def visit_Call(self, node: ast.Call) -> ast.AST: ): return _FUNCTION_EXPANDERS[func_name](self, node) - return node + kwargs = { + "func": self.visit(node.func), + "args": [self.visit(x) for x in node.args], + } + + if hasattr(node, "keywords"): + kwargs["keywords"] = [ + ast.keyword(arg=x.arg, value=self.visit(x.value)) for x in node.keywords + ] + + return ast.Call(**kwargs) def _atan2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST: @@ -86,7 +96,7 @@ def _expm1_expander(function_expander: FunctionExpander, node: ast.Call) -> ast. def _hypot_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST: - if len(node.args) == 0: + if not node.args: return ast_utils.make_constant(0) args = [ diff --git a/src/latexify/transformers/function_expander_test.py b/src/latexify/transformers/function_expander_test.py index deba8dd..a193245 100644 --- a/src/latexify/transformers/function_expander_test.py +++ b/src/latexify/transformers/function_expander_test.py @@ -3,218 +3,239 @@ from __future__ import annotations import ast -import math -from latexify import ast_utils, constants, parser, test_utils +from latexify import ast_utils, test_utils from latexify.transformers.function_expander import FunctionExpander -def _make_ast(args: list[str], body: ast.expr) -> ast.Module: - """Helper function to generate an AST for f(x). - - Args: - args: The arguments passed to the method. - body: The body of the return statement. - - Returns: - Generated AST. - """ - return ast.Module( - body=[ - ast.FunctionDef( - name="f", - args=ast.arguments( - args=[ast.arg(arg=arg) for arg in args], - kwonlyargs=[], - kw_defaults=[], - defaults=[], - ), - body=[ast.Return(body)], - decorator_list=[], - ) - ], +def test_preserve_keywords() -> None: + tree = ast.Call( + func=ast_utils.make_name("f"), + args=[ast_utils.make_name("x")], + keywords=[ast.keyword(arg="y", value=ast_utils.make_constant(0))], + ) + expected = ast.Call( + func=ast_utils.make_name("f"), + args=[ast_utils.make_name("x")], + keywords=[ast.keyword(arg="y", value=ast_utils.make_constant(0))], ) + transformed = FunctionExpander(set()).visit(tree) + test_utils.assert_ast_equal(transformed, expected) + +def test_exp() -> None: + tree = ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], + ) + expected = ast.BinOp( + left=ast_utils.make_name("e"), + op=ast.Pow(), + right=ast_utils.make_name("x"), + ) + transformed = FunctionExpander({"exp"}).visit(tree) + test_utils.assert_ast_equal(transformed, expected) -def test_atan2_expanded() -> None: - def f(x, y): - return math.atan2(y, x) - expected = _make_ast( - ["x", "y"], - ast.Call( - func=ast.Name(id="atan", ctx=ast.Load()), - args=[ - ast.BinOp( - left=ast.Name(id="y", ctx=ast.Load()), - op=ast.Div(), - right=ast.Name(id="x", ctx=ast.Load()), - ) - ], - ), +def test_exp_unchanged() -> None: + tree = ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], + ) + expected = ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], ) - transformed = FunctionExpander({"atan2"}).visit(parser.parse_function(f)) + transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_exp_expanded() -> None: - def f(x): - return math.exp(x) +def test_exp_with_attribute() -> None: + tree = ast.Call( + func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), + args=[ast_utils.make_name("x")], + ) + expected = ast.BinOp( + left=ast_utils.make_name("e"), + op=ast.Pow(), + right=ast_utils.make_name("x"), + ) + transformed2 = FunctionExpander({"exp"}).visit(tree) + test_utils.assert_ast_equal(transformed2, expected) - expected = _make_ast( - ["x"], - ast.BinOp( - left=ast.Name(id="e", ctx=ast.Load()), - op=ast.Pow(), - right=ast.Name(id="x", ctx=ast.Load()), - ), + +def test_exp_unchanged_with_attribute() -> None: + tree = ast.Call( + func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), + args=[ast_utils.make_name("x")], ) - transformed = FunctionExpander({"exp"}).visit(parser.parse_function(f)) + expected = ast.Call( + func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), + args=[ast_utils.make_name("x")], + ) + transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_exp2_expanded() -> None: - def f(x): - return math.exp2(x) - - expected = _make_ast( - ["x"], - ast.BinOp( - left=ast_utils.make_constant(2), +def test_exp_nested1() -> None: + tree = ast.Call( + func=ast_utils.make_name("exp"), + args=[ + ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], + ) + ], + ) + expected = ast.BinOp( + left=ast_utils.make_name("e"), + op=ast.Pow(), + right=ast.BinOp( + left=ast_utils.make_name("e"), op=ast.Pow(), - right=ast.Name(id="x", ctx=ast.Load()), + right=ast_utils.make_name("x"), ), ) - transformed = FunctionExpander({"exp2"}).visit(parser.parse_function(f)) + transformed = FunctionExpander({"exp"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_expm1_expanded() -> None: - def f(x): - return math.expm1(x) - - expected = _make_ast( - ["x"], - ast.BinOp( - left=ast.Call( - func=ast.Name(id=constants.BuiltinFnName.EXP.value, ctx=ast.Load()), - args=[ast.Name(id="x", ctx=ast.Load())], - ), - op=ast.Sub(), - right=ast_utils.make_constant(1), - ), +def test_exp_nested2() -> None: + tree = ast.Call( + func=ast_utils.make_name("f"), + args=[ + ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], + ) + ], ) - transformed = FunctionExpander({"expm1"}).visit(parser.parse_function(f)) + expected = ast.Call( + func=ast_utils.make_name("f"), + args=[ + ast.BinOp( + left=ast_utils.make_name("e"), + op=ast.Pow(), + right=ast_utils.make_name("x"), + ) + ], + ) + transformed = FunctionExpander({"exp"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_hypot_unchanged_without_attribute_access() -> None: - from math import hypot +def test_atan2() -> None: + tree = ast.Call( + func=ast_utils.make_name("atan2"), + args=[ast_utils.make_name("y"), ast_utils.make_name("x")], + ) + expected = ast.Call( + func=ast_utils.make_name("atan"), + args=[ + ast.BinOp( + left=ast_utils.make_name("y"), + op=ast.Div(), + right=ast_utils.make_name("x"), + ) + ], + ) + transformed = FunctionExpander({"atan2"}).visit(tree) + test_utils.assert_ast_equal(transformed, expected) - def f(x, y): - return hypot(x, y) - expected = _make_ast( - ["x", "y"], - ast.Call( - func=ast.Name(id="hypot"), - args=[ast.Name(id="x", ctx=ast.Load()), ast.Name(id="y", ctx=ast.Load())], - ), +def test_exp2() -> None: + tree = ast.Call( + func=ast_utils.make_name("exp2"), + args=[ast_utils.make_name("x")], ) - transformed = FunctionExpander(set()).visit(parser.parse_function(f)) + expected = ast.BinOp( + left=ast_utils.make_constant(2), + op=ast.Pow(), + right=ast_utils.make_name("x"), + ) + transformed = FunctionExpander({"exp2"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_hypot_unchanged() -> None: - def f(x, y): - return math.hypot(x, y) - - expected = _make_ast( - ["x", "y"], - ast.Call( - func=ast.Attribute( - ast.Name(id="math", ctx=ast.Load()), attr="hypot", ctx=ast.Load() - ), - args=[ast.Name(id="x", ctx=ast.Load()), ast.Name(id="y", ctx=ast.Load())], +def test_expm1() -> None: + tree = ast.Call( + func=ast_utils.make_name("expm1"), + args=[ast_utils.make_name("x")], + ) + expected = ast.BinOp( + left=ast.Call( + func=ast_utils.make_name("exp"), + args=[ast_utils.make_name("x")], ), + op=ast.Sub(), + right=ast_utils.make_constant(1), ) - transformed = FunctionExpander(set()).visit(parser.parse_function(f)) + transformed = FunctionExpander({"expm1"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_hypot_expanded() -> None: - def f(x, y): - return math.hypot(x, y) - - expected = _make_ast( - ["x", "y"], - ast.Call( - func=ast.Name(id="sqrt", ctx=ast.Load()), - args=[ - ast.BinOp( - left=ast.BinOp( - left=ast.Name(id="x", ctx=ast.Load()), - op=ast.Pow(), - right=ast_utils.make_constant(2), - ), - op=ast.Add(), - right=ast.BinOp( - left=ast.Name(id="y", ctx=ast.Load()), - op=ast.Pow(), - right=ast_utils.make_constant(2), - ), - ) - ], - ), +def test_hypot() -> None: + tree = ast.Call( + func=ast_utils.make_name("hypot"), + args=[ast_utils.make_name("x"), ast_utils.make_name("y")], + ) + expected = ast.Call( + func=ast_utils.make_name("sqrt"), + args=[ + ast.BinOp( + left=ast.BinOp( + left=ast_utils.make_name("x"), + op=ast.Pow(), + right=ast_utils.make_constant(2), + ), + op=ast.Add(), + right=ast.BinOp( + left=ast_utils.make_name("y"), + op=ast.Pow(), + right=ast_utils.make_constant(2), + ), + ) + ], ) - transformed = FunctionExpander({"hypot"}).visit(parser.parse_function(f)) + transformed = FunctionExpander({"hypot"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_hypot_expanded_no_args() -> None: - def f(): - return math.hypot() - - expected = _make_ast( - [], - ast_utils.make_constant(0), - ) - transformed = FunctionExpander({"hypot"}).visit(parser.parse_function(f)) +def test_hypot_no_args() -> None: + tree = ast.Call(func=ast_utils.make_name("hypot"), args=[]) + expected = ast_utils.make_constant(0) + transformed = FunctionExpander({"hypot"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_log1p_expanded() -> None: - def f(x): - return math.log1p(x) - - expected = _make_ast( - ["x"], - ast.Call( - func=ast.Name(id=constants.BuiltinFnName.LOG.value, ctx=ast.Load()), - args=[ - ast.BinOp( - left=ast_utils.make_constant(1), - op=ast.Add(), - right=ast.Name(id="x", ctx=ast.Load()), - ) - ], - ), +def test_log1p() -> None: + tree = ast.Call( + func=ast_utils.make_name("log1p"), + args=[ast_utils.make_name("x")], ) - transformed = FunctionExpander({"log1p"}).visit(parser.parse_function(f)) + expected = ast.Call( + func=ast_utils.make_name("log"), + args=[ + ast.BinOp( + left=ast_utils.make_constant(1), + op=ast.Add(), + right=ast_utils.make_name("x"), + ) + ], + ) + transformed = FunctionExpander({"log1p"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) -def test_pow_expanded() -> None: - def f(x, y): - return math.pow(x, y) - - expected = _make_ast( - ["x", "y"], - ast.BinOp( - left=ast.Name(id="x", ctx=ast.Load()), - op=ast.Pow(), - right=ast.Name(id="y", ctx=ast.Load()), - ), +def test_pow() -> None: + tree = ast.Call( + func=ast_utils.make_name("pow"), + args=[ast_utils.make_name("x"), ast_utils.make_name("y")], + ) + expected = ast.BinOp( + left=ast_utils.make_name("x"), + op=ast.Pow(), + right=ast_utils.make_name("y"), ) - transformed = FunctionExpander({"pow"}).visit(parser.parse_function(f)) + transformed = FunctionExpander({"pow"}).visit(tree) test_utils.assert_ast_equal(transformed, expected)