7
7
import pytest
8
8
9
9
from latexify import ast_utils , exceptions , test_utils
10
- from latexify .codegen import ExpressionCodegen
10
+ from latexify .codegen import expression_codegen
11
11
12
12
13
13
def test_generic_visit () -> None :
@@ -18,7 +18,7 @@ class UnknownNode(ast.AST):
18
18
exceptions .LatexifyNotSupportedError ,
19
19
match = r"^Unsupported AST: UnknownNode$" ,
20
20
):
21
- ExpressionCodegen ().visit (UnknownNode ())
21
+ expression_codegen . ExpressionCodegen ().visit (UnknownNode ())
22
22
23
23
24
24
@pytest .mark .parametrize (
@@ -33,7 +33,7 @@ class UnknownNode(ast.AST):
33
33
def test_visit_tuple (code : str , latex : str ) -> None :
34
34
node = ast_utils .parse_expr (code )
35
35
assert isinstance (node , ast .Tuple )
36
- assert ExpressionCodegen ().visit (node ) == latex
36
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
37
37
38
38
39
39
@pytest .mark .parametrize (
@@ -48,7 +48,7 @@ def test_visit_tuple(code: str, latex: str) -> None:
48
48
def test_visit_list (code : str , latex : str ) -> None :
49
49
node = ast_utils .parse_expr (code )
50
50
assert isinstance (node , ast .List )
51
- assert ExpressionCodegen ().visit (node ) == latex
51
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
52
52
53
53
54
54
@pytest .mark .parametrize (
@@ -64,7 +64,7 @@ def test_visit_list(code: str, latex: str) -> None:
64
64
def test_visit_set (code : str , latex : str ) -> None :
65
65
node = ast_utils .parse_expr (code )
66
66
assert isinstance (node , ast .Set )
67
- assert ExpressionCodegen ().visit (node ) == latex
67
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
68
68
69
69
70
70
@pytest .mark .parametrize (
@@ -114,7 +114,7 @@ def test_visit_set(code: str, latex: str) -> None:
114
114
def test_visit_listcomp (code : str , latex : str ) -> None :
115
115
node = ast_utils .parse_expr (code )
116
116
assert isinstance (node , ast .ListComp )
117
- assert ExpressionCodegen ().visit (node ) == latex
117
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
118
118
119
119
120
120
@pytest .mark .parametrize (
@@ -164,7 +164,7 @@ def test_visit_listcomp(code: str, latex: str) -> None:
164
164
def test_visit_setcomp (code : str , latex : str ) -> None :
165
165
node = ast_utils .parse_expr (code )
166
166
assert isinstance (node , ast .SetComp )
167
- assert ExpressionCodegen ().visit (node ) == latex
167
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
168
168
169
169
170
170
@pytest .mark .parametrize (
@@ -215,7 +215,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
215
215
def test_visit_call (code : str , latex : str ) -> None :
216
216
node = ast_utils .parse_expr (code )
217
217
assert isinstance (node , ast .Call )
218
- assert ExpressionCodegen ().visit (node ) == latex
218
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
219
219
220
220
221
221
@pytest .mark .parametrize (
@@ -330,7 +330,9 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
330
330
for src_fn , dest_fn in [("fsum" , r"\sum" ), ("sum" , r"\sum" ), ("prod" , r"\prod" )]:
331
331
node = ast_utils .parse_expr (src_fn + src_suffix )
332
332
assert isinstance (node , ast .Call )
333
- assert ExpressionCodegen ().visit (node ) == dest_fn + dest_suffix
333
+ assert (
334
+ expression_codegen .ExpressionCodegen ().visit (node ) == dest_fn + dest_suffix
335
+ )
334
336
335
337
336
338
@pytest .mark .parametrize (
@@ -381,7 +383,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
381
383
def test_visit_call_sum_prod_multiple_comprehension (code : str , latex : str ) -> None :
382
384
node = ast_utils .parse_expr (code )
383
385
assert isinstance (node , ast .Call )
384
- assert ExpressionCodegen ().visit (node ) == latex
386
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
385
387
386
388
387
389
@pytest .mark .parametrize (
@@ -407,7 +409,9 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
407
409
for src_fn , dest_fn in [("sum" , r"\sum" ), ("prod" , r"\prod" )]:
408
410
node = ast_utils .parse_expr (src_fn + src_suffix )
409
411
assert isinstance (node , ast .Call )
410
- assert ExpressionCodegen ().visit (node ) == dest_fn + dest_suffix
412
+ assert (
413
+ expression_codegen .ExpressionCodegen ().visit (node ) == dest_fn + dest_suffix
414
+ )
411
415
412
416
413
417
@pytest .mark .parametrize (
@@ -442,7 +446,7 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
442
446
def test_if_then_else (code : str , latex : str ) -> None :
443
447
node = ast_utils .parse_expr (code )
444
448
assert isinstance (node , ast .IfExp )
445
- assert ExpressionCodegen ().visit (node ) == latex
449
+ assert expression_codegen . ExpressionCodegen ().visit (node ) == latex
446
450
447
451
448
452
@pytest .mark .parametrize (
@@ -625,7 +629,7 @@ def test_if_then_else(code: str, latex: str) -> None:
625
629
def test_visit_binop (code : str , latex : str ) -> None :
626
630
tree = ast_utils .parse_expr (code )
627
631
assert isinstance (tree , ast .BinOp )
628
- assert ExpressionCodegen ().visit (tree ) == latex
632
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
629
633
630
634
631
635
@pytest .mark .parametrize (
@@ -664,7 +668,7 @@ def test_visit_binop(code: str, latex: str) -> None:
664
668
def test_visit_unaryop (code : str , latex : str ) -> None :
665
669
tree = ast_utils .parse_expr (code )
666
670
assert isinstance (tree , ast .UnaryOp )
667
- assert ExpressionCodegen ().visit (tree ) == latex
671
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
668
672
669
673
670
674
@pytest .mark .parametrize (
@@ -718,7 +722,7 @@ def test_visit_unaryop(code: str, latex: str) -> None:
718
722
def test_visit_compare (code : str , latex : str ) -> None :
719
723
tree = ast_utils .parse_expr (code )
720
724
assert isinstance (tree , ast .Compare )
721
- assert ExpressionCodegen ().visit (tree ) == latex
725
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
722
726
723
727
724
728
@pytest .mark .parametrize (
@@ -764,7 +768,7 @@ def test_visit_compare(code: str, latex: str) -> None:
764
768
def test_visit_boolop (code : str , latex : str ) -> None :
765
769
tree = ast_utils .parse_expr (code )
766
770
assert isinstance (tree , ast .BoolOp )
767
- assert ExpressionCodegen ().visit (tree ) == latex
771
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
768
772
769
773
770
774
@test_utils .require_at_most (7 )
@@ -789,7 +793,7 @@ def test_visit_boolop(code: str, latex: str) -> None:
789
793
def test_visit_constant_lagacy (code : str , cls : type [ast .expr ], latex : str ) -> None :
790
794
tree = ast_utils .parse_expr (code )
791
795
assert isinstance (tree , cls )
792
- assert ExpressionCodegen ().visit (tree ) == latex
796
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
793
797
794
798
795
799
@test_utils .require_at_least (8 )
@@ -814,7 +818,7 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No
814
818
def test_visit_constant (code : str , latex : str ) -> None :
815
819
tree = ast_utils .parse_expr (code )
816
820
assert isinstance (tree , ast .Constant )
817
- assert ExpressionCodegen ().visit (tree ) == latex
821
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
818
822
819
823
820
824
@pytest .mark .parametrize (
@@ -830,7 +834,7 @@ def test_visit_constant(code: str, latex: str) -> None:
830
834
def test_visit_subscript (code : str , latex : str ) -> None :
831
835
tree = ast_utils .parse_expr (code )
832
836
assert isinstance (tree , ast .Subscript )
833
- assert ExpressionCodegen ().visit (tree ) == latex
837
+ assert expression_codegen . ExpressionCodegen ().visit (tree ) == latex
834
838
835
839
836
840
@pytest .mark .parametrize (
@@ -845,7 +849,9 @@ def test_visit_subscript(code: str, latex: str) -> None:
845
849
def test_visit_binop_use_set_symbols (code : str , latex : str ) -> None :
846
850
tree = ast_utils .parse_expr (code )
847
851
assert isinstance (tree , ast .BinOp )
848
- assert ExpressionCodegen (use_set_symbols = True ).visit (tree ) == latex
852
+ assert (
853
+ expression_codegen .ExpressionCodegen (use_set_symbols = True ).visit (tree ) == latex
854
+ )
849
855
850
856
851
857
@pytest .mark .parametrize (
@@ -860,7 +866,9 @@ def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
860
866
def test_visit_compare_use_set_symbols (code : str , latex : str ) -> None :
861
867
tree = ast_utils .parse_expr (code )
862
868
assert isinstance (tree , ast .Compare )
863
- assert ExpressionCodegen (use_set_symbols = True ).visit (tree ) == latex
869
+ assert (
870
+ expression_codegen .ExpressionCodegen (use_set_symbols = True ).visit (tree ) == latex
871
+ )
864
872
865
873
866
874
@pytest .mark .parametrize (
@@ -906,4 +914,59 @@ def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
906
914
def test_numpy_array (code : str , latex : str ) -> None :
907
915
tree = ast_utils .parse_expr (code )
908
916
assert isinstance (tree , ast .Call )
909
- assert ExpressionCodegen ().visit (tree ) == latex
917
+ assert expression_codegen .ExpressionCodegen ().visit (tree ) == latex
918
+
919
+
920
+ @pytest .mark .parametrize (
921
+ "code,latex" ,
922
+ [
923
+ ("zeros(0)" , r"\mathbf{0}^{1 \times 0}" ),
924
+ ("zeros(1)" , r"\mathbf{0}^{1 \times 1}" ),
925
+ ("zeros(2)" , r"\mathbf{0}^{1 \times 2}" ),
926
+ ("zeros(())" , r"0" ),
927
+ ("zeros((0,))" , r"\mathbf{0}^{1 \times 0}" ),
928
+ ("zeros((1,))" , r"\mathbf{0}^{1 \times 1}" ),
929
+ ("zeros((2,))" , r"\mathbf{0}^{1 \times 2}" ),
930
+ ("zeros((0, 0))" , r"\mathbf{0}^{0 \times 0}" ),
931
+ ("zeros((1, 1))" , r"\mathbf{0}^{1 \times 1}" ),
932
+ ("zeros((2, 3))" , r"\mathbf{0}^{2 \times 3}" ),
933
+ ("zeros((0, 0, 0))" , r"\mathbf{0}^{0 \times 0 \times 0}" ),
934
+ ("zeros((1, 1, 1))" , r"\mathbf{0}^{1 \times 1 \times 1}" ),
935
+ ("zeros((2, 3, 5))" , r"\mathbf{0}^{2 \times 3 \times 5}" ),
936
+ # Unsupported
937
+ ("zeros()" , r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)" ),
938
+ ("zeros(x)" , r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)" ),
939
+ ("zeros(0, x)" , r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)" ),
940
+ (
941
+ "zeros((x,))" ,
942
+ r"\mathrm{zeros} \mathopen{}\left("
943
+ r" \mathopen{}\left( x \mathclose{}\right)"
944
+ r" \mathclose{}\right)" ,
945
+ ),
946
+ ],
947
+ )
948
+ def test_zeros (code : str , latex : str ) -> None :
949
+ tree = ast_utils .parse_expr (code )
950
+ assert isinstance (tree , ast .Call )
951
+ assert expression_codegen .ExpressionCodegen ().visit (tree ) == latex
952
+
953
+
954
+ @pytest .mark .parametrize (
955
+ "code,latex" ,
956
+ [
957
+ ("identity(0)" , r"\mathbf{I}_{0}" ),
958
+ ("identity(1)" , r"\mathbf{I}_{1}" ),
959
+ ("identity(2)" , r"\mathbf{I}_{2}" ),
960
+ # Unsupported
961
+ ("identity()" , r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)" ),
962
+ ("identity(x)" , r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)" ),
963
+ (
964
+ "identity(0, x)" ,
965
+ r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)" ,
966
+ ),
967
+ ],
968
+ )
969
+ def test_identity (code : str , latex : str ) -> None :
970
+ tree = ast_utils .parse_expr (code )
971
+ assert isinstance (tree , ast .Call )
972
+ assert expression_codegen .ExpressionCodegen ().visit (tree ) == latex
0 commit comments