diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 4f76a78..d6541b6 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -24,7 +24,7 @@ def solve(a, b, c): solve_latex = ( - r"\mathrm{solve}(a, b, c) \triangleq " r"\frac{-b + \sqrt{b^{2} - 4ac}}{2a}" + r"\mathrm{solve}(a, b, c) \triangleq " r"\frac{-b + \sqrt{b^{{2}} - {4}ac}}{{2}a}" ) @@ -37,8 +37,8 @@ def sinc(x): sinc_latex = ( - r"\mathrm{sinc}(x) \triangleq \left\{ \begin{array}{ll} 1, & \mathrm{if} \ " - r"{x = 0} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}" + r"\mathrm{sinc}(x) \triangleq \left\{ \begin{array}{ll} {1}, & \mathrm{if} \ " + r"{x = {0}} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}" r" \right." ) @@ -56,7 +56,7 @@ def sum_with_limit(n): sum_with_limit_latex = ( - r"\mathrm{sum_with_limit}(n) \triangleq \sum_{i=0}^{n-1} \left({i^{2}}\right)" + r"\mathrm{sum_with_limit}(n) \triangleq \sum_{i=0}^{n-1} \left({i^{{2}}}\right)" ) @@ -66,7 +66,7 @@ def sum_with_limit_two_args(a, n): sum_with_limit_two_args_latex = ( r"\mathrm{sum_with_limit_two_args}(a, n) " - r"\triangleq \sum_{i=a}^{n-1} \left({i^{2}}\right)" + r"\triangleq \sum_{i=a}^{n-1} \left({i^{{2}}}\right)" ) @@ -96,7 +96,7 @@ def test_nested_function(): def nested(x): return 3 * x - assert get_latex(nested) == r"\mathrm{nested}(x) \triangleq 3x" + assert get_latex(nested) == r"\mathrm{nested}(x) \triangleq {3}x" def test_double_nested_function(): @@ -113,14 +113,14 @@ def test_use_raw_function_name(): def foo_bar(): return 42 - assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq 42" + assert str(with_latex(foo_bar)) == r"\mathrm{foo_bar}() \triangleq {42}" assert ( str(with_latex(foo_bar, use_raw_function_name=True)) - == r"\mathrm{foo\_bar}() \triangleq 42" + == r"\mathrm{foo\_bar}() \triangleq {42}" ) assert ( str(with_latex(use_raw_function_name=True)(foo_bar)) - == r"\mathrm{foo\_bar}() \triangleq 42" + == r"\mathrm{foo\_bar}() \triangleq {42}" ) @@ -129,9 +129,9 @@ def f(x): a = x + x return 3 * a - assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq 3a" + assert str(with_latex(f)) == r"a \triangleq x + x \\ \mathrm{f}(x) \triangleq {3}a" - latex_with_option = r"\mathrm{f}(x) \triangleq 3\left( x + x \right)" + latex_with_option = r"\mathrm{f}(x) \triangleq {3}\left( x + x \right)" assert str(with_latex(f, reduce_assignments=True)) == latex_with_option assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option @@ -143,12 +143,12 @@ def f(x): return 3 * b assert str(with_latex(f)) == ( - r"a \triangleq x^{2} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq 3b" + r"a \triangleq x^{{2}} \\ b \triangleq a + a \\ \mathrm{f}(x) \triangleq {3}b" ) latex_with_option = ( r"\mathrm{f}(x) \triangleq " - r"3\left( \left( x^{2} \right) + \left( x^{2} \right) \right)" + r"{3}\left( \left( x^{{2}} \right) + \left( x^{{2}} \right) \right)" ) assert str(with_latex(f, reduce_assignments=True)) == latex_with_option assert str(with_latex(reduce_assignments=True)(f)) == latex_with_option diff --git a/src/latexify/exceptions.py b/src/latexify/exceptions.py index 89174af..6601f29 100644 --- a/src/latexify/exceptions.py +++ b/src/latexify/exceptions.py @@ -4,12 +4,12 @@ class LatexifyError(Exception): """Base class of all Latexify exceptions. - Subclasses of this exception does not mean incorrect use of the library by the user, - but informs users that Latexify went into something wrong during compiling the given - functions. - These functions are usually captured by the frontend functions (e.g., `with_latex`) + Subclasses of this exception does not mean incorrect use of the library by the user + at the interface level. These exceptions inform users that Latexify went into + something wrong during processing the given functions. + These exceptions are usually captured by the frontend functions (e.g., `with_latex`) to prevent destroying the entire program. - Errors caused by the wrong inputs should raise built-in exceptions. + Errors caused by wrong inputs should raise built-in exceptions. """ ... diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index e92ef5a..477d314 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -2,16 +2,13 @@ from __future__ import annotations -import ast from collections.abc import Callable -import inspect -import textwrap from typing import Any -import dill - -from latexify import exceptions, latexify_visitor -from latexify.transformers.identifier_replacer import IdentifierReplacer +from latexify import exceptions +from latexify import latexify_visitor +from latexify import parser +from latexify import transformers def get_latex( @@ -44,19 +41,10 @@ def get_latex( Raises: latexify.exceptions.LatexifyError: Something went wrong during conversion. """ - try: - source = inspect.getsource(fn) - except Exception: - # Maybe running on console. - source = dill.source.getsource(fn) - - # Remove extra indentation so that ast.parse runs correctly. - source = textwrap.dedent(source) - - tree = ast.parse(source) + tree = parser.parse_function(fn) if identifiers is not None: - tree = IdentifierReplacer(identifiers).visit(tree) + tree = transformers.IdentifierReplacer(identifiers).visit(tree) visitor = latexify_visitor.LatexifyVisitor( use_math_symbols=use_math_symbols, diff --git a/src/latexify/frontend_test.py b/src/latexify/frontend_test.py index 91a9d30..7ecf196 100644 --- a/src/latexify/frontend_test.py +++ b/src/latexify/frontend_test.py @@ -15,5 +15,5 @@ def very_long_name_function(very_long_name_variable): } assert frontend.get_latex(very_long_name_function, identifiers=identifiers) == ( - r"\mathrm{f}(x) \triangleq 3x" + r"\mathrm{f}(x) \triangleq {3}x" ) diff --git a/src/latexify/latexify_visitor.py b/src/latexify/latexify_visitor.py index 21cfad5..e217383 100644 --- a/src/latexify/latexify_visitor.py +++ b/src/latexify/latexify_visitor.py @@ -3,7 +3,7 @@ from __future__ import annotations import ast -from typing import ClassVar +from typing import Any, ClassVar from latexify import constants from latexify import math_symbols @@ -51,10 +51,10 @@ def generic_visit(self, node, action) -> str: f"Unsupported AST: {type(node).__name__}" ) - def visit_Module(self, node, action): # pylint: disable=invalid-name + def visit_Module(self, node, action): return self.visit(node.body[0], "multi_lines") - def visit_FunctionDef(self, node, action): # pylint: disable=invalid-name + def visit_FunctionDef(self, node, action): name_str = str(node.name) if self._use_raw_function_name: name_str = name_str.replace(r"_", r"\_") @@ -108,22 +108,22 @@ def visit_Assign(self, node, action): else: return rf"{node.targets[0].id} \triangleq {var} \\ " - def visit_Return(self, node, action): # pylint: disable=invalid-name + def visit_Return(self, node, action): return self.visit(node.value) - def visit_Tuple(self, node, action): # pylint: disable=invalid-name + def visit_Tuple(self, node, action): elts = [self.visit(i) for i in node.elts] return r"\left( " + r"\space,\space ".join(elts) + r"\right) " - def visit_List(self, node, action): # pylint: disable=invalid-name + def visit_List(self, node, action): elts = [self.visit(i) for i in node.elts] return r"\left[ " + r"\space,\space ".join(elts) + r"\right] " - def visit_Set(self, node, action): # pylint: disable=invalid-name + def visit_Set(self, node, action): elts = [self.visit(i) for i in node.elts] return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} " - def visit_Call(self, node, action): # pylint: disable=invalid-name + def visit_Call(self, node, action): """Visit a call node.""" def _decorated_lstr_and_arg(node, callee_str, lstr): @@ -171,26 +171,65 @@ def _decorated_lstr_and_arg(node, callee_str, lstr): lstr, arg_str = _decorated_lstr_and_arg(node, callee_str, lstr) return lstr + arg_str + rstr - def visit_Attribute(self, node, action): # pylint: disable=invalid-name + def visit_Attribute(self, node, action): vstr = self.visit(node.value) astr = str(node.attr) return vstr + "." + astr - def visit_Name(self, node, action): # pylint: disable=invalid-name + def visit_Name(self, node, action): if self._reduce_assignments and node.id in self.assign_var.keys(): return self.assign_var[node.id] return self._math_symbol_converter.convert(str(node.id)) - def visit_Constant(self, node, action): # pylint: disable=invalid-name - # for python >= 3.8 - return str(node.n) + def convert_constant(self, value: Any) -> str: + """Helper to convert constant values to LaTeX. - def visit_Num(self, node, action): # pylint: disable=invalid-name - # for python < 3.8 - return str(node.n) + Args: + value: A constant value. + + Returns: + The LaTeX representation of `value`. + """ + if value is None or isinstance(value, bool): + return r"\mathrm{" + str(value) + "}" + if isinstance(value, (int, float, complex)): + return "{" + str(value) + "}" + if isinstance(value, str): + return r'\textrm{"' + value + '"}' + if isinstance(value, bytes): + return r"\textrm{" + str(value) + "}" + if value is ...: + return r"{\cdots}" + raise exceptions.LatexifyNotSupportedError( + f"Unrecognized constant: {type(value).__name__}" + ) + + # From Python 3.8 + def visit_Constant(self, node: ast.Constant, action) -> str: + return self.convert_constant(node.value) + + # Until Python 3.7 + def visit_Num(self, node: ast.Num, action) -> str: + return self.convert_constant(node.n) + + # Until Python 3.7 + def visit_Str(self, node: ast.Str, action) -> str: + return self.convert_constant(node.s) + + # Until Python 3.7 + def visit_Bytes(self, node: ast.Bytes, action) -> str: + return self.convert_constant(node.s) + + # Until Python 3.7 + def visit_NameConstant(self, node: ast.NameConstant, action) -> str: + return self.convert_constant(node.value) + + # Until Python 3.7 + def visit_Ellipsis(self, node: ast.Ellipsis, action) -> str: + return self.convert_constant(...) - def visit_UnaryOp(self, node, action): # pylint: disable=invalid-name + def visit_UnaryOp(self, node, action): """Visit a unary op node.""" def _wrap(child): @@ -211,7 +250,7 @@ def _wrap(child): return reprs[type(node.op)]() return r"\mathrm{unknown\_uniop}(" + self.visit(node.operand) + ")" - def visit_BinOp(self, node, action): # pylint: disable=invalid-name + def visit_BinOp(self, node, action): """Visit a binary op node.""" priority = constants.BIN_OP_PRIORITY @@ -263,7 +302,7 @@ def _wrap(child): ast.NotIn: r"\notin", } - def visit_Compare(self, node: ast.Compare, action): # pylint: disable=invalid-name + def visit_Compare(self, node: ast.Compare, action): """Visit a compare node.""" lhs = self.visit(node.left) ops = [self._compare_ops[type(x)] for x in node.ops] @@ -276,13 +315,13 @@ def visit_Compare(self, node: ast.Compare, action): # pylint: disable=invalid-n ast.Or: r"\lor", } - def visit_BoolOp(self, node: ast.BoolOp, action): # pylint: disable=invalid-name + def visit_BoolOp(self, node: ast.BoolOp, action): """Visit a BoolOp node.""" values = [rf"\left( {self.visit(x)} \right)" for x in node.values] op = f" {self._bool_ops[type(node.op)]} " return "{" + op.join(values) + "}" - def visit_If(self, node, action): # pylint: disable=invalid-name + def visit_If(self, node, action): """Visit an if node.""" latex = r"\left\{ \begin{array}{ll} " @@ -295,7 +334,7 @@ def visit_If(self, node, action): # pylint: disable=invalid-name latex += self.visit(node) return latex + r", & \mathrm{otherwise} \end{array} \right." - def visit_GeneratorExp_set_bounds(self, node): # pylint: disable=invalid-name + def visit_GeneratorExp_set_bounds(self, node): output = self.visit(node.elt) comprehensions = [ self.visit(generator, "set_bounds") for generator in node.generators @@ -344,7 +383,7 @@ def visit_Subscript(self, node: ast.Subscript, action) -> str: return f"{{{value}_{indices_str}}}" - def visit_comprehension_set_bounds(self, node): # pylint: disable=invalid-name + def visit_comprehension_set_bounds(self, node): """Visit a comprehension node, which represents a for clause""" var = self.visit(node.target) if isinstance(node.iter, ast.Call): diff --git a/src/latexify/latexify_visitor_test.py b/src/latexify/latexify_visitor_test.py index ca5a287..258c94f 100644 --- a/src/latexify/latexify_visitor_test.py +++ b/src/latexify/latexify_visitor_test.py @@ -1,10 +1,14 @@ """Tests for latexify.latexify_visitor.""" +from __future__ import annotations + import ast from latexify import exceptions +from latexify import test_utils + import pytest -from latexify.latexify_visitor import LatexifyVisitor +from latexify import latexify_visitor def test_generic_visit() -> None: @@ -15,7 +19,7 @@ class UnknownNode(ast.AST): exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: UnknownNode$", ): - LatexifyVisitor().visit(UnknownNode()) + latexify_visitor.LatexifyVisitor().visit(UnknownNode()) @pytest.mark.parametrize( @@ -55,7 +59,7 @@ class UnknownNode(ast.AST): def test_visit_compare(code: str, latex: str) -> None: tree = ast.parse(code).body[0].value assert isinstance(tree, ast.Compare) - assert LatexifyVisitor().visit(tree) == latex + assert latexify_visitor.LatexifyVisitor().visit(tree) == latex @pytest.mark.parametrize( @@ -76,15 +80,64 @@ def test_visit_compare(code: str, latex: str) -> None: def test_visit_boolop(code: str, latex: str) -> None: tree = ast.parse(code).body[0].value assert isinstance(tree, ast.BoolOp) - assert LatexifyVisitor().visit(tree) == latex + assert latexify_visitor.LatexifyVisitor().visit(tree) == latex + + +@test_utils.require_at_most(7) +@pytest.mark.parametrize( + "code,cls,latex", + [ + ("0", ast.Num, "{0}"), + ("1", ast.Num, "{1}"), + ("0.0", ast.Num, "{0.0}"), + ("1.5", ast.Num, "{1.5}"), + ("0.0j", ast.Num, "{0j}"), + ("1.0j", ast.Num, "{1j}"), + ("1.5j", ast.Num, "{1.5j}"), + ('"abc"', ast.Str, r'\textrm{"abc"}'), + ('b"abc"', ast.Bytes, r"\textrm{b'abc'}"), + ("None", ast.NameConstant, r"\mathrm{None}"), + ("False", ast.NameConstant, r"\mathrm{False}"), + ("True", ast.NameConstant, r"\mathrm{True}"), + ("...", ast.Ellipsis, r"{\cdots}"), + ], +) +def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None: + tree = ast.parse(code).body[0].value + assert isinstance(tree, cls) + assert latexify_visitor.LatexifyVisitor().visit(tree) == latex + + +@test_utils.require_at_least(8) +@pytest.mark.parametrize( + "code,latex", + [ + ("0", "{0}"), + ("1", "{1}"), + ("0.0", "{0.0}"), + ("1.5", "{1.5}"), + ("0.0j", "{0j}"), + ("1.0j", "{1j}"), + ("1.5j", "{1.5j}"), + ('"abc"', r'\textrm{"abc"}'), + ('b"abc"', r"\textrm{b'abc'}"), + ("None", r"\mathrm{None}"), + ("False", r"\mathrm{False}"), + ("True", r"\mathrm{True}"), + ("...", r"{\cdots}"), + ], +) +def test_visit_constant(code: str, latex: str) -> None: + tree = ast.parse(code).body[0].value + assert isinstance(tree, ast.Constant) @pytest.mark.parametrize( "code,latex", [ - ("x[0]", "{x_{0}}"), - ("x[0][1]", "{x_{0, 1}}"), - ("x[0][1][2]", "{x_{0, 1, 2}}"), + ("x[0]", "{x_{{0}}}"), + ("x[0][1]", "{x_{{0}, {1}}}"), + ("x[0][1][2]", "{x_{{0}, {1}, {2}}}"), ("x[foo]", "{x_{foo}}"), ("x[math.floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"), ], @@ -92,4 +145,4 @@ def test_visit_boolop(code: str, latex: str) -> None: def test_visit_subscript(code: str, latex: str) -> None: tree = ast.parse(code).body[0].value assert isinstance(tree, ast.Subscript) - assert LatexifyVisitor().visit(tree) == latex + assert latexify_visitor.LatexifyVisitor().visit(tree) == latex diff --git a/src/latexify/parser.py b/src/latexify/parser.py new file mode 100644 index 0000000..295e103 --- /dev/null +++ b/src/latexify/parser.py @@ -0,0 +1,38 @@ +"""Parsing utilities.""" + +from __future__ import annotations + +from collections.abc import Callable +import ast +import inspect +import textwrap +from typing import Any + +import dill + +from latexify import exceptions + + +def parse_function(fn: Callable[..., Any]) -> ast.FunctionDef: + """Parses given function. + + Args: + fn: Target function. + + Returns: + AST tree representing `fn`. + """ + try: + source = inspect.getsource(fn) + except Exception: + # Maybe running on console. + source = dill.source.getsource(fn) + + # Remove extra indentation so that ast.parse runs correctly. + source = textwrap.dedent(source) + + tree = ast.parse(source) + if not tree.body or not isinstance(tree.body[0], ast.FunctionDef): + raise exceptions.LatexifySyntaxError("Not a function.") + + return tree diff --git a/src/latexify/parser_test.py b/src/latexify/parser_test.py new file mode 100644 index 0000000..201acaa --- /dev/null +++ b/src/latexify/parser_test.py @@ -0,0 +1,38 @@ +"""Tests for latexify.parser.""" + +from __future__ import annotations + +import ast + +import pytest + +from latexify import exceptions, parser +from latexify import test_utils + + +def test_parse_function_with_posonlyargs() -> None: + def f(x): + return x + + expected = ast.Module( + body=[ + ast.FunctionDef( + name="f", + args=ast.arguments( + args=[ast.arg(arg="x")], + ), + body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))], + ) + ], + ) + + obtained = parser.parse_function(f) + test_utils.assert_ast_equal(obtained, expected) + + +def test_parse_function_with_lambda() -> None: + with pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"): + parser.parse_function(lambda: ()) + with pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"): + x = lambda: () # noqa: E731 + parser.parse_function(x) diff --git a/src/latexify/test_utils.py b/src/latexify/test_utils.py index b45da9b..89b7344 100644 --- a/src/latexify/test_utils.py +++ b/src/latexify/test_utils.py @@ -3,57 +3,137 @@ from __future__ import annotations import ast +from collections.abc import Callable +import functools +import sys from typing import cast -def ast_equal(tree1: ast.AST, tree2: ast.AST) -> bool: +def require_at_least( + minor: int, +) -> Callable[[Callable[..., None]], Callable[..., None]]: + """Require the minimum minor version of Python 3 to run the test. + + Args: + minor: Minimum minor version (inclusive) that the test case supports. + + Returns: + A decorator function to wrap the test case function. + """ + + def decorator(fn: Callable[..., None]) -> Callable[..., None]: + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if sys.version_info.minor < minor: + return + fn(*args, **kwargs) + + return wrapper + + return decorator + + +def require_at_most( + minor: int, +) -> Callable[[Callable[..., None]], Callable[..., None]]: + """Require the maximum minor version of Python 3 to run the test. + + Args: + minor: Maximum minor version (inclusive) that the test case supports. + + Returns: + A decorator function to wrap the test case function. + """ + + def decorator(fn: Callable[..., None]) -> Callable[..., None]: + @functools.wraps(fn) + def wrapper(*args, **kwargs): + if sys.version_info.minor > minor: + return + fn(*args, **kwargs) + + return wrapper + + return decorator + + +def ast_equal(observed: ast.AST, expected: ast.AST) -> bool: """Checks the equality between two ASTs. + This function checks if `ovserved` contains at least the same subtree with + `expected`. If `ovserved` has some extra branches that `expected` does not cover, + it is ignored. + Args: - tree1: An AST to compare. - tree2: Another AST. + observed: An AST to check. + expected: The expected AST. Returns: - True if tree1 and tree2 represent the same AST, False otherwise. + True if observed and expected represent the same AST, False otherwise. """ try: - assert type(tree1) is type(tree2) + assert type(observed) is type(expected) - for k, v1 in vars(tree1).items(): - v2 = getattr(tree2, k) + for k, ve in vars(expected).items(): + vo = getattr(observed, k) # May cause AttributeError. - if isinstance(v1, ast.AST): - assert ast_equal(v1, cast(ast.AST, v2)) - elif isinstance(v1, list): - v2 = cast(list, v2) - assert len(v1) == len(v2) + if isinstance(ve, ast.AST): + assert ast_equal(cast(ast.AST, vo), ve) + elif isinstance(ve, list): + vo = cast(list, vo) + assert len(vo) == len(ve) assert all( - ast_equal(cast(ast.AST, c1), cast(ast.AST, c2)) - for c1, c2 in zip(v1, v2) + ast_equal(cast(ast.AST, co), cast(ast.AST, ce)) + for co, ce in zip(vo, ve) ) else: - assert v1 == v2 + assert vo == ve - except AssertionError: + except (AssertionError, AttributeError): + raise return False return True -def assert_ast_equal(tree1: ast.AST, tree2: ast.AST) -> None: +def assert_ast_equal(observed: ast.AST, expected: ast.AST) -> None: """Asserts the equality between two ASTs. Args: - tree1: An AST to compare. - tree2: Another AST. + observed: An AST to compare. + expected: Another AST. Raises: - AssertionError: tree1 and tree2 represent different ASTs. + AssertionError: observed and expected represent different ASTs. """ - assert ast_equal( - tree1, tree2 - ), f"""\ + if sys.version_info.minor >= 9: + assert ast_equal( + observed, expected + ), f"""\ +AST does not match. +observed={ast.dump(observed, indent=4)} +expected={ast.dump(expected, indent=4)} +""" + else: + assert ast_equal( + observed, expected + ), f"""\ AST does not match. -tree1={ast.dump(tree1, indent=4)} -tree2={ast.dump(tree2, indent=4)} +observed={ast.dump(observed)} +expected={ast.dump(expected)} """ + + +def make_num(value: int) -> ast.expr: + """Helper function to generate a node for number. + + Args: + value: The value of the node. + + Returns: + Generated AST. + """ + if sys.version_info.minor < 8: + return ast.Num(n=value) + else: + return ast.Constant(value=value) diff --git a/src/latexify/transformers/__init__.py b/src/latexify/transformers/__init__.py index 005a0d2..6a4fc25 100644 --- a/src/latexify/transformers/__init__.py +++ b/src/latexify/transformers/__init__.py @@ -1 +1,6 @@ """Package latexify.transformers.""" + +from latexify.transformers import identifier_replacer + + +IdentifierReplacer = identifier_replacer.IdentifierReplacer diff --git a/src/latexify/transformers/identifier_replacer_test.py b/src/latexify/transformers/identifier_replacer_test.py index f555b92..75c16e0 100644 --- a/src/latexify/transformers/identifier_replacer_test.py +++ b/src/latexify/transformers/identifier_replacer_test.py @@ -1,5 +1,7 @@ """Tests for latexify.transformer.identifier_replacer.""" +from __future__ import annotations + import ast import pytest @@ -31,11 +33,53 @@ def test_name_not_replaced() -> None: test_utils.assert_ast_equal(transformed, expected) +@test_utils.require_at_most(7) def test_functiondef() -> None: + # Subtree of: + # @d + # def f(y=b, *, z=c): + # pass + source = ast.FunctionDef( + name="f", + args=ast.arguments( + args=[ast.arg(arg="y")], + kwonlyargs=[ast.arg(arg="z")], + kw_defaults=[ast.Name(id="c", ctx=ast.Load())], + defaults=[ + ast.Name(id="a", ctx=ast.Load()), + ast.Name(id="b", ctx=ast.Load()), + ], + ), + body=[ast.Pass()], + decorator_list=[ast.Name(id="d", ctx=ast.Load())], + ) + + expected = ast.FunctionDef( + name="F", + args=ast.arguments( + args=[ast.arg(arg="Y")], + kwonlyargs=[ast.arg(arg="Z")], + kw_defaults=[ast.Name(id="C", ctx=ast.Load())], + defaults=[ + ast.Name(id="A", ctx=ast.Load()), + ast.Name(id="B", ctx=ast.Load()), + ], + ), + body=[ast.Pass()], + decorator_list=[ast.Name(id="D", ctx=ast.Load())], + ) + + mapping = {x: x.upper() for x in "abcdfyz"} + transformed = IdentifierReplacer(mapping).visit(source) + test_utils.assert_ast_equal(transformed, expected) + + +@test_utils.require_at_least(8) +def test_functiondef_with_posonlyargs() -> None: # Subtree of: # @d # def f(x=a, /, y=b, *, z=c): - # ... + # pass source = ast.FunctionDef( name="f", args=ast.arguments( @@ -48,7 +92,7 @@ def test_functiondef() -> None: ast.Name(id="b", ctx=ast.Load()), ], ), - body=[ast.Ellipsis()], + body=[ast.Pass()], decorator_list=[ast.Name(id="d", ctx=ast.Load())], ) @@ -64,7 +108,7 @@ def test_functiondef() -> None: ast.Name(id="B", ctx=ast.Load()), ], ), - body=[ast.Ellipsis()], + body=[ast.Pass()], decorator_list=[ast.Name(id="D", ctx=ast.Load())], )