3
3
from __future__ import annotations
4
4
5
5
import ast
6
- import dataclasses
7
-
8
- from latexify import analyzers , ast_utils , constants , exceptions
9
- from latexify .codegen import codegen_utils , identifier_converter
10
-
11
- # Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes.
12
- # Note that this value affects only the appearance of surrounding parentheses for each
13
- # expression, and does not affect the AST itself.
14
- # See also:
15
- # https://docs.python.org/3/reference/expressions.html#operator-precedence
16
- _PRECEDENCES : dict [type [ast .AST ], int ] = {
17
- ast .Pow : 120 ,
18
- ast .UAdd : 110 ,
19
- ast .USub : 110 ,
20
- ast .Invert : 110 ,
21
- ast .Mult : 100 ,
22
- ast .MatMult : 100 ,
23
- ast .Div : 100 ,
24
- ast .FloorDiv : 100 ,
25
- ast .Mod : 100 ,
26
- ast .Add : 90 ,
27
- ast .Sub : 90 ,
28
- ast .LShift : 80 ,
29
- ast .RShift : 80 ,
30
- ast .BitAnd : 70 ,
31
- ast .BitXor : 60 ,
32
- ast .BitOr : 50 ,
33
- ast .In : 40 ,
34
- ast .NotIn : 40 ,
35
- ast .Is : 40 ,
36
- ast .IsNot : 40 ,
37
- ast .Lt : 40 ,
38
- ast .LtE : 40 ,
39
- ast .Gt : 40 ,
40
- ast .GtE : 40 ,
41
- ast .NotEq : 40 ,
42
- ast .Eq : 40 ,
43
- # NOTE(odashi):
44
- # We assume that the `not` operator has the same precedence with other unary
45
- # operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ looks to have a
46
- # high precedence.
47
- # ast.Not: 30,
48
- ast .Not : 110 ,
49
- ast .And : 20 ,
50
- ast .Or : 10 ,
51
- }
52
-
53
- # NOTE(odashi):
54
- # Function invocation is treated as a unary operator with a higher precedence.
55
- # This ensures that the argument with a unary operator is wrapped:
56
- # exp(x) --> \exp x
57
- # exp(-x) --> \exp (-x)
58
- # -exp(x) --> - \exp x
59
- _CALL_PRECEDENCE = _PRECEDENCES [ast .UAdd ] + 1
60
-
61
-
62
- def _get_precedence (node : ast .AST ) -> int :
63
- """Obtains the precedence of the subtree.
64
-
65
- Args:
66
- node: Subtree to investigate.
67
-
68
- Returns:
69
- If `node` is a subtree with some operator, returns the precedence of the
70
- operator. Otherwise, returns a number larger enough from other precedences.
71
- """
72
- if isinstance (node , ast .Call ):
73
- return _CALL_PRECEDENCE
74
-
75
- if isinstance (node , (ast .BoolOp , ast .BinOp , ast .UnaryOp )):
76
- return _PRECEDENCES [type (node .op )]
77
-
78
- if isinstance (node , ast .Compare ):
79
- # Compare operators have the same precedence. It is enough to check only the
80
- # first operator.
81
- return _PRECEDENCES [type (node .ops [0 ])]
82
-
83
- return 1_000_000
84
-
85
-
86
- @dataclasses .dataclass (frozen = True )
87
- class BinOperandRule :
88
- """Syntax rules for operands of BinOp."""
89
-
90
- # Whether to require wrapping operands by parentheses according to the precedence.
91
- wrap : bool = True
92
-
93
- # Whether to require wrapping operands by parentheses if the operand has the same
94
- # precedence with this operator.
95
- # This is used to control the behavior of non-associative operators.
96
- force : bool = False
97
-
98
-
99
- @dataclasses .dataclass (frozen = True )
100
- class BinOpRule :
101
- """Syntax rules for BinOp."""
102
-
103
- # Left/middle/right syntaxes to wrap operands.
104
- latex_left : str
105
- latex_middle : str
106
- latex_right : str
107
-
108
- # Operand rules.
109
- operand_left : BinOperandRule = dataclasses .field (default_factory = BinOperandRule )
110
- operand_right : BinOperandRule = dataclasses .field (default_factory = BinOperandRule )
111
-
112
- # Whether to assume the resulting syntax is wrapped by some bracket operators.
113
- # If True, the parent operator can avoid wrapping this operator by parentheses.
114
- is_wrapped : bool = False
115
-
116
-
117
- _BIN_OP_RULES : dict [type [ast .operator ], BinOpRule ] = {
118
- ast .Pow : BinOpRule (
119
- "" ,
120
- "^{" ,
121
- "}" ,
122
- operand_left = BinOperandRule (force = True ),
123
- operand_right = BinOperandRule (wrap = False ),
124
- ),
125
- ast .Mult : BinOpRule ("" , " " , "" ),
126
- ast .MatMult : BinOpRule ("" , " " , "" ),
127
- ast .Div : BinOpRule (
128
- r"\frac{" ,
129
- "}{" ,
130
- "}" ,
131
- operand_left = BinOperandRule (wrap = False ),
132
- operand_right = BinOperandRule (wrap = False ),
133
- ),
134
- ast .FloorDiv : BinOpRule (
135
- r"\left\lfloor\frac{" ,
136
- "}{" ,
137
- r"}\right\rfloor" ,
138
- operand_left = BinOperandRule (wrap = False ),
139
- operand_right = BinOperandRule (wrap = False ),
140
- is_wrapped = True ,
141
- ),
142
- ast .Mod : BinOpRule (
143
- "" , r" \mathbin{\%} " , "" , operand_right = BinOperandRule (force = True )
144
- ),
145
- ast .Add : BinOpRule ("" , " + " , "" ),
146
- ast .Sub : BinOpRule ("" , " - " , "" , operand_right = BinOperandRule (force = True )),
147
- ast .LShift : BinOpRule ("" , r" \ll " , "" , operand_right = BinOperandRule (force = True )),
148
- ast .RShift : BinOpRule ("" , r" \gg " , "" , operand_right = BinOperandRule (force = True )),
149
- ast .BitAnd : BinOpRule ("" , r" \mathbin{\&} " , "" ),
150
- ast .BitXor : BinOpRule ("" , r" \oplus " , "" ),
151
- ast .BitOr : BinOpRule ("" , r" \mathbin{|} " , "" ),
152
- }
153
-
154
- # Typeset for BinOp of sets.
155
- _SET_BIN_OP_RULES : dict [type [ast .operator ], BinOpRule ] = {
156
- ** _BIN_OP_RULES ,
157
- ast .Sub : BinOpRule (
158
- "" , r" \setminus " , "" , operand_right = BinOperandRule (force = True )
159
- ),
160
- ast .BitAnd : BinOpRule ("" , r" \cap " , "" ),
161
- ast .BitXor : BinOpRule ("" , r" \mathbin{\triangle} " , "" ),
162
- ast .BitOr : BinOpRule ("" , r" \cup " , "" ),
163
- }
164
-
165
- _UNARY_OPS : dict [type [ast .unaryop ], str ] = {
166
- ast .Invert : r"\mathord{\sim} " ,
167
- ast .UAdd : "+" , # Explicitly adds the $+$ operator.
168
- ast .USub : "-" ,
169
- ast .Not : r"\lnot " ,
170
- }
171
-
172
- _COMPARE_OPS : dict [type [ast .cmpop ], str ] = {
173
- ast .Eq : "=" ,
174
- ast .Gt : ">" ,
175
- ast .GtE : r"\ge" ,
176
- ast .In : r"\in" ,
177
- ast .Is : r"\equiv" ,
178
- ast .IsNot : r"\not\equiv" ,
179
- ast .Lt : "<" ,
180
- ast .LtE : r"\le" ,
181
- ast .NotEq : r"\ne" ,
182
- ast .NotIn : r"\notin" ,
183
- }
184
-
185
- # Typeset for Compare of sets.
186
- _SET_COMPARE_OPS : dict [type [ast .cmpop ], str ] = {
187
- ** _COMPARE_OPS ,
188
- ast .Gt : r"\supset" ,
189
- ast .GtE : r"\supseteq" ,
190
- ast .Lt : r"\subset" ,
191
- ast .LtE : r"\subseteq" ,
192
- }
193
-
194
- _BOOL_OPS : dict [type [ast .boolop ], str ] = {
195
- ast .And : r"\land" ,
196
- ast .Or : r"\lor" ,
197
- }
6
+
7
+ from latexify import analyzers , ast_utils , exceptions
8
+ from latexify .codegen import codegen_utils , expression_rules , identifier_converter
198
9
199
10
200
11
class ExpressionCodegen (ast .NodeVisitor ):
201
12
"""Codegen for single expressions."""
202
13
203
14
_identifier_converter : identifier_converter .IdentifierConverter
204
15
205
- _bin_op_rules : dict [type [ast .operator ], BinOpRule ]
16
+ _bin_op_rules : dict [type [ast .operator ], expression_rules . BinOpRule ]
206
17
_compare_ops : dict [type [ast .cmpop ], str ]
207
18
208
19
def __init__ (
@@ -219,8 +30,16 @@ def __init__(
219
30
use_math_symbols = use_math_symbols
220
31
)
221
32
222
- self ._bin_op_rules = _SET_BIN_OP_RULES if use_set_symbols else _BIN_OP_RULES
223
- self ._compare_ops = _SET_COMPARE_OPS if use_set_symbols else _COMPARE_OPS
33
+ self ._bin_op_rules = (
34
+ expression_rules .SET_BIN_OP_RULES
35
+ if use_set_symbols
36
+ else expression_rules .BIN_OP_RULES
37
+ )
38
+ self ._compare_ops = (
39
+ expression_rules .SET_COMPARE_OPS
40
+ if use_set_symbols
41
+ else expression_rules .COMPARE_OPS
42
+ )
224
43
225
44
def generic_visit (self , node : ast .AST ) -> str :
226
45
raise exceptions .LatexifyNotSupportedError (
@@ -420,17 +239,21 @@ def visit_Call(self, node: ast.Call) -> str:
420
239
return special_latex
421
240
422
241
# Obtains the codegen rule.
423
- rule = constants .BUILTIN_FUNCS .get (func_name ) if func_name is not None else None
242
+ rule = (
243
+ expression_rules .BUILTIN_FUNCS .get (func_name )
244
+ if func_name is not None
245
+ else None
246
+ )
424
247
425
248
if rule is None :
426
- rule = constants .FunctionRule (self .visit (node .func ))
249
+ rule = expression_rules .FunctionRule (self .visit (node .func ))
427
250
428
251
if rule .is_unary and len (node .args ) == 1 :
429
252
# Unary function. Applies the same wrapping policy with the unary operators.
430
253
# NOTE(odashi):
431
254
# Factorial "x!" is treated as a special case: it requires both inner/outer
432
255
# parentheses for correct interpretation.
433
- precedence = _get_precedence (node )
256
+ precedence = expression_rules . get_precedence (node )
434
257
arg = node .args [0 ]
435
258
force_wrap = isinstance (arg , ast .Call ) and (
436
259
func_name == "factorial"
@@ -507,7 +330,7 @@ def _wrap_operand(
507
330
LaTeX form of `child`, with or without surrounding parentheses.
508
331
"""
509
332
latex = self .visit (child )
510
- child_prec = _get_precedence (child )
333
+ child_prec = expression_rules . get_precedence (child )
511
334
512
335
if child_prec < parent_prec or force_wrap and child_prec == parent_prec :
513
336
return rf"\mathopen{{}}\left( { latex } \mathclose{{}}\right)"
@@ -518,7 +341,7 @@ def _wrap_binop_operand(
518
341
self ,
519
342
child : ast .expr ,
520
343
parent_prec : int ,
521
- operand_rule : BinOperandRule ,
344
+ operand_rule : expression_rules . BinOperandRule ,
522
345
) -> str :
523
346
"""Wraps the operand subtree of BinOp with parentheses.
524
347
@@ -536,7 +359,7 @@ def _wrap_binop_operand(
536
359
if isinstance (child , ast .Call ):
537
360
child_fn_name = ast_utils .extract_function_name_or_none (child )
538
361
rule = (
539
- constants .BUILTIN_FUNCS .get (child_fn_name )
362
+ expression_rules .BUILTIN_FUNCS .get (child_fn_name )
540
363
if child_fn_name is not None
541
364
else None
542
365
)
@@ -548,10 +371,10 @@ def _wrap_binop_operand(
548
371
549
372
latex = self .visit (child )
550
373
551
- if _BIN_OP_RULES [type (child .op )].is_wrapped :
374
+ if expression_rules . BIN_OP_RULES [type (child .op )].is_wrapped :
552
375
return latex
553
376
554
- child_prec = _get_precedence (child )
377
+ child_prec = expression_rules . get_precedence (child )
555
378
556
379
if child_prec > parent_prec or (
557
380
child_prec == parent_prec and not operand_rule .force
@@ -562,20 +385,20 @@ def _wrap_binop_operand(
562
385
563
386
def visit_BinOp (self , node : ast .BinOp ) -> str :
564
387
"""Visit a BinOp node."""
565
- prec = _get_precedence (node )
388
+ prec = expression_rules . get_precedence (node )
566
389
rule = self ._bin_op_rules [type (node .op )]
567
390
lhs = self ._wrap_binop_operand (node .left , prec , rule .operand_left )
568
391
rhs = self ._wrap_binop_operand (node .right , prec , rule .operand_right )
569
392
return f"{ rule .latex_left } { lhs } { rule .latex_middle } { rhs } { rule .latex_right } "
570
393
571
394
def visit_UnaryOp (self , node : ast .UnaryOp ) -> str :
572
395
"""Visit a UnaryOp node."""
573
- latex = self ._wrap_operand (node .operand , _get_precedence (node ))
574
- return _UNARY_OPS [type (node .op )] + latex
396
+ latex = self ._wrap_operand (node .operand , expression_rules . get_precedence (node ))
397
+ return expression_rules . UNARY_OPS [type (node .op )] + latex
575
398
576
399
def visit_Compare (self , node : ast .Compare ) -> str :
577
400
"""Visit a Compare node."""
578
- parent_prec = _get_precedence (node )
401
+ parent_prec = expression_rules . get_precedence (node )
579
402
lhs = self ._wrap_operand (node .left , parent_prec )
580
403
ops = [self ._compare_ops [type (x )] for x in node .ops ]
581
404
rhs = [self ._wrap_operand (x , parent_prec ) for x in node .comparators ]
@@ -584,9 +407,9 @@ def visit_Compare(self, node: ast.Compare) -> str:
584
407
585
408
def visit_BoolOp (self , node : ast .BoolOp ) -> str :
586
409
"""Visit a BoolOp node."""
587
- parent_prec = _get_precedence (node )
410
+ parent_prec = expression_rules . get_precedence (node )
588
411
values = [self ._wrap_operand (x , parent_prec ) for x in node .values ]
589
- op = f" { _BOOL_OPS [type (node .op )]} "
412
+ op = f" { expression_rules . BOOL_OPS [type (node .op )]} "
590
413
return op .join (values )
591
414
592
415
def visit_IfExp (self , node : ast .IfExp ) -> str :
0 commit comments