Skip to content

Commit c881cce

Browse files
author
Yusuke Oda
authored
Improve visit_Compare implementation (#61)
* improve visit_Compare * support multiple comparators
1 parent b61ebd5 commit c881cce

File tree

3 files changed

+67
-21
lines changed

3 files changed

+67
-21
lines changed

src/integration_tests/regression_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def sinc(x):
3838

3939
sinc_latex = (
4040
r"\mathrm{sinc}(x) \triangleq \left\{ \begin{array}{ll} 1, & \mathrm{if} \ "
41-
r"x=0 \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}"
41+
r"{x = 0} \\ \frac{\sin{\left({x}\right)}}{x}, & \mathrm{otherwise} \end{array}"
4242
r" \right."
4343
)
4444

src/latexify/latexify_visitor.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations
44

55
import ast
6+
from typing import ClassVar
67

78
from latexify import constants
89
from latexify import math_symbols
@@ -246,27 +247,26 @@ def _wrap(child):
246247
return reprs[type(node.op)]()
247248
return r"\mathrm{unknown\_binop}(" + _unwrap(lhs) + ", " + _unwrap(rhs) + ")"
248249

249-
def visit_Compare(self, node, action): # pylint: disable=invalid-name
250+
_compare_ops: ClassVar[dict[type[ast.cmpop], str]] = {
251+
ast.Eq: "=",
252+
ast.Gt: ">",
253+
ast.GtE: r"\ge",
254+
ast.In: r"\in",
255+
ast.Is: r"\equiv",
256+
ast.IsNot: r"\not\equiv",
257+
ast.Lt: "<",
258+
ast.LtE: r"\le",
259+
ast.NotEq: r"\ne",
260+
ast.NotIn: r"\notin",
261+
}
262+
263+
def visit_Compare(self, node: ast.Compare, action): # pylint: disable=invalid-name
250264
"""Visit a compare node."""
251-
lstr = self.visit(node.left)
252-
rstr = self.visit(node.comparators[0])
253-
254-
if isinstance(node.ops[0], ast.Eq):
255-
return lstr + "=" + rstr
256-
if isinstance(node.ops[0], ast.Gt):
257-
return lstr + ">" + rstr
258-
if isinstance(node.ops[0], ast.Lt):
259-
return lstr + "<" + rstr
260-
if isinstance(node.ops[0], ast.GtE):
261-
return lstr + r"\ge " + rstr
262-
if isinstance(node.ops[0], ast.LtE):
263-
return lstr + r"\le " + rstr
264-
if isinstance(node.ops[0], ast.NotEq):
265-
return lstr + r"\ne " + rstr
266-
if isinstance(node.ops[0], ast.Is):
267-
return lstr + r"\equiv" + rstr
268-
269-
return r"\mathrm{unknown\_comparator}(" + lstr + ", " + rstr + ")"
265+
lhs = self.visit(node.left)
266+
ops = [self._compare_ops[type(x)] for x in node.ops]
267+
rhs = [self.visit(x) for x in node.comparators]
268+
ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)]
269+
return "{" + lhs + "".join(ops_rhs) + "}"
270270

271271
def visit_BoolOp(self, node, action): # pylint: disable=invalid-name
272272
logic_operator = (

src/latexify/latexify_visitor_test.py

+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
"""Tests for latexify.latexify_visitor."""
2+
3+
import ast
4+
import pytest
5+
6+
from latexify.latexify_visitor import LatexifyVisitor
7+
8+
9+
@pytest.mark.parametrize(
10+
"code,latex",
11+
[
12+
# 1 comparator
13+
("a == b", "{a = b}"),
14+
("a > b", "{a > b}"),
15+
("a >= b", r"{a \ge b}"),
16+
("a in b", r"{a \in b}"),
17+
("a is b", r"{a \equiv b}"),
18+
("a is not b", r"{a \not\equiv b}"),
19+
("a < b", "{a < b}"),
20+
("a <= b", r"{a \le b}"),
21+
("a != b", r"{a \ne b}"),
22+
("a not in b", r"{a \notin b}"),
23+
# 2 comparators
24+
("a == b == c", "{a = b = c}"),
25+
("a == b > c", "{a = b > c}"),
26+
("a == b >= c", r"{a = b \ge c}"),
27+
("a == b < c", "{a = b < c}"),
28+
("a == b <= c", r"{a = b \le c}"),
29+
("a > b == c", "{a > b = c}"),
30+
("a > b > c", "{a > b > c}"),
31+
("a > b >= c", r"{a > b \ge c}"),
32+
("a >= b == c", r"{a \ge b = c}"),
33+
("a >= b > c", r"{a \ge b > c}"),
34+
("a >= b >= c", r"{a \ge b \ge c}"),
35+
("a < b == c", "{a < b = c}"),
36+
("a < b < c", "{a < b < c}"),
37+
("a < b <= c", r"{a < b \le c}"),
38+
("a <= b == c", r"{a \le b = c}"),
39+
("a <= b < c", r"{a \le b < c}"),
40+
("a <= b <= c", r"{a \le b \le c}"),
41+
],
42+
)
43+
def test_visit_compare(code: str, latex: str) -> None:
44+
tree = ast.parse(code).body[0].value
45+
assert isinstance(tree, ast.Compare)
46+
assert LatexifyVisitor().visit(tree) == latex

0 commit comments

Comments
 (0)