From 465d7d0eb87d5c7c00a0d696f349939ee9dceebe Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sun, 27 Nov 2022 12:08:49 +0000 Subject: [PATCH 1/4] impl --- src/integration_tests/regression_test.py | 21 +++++++++ src/latexify/config.py | 7 ++- src/latexify/constants.py | 2 +- src/latexify/frontend.py | 8 ++-- src/latexify/transformers/__init__.py | 2 + src/latexify/transformers/prefix_trimmer.py | 30 ++++++++++++ .../transformers/prefix_trimmer_test.py | 46 +++++++++++++++++++ 7 files changed, 110 insertions(+), 6 deletions(-) create mode 100644 src/latexify/transformers/prefix_trimmer.py create mode 100644 src/latexify/transformers/prefix_trimmer_test.py diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index fdbd080..a77be7d 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -330,3 +330,24 @@ def solve(x): latex = r"\mathrm{solve}(x) = x" _check_function(solve, latex) + + +def test_prefix_trimmed() -> None: + import math as prefix + + def solve(x): + return prefix.sin(x) + + latex = r"\mathrm{solve}(x) = \sin{\left({x}\right)}" + _check_function(solve, latex, prefixes={"prefix"}) + + +def test_complex_prefix_trimmed() -> None: + class multiple: + prefixes = math + + def solve(x): + return multiple.prefixes.sin(x) + + latex = r"\mathrm{solve}(x) = \mathrm{prefixes.sin}\mathopen{}\left(x\mathclose{}\right)" + _check_function(solve, latex, prefixes={"multiple"}) diff --git a/src/latexify/config.py b/src/latexify/config.py index fd34ce5..187fa41 100644 --- a/src/latexify/config.py +++ b/src/latexify/config.py @@ -3,9 +3,10 @@ from __future__ import annotations import dataclasses - from typing import Any +from latexify import constants + @dataclasses.dataclass(frozen=True) class Config: @@ -18,6 +19,8 @@ class Config: and corresponding values are the replacements. Both keys and values have to represent valid Python identifiers: ^[A-Za-z_][A-Za-z0-9_]*$ + prefixes: If set, the names of prefixes to trim. Defaults to a set of commonly + used modules. reduce_assignments: If True, assignment statements are used to synthesize the final expression. use_math_symbols: Whether to convert identifiers with a math symbol surface @@ -31,6 +34,7 @@ class Config: expand_functions: set[str] | None identifiers: dict[str, str] | None + prefixes: set[str] reduce_assignments: bool use_math_symbols: bool use_raw_function_name: bool @@ -71,6 +75,7 @@ def defaults() -> Config: return Config( expand_functions=None, identifiers=None, + prefixes=constants.PREFIXES, reduce_assignments=False, use_math_symbols=False, use_raw_function_name=False, diff --git a/src/latexify/constants.py b/src/latexify/constants.py index 8ee7dc2..8284d85 100644 --- a/src/latexify/constants.py +++ b/src/latexify/constants.py @@ -4,7 +4,7 @@ import enum -PREFIXES = ["math", "numpy", "np"] +PREFIXES = {"math", "numpy", "np"} class BuiltinFnName(str, enum.Enum): diff --git a/src/latexify/frontend.py b/src/latexify/frontend.py index e00a4eb..ea4d93e 100644 --- a/src/latexify/frontend.py +++ b/src/latexify/frontend.py @@ -7,10 +7,8 @@ from typing import Any from latexify import codegen -from latexify import exceptions -from latexify import parser -from latexify import transformers from latexify import config as cfg +from latexify import exceptions, parser, transformers # TODO(odashi): move expand_functions to Config. @@ -30,7 +28,7 @@ def get_latex( by users. Returns: - Generatee LaTeX description. + Generated LaTeX description. Raises: latexify.exceptions.LatexifyError: Something went wrong during conversion. @@ -45,6 +43,8 @@ def get_latex( tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree) if merged_config.reduce_assignments: tree = transformers.AssignmentReducer().visit(tree) + if merged_config.prefixes: + tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree) if merged_config.expand_functions is not None: tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) diff --git a/src/latexify/transformers/__init__.py b/src/latexify/transformers/__init__.py index 9dc5c37..6adb195 100644 --- a/src/latexify/transformers/__init__.py +++ b/src/latexify/transformers/__init__.py @@ -4,8 +4,10 @@ assignment_reducer, function_expander, identifier_replacer, + prefix_trimmer, ) AssignmentReducer = assignment_reducer.AssignmentReducer FunctionExpander = function_expander.FunctionExpander IdentifierReplacer = identifier_replacer.IdentifierReplacer +PrefixTrimmer = prefix_trimmer.PrefixTrimmer diff --git a/src/latexify/transformers/prefix_trimmer.py b/src/latexify/transformers/prefix_trimmer.py new file mode 100644 index 0000000..259fac7 --- /dev/null +++ b/src/latexify/transformers/prefix_trimmer.py @@ -0,0 +1,30 @@ +import ast + + +class PrefixTrimmer(ast.NodeTransformer): + """NodeTransformer to trim function prefixes. + + Example: + def f(x, y): + return math.hypot(x, y) + + PrefixTrimmer({"math"}) will modify the AST of the function above to below: + + def f(x, y): + return hypot(x**2, y**2) + """ + + def __init__(self, prefixes: set[str]) -> None: + self._prefixes = prefixes + + def visit_Attribute(self, node: ast.Attribute) -> ast.AST: + """Visitor of Attribute nodes.""" + if issubclass(node.value.__class__, ast.Name): + if node.value.id in self._prefixes: + print("!!!!!!!!!!!!!!!") + return ast.Name(id=node.attr, ctx=node.ctx) + if issubclass(node.value.__class__, ast.Attribute): + kwargs = node.__dict__ + kwargs["value"] = self.visit_Attribute(node.value) + return ast.Attribute(**kwargs) + return node diff --git a/src/latexify/transformers/prefix_trimmer_test.py b/src/latexify/transformers/prefix_trimmer_test.py new file mode 100644 index 0000000..74b4318 --- /dev/null +++ b/src/latexify/transformers/prefix_trimmer_test.py @@ -0,0 +1,46 @@ +import ast +from latexify import test_utils +from latexify.transformers.prefix_trimmer import PrefixTrimmer + + +def test_not_trimmed(): + prefix = ast.Attribute( + value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load() + ) + trimmed_prefix = PrefixTrimmer(set()).visit_Attribute(prefix) + + test_utils.assert_ast_equal( + trimmed_prefix, + ast.Attribute( + value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load() + ), + ) + + +def test_trim_basic_prefix(): + prefix = ast.Attribute( + value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load() + ) + trimmed_prefix = PrefixTrimmer({"math"}).visit_Attribute(prefix) + + test_utils.assert_ast_equal(trimmed_prefix, ast.Name(id="sin", ctx=ast.Load())) + + +def test_trim_complex_prefix(): + prefix = ast.Attribute( + value=ast.Attribute( + value=ast.Name(id="multiple", ctx=ast.Load()), attr="prefixes", ctx=ast.Load() + ), + attr="sin", + ctx=ast.Load(), + ) + trimmed_prefix = PrefixTrimmer({"multiple"}).visit_Attribute(prefix) + + test_utils.assert_ast_equal( + trimmed_prefix, + ast.Attribute( + value=ast.Name(id="prefixes", ctx=ast.Load()), + attr="sin", + ctx=ast.Load(), + ), + ) From 7e59cf21d178d31231844132ddec04877602ae31 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sun, 27 Nov 2022 12:10:42 +0000 Subject: [PATCH 2/4] fix breaking tests --- src/latexify/codegen/function_codegen.py | 7 ------- src/latexify/codegen/function_codegen_test.py | 14 +++++++------- 2 files changed, 7 insertions(+), 14 deletions(-) diff --git a/src/latexify/codegen/function_codegen.py b/src/latexify/codegen/function_codegen.py index e660804..fe21204 100644 --- a/src/latexify/codegen/function_codegen.py +++ b/src/latexify/codegen/function_codegen.py @@ -346,13 +346,6 @@ def visit_Call(self, node: ast.Call) -> str: # Function signature (possibly an expression). func_str = self.visit(node.func) - # Removes common prefixes: math.sqrt -> sqrt - # TODO(odashi): This process can be implemented as a NodeTransformer. - for prefix in constants.PREFIXES: - if func_str.startswith(f"{prefix}."): - func_str = func_str[len(prefix) + 1 :] - break - # Obtains wrapper syntax: sqrt -> "\sqrt{" and "}" lstr, rstr = constants.BUILTIN_FUNCS.get( func_str, diff --git a/src/latexify/codegen/function_codegen_test.py b/src/latexify/codegen/function_codegen_test.py index 793ab62..92e6617 100644 --- a/src/latexify/codegen/function_codegen_test.py +++ b/src/latexify/codegen/function_codegen_test.py @@ -221,7 +221,7 @@ def test_visit_setcomp(code: str, latex: str) -> None: ], ) def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: - for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]: + for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: node = ast.parse(src_fn + src_suffix).body[0].value assert isinstance(node, ast.Call) assert FunctionCodegen().visit(node) == dest_fn + dest_suffix @@ -243,12 +243,12 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: ), # 3 clauses ( - "math.prod(i for y in x for i in y)", + "prod(i for y in x for i in y)", r"\prod_{y \in x}^{} \prod_{i \in y}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( - "math.prod(i for y in x for z in y for i in z)", + "prod(i for y in x for z in y for i in z)", r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), @@ -258,7 +258,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)", ), ( - "math.prod(i for i in range(n-1))", + "prod(i for i in range(n-1))", r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)", ), # reduce stop parameter @@ -267,7 +267,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)", ), ( - "math.prod(i for i in range(n-1))", + "prod(i for i in range(n-1))", r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)", ), ], @@ -298,7 +298,7 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No ], ) def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: - for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]: + for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: node = ast.parse(src_fn + src_suffix).body[0].value assert isinstance(node, ast.Call) assert FunctionCodegen().visit(node) == dest_fn + dest_suffix @@ -707,7 +707,7 @@ def test_visit_constant(code: str, latex: str) -> None: ("x[0][1]", "{x_{{0}, {1}}}"), ("x[0][1][2]", "{x_{{0}, {1}, {2}}}"), ("x[foo]", "{x_{foo}}"), - ("x[math.floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"), + ("x[floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"), ], ) def test_visit_subscript(code: str, latex: str) -> None: From acc03c700b46a5252d6b0e82cfd7b8964b3d9543 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sun, 27 Nov 2022 12:16:35 +0000 Subject: [PATCH 3/4] clean up --- src/latexify/transformers/prefix_trimmer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/latexify/transformers/prefix_trimmer.py b/src/latexify/transformers/prefix_trimmer.py index 259fac7..77f5240 100644 --- a/src/latexify/transformers/prefix_trimmer.py +++ b/src/latexify/transformers/prefix_trimmer.py @@ -11,7 +11,7 @@ def f(x, y): PrefixTrimmer({"math"}) will modify the AST of the function above to below: def f(x, y): - return hypot(x**2, y**2) + return hypot(x, y) """ def __init__(self, prefixes: set[str]) -> None: @@ -21,7 +21,6 @@ def visit_Attribute(self, node: ast.Attribute) -> ast.AST: """Visitor of Attribute nodes.""" if issubclass(node.value.__class__, ast.Name): if node.value.id in self._prefixes: - print("!!!!!!!!!!!!!!!") return ast.Name(id=node.attr, ctx=node.ctx) if issubclass(node.value.__class__, ast.Attribute): kwargs = node.__dict__ From 70e0ae31844cd37b6574469fc79e1c26ba4744c5 Mon Sep 17 00:00:00 2001 From: Zibing Zhang Date: Sun, 27 Nov 2022 12:19:36 +0000 Subject: [PATCH 4/4] fix --- src/integration_tests/regression_test.py | 5 ++++- src/latexify/transformers/prefix_trimmer.py | 2 ++ src/latexify/transformers/prefix_trimmer_test.py | 4 +++- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/integration_tests/regression_test.py b/src/integration_tests/regression_test.py index 62eaa18..f217883 100644 --- a/src/integration_tests/regression_test.py +++ b/src/integration_tests/regression_test.py @@ -278,5 +278,8 @@ class multiple: def solve(x): return multiple.prefixes.sin(x) - latex = r"\mathrm{solve}(x) = \mathrm{prefixes.sin}\mathopen{}\left(x\mathclose{}\right)" + latex = ( + r"\mathrm{solve}(x) = " + r"\mathrm{prefixes.sin}\mathopen{}\left(x\mathclose{}\right)" + ) utils.check_function(solve, latex, prefixes={"multiple"}) diff --git a/src/latexify/transformers/prefix_trimmer.py b/src/latexify/transformers/prefix_trimmer.py index 77f5240..8773b9f 100644 --- a/src/latexify/transformers/prefix_trimmer.py +++ b/src/latexify/transformers/prefix_trimmer.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import ast diff --git a/src/latexify/transformers/prefix_trimmer_test.py b/src/latexify/transformers/prefix_trimmer_test.py index 74b4318..d5d03ca 100644 --- a/src/latexify/transformers/prefix_trimmer_test.py +++ b/src/latexify/transformers/prefix_trimmer_test.py @@ -29,7 +29,9 @@ def test_trim_basic_prefix(): def test_trim_complex_prefix(): prefix = ast.Attribute( value=ast.Attribute( - value=ast.Name(id="multiple", ctx=ast.Load()), attr="prefixes", ctx=ast.Load() + value=ast.Name(id="multiple", ctx=ast.Load()), + attr="prefixes", + ctx=ast.Load(), ), attr="sin", ctx=ast.Load(),