@@ -296,7 +296,11 @@ def visit_Assign(self, node: ast.Assign) -> str:
296
296
return " = " .join (operands )
297
297
298
298
def visit_Return (self , node : ast .Return ) -> str :
299
- return self .visit (node .value )
299
+ return (
300
+ self .visit (node .value )
301
+ if node .value is not None
302
+ else self ._convert_constant (None )
303
+ )
300
304
301
305
def visit_Tuple (self , node : ast .Tuple ) -> str :
302
306
elts = [self .visit (i ) for i in node .elts ]
@@ -401,15 +405,16 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:
401
405
402
406
ncols = len (row0 .elts )
403
407
404
- if not all (
405
- isinstance (row , ast .List ) and len (row .elts ) == ncols for row in arg .elts
406
- ):
407
- # Length mismatch
408
- return None
408
+ rows : list [list [str ]] = []
409
409
410
- return generate_matrix_from_array (
411
- [[self .visit (x ) for x in row .elts ] for row in arg .elts ]
412
- )
410
+ for row in arg .elts :
411
+ if not isinstance (row , ast .List ) or len (row .elts ) != ncols :
412
+ # Length mismatch
413
+ return None
414
+
415
+ rows .append ([self .visit (x ) for x in row .elts ])
416
+
417
+ return generate_matrix_from_array (rows )
413
418
414
419
def visit_Call (self , node : ast .Call ) -> str :
415
420
"""Visit a call node."""
@@ -427,7 +432,8 @@ def visit_Call(self, node: ast.Call) -> str:
427
432
return special_latex
428
433
429
434
# Obtains the codegen rule.
430
- rule = constants .BUILTIN_FUNCS .get (func_name )
435
+ rule = constants .BUILTIN_FUNCS .get (func_name ) if func_name is not None else None
436
+
431
437
if rule is None :
432
438
rule = constants .FunctionRule (self .visit (node .func ))
433
439
@@ -556,8 +562,11 @@ def _wrap_binop_operand(
556
562
return self .visit (child )
557
563
558
564
if isinstance (child , ast .Call ):
559
- rule = constants .BUILTIN_FUNCS .get (
560
- ast_utils .extract_function_name_or_none (child )
565
+ child_fn_name = ast_utils .extract_function_name_or_none (child )
566
+ rule = (
567
+ constants .BUILTIN_FUNCS .get (child_fn_name )
568
+ if child_fn_name is not None
569
+ else None
561
570
)
562
571
if rule is not None and rule .is_wrapped :
563
572
return self .visit (child )
@@ -612,30 +621,35 @@ def visit_If(self, node: ast.If) -> str:
612
621
"""Visit an if node."""
613
622
latex = r"\left\{ \begin{array}{ll} "
614
623
615
- while isinstance (node , ast .If ):
616
- if len (node .body ) != 1 or len (node .orelse ) != 1 :
624
+ current_stmt : ast .stmt = node
625
+
626
+ while isinstance (current_stmt , ast .If ):
627
+ if len (current_stmt .body ) != 1 or len (current_stmt .orelse ) != 1 :
617
628
raise exceptions .LatexifySyntaxError (
618
629
"Multiple statements are not supported in If nodes."
619
630
)
620
631
621
- cond_latex = self .visit (node .test )
622
- true_latex = self .visit (node .body [0 ])
632
+ cond_latex = self .visit (current_stmt .test )
633
+ true_latex = self .visit (current_stmt .body [0 ])
623
634
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
624
- node = node .orelse [0 ]
635
+ current_stmt = current_stmt .orelse [0 ]
625
636
626
- latex += self .visit (node )
637
+ latex += self .visit (current_stmt )
627
638
return latex + r", & \mathrm{otherwise} \end{array} \right."
628
639
629
640
def visit_IfExp (self , node : ast .IfExp ) -> str :
630
641
"""Visit an ifexp node"""
631
642
latex = r"\left\{ \begin{array}{ll} "
632
- while isinstance (node , ast .IfExp ):
633
- cond_latex = self .visit (node .test )
634
- true_latex = self .visit (node .body )
643
+
644
+ current_expr : ast .expr = node
645
+
646
+ while isinstance (current_expr , ast .IfExp ):
647
+ cond_latex = self .visit (current_expr .test )
648
+ true_latex = self .visit (current_expr .body )
635
649
latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
636
- node = node .orelse
650
+ current_expr = current_expr .orelse
637
651
638
- latex += self .visit (node )
652
+ latex += self .visit (current_expr )
639
653
return latex + r", & \mathrm{otherwise} \end{array} \right."
640
654
641
655
def _reduce_stop_parameter (self , node : ast .expr ) -> ast .expr :
@@ -768,7 +782,7 @@ def _get_sum_prod_info(
768
782
# Until 3.8
769
783
def visit_Index (self , node : ast .Index ) -> str :
770
784
"""Visitor for the Index nodes."""
771
- return self .visit (node .value )
785
+ return self .visit (node .value ) # type: ignore[attr-defined]
772
786
773
787
def _convert_nested_subscripts (self , node : ast .Subscript ) -> tuple [str , list [str ]]:
774
788
"""Helper function to convert nested subscription.
0 commit comments