52
52
ast .Or : 10 ,
53
53
}
54
54
55
+ # NOTE(odashi):
56
+ # Function invocation is treated as a unary operator with a higher precedence.
57
+ # This ensures that the argument with a unary operator is wrapped:
58
+ # exp(x) --> \exp x
59
+ # exp(-x) --> \exp (-x)
60
+ # -exp(x) --> - \exp x
61
+ _CALL_PRECEDENCE = _PRECEDENCES [ast .UAdd ] + 1
62
+
55
63
56
64
def _get_precedence (node : ast .AST ) -> int :
57
65
"""Obtains the precedence of the subtree.
@@ -63,6 +71,9 @@ def _get_precedence(node: ast.AST) -> int:
63
71
If `node` is a subtree with some operator, returns the precedence of the
64
72
operator. Otherwise, returns a number larger enough from other precedences.
65
73
"""
74
+ if isinstance (node , ast .Call ):
75
+ return _CALL_PRECEDENCE
76
+
66
77
if isinstance (node , (ast .BoolOp , ast .BinOp , ast .UnaryOp )):
67
78
return _PRECEDENCES [type (node .op )]
68
79
@@ -289,38 +300,34 @@ def visit_Return(self, node: ast.Return) -> str:
289
300
290
301
def visit_Tuple (self , node : ast .Tuple ) -> str :
291
302
elts = [self .visit (i ) for i in node .elts ]
292
- return (
293
- r"\mathopen{}\left( "
294
- + r"\space,\space " .join (elts )
295
- + r"\mathclose{}\right) "
296
- )
303
+ return r"\mathopen{}\left( " + r", " .join (elts ) + r" \mathclose{}\right)"
297
304
298
305
def visit_List (self , node : ast .List ) -> str :
299
306
elts = [self .visit (i ) for i in node .elts ]
300
- return r"\left[ " + r"\space,\space " .join (elts ) + r"\ right] "
307
+ return r"\mathopen{}\ left[ " + r", " .join (elts ) + r" \mathclose{}\ right]"
301
308
302
309
def visit_Set (self , node : ast .Set ) -> str :
303
310
elts = [self .visit (i ) for i in node .elts ]
304
- return r"\left\{ " + r"\space,\space " .join (elts ) + r"\ right\} "
311
+ return r"\mathopen{}\ left\{ " + r", " .join (elts ) + r" \mathclose{}\ right\}"
305
312
306
313
def visit_ListComp (self , node : ast .ListComp ) -> str :
307
314
generators = [self .visit (comp ) for comp in node .generators ]
308
315
return (
309
- r"\left[ "
316
+ r"\mathopen{}\ left[ "
310
317
+ self .visit (node .elt )
311
318
+ r" \mid "
312
319
+ ", " .join (generators )
313
- + r" \right]"
320
+ + r" \mathclose{}\ right]"
314
321
)
315
322
316
323
def visit_SetComp (self , node : ast .SetComp ) -> str :
317
324
generators = [self .visit (comp ) for comp in node .generators ]
318
325
return (
319
- r"\left\{ "
326
+ r"\mathopen{}\ left\{ "
320
327
+ self .visit (node .elt )
321
328
+ r" \mid "
322
329
+ ", " .join (generators )
323
- + r" \right\}"
330
+ + r" \mathclose{}\ right\}"
324
331
)
325
332
326
333
def visit_comprehension (self , node : ast .comprehension ) -> str :
@@ -347,10 +354,16 @@ def _generate_sum_prod(self, node: ast.Call) -> str | None:
347
354
return None
348
355
349
356
name = ast_utils .extract_function_name_or_none (node )
350
- assert name is not None
357
+ assert name in ("fsum" , "sum" , "prod" )
358
+
359
+ command = {
360
+ "fsum" : r"\sum" ,
361
+ "sum" : r"\sum" ,
362
+ "prod" : r"\prod" ,
363
+ }[name ]
351
364
352
365
elt , scripts = self ._get_sum_prod_info (node .args [0 ])
353
- scripts_str = [rf"\ { name } _{{{ lo } }}^{{{ up } }}" for lo , up in scripts ]
366
+ scripts_str = [rf"{ command } _{{{ lo } }}^{{{ up } }}" for lo , up in scripts ]
354
367
return (
355
368
" " .join (scripts_str )
356
369
+ rf" \mathopen{{}}\left({{{ elt } }}\mathclose{{}}\right)"
@@ -403,7 +416,7 @@ def visit_Call(self, node: ast.Call) -> str:
403
416
func_name = ast_utils .extract_function_name_or_none (node )
404
417
405
418
# Special treatments for some functions.
406
- if func_name in ("sum" , "prod" ):
419
+ if func_name in ("fsum" , " sum" , "prod" ):
407
420
special_latex = self ._generate_sum_prod (node )
408
421
elif func_name in ("array" , "ndarray" ):
409
422
special_latex = self ._generate_matrix (node )
@@ -413,17 +426,38 @@ def visit_Call(self, node: ast.Call) -> str:
413
426
if special_latex is not None :
414
427
return special_latex
415
428
416
- # Function signature (possibly an expression).
417
- default_func_str = self .visit (node .func )
418
-
419
- # Obtains wrapper syntax: sqrt -> "\sqrt{" and "}"
420
- lstr , rstr = constants .BUILTIN_FUNCS .get (
421
- func_name ,
422
- (default_func_str + r"\mathopen{}\left(" , r"\mathclose{}\right)" ),
423
- )
429
+ # Obtains the codegen rule.
430
+ rule = constants .BUILTIN_FUNCS .get (func_name )
431
+ if rule is None :
432
+ rule = constants .FunctionRule (self .visit (node .func ))
433
+
434
+ if rule .is_unary and len (node .args ) == 1 :
435
+ # Unary function. Applies the same wrapping policy with the unary operators.
436
+ # NOTE(odashi):
437
+ # Factorial "x!" is treated as a special case: it requires both inner/outer
438
+ # parentheses for correct interpretation.
439
+ precedence = _get_precedence (node )
440
+ arg = node .args [0 ]
441
+ force_wrap = isinstance (arg , ast .Call ) and (
442
+ func_name == "factorial"
443
+ or ast_utils .extract_function_name_or_none (arg ) == "factorial"
444
+ )
445
+ arg_latex = self ._wrap_operand (arg , precedence , force_wrap )
446
+ elements = [rule .left , arg_latex , rule .right ]
447
+ else :
448
+ arg_latex = ", " .join (self .visit (arg ) for arg in node .args )
449
+ if rule .is_wrapped :
450
+ elements = [rule .left , arg_latex , rule .right ]
451
+ else :
452
+ elements = [
453
+ rule .left ,
454
+ r"\mathopen{}\left(" ,
455
+ arg_latex ,
456
+ r"\mathclose{}\right)" ,
457
+ rule .right ,
458
+ ]
424
459
425
- arg_strs = [self .visit (arg ) for arg in node .args ]
426
- return lstr + ", " .join (arg_strs ) + rstr
460
+ return " " .join (x for x in elements if x )
427
461
428
462
def visit_Attribute (self , node : ast .Attribute ) -> str :
429
463
vstr = self .visit (node .value )
@@ -481,20 +515,26 @@ def visit_NameConstant(self, node: ast.NameConstant) -> str:
481
515
def visit_Ellipsis (self , node : ast .Ellipsis ) -> str :
482
516
return self ._convert_constant (...)
483
517
484
- def _wrap_operand (self , child : ast .expr , parent_prec : int ) -> str :
518
+ def _wrap_operand (
519
+ self , child : ast .expr , parent_prec : int , force_wrap : bool = False
520
+ ) -> str :
485
521
"""Wraps the operand subtree with parentheses.
486
522
487
523
Args:
488
524
child: Operand subtree.
489
525
parent_prec: Precedence of the parent operator.
526
+ force_wrap: Whether to wrap the operand or not when the precedence is equal.
490
527
491
528
Returns:
492
529
LaTeX form of `child`, with or without surrounding parentheses.
493
530
"""
494
531
latex = self .visit (child )
495
- if _get_precedence (child ) >= parent_prec :
496
- return latex
497
- return rf"\mathopen{{}}\left( { latex } \mathclose{{}}\right)"
532
+ child_prec = _get_precedence (child )
533
+
534
+ if child_prec < parent_prec or force_wrap and child_prec == parent_prec :
535
+ return rf"\mathopen{{}}\left( { latex } \mathclose{{}}\right)"
536
+
537
+ return latex
498
538
499
539
def _wrap_binop_operand (
500
540
self ,
@@ -515,6 +555,13 @@ def _wrap_binop_operand(
515
555
if not operand_rule .wrap :
516
556
return self .visit (child )
517
557
558
+ if isinstance (child , ast .Call ):
559
+ rule = constants .BUILTIN_FUNCS .get (
560
+ ast_utils .extract_function_name_or_none (child )
561
+ )
562
+ if rule is not None and rule .is_wrapped :
563
+ return self .visit (child )
564
+
518
565
if not isinstance (child , ast .BinOp ):
519
566
return self ._wrap_operand (child , parent_prec )
520
567
0 commit comments