Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prefix trimming as a node transformer #137

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
@@ -259,3 +259,27 @@ def solve(x):

latex = r"\mathrm{solve}(x) = x"
utils.check_function(solve, latex)


def test_prefix_trimmed() -> None:
import math as prefix

def solve(x):
return prefix.sin(x)

latex = r"\mathrm{solve}(x) = \sin{\left({x}\right)}"
utils.check_function(solve, latex, prefixes={"prefix"})


def test_complex_prefix_trimmed() -> None:
class multiple:
prefixes = math

def solve(x):
return multiple.prefixes.sin(x)

latex = (
r"\mathrm{solve}(x) = "
r"\mathrm{prefixes.sin}\mathopen{}\left(x\mathclose{}\right)"
)
utils.check_function(solve, latex, prefixes={"multiple"})
7 changes: 0 additions & 7 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
@@ -346,13 +346,6 @@ def visit_Call(self, node: ast.Call) -> str:
# Function signature (possibly an expression).
func_str = self.visit(node.func)

# Removes common prefixes: math.sqrt -> sqrt
# TODO(odashi): This process can be implemented as a NodeTransformer.
for prefix in constants.PREFIXES:
if func_str.startswith(f"{prefix}."):
func_str = func_str[len(prefix) + 1 :]
break

# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
lstr, rstr = constants.BUILTIN_FUNCS.get(
func_str,
14 changes: 7 additions & 7 deletions src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
@@ -221,7 +221,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
],
)
def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]:
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
node = ast.parse(src_fn + src_suffix).body[0].value
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix
@@ -243,12 +243,12 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
),
# 3 clauses
(
"math.prod(i for y in x for i in y)",
"prod(i for y in x for i in y)",
r"\prod_{y \in x}^{} \prod_{i \in y}^{} "
r"\mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for y in x for z in y for i in z)",
"prod(i for y in x for z in y for i in z)",
r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} "
r"\mathopen{}\left({i}\mathclose{}\right)",
),
@@ -258,7 +258,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for i in range(n-1))",
"prod(i for i in range(n-1))",
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
),
# reduce stop parameter
@@ -267,7 +267,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
),
(
"math.prod(i for i in range(n-1))",
"prod(i for i in range(n-1))",
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
),
],
@@ -298,7 +298,7 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
],
)
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]:
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
node = ast.parse(src_fn + src_suffix).body[0].value
assert isinstance(node, ast.Call)
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix
@@ -707,7 +707,7 @@ def test_visit_constant(code: str, latex: str) -> None:
("x[0][1]", "{x_{{0}, {1}}}"),
("x[0][1][2]", "{x_{{0}, {1}, {2}}}"),
("x[foo]", "{x_{foo}}"),
("x[math.floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
("x[floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
],
)
def test_visit_subscript(code: str, latex: str) -> None:
6 changes: 6 additions & 0 deletions src/latexify/config.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,8 @@
import dataclasses
from typing import Any

from latexify import constants


@dataclasses.dataclass(frozen=True)
class Config:
@@ -17,6 +19,8 @@ class Config:
and corresponding values are the replacements.
Both keys and values have to represent valid Python identifiers:
^[A-Za-z_][A-Za-z0-9_]*$
prefixes: If set, the names of prefixes to trim. Defaults to a set of commonly
used modules.
reduce_assignments: If True, assignment statements are used to synthesize
the final expression.
use_math_symbols: Whether to convert identifiers with a math symbol surface
@@ -30,6 +34,7 @@ class Config:

expand_functions: set[str] | None
identifiers: dict[str, str] | None
prefixes: set[str]
reduce_assignments: bool
use_math_symbols: bool
use_raw_function_name: bool
@@ -70,6 +75,7 @@ def defaults() -> Config:
return Config(
expand_functions=None,
identifiers=None,
prefixes=constants.PREFIXES,
reduce_assignments=False,
use_math_symbols=False,
use_raw_function_name=False,
2 changes: 1 addition & 1 deletion src/latexify/constants.py
Original file line number Diff line number Diff line change
@@ -4,7 +4,7 @@

import enum

PREFIXES = ["math", "numpy", "np"]
PREFIXES = {"math", "numpy", "np"}


class BuiltinFnName(str, enum.Enum):
4 changes: 3 additions & 1 deletion src/latexify/frontend.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ def get_latex(
by users.
Returns:
Generatee LaTeX description.
Generated LaTeX description.
Raises:
latexify.exceptions.LatexifyError: Something went wrong during conversion.
@@ -43,6 +43,8 @@ def get_latex(
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
if merged_config.reduce_assignments:
tree = transformers.AssignmentReducer().visit(tree)
if merged_config.prefixes:
tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree)
if merged_config.expand_functions is not None:
tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree)

2 changes: 2 additions & 0 deletions src/latexify/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -4,8 +4,10 @@
assignment_reducer,
function_expander,
identifier_replacer,
prefix_trimmer,
)

AssignmentReducer = assignment_reducer.AssignmentReducer
FunctionExpander = function_expander.FunctionExpander
IdentifierReplacer = identifier_replacer.IdentifierReplacer
PrefixTrimmer = prefix_trimmer.PrefixTrimmer
31 changes: 31 additions & 0 deletions src/latexify/transformers/prefix_trimmer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from __future__ import annotations

import ast


class PrefixTrimmer(ast.NodeTransformer):
"""NodeTransformer to trim function prefixes.
Example:
def f(x, y):
return math.hypot(x, y)
PrefixTrimmer({"math"}) will modify the AST of the function above to below:
def f(x, y):
return hypot(x, y)
"""

def __init__(self, prefixes: set[str]) -> None:
self._prefixes = prefixes

def visit_Attribute(self, node: ast.Attribute) -> ast.AST:
"""Visitor of Attribute nodes."""
if issubclass(node.value.__class__, ast.Name):
if node.value.id in self._prefixes:
return ast.Name(id=node.attr, ctx=node.ctx)
if issubclass(node.value.__class__, ast.Attribute):
kwargs = node.__dict__
kwargs["value"] = self.visit_Attribute(node.value)
return ast.Attribute(**kwargs)
return node
48 changes: 48 additions & 0 deletions src/latexify/transformers/prefix_trimmer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import ast
from latexify import test_utils
from latexify.transformers.prefix_trimmer import PrefixTrimmer


def test_not_trimmed():
prefix = ast.Attribute(
value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load()
)
trimmed_prefix = PrefixTrimmer(set()).visit_Attribute(prefix)

test_utils.assert_ast_equal(
trimmed_prefix,
ast.Attribute(
value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load()
),
)


def test_trim_basic_prefix():
prefix = ast.Attribute(
value=ast.Name(id="math", ctx=ast.Load()), attr="sin", ctx=ast.Load()
)
trimmed_prefix = PrefixTrimmer({"math"}).visit_Attribute(prefix)

test_utils.assert_ast_equal(trimmed_prefix, ast.Name(id="sin", ctx=ast.Load()))


def test_trim_complex_prefix():
prefix = ast.Attribute(
value=ast.Attribute(
value=ast.Name(id="multiple", ctx=ast.Load()),
attr="prefixes",
ctx=ast.Load(),
),
attr="sin",
ctx=ast.Load(),
)
trimmed_prefix = PrefixTrimmer({"multiple"}).visit_Attribute(prefix)

test_utils.assert_ast_equal(
trimmed_prefix,
ast.Attribute(
value=ast.Name(id="prefixes", ctx=ast.Load()),
attr="sin",
ctx=ast.Load(),
),
)