Skip to content

Commit 3ae62c1

Browse files
author
odashi
committed
fix
1 parent 716e87a commit 3ae62c1

File tree

1 file changed

+32
-209
lines changed

1 file changed

+32
-209
lines changed

src/latexify/codegen/expression_codegen.py

+32-209
Original file line numberDiff line numberDiff line change
@@ -3,206 +3,17 @@
33
from __future__ import annotations
44

55
import ast
6-
import dataclasses
7-
8-
from latexify import analyzers, ast_utils, constants, exceptions
9-
from latexify.codegen import codegen_utils, identifier_converter
10-
11-
# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes.
12-
# Note that this value affects only the appearance of surrounding parentheses for each
13-
# expression, and does not affect the AST itself.
14-
# See also:
15-
# https://docs.python.org/3/reference/expressions.html#operator-precedence
16-
_PRECEDENCES: dict[type[ast.AST], int] = {
17-
ast.Pow: 120,
18-
ast.UAdd: 110,
19-
ast.USub: 110,
20-
ast.Invert: 110,
21-
ast.Mult: 100,
22-
ast.MatMult: 100,
23-
ast.Div: 100,
24-
ast.FloorDiv: 100,
25-
ast.Mod: 100,
26-
ast.Add: 90,
27-
ast.Sub: 90,
28-
ast.LShift: 80,
29-
ast.RShift: 80,
30-
ast.BitAnd: 70,
31-
ast.BitXor: 60,
32-
ast.BitOr: 50,
33-
ast.In: 40,
34-
ast.NotIn: 40,
35-
ast.Is: 40,
36-
ast.IsNot: 40,
37-
ast.Lt: 40,
38-
ast.LtE: 40,
39-
ast.Gt: 40,
40-
ast.GtE: 40,
41-
ast.NotEq: 40,
42-
ast.Eq: 40,
43-
# NOTE(odashi):
44-
# We assume that the `not` operator has the same precedence as other unary
45-
# operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ looks to have a
46-
# high precedence.
47-
# ast.Not: 30,
48-
ast.Not: 110,
49-
ast.And: 20,
50-
ast.Or: 10,
51-
}
52-
53-
# NOTE(odashi):
54-
# Function invocation is treated as a unary operator with a higher precedence.
55-
# This ensures that the argument with a unary operator is wrapped:
56-
# exp(x) --> \exp x
57-
# exp(-x) --> \exp (-x)
58-
# -exp(x) --> - \exp x
59-
_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1
60-
61-
62-
def _get_precedence(node: ast.AST) -> int:
63-
"""Obtains the precedence of the subtree.
64-
65-
Args:
66-
node: Subtree to investigate.
67-
68-
Returns:
69-
If `node` is a subtree with some operator, returns the precedence of the
70-
operator. Otherwise, returns a number far larger than any other precedence.
71-
"""
72-
if isinstance(node, ast.Call):
73-
return _CALL_PRECEDENCE
74-
75-
if isinstance(node, (ast.BoolOp, ast.BinOp, ast.UnaryOp)):
76-
return _PRECEDENCES[type(node.op)]
77-
78-
if isinstance(node, ast.Compare):
79-
# Compare operators have the same precedence. It is enough to check only the
80-
# first operator.
81-
return _PRECEDENCES[type(node.ops[0])]
82-
83-
return 1_000_000
84-
85-
86-
@dataclasses.dataclass(frozen=True)
87-
class BinOperandRule:
88-
"""Syntax rules for operands of BinOp."""
89-
90-
# Whether to require wrapping operands by parentheses according to the precedence.
91-
wrap: bool = True
92-
93-
# Whether to require wrapping operands by parentheses if the operand has the same
94-
# precedence as this operator.
95-
# This is used to control the behavior of non-associative operators.
96-
force: bool = False
97-
98-
99-
@dataclasses.dataclass(frozen=True)
100-
class BinOpRule:
101-
"""Syntax rules for BinOp."""
102-
103-
# Left/middle/right syntaxes to wrap operands.
104-
latex_left: str
105-
latex_middle: str
106-
latex_right: str
107-
108-
# Operand rules.
109-
operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
110-
operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
111-
112-
# Whether to assume the resulting syntax is wrapped by some bracket operators.
113-
# If True, the parent operator can avoid wrapping this operator by parentheses.
114-
is_wrapped: bool = False
115-
116-
117-
_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
118-
ast.Pow: BinOpRule(
119-
"",
120-
"^{",
121-
"}",
122-
operand_left=BinOperandRule(force=True),
123-
operand_right=BinOperandRule(wrap=False),
124-
),
125-
ast.Mult: BinOpRule("", " ", ""),
126-
ast.MatMult: BinOpRule("", " ", ""),
127-
ast.Div: BinOpRule(
128-
r"\frac{",
129-
"}{",
130-
"}",
131-
operand_left=BinOperandRule(wrap=False),
132-
operand_right=BinOperandRule(wrap=False),
133-
),
134-
ast.FloorDiv: BinOpRule(
135-
r"\left\lfloor\frac{",
136-
"}{",
137-
r"}\right\rfloor",
138-
operand_left=BinOperandRule(wrap=False),
139-
operand_right=BinOperandRule(wrap=False),
140-
is_wrapped=True,
141-
),
142-
ast.Mod: BinOpRule(
143-
"", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True)
144-
),
145-
ast.Add: BinOpRule("", " + ", ""),
146-
ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)),
147-
ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)),
148-
ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)),
149-
ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""),
150-
ast.BitXor: BinOpRule("", r" \oplus ", ""),
151-
ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
152-
}
153-
154-
# Typeset for BinOp of sets.
155-
_SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
156-
**_BIN_OP_RULES,
157-
ast.Sub: BinOpRule(
158-
"", r" \setminus ", "", operand_right=BinOperandRule(force=True)
159-
),
160-
ast.BitAnd: BinOpRule("", r" \cap ", ""),
161-
ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""),
162-
ast.BitOr: BinOpRule("", r" \cup ", ""),
163-
}
164-
165-
_UNARY_OPS: dict[type[ast.unaryop], str] = {
166-
ast.Invert: r"\mathord{\sim} ",
167-
ast.UAdd: "+", # Explicitly adds the $+$ operator.
168-
ast.USub: "-",
169-
ast.Not: r"\lnot ",
170-
}
171-
172-
_COMPARE_OPS: dict[type[ast.cmpop], str] = {
173-
ast.Eq: "=",
174-
ast.Gt: ">",
175-
ast.GtE: r"\ge",
176-
ast.In: r"\in",
177-
ast.Is: r"\equiv",
178-
ast.IsNot: r"\not\equiv",
179-
ast.Lt: "<",
180-
ast.LtE: r"\le",
181-
ast.NotEq: r"\ne",
182-
ast.NotIn: r"\notin",
183-
}
184-
185-
# Typeset for Compare of sets.
186-
_SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
187-
**_COMPARE_OPS,
188-
ast.Gt: r"\supset",
189-
ast.GtE: r"\supseteq",
190-
ast.Lt: r"\subset",
191-
ast.LtE: r"\subseteq",
192-
}
193-
194-
_BOOL_OPS: dict[type[ast.boolop], str] = {
195-
ast.And: r"\land",
196-
ast.Or: r"\lor",
197-
}
6+
7+
from latexify import analyzers, ast_utils, exceptions
8+
from latexify.codegen import codegen_utils, expression_rules, identifier_converter
1989

19910

20011
class ExpressionCodegen(ast.NodeVisitor):
20112
"""Codegen for single expressions."""
20213

20314
_identifier_converter: identifier_converter.IdentifierConverter
20415

205-
_bin_op_rules: dict[type[ast.operator], BinOpRule]
16+
_bin_op_rules: dict[type[ast.operator], expression_rules.BinOpRule]
20617
_compare_ops: dict[type[ast.cmpop], str]
20718

20819
def __init__(
@@ -219,8 +30,16 @@ def __init__(
21930
use_math_symbols=use_math_symbols
22031
)
22132

222-
self._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
223-
self._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS
33+
self._bin_op_rules = (
34+
expression_rules.SET_BIN_OP_RULES
35+
if use_set_symbols
36+
else expression_rules.BIN_OP_RULES
37+
)
38+
self._compare_ops = (
39+
expression_rules.SET_COMPARE_OPS
40+
if use_set_symbols
41+
else expression_rules.COMPARE_OPS
42+
)
22443

22544
def generic_visit(self, node: ast.AST) -> str:
22645
raise exceptions.LatexifyNotSupportedError(
@@ -420,17 +239,21 @@ def visit_Call(self, node: ast.Call) -> str:
420239
return special_latex
421240

422241
# Obtains the codegen rule.
423-
rule = constants.BUILTIN_FUNCS.get(func_name) if func_name is not None else None
242+
rule = (
243+
expression_rules.BUILTIN_FUNCS.get(func_name)
244+
if func_name is not None
245+
else None
246+
)
424247

425248
if rule is None:
426-
rule = constants.FunctionRule(self.visit(node.func))
249+
rule = expression_rules.FunctionRule(self.visit(node.func))
427250

428251
if rule.is_unary and len(node.args) == 1:
429252
# Unary function. Applies the same wrapping policy as the unary operators.
430253
# NOTE(odashi):
431254
# Factorial "x!" is treated as a special case: it requires both inner/outer
432255
# parentheses for correct interpretation.
433-
precedence = _get_precedence(node)
256+
precedence = expression_rules.get_precedence(node)
434257
arg = node.args[0]
435258
force_wrap = isinstance(arg, ast.Call) and (
436259
func_name == "factorial"
@@ -507,7 +330,7 @@ def _wrap_operand(
507330
LaTeX form of `child`, with or without surrounding parentheses.
508331
"""
509332
latex = self.visit(child)
510-
child_prec = _get_precedence(child)
333+
child_prec = expression_rules.get_precedence(child)
511334

512335
if child_prec < parent_prec or force_wrap and child_prec == parent_prec:
513336
return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
@@ -518,7 +341,7 @@ def _wrap_binop_operand(
518341
self,
519342
child: ast.expr,
520343
parent_prec: int,
521-
operand_rule: BinOperandRule,
344+
operand_rule: expression_rules.BinOperandRule,
522345
) -> str:
523346
"""Wraps the operand subtree of BinOp with parentheses.
524347
@@ -536,7 +359,7 @@ def _wrap_binop_operand(
536359
if isinstance(child, ast.Call):
537360
child_fn_name = ast_utils.extract_function_name_or_none(child)
538361
rule = (
539-
constants.BUILTIN_FUNCS.get(child_fn_name)
362+
expression_rules.BUILTIN_FUNCS.get(child_fn_name)
540363
if child_fn_name is not None
541364
else None
542365
)
@@ -548,10 +371,10 @@ def _wrap_binop_operand(
548371

549372
latex = self.visit(child)
550373

551-
if _BIN_OP_RULES[type(child.op)].is_wrapped:
374+
if expression_rules.BIN_OP_RULES[type(child.op)].is_wrapped:
552375
return latex
553376

554-
child_prec = _get_precedence(child)
377+
child_prec = expression_rules.get_precedence(child)
555378

556379
if child_prec > parent_prec or (
557380
child_prec == parent_prec and not operand_rule.force
@@ -562,20 +385,20 @@ def _wrap_binop_operand(
562385

563386
def visit_BinOp(self, node: ast.BinOp) -> str:
564387
"""Visit a BinOp node."""
565-
prec = _get_precedence(node)
388+
prec = expression_rules.get_precedence(node)
566389
rule = self._bin_op_rules[type(node.op)]
567390
lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
568391
rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)
569392
return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}"
570393

571394
def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
572395
"""Visit a UnaryOp node."""
573-
latex = self._wrap_operand(node.operand, _get_precedence(node))
574-
return _UNARY_OPS[type(node.op)] + latex
396+
latex = self._wrap_operand(node.operand, expression_rules.get_precedence(node))
397+
return expression_rules.UNARY_OPS[type(node.op)] + latex
575398

576399
def visit_Compare(self, node: ast.Compare) -> str:
577400
"""Visit a Compare node."""
578-
parent_prec = _get_precedence(node)
401+
parent_prec = expression_rules.get_precedence(node)
579402
lhs = self._wrap_operand(node.left, parent_prec)
580403
ops = [self._compare_ops[type(x)] for x in node.ops]
581404
rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators]
@@ -584,9 +407,9 @@ def visit_Compare(self, node: ast.Compare) -> str:
584407

585408
def visit_BoolOp(self, node: ast.BoolOp) -> str:
586409
"""Visit a BoolOp node."""
587-
parent_prec = _get_precedence(node)
410+
parent_prec = expression_rules.get_precedence(node)
588411
values = [self._wrap_operand(x, parent_prec) for x in node.values]
589-
op = f" {_BOOL_OPS[type(node.op)]} "
412+
op = f" {expression_rules.BOOL_OPS[type(node.op)]} "
590413
return op.join(values)
591414

592415
def visit_IfExp(self, node: ast.IfExp) -> str:

0 commit comments

Comments
 (0)