Skip to content

Commit 7d8a436

Browse files
author
Yusuke Oda
authored
Prefix trimmer (#138)
* prefix trimmer * add test * add frontend flags
1 parent 4b37366 commit 7d8a436

15 files changed

+276
-27
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,4 +68,4 @@ jobs:
6868
python -m pip install --upgrade pip
6969
python -m pip install isort
7070
- name: Check
71-
run: python -m isort -v src
71+
run: python -m isort --check src

checks.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ set -eoux pipefail
44
python -m pytest src -vv
55
python -m black --check src
66
python -m pflake8 src
7-
python -m isort -v src
7+
python -m isort --check src

src/integration_tests/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Package integration_tests."""
2+
3+
import pytest
4+
5+
pytest.register_assert_rewrite("integration_tests.utils")

src/latexify/ast_utils.py

+25
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,31 @@
77
from typing import Any
88

99

10+
def make_name(id: str) -> ast.Name:
    """Creates an ast.Name node referring to the given identifier.

    Args:
        id: Identifier that the new node refers to.

    Returns:
        A fresh ast.Name with a Load context.
    """
    load_ctx = ast.Load()
    return ast.Name(id, load_ctx)
20+
21+
22+
def make_attribute(value: ast.expr, attr: str) -> ast.Attribute:
    """Generates a new Attribute node.

    Args:
        value: Parent value, i.e., the expression on the left side of the period.
        attr: Attribute name.

    Returns:
        Generated ast.Attribute with a Load context.
    """
    # NOTE: `value` is annotated with ast.expr (the base class of expression nodes),
    # not ast.Expr, which is the statement node wrapping a standalone expression.
    return ast.Attribute(value=value, attr=attr, ctx=ast.Load())
33+
34+
1035
def make_constant(value: Any) -> ast.expr:
1136
"""Generates a new Constant node.
1237

src/latexify/ast_utils_test.py

+13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,19 @@
1010
from latexify import ast_utils, test_utils
1111

1212

13+
def test_make_name() -> None:
    # make_name must yield a Name node in Load context carrying the given id.
    expected = ast.Name(id="foo", ctx=ast.Load())
    observed = ast_utils.make_name("foo")
    test_utils.assert_ast_equal(observed, expected)
17+
18+
19+
def test_make_attribute() -> None:
    # make_attribute must wrap the parent node into an Attribute in Load context.
    parent = ast.Name(id="foo", ctx=ast.Load())
    expected = ast.Attribute(parent, attr="bar", ctx=ast.Load())
    observed = ast_utils.make_attribute(ast_utils.make_name("foo"), "bar")
    test_utils.assert_ast_equal(observed, expected)
24+
25+
1326
@test_utils.require_at_most(7)
1427
@pytest.mark.parametrize(
1528
"value,expected",

src/latexify/codegen/function_codegen.py

-7
Original file line numberDiff line numberDiff line change
@@ -346,13 +346,6 @@ def visit_Call(self, node: ast.Call) -> str:
346346
# Function signature (possibly an expression).
347347
func_str = self.visit(node.func)
348348

349-
# Removes common prefixes: math.sqrt -> sqrt
350-
# TODO(odashi): This process can be implemented as a NodeTransformer.
351-
for prefix in constants.PREFIXES:
352-
if func_str.startswith(f"{prefix}."):
353-
func_str = func_str[len(prefix) + 1 :]
354-
break
355-
356349
# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
357350
lstr, rstr = constants.BUILTIN_FUNCS.get(
358351
func_str,

src/latexify/codegen/function_codegen_test.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
221221
],
222222
)
223223
def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
224-
for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]:
224+
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
225225
node = ast.parse(src_fn + src_suffix).body[0].value
226226
assert isinstance(node, ast.Call)
227227
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix
@@ -243,12 +243,12 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
243243
),
244244
# 3 clauses
245245
(
246-
"math.prod(i for y in x for i in y)",
246+
"prod(i for y in x for i in y)",
247247
r"\prod_{y \in x}^{} \prod_{i \in y}^{} "
248248
r"\mathopen{}\left({i}\mathclose{}\right)",
249249
),
250250
(
251-
"math.prod(i for y in x for z in y for i in z)",
251+
"prod(i for y in x for z in y for i in z)",
252252
r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} "
253253
r"\mathopen{}\left({i}\mathclose{}\right)",
254254
),
@@ -258,7 +258,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
258258
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
259259
),
260260
(
261-
"math.prod(i for i in range(n-1))",
261+
"prod(i for i in range(n-1))",
262262
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
263263
),
264264
# reduce stop parameter
@@ -267,7 +267,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
267267
r"\sum_{i = {0}}^{{n}} \mathopen{}\left({i}\mathclose{}\right)",
268268
),
269269
(
270-
"math.prod(i for i in range(n-1))",
270+
"prod(i for i in range(n-1))",
271271
r"\prod_{i = {0}}^{{n - {2}}} \mathopen{}\left({i}\mathclose{}\right)",
272272
),
273273
],
@@ -298,7 +298,7 @@ def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> No
298298
],
299299
)
300300
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
301-
for src_fn, dest_fn in [("sum", r"\sum"), ("math.prod", r"\prod")]:
301+
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
302302
node = ast.parse(src_fn + src_suffix).body[0].value
303303
assert isinstance(node, ast.Call)
304304
assert FunctionCodegen().visit(node) == dest_fn + dest_suffix
@@ -707,7 +707,7 @@ def test_visit_constant(code: str, latex: str) -> None:
707707
("x[0][1]", "{x_{{0}, {1}}}"),
708708
("x[0][1][2]", "{x_{{0}, {1}, {2}}}"),
709709
("x[foo]", "{x_{foo}}"),
710-
("x[math.floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
710+
("x[floor(x)]", r"{x_{\left\lfloor{x}\right\rfloor}}"),
711711
],
712712
)
713713
def test_visit_subscript(code: str, latex: str) -> None:

src/latexify/codegen/latex_test.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from __future__ import annotations
44

5-
from latexify.codegen.latex import Latex # Ignores [22-imports] for convenience.
5+
# Ignores [22-imports] for convenience.
6+
from latexify.codegen.latex import Latex
67

78

89
def test_eq() -> None:

src/latexify/config.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ class Config:
1717
and corresponding values are the replacements.
1818
Both keys and values have to represent valid Python identifiers:
1919
^[A-Za-z_][A-Za-z0-9_]*$
20+
prefixes: Prefixes of identifiers to trim. E.g., if "foo.bar" in prefixes, all
21+
identifiers with the form "foo.bar.suffix" will be replaced with "suffix".
2022
reduce_assignments: If True, assignment statements are used to synthesize
2123
the final expression.
2224
use_math_symbols: Whether to convert identifiers with a math symbol surface
@@ -30,6 +32,7 @@ class Config:
3032

3133
expand_functions: set[str] | None
3234
identifiers: dict[str, str] | None
35+
prefixes: set[str] | None
3336
reduce_assignments: bool
3437
use_math_symbols: bool
3538
use_raw_function_name: bool
@@ -70,6 +73,7 @@ def defaults() -> Config:
7073
return Config(
7174
expand_functions=None,
7275
identifiers=None,
76+
prefixes=None,
7377
reduce_assignments=False,
7478
use_math_symbols=False,
7579
use_raw_function_name=False,

src/latexify/constants.py

-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44

55
import enum
66

7-
PREFIXES = ["math", "numpy", "np"]
8-
97

108
class BuiltinFnName(str, enum.Enum):
119
"""Built-in function name."""

src/latexify/frontend.py

+10
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,12 @@
1010
from latexify import config as cfg
1111
from latexify import exceptions, parser, transformers
1212

13+
# NOTE(odashi):
14+
# These prefixes are trimmed by default.
15+
# This behavior shouldn't be controlled by users in the current implementation because
16+
# some processes expect the absence of these prefixes.
17+
_COMMON_PREFIXES = {"math", "numpy", "np"}
18+
1319

1420
# TODO(odashi): move expand_functions to Config.
1521
def get_latex(
@@ -39,6 +45,10 @@ def get_latex(
3945
tree = parser.parse_function(fn)
4046

4147
# Applies AST transformations.
48+
49+
prefixes = _COMMON_PREFIXES | (merged_config.prefixes or set())
50+
tree = transformers.PrefixTrimmer(prefixes).visit(tree)
51+
4252
if merged_config.identifiers is not None:
4353
tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree)
4454
if merged_config.reduce_assignments:

src/latexify/frontend_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,27 @@ def myfn(myvar):
1818
assert frontend.get_latex(myfn, identifiers=identifiers) == latex_with_flag
1919

2020

21+
def test_get_latex_prefixes() -> None:
    """Checks the output of get_latex for several values of the prefixes flag."""
    math = numpy = np = abc = object()

    def f(x):
        return math.foo + numpy.bar + np.baz + abc.qux + x.y.z.quux

    # Expected outputs. "math", "numpy" and "np" are trimmed in every case.
    nothing_extra = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + x.y.z.quux"
    abc_trimmed = r"\mathrm{f}(x) = foo + bar + baz + qux + x.y.z.quux"
    x_trimmed = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + y.z.quux"
    x_y_trimmed = r"\mathrm{f}(x) = foo + bar + baz + abc.qux + z.quux"
    all_trimmed = r"\mathrm{f}(x) = foo + bar + baz + qux + quux"

    # Pairs of (keyword arguments passed to get_latex, expected LaTeX).
    cases = [
        ({}, nothing_extra),
        ({"prefixes": set()}, nothing_extra),
        ({"prefixes": {"abc"}}, abc_trimmed),
        ({"prefixes": {"x"}}, x_trimmed),
        ({"prefixes": {"x.y"}}, x_y_trimmed),
        ({"prefixes": {"abc", "x.y.z"}}, all_trimmed),
        # The longest matching prefix wins even if a shorter one is also given.
        ({"prefixes": {"abc", "x", "x.y.z"}}, all_trimmed),
    ]

    for kwargs, expected in cases:
        assert frontend.get_latex(f, **kwargs) == expected
40+
41+
2142
def test_get_latex_reduce_assignments() -> None:
2243
def f(x):
2344
y = 3 * x

src/latexify/transformers/__init__.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
"""Package latexify.transformers."""
22

3-
from latexify.transformers import (
4-
assignment_reducer,
5-
function_expander,
6-
identifier_replacer,
7-
)
3+
from latexify.transformers.assignment_reducer import AssignmentReducer
4+
from latexify.transformers.function_expander import FunctionExpander
5+
from latexify.transformers.identifier_replacer import IdentifierReplacer
6+
from latexify.transformers.prefix_trimmer import PrefixTrimmer
87

9-
AssignmentReducer = assignment_reducer.AssignmentReducer
10-
FunctionExpander = function_expander.FunctionExpander
11-
IdentifierReplacer = identifier_replacer.IdentifierReplacer
8+
# NOTE: __all__ must contain the *names* of the public objects as strings.
# Listing the objects themselves makes `from latexify.transformers import *` fail
# with a TypeError.
__all__ = [
    "AssignmentReducer",
    "FunctionExpander",
    "IdentifierReplacer",
    "PrefixTrimmer",
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""NodeTransformer to trim unnecessary prefixes."""
2+
3+
from __future__ import annotations
4+
5+
import ast
6+
import re
7+
8+
from latexify import ast_utils
9+
10+
_PREFIX_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")


class PrefixTrimmer(ast.NodeTransformer):
    """NodeTransformer to trim unnecessary prefixes.

    This class investigates all Attribute subtrees, and replaces them if the prefix
    of the attribute matches the given set of prefixes.
    Prefix is searched in the manner of leftmost longest matching.

    Example:
        def f(x):
            return math.sqrt(x)

        PrefixTrimmer({"math"}) will modify the AST of the function above to below:

        def f(x):
            return sqrt(x)
    """

    _prefixes: list[tuple[str, ...]]

    def __init__(self, prefixes: set[str]) -> None:
        """Initializer.

        Args:
            prefixes: Set of prefixes to be trimmed. Nested prefix is allowed too.
                Each value must follow one of the following formats:
                - A Python identifier, e.g., "math"
                - Python identifiers joined by periods, e.g., "numpy.random"

        Raises:
            ValueError: Any of the given prefixes has an invalid format.
        """
        for p in prefixes:
            # fullmatch() is required here: "$" alone also matches right before a
            # trailing newline, so match() would accept values such as "math\n".
            if not _PREFIX_PATTERN.fullmatch(p):
                raise ValueError(f"Invalid prefix: {p}")

        # Each prefix is stored as a tuple of its period-separated components.
        self._prefixes = [tuple(p.split(".")) for p in prefixes]

    def _get_prefix(self, node: ast.expr) -> tuple[str, ...] | None:
        """Helper to obtain nested prefix.

        Args:
            node: Node to investigate.

        Returns:
            The prefix tuple, or None if the node has unsupported syntax, i.e., it
            is not a Name or a chain of Attributes rooted at a Name.
        """
        if isinstance(node, ast.Name):
            return (node.id,)

        if isinstance(node, ast.Attribute):
            parent = self._get_prefix(node.value)
            return parent + (node.attr,) if parent is not None else None

        return None

    def _make_attribute(self, prefix: tuple[str, ...], name: str) -> ast.expr:
        """Helper to generate a new Attribute or Name node.

        Args:
            prefix: List of prefixes.
            name: Attribute name.

        Returns:
            Name node if prefix == (), (possibly nested) Attribute node otherwise.
        """
        if not prefix:
            return ast_utils.make_name(name)

        parent = self._make_attribute(prefix[:-1], prefix[-1])
        return ast_utils.make_attribute(parent, name)

    def visit_Attribute(self, node: ast.Attribute) -> ast.expr:
        """Trims the longest matching prefix from a dotted attribute access.

        Args:
            node: Attribute node to process.

        Returns:
            A rebuilt Name/Attribute node without the matched prefix if the value of
            the node is a dotted name. Otherwise the given node, with its children
            visited.
        """
        prefix = self._get_prefix(node.value)
        if prefix is None:
            # The value is not a pure dotted name (e.g., a call or a subscript), so
            # this node itself can't be trimmed. Its children may still contain
            # attributes to trim, e.g., "f(math.pi).real", so they must be visited.
            return self.generic_visit(node)

        # Performs leftmost longest match.
        # NOTE(odashi):
        # This implementation is very naive, but would work efficiently as long as the
        # number of patterns is small.
        # A pattern longer than `prefix` never compares equal to the slice, so no
        # explicit length check is needed.
        matched_length = max(
            (len(p) for p in self._prefixes if prefix[: len(p)] == p),
            default=0,
        )

        return self._make_attribute(prefix[matched_length:], node.attr)

0 commit comments

Comments
 (0)