Skip to content

Commit eebbae7

Browse files
Yusuke Odakshxtij
Yusuke Oda
andauthored
Several refactoring and bugfix. (#145)
* Add basic matrix support * Fix generation function * Fix formatting and minor issues * Fix tests * refactor * fix * add ast_utils.parse_expr * refactor Co-authored-by: Kshitij Sharma <kshitij4sharma@gmail.com>
1 parent fb233b5 commit eebbae7

8 files changed

+564
-334
lines changed

src/integration_tests/function_expansion_test.py

+32-26
Original file line numberDiff line numberDiff line change
@@ -3,118 +3,124 @@
33
from integration_tests import utils
44

55

6-
def test_expand_atan2_function() -> None:
6+
def test_atan2() -> None:
77
def solve(x, y):
88
return math.atan2(y, x)
99

10-
latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{y}{x}}\right)}"
10+
latex = (
11+
r"\mathrm{solve}(x, y) ="
12+
r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)"
13+
)
1114
utils.check_function(solve, latex, expand_functions={"atan2"})
1215

1316

14-
def test_expand_atan2_nested_function() -> None:
17+
def test_atan2_nested() -> None:
1518
def solve(x, y):
1619
return math.atan2(math.exp(y), math.exp(x))
1720

18-
latex = r"\mathrm{solve}(x, y) = \arctan{\left({\frac{e^{y}}{e^{x}}}\right)}"
21+
latex = (
22+
r"\mathrm{solve}(x, y) ="
23+
r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)"
24+
)
1925
utils.check_function(solve, latex, expand_functions={"atan2", "exp"})
2026

2127

22-
def test_expand_exp_function() -> None:
28+
def test_exp() -> None:
2329
def solve(x):
2430
return math.exp(x)
2531

2632
latex = r"\mathrm{solve}(x) = e^{x}"
2733
utils.check_function(solve, latex, expand_functions={"exp"})
2834

2935

30-
def test_expand_exp_nested_function() -> None:
36+
def test_exp_nested() -> None:
3137
def solve(x):
3238
return math.exp(math.exp(x))
3339

3440
latex = r"\mathrm{solve}(x) = e^{e^{x}}"
3541
utils.check_function(solve, latex, expand_functions={"exp"})
3642

3743

38-
def test_expand_exp2_function() -> None:
44+
def test_exp2() -> None:
3945
def solve(x):
4046
return math.exp2(x)
4147

4248
latex = r"\mathrm{solve}(x) = {2}^{x}"
4349
utils.check_function(solve, latex, expand_functions={"exp2"})
4450

4551

46-
def test_expand_exp2_nested_function() -> None:
52+
def test_exp2_nested() -> None:
4753
def solve(x):
4854
return math.exp2(math.exp2(x))
4955

5056
latex = r"\mathrm{solve}(x) = {2}^{{2}^{x}}"
5157
utils.check_function(solve, latex, expand_functions={"exp2"})
5258

5359

54-
def test_expand_expm1_function() -> None:
60+
def test_expm1() -> None:
5561
def solve(x):
5662
return math.expm1(x)
5763

58-
latex = r"\mathrm{solve}(x) = \exp{\left({x}\right)} - {1}"
64+
latex = r"\mathrm{solve}(x) = \exp x - {1}"
5965
utils.check_function(solve, latex, expand_functions={"expm1"})
6066

6167

62-
def test_expand_expm1_nested_function() -> None:
68+
def test_expm1_nested() -> None:
6369
def solve(x, y, z):
6470
return math.expm1(math.pow(y, z))
6571

6672
latex = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - {1}"
6773
utils.check_function(solve, latex, expand_functions={"expm1", "exp", "pow"})
6874

6975

70-
def test_expand_hypot_function_without_attribute_access() -> None:
76+
def test_hypot_without_attribute() -> None:
7177
from math import hypot
7278

7379
def solve(x, y, z):
7480
return hypot(x, y, z)
7581

76-
latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
82+
latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }"
7783
utils.check_function(solve, latex, expand_functions={"hypot"})
7884

7985

80-
def test_expand_hypot_function() -> None:
86+
def test_hypot() -> None:
8187
def solve(x, y, z):
8288
return math.hypot(x, y, z)
8389

84-
latex = r"\mathrm{solve}(x, y, z) = \sqrt{x^{{2}} + y^{{2}} + z^{{2}}}"
90+
latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{{2}} + y^{{2}} + z^{{2}} }"
8591
utils.check_function(solve, latex, expand_functions={"hypot"})
8692

8793

88-
def test_expand_hypot_nested_function() -> None:
94+
def test_hypot_nested() -> None:
8995
def solve(a, b, x, y):
9096
return math.hypot(math.hypot(a, b), x, y)
9197

9298
latex = (
93-
r"\mathrm{solve}(a, b, x, y) = "
94-
r"\sqrt{"
95-
r"\sqrt{a^{{2}} + b^{{2}}}^{{2}} + "
96-
r"x^{{2}} + y^{{2}}}"
99+
r"\mathrm{solve}(a, b, x, y) ="
100+
r" \sqrt{ \sqrt{ a^{{2}} + b^{{2}} }^{{2}} + x^{{2}} + y^{{2}} }"
97101
)
98102
utils.check_function(solve, latex, expand_functions={"hypot"})
99103

100104

101-
def test_expand_log1p_function() -> None:
105+
def test_log1p() -> None:
102106
def solve(x):
103107
return math.log1p(x)
104108

105-
latex = r"\mathrm{solve}(x) = \log{\left({{1} + x}\right)}"
109+
latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + x \mathclose{}\right)"
106110
utils.check_function(solve, latex, expand_functions={"log1p"})
107111

108112

109-
def test_expand_log1p_nested_function() -> None:
113+
def test_log1p_nested() -> None:
110114
def solve(x):
111115
return math.log1p(math.exp(x))
112116

113-
latex = r"\mathrm{solve}(x) = \log{\left({{1} + e^{x}}\right)}"
117+
latex = (
118+
r"\mathrm{solve}(x) = \log \mathopen{}\left( {1} + e^{x} \mathclose{}\right)"
119+
)
114120
utils.check_function(solve, latex, expand_functions={"log1p", "exp"})
115121

116122

117-
def test_expand_pow_nested_function() -> None:
123+
def test_pow_nested() -> None:
118124
def solve(w, x, y, z):
119125
return math.pow(math.pow(w, x), math.pow(y, z))
120126

@@ -125,7 +131,7 @@ def solve(w, x, y, z):
125131
utils.check_function(solve, latex, expand_functions={"pow"})
126132

127133

128-
def test_expand_pow_function() -> None:
134+
def test_pow() -> None:
129135
def solve(x, y):
130136
return math.pow(x, y)
131137

src/integration_tests/regression_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_quadratic_solution() -> None:
1111
def solve(a, b, c):
1212
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
1313

14-
latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{b^{{2}} - {4} a c}}{{2} a}"
14+
latex = r"\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{{2}} - {4} a c }}{{2} a}"
1515
utils.check_function(solve, latex)
1616

1717

@@ -26,7 +26,7 @@ def sinc(x):
2626
r"\mathrm{sinc}(x) = "
2727
r"\left\{ \begin{array}{ll} "
2828
r"{1}, & \mathrm{if} \ "
29-
r"{x = {0}} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} "
29+
r"{x = {0}} \\ \frac{\sin x}{x}, & \mathrm{otherwise} "
3030
r"\end{array} \right."
3131
)
3232
utils.check_function(sinc, latex)
@@ -201,9 +201,9 @@ def sigmoid(x):
201201
sigmoid,
202202
(
203203
r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll} "
204-
r"\frac{{1}}{{1} + \exp{\left({-x}\right)}}, & "
204+
r"\frac{{1}}{{1} + \exp \mathopen{}\left( -x \mathclose{}\right)}, & "
205205
r"\mathrm{if} \ {x > {0}} \\ "
206-
r"\frac{\exp{\left({x}\right)}}{\exp{\left({x}\right)} + {1}}, & "
206+
r"\frac{\exp x}{\exp x + {1}}, & "
207207
r"\mathrm{otherwise} "
208208
r"\end{array} \right."
209209
),

src/latexify/ast_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def make_name(id: str) -> ast.Name:
3131
return ast.Name(id=id, ctx=ast.Load())
3232

3333

34-
def make_attribute(value: ast.Expr, attr: str):
34+
def make_attribute(value: ast.expr, attr: str):
3535
"""Generates a new Attribute node.
3636
3737
Args:

src/latexify/codegen/function_codegen.py

+75-28
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,14 @@
5252
ast.Or: 10,
5353
}
5454

55+
# NOTE(odashi):
56+
# Function invocation is treated as a unary operator with a higher precedence.
57+
# This ensures that the argument with a unary operator is wrapped:
58+
# exp(x) --> \exp x
59+
# exp(-x) --> \exp (-x)
60+
# -exp(x) --> - \exp x
61+
_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1
62+
5563

5664
def _get_precedence(node: ast.AST) -> int:
5765
"""Obtains the precedence of the subtree.
@@ -63,6 +71,9 @@ def _get_precedence(node: ast.AST) -> int:
6371
If `node` is a subtree with some operator, returns the precedence of the
6472
operator. Otherwise, returns a number larger enough from other precedences.
6573
"""
74+
if isinstance(node, ast.Call):
75+
return _CALL_PRECEDENCE
76+
6677
if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)):
6778
return _PRECEDENCES[type(node.op)]
6879

@@ -289,38 +300,34 @@ def visit_Return(self, node: ast.Return) -> str:
289300

290301
def visit_Tuple(self, node: ast.Tuple) -> str:
291302
elts = [self.visit(i) for i in node.elts]
292-
return (
293-
r"\mathopen{}\left( "
294-
+ r"\space,\space ".join(elts)
295-
+ r"\mathclose{}\right) "
296-
)
303+
return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)"
297304

298305
def visit_List(self, node: ast.List) -> str:
299306
elts = [self.visit(i) for i in node.elts]
300-
return r"\left[ " + r"\space,\space ".join(elts) + r"\right] "
307+
return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]"
301308

302309
def visit_Set(self, node: ast.Set) -> str:
303310
elts = [self.visit(i) for i in node.elts]
304-
return r"\left\{ " + r"\space,\space ".join(elts) + r"\right\} "
311+
return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}"
305312

306313
def visit_ListComp(self, node: ast.ListComp) -> str:
307314
generators = [self.visit(comp) for comp in node.generators]
308315
return (
309-
r"\left[ "
316+
r"\mathopen{}\left[ "
310317
+ self.visit(node.elt)
311318
+ r" \mid "
312319
+ ", ".join(generators)
313-
+ r" \right]"
320+
+ r" \mathclose{}\right]"
314321
)
315322

316323
def visit_SetComp(self, node: ast.SetComp) -> str:
317324
generators = [self.visit(comp) for comp in node.generators]
318325
return (
319-
r"\left\{ "
326+
r"\mathopen{}\left\{ "
320327
+ self.visit(node.elt)
321328
+ r" \mid "
322329
+ ", ".join(generators)
323-
+ r" \right\}"
330+
+ r" \mathclose{}\right\}"
324331
)
325332

326333
def visit_comprehension(self, node: ast.comprehension) -> str:
@@ -347,10 +354,16 @@ def _generate_sum_prod(self, node: ast.Call) -> str | None:
347354
return None
348355

349356
name = ast_utils.extract_function_name_or_none(node)
350-
assert name is not None
357+
assert name in ("fsum", "sum", "prod")
358+
359+
command = {
360+
"fsum": r"\sum",
361+
"sum": r"\sum",
362+
"prod": r"\prod",
363+
}[name]
351364

352365
elt, scripts = self._get_sum_prod_info(node.args[0])
353-
scripts_str = [rf"\{name}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
366+
scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
354367
return (
355368
" ".join(scripts_str)
356369
+ rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)"
@@ -403,7 +416,7 @@ def visit_Call(self, node: ast.Call) -> str:
403416
func_name = ast_utils.extract_function_name_or_none(node)
404417

405418
# Special treatments for some functions.
406-
if func_name in ("sum", "prod"):
419+
if func_name in ("fsum", "sum", "prod"):
407420
special_latex = self._generate_sum_prod(node)
408421
elif func_name in ("array", "ndarray"):
409422
special_latex = self._generate_matrix(node)
@@ -413,17 +426,38 @@ def visit_Call(self, node: ast.Call) -> str:
413426
if special_latex is not None:
414427
return special_latex
415428

416-
# Function signature (possibly an expression).
417-
default_func_str = self.visit(node.func)
418-
419-
# Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
420-
lstr, rstr = constants.BUILTIN_FUNCS.get(
421-
func_name,
422-
(default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"),
423-
)
429+
# Obtains the codegen rule.
430+
rule = constants.BUILTIN_FUNCS.get(func_name)
431+
if rule is None:
432+
rule = constants.FunctionRule(self.visit(node.func))
433+
434+
if rule.is_unary and len(node.args) == 1:
435+
# Unary function. Applies the same wrapping policy with the unary operators.
436+
# NOTE(odashi):
437+
# Factorial "x!" is treated as a special case: it requires both inner/outer
438+
# parentheses for correct interpretation.
439+
precedence = _get_precedence(node)
440+
arg = node.args[0]
441+
force_wrap = isinstance(arg, ast.Call) and (
442+
func_name == "factorial"
443+
or ast_utils.extract_function_name_or_none(arg) == "factorial"
444+
)
445+
arg_latex = self._wrap_operand(arg, precedence, force_wrap)
446+
elements = [rule.left, arg_latex, rule.right]
447+
else:
448+
arg_latex = ", ".join(self.visit(arg) for arg in node.args)
449+
if rule.is_wrapped:
450+
elements = [rule.left, arg_latex, rule.right]
451+
else:
452+
elements = [
453+
rule.left,
454+
r"\mathopen{}\left(",
455+
arg_latex,
456+
r"\mathclose{}\right)",
457+
rule.right,
458+
]
424459

425-
arg_strs = [self.visit(arg) for arg in node.args]
426-
return lstr + ", ".join(arg_strs) + rstr
460+
return " ".join(x for x in elements if x)
427461

428462
def visit_Attribute(self, node: ast.Attribute) -> str:
429463
vstr = self.visit(node.value)
@@ -481,20 +515,26 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str:
481515
def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
482516
return self._convert_constant(...)
483517

484-
def _wrap_operand(self, child: ast.expr, parent_prec: int) -> str:
518+
def _wrap_operand(
519+
self, child: ast.expr, parent_prec: int, force_wrap: bool = False
520+
) -> str:
485521
"""Wraps the operand subtree with parentheses.
486522
487523
Args:
488524
child: Operand subtree.
489525
parent_prec: Precedence of the parent operator.
526+
force_wrap: Whether to wrap the operand or not when the precedence is equal.
490527
491528
Returns:
492529
LaTeX form of `child`, with or without surrounding parentheses.
493530
"""
494531
latex = self.visit(child)
495-
if _get_precedence(child) >= parent_prec:
496-
return latex
497-
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
532+
child_prec = _get_precedence(child)
533+
534+
if child_prec < parent_prec or force_wrap and child_prec == parent_prec:
535+
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
536+
537+
return latex
498538

499539
def _wrap_binop_operand(
500540
self,
@@ -515,6 +555,13 @@ def _wrap_binop_operand(
515555
if not operand_rule.wrap:
516556
return self.visit(child)
517557

558+
if isinstance(child, ast.Call):
559+
rule = constants.BUILTIN_FUNCS.get(
560+
ast_utils.extract_function_name_or_none(child)
561+
)
562+
if rule is not None and rule.is_wrapped:
563+
return self.visit(child)
564+
518565
if not isinstance(child, ast.BinOp):
519566
return self._wrap_operand(child, parent_prec)
520567

0 commit comments

Comments
 (0)