Skip to content

Commit 85dddda

Browse files
ZibingZhangZibing Zhang
and
Zibing Zhang
authored
Support expanding some functions (#125)
* Impl function expander * Fix error message * fix flake8 errors * file doc * Suggestions * Fix docstring * Fix docstring * rename test methods * slight rephrasing * rm local var * Reformat import * suggestions Co-authored-by: Zibing Zhang <zizhang@hubspot.com>
1 parent a782fab commit 85dddda

File tree

7 files changed

+249
-1
lines changed

7 files changed

+249
-1
lines changed

src/integration_tests/regression_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,37 @@ def solve(a, b):
281281
_check_function(solve, latex)
282282

283283

284+
def test_expand_hypot_function_without_attribute_access() -> None:
285+
from math import hypot
286+
287+
def solve(x, y, z):
288+
return hypot(x, y, z)
289+
290+
latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
291+
_check_function(solve, latex, expand_functions={"hypot"})
292+
293+
294+
def test_expand_hypot_function() -> None:
295+
def solve(x, y, z):
296+
return math.hypot(x, y, z)
297+
298+
latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
299+
_check_function(solve, latex, expand_functions={"hypot"})
300+
301+
302+
def test_expand_nested_function() -> None:
303+
def solve(a, b, x, y):
304+
return math.hypot(math.hypot(a, b), x, y)
305+
306+
latex = (
307+
r"\mathrm{solve}(a, b, x, y) = "
308+
r"\sqrt{"
309+
r"\sqrt{a^{{2}} + b^{{2}}}^{{2}} + "
310+
r"x^{{2}} + y^{{2}}}"
311+
)
312+
_check_function(solve, latex, expand_functions={"hypot"})
313+
314+
284315
def test_docstring_allowed() -> None:
285316
def solve(x):
286317
"""The identity function."""

src/latexify/ast_utils.py

+17
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,20 @@ def extract_int(node: ast.expr) -> int:
104104
raise ValueError(f"Unsupported node to extract int: {type(node).__name__}")
105105

106106
return value
107+
108+
109+
def extract_function_name_or_none(node: ast.Call) -> str | None:
110+
"""Extracts function name from the given Call node.
111+
112+
Args:
113+
node: ast.Call.
114+
115+
Returns:
116+
Extracted function name, or None if not found.
117+
"""
118+
if isinstance(node.func, ast.Name):
119+
return node.func.id
120+
if isinstance(node.func, ast.Attribute):
121+
return node.func.attr
122+
123+
return None

src/latexify/ast_utils_test.py

+34
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,37 @@ def test_extract_int_invalid() -> None:
146146
ast_utils.extract_int(ast_utils.make_constant("123"))
147147
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
148148
ast_utils.extract_int(ast_utils.make_constant(b"123"))
149+
150+
151+
@pytest.mark.parametrize(
152+
"value,expected",
153+
[
154+
(
155+
ast.Call(
156+
func=ast.Name(id="hypot", ctx=ast.Load()),
157+
args=[],
158+
),
159+
"hypot",
160+
),
161+
(
162+
ast.Call(
163+
func=ast.Attribute(
164+
value=ast.Name(id="math", ctx=ast.Load()),
165+
attr="hypot",
166+
ctx=ast.Load(),
167+
),
168+
args=[],
169+
),
170+
"hypot",
171+
),
172+
(
173+
ast.Call(
174+
func=ast.Call(func=ast.Name(id="foo", ctx=ast.Load()), args=[]),
175+
args=[],
176+
),
177+
None,
178+
),
179+
],
180+
)
181+
def test_extract_function_name_or_none(value: ast.Call, expected: str | None) -> None:
182+
assert ast_utils.extract_function_name_or_none(value) == expected

src/latexify/frontend.py

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def get_latex(
1313
fn: Callable[..., Any],
1414
*,
1515
identifiers: dict[str, str] | None = None,
16+
expand_functions: set[str] | None = None,
1617
reduce_assignments: bool = False,
1718
use_math_symbols: bool = False,
1819
use_raw_function_name: bool = False,
@@ -28,6 +29,7 @@ def get_latex(
2829
the replacements.
2930
Both keys and values have to represent valid Python identifiers:
3031
^[A-Za-z_][A-Za-z0-9_]*$
32+
expand_functions: If set, the names of the functions to expand.
3133
reduce_assignments: If True, assignment statements are used to synthesize
3234
the final expression.
3335
use_math_symbols: Whether to convert identifiers with a math symbol surface
@@ -52,6 +54,8 @@ def get_latex(
5254
tree = transformers.IdentifierReplacer(identifiers).visit(tree)
5355
if reduce_assignments:
5456
tree = transformers.AssignmentReducer().visit(tree)
57+
if expand_functions is not None:
58+
tree = transformers.FunctionExpander(expand_functions).visit(tree)
5559

5660
# Generates LaTeX.
5761
return codegen.FunctionCodegen(

src/latexify/transformers/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
"""Package latexify.transformers."""
22

3-
from latexify.transformers import assignment_reducer, identifier_replacer
3+
from latexify.transformers import (
4+
assignment_reducer,
5+
function_expander,
6+
identifier_replacer,
7+
)
48

59
AssignmentReducer = assignment_reducer.AssignmentReducer
10+
FunctionExpander = function_expander.FunctionExpander
611
IdentifierReplacer = identifier_replacer.IdentifierReplacer
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from __future__ import annotations
2+
3+
import ast
4+
import functools
5+
from collections.abc import Callable
6+
7+
from latexify import ast_utils
8+
9+
10+
# TODO(ZibingZhang): handle recursive function expansions
11+
class FunctionExpander(ast.NodeTransformer):
12+
"""NodeTransformer to expand functions.
13+
14+
This class replaces function calls with an expanded form.
15+
16+
Example:
17+
def f(x, y):
18+
return hypot(x, y)
19+
20+
FunctionExpander({"hypot"}) will modify the AST of the function above to below:
21+
22+
def f(x, y):
23+
return sqrt(x**2, y**2)
24+
"""
25+
26+
def __init__(self, functions: set[str]) -> None:
27+
self._functions = functions
28+
29+
def visit_Call(self, node: ast.Call) -> ast.AST:
30+
"""Visitor of Call nodes."""
31+
func_name = ast_utils.extract_function_name_or_none(node)
32+
if (
33+
func_name is not None
34+
and func_name in self._functions
35+
and func_name in _FUNCTION_EXPANDERS
36+
):
37+
return _FUNCTION_EXPANDERS[func_name](self, node)
38+
39+
return node
40+
41+
42+
def _hypot_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
43+
if len(node.args) == 0:
44+
return ast_utils.make_constant(0)
45+
46+
args = [
47+
ast.BinOp(function_expander.visit(arg), ast.Pow(), ast_utils.make_constant(2))
48+
for arg in node.args
49+
]
50+
51+
args_reduced = functools.reduce(lambda a, b: ast.BinOp(a, ast.Add(), b), args)
52+
return ast.Call(
53+
func=ast.Name(id="sqrt", ctx=ast.Load()),
54+
args=[args_reduced],
55+
)
56+
57+
58+
_FUNCTION_EXPANDERS: dict[str, Callable[[FunctionExpander, ast.Call], ast.AST]] = {
59+
"hypot": _hypot_expander
60+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Tests for latexify.transformers.function_expander."""
2+
3+
from __future__ import annotations
4+
5+
import ast
6+
import math
7+
8+
from latexify import ast_utils, parser, test_utils
9+
from latexify.transformers.function_expander import FunctionExpander
10+
11+
12+
def _make_ast(args: list[str], body: ast.expr) -> ast.Module:
13+
"""Helper function to generate an AST for f(x).
14+
15+
Args:
16+
args: The arguments passed to the method.
17+
body: The body of the return statement.
18+
19+
Returns:
20+
Generated AST.
21+
"""
22+
return ast.Module(
23+
body=[
24+
ast.FunctionDef(
25+
name="f",
26+
args=ast.arguments(
27+
args=[ast.arg(arg=arg) for arg in args],
28+
kwonlyargs=[],
29+
kw_defaults=[],
30+
defaults=[],
31+
),
32+
body=[ast.Return(body)],
33+
decorator_list=[],
34+
)
35+
],
36+
)
37+
38+
39+
def test_hypot_unchanged_without_attribute_access() -> None:
40+
from math import hypot
41+
42+
def f(x, y):
43+
return hypot(x, y)
44+
45+
expected = _make_ast(
46+
["x", "y"], ast.Call(ast.Name("hypot"), [ast.Name("x"), ast.Name("y")])
47+
)
48+
transformed = FunctionExpander(set()).visit(parser.parse_function(f))
49+
test_utils.assert_ast_equal(transformed, expected)
50+
51+
52+
def test_hypot_unchanged() -> None:
53+
def f(x, y):
54+
return math.hypot(x, y)
55+
56+
expected = _make_ast(
57+
["x", "y"],
58+
ast.Call(
59+
ast.Attribute(ast.Name("math"), "hypot", ast.Load()),
60+
[ast.Name("x"), ast.Name("y")],
61+
),
62+
)
63+
transformed = FunctionExpander(set()).visit(parser.parse_function(f))
64+
test_utils.assert_ast_equal(transformed, expected)
65+
66+
67+
def test_hypot_expanded() -> None:
68+
def f(x, y):
69+
return math.hypot(x, y)
70+
71+
expected = _make_ast(
72+
["x", "y"],
73+
ast.Call(
74+
ast.Name("sqrt"),
75+
[
76+
ast.BinOp(
77+
ast.BinOp(ast.Name("x"), ast.Pow(), ast_utils.make_constant(2)),
78+
ast.Add(),
79+
ast.BinOp(ast.Name("y"), ast.Pow(), ast_utils.make_constant(2)),
80+
)
81+
],
82+
),
83+
)
84+
transformed = FunctionExpander({"hypot"}).visit(parser.parse_function(f))
85+
test_utils.assert_ast_equal(transformed, expected)
86+
87+
88+
def test_hypot_expanded_no_args() -> None:
89+
def f():
90+
return math.hypot()
91+
92+
expected = _make_ast(
93+
[],
94+
ast_utils.make_constant(0),
95+
)
96+
transformed = FunctionExpander({"hypot"}).visit(parser.parse_function(f))
97+
test_utils.assert_ast_equal(transformed, expected)

0 commit comments

Comments
 (0)