Skip to content

Commit d0f2d98

Browse files
author
Zibing Zhang
committed
readds changes from google#148
2 parents 05383ca + c96e0aa commit d0f2d98

File tree

2 files changed

+142
-22
lines changed

2 files changed

+142
-22
lines changed

src/latexify/codegen/expression_codegen.py

+57
Original file line numberDiff line numberDiff line change
@@ -347,15 +347,72 @@ def generate_matrix_from_array(data: list[list[str]]) -> str:
347347

348348
return generate_matrix_from_array(rows)
349349

350+
def _generate_zeros(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.zeros.

    Only calls with exactly one positional argument are handled. That
    argument is either a single integer literal or a tuple of integer
    literals describing the shape.

    Args:
        node: ast.Call node whose extracted function name is "zeros".

    Returns:
        "0" for an empty tuple shape, a bold zero carrying the shape as a
        superscript for integer shapes, or None if the node has unsupported
        syntax (wrong number of positional arguments, or a dimension that is
        not recognized as an integer).
    """
    func_name = ast_utils.extract_function_name_or_none(node)
    assert func_name == "zeros"

    positional = node.args
    if len(positional) != 1:
        return None

    shape_node = positional[0]

    if not isinstance(shape_node, ast.Tuple):
        # Bare scalar shape: rendered as a 1 x N array of zeros.
        length = ast_utils.extract_int_or_none(shape_node)
        if not isinstance(length, int):
            return None
        shape_latex = rf"1 \times {length}"
    else:
        # Tuple shape: every element must be extractable as an integer.
        sizes = [ast_utils.extract_int_or_none(elt) for elt in shape_node.elts]
        for size in sizes:
            if size is None:
                return None

        # An empty shape tuple is rendered as a plain scalar zero.
        if len(sizes) == 0:
            return "0"

        # A one-element shape is shown the same way as the scalar form.
        if len(sizes) == 1:
            sizes = [1, *sizes]

        shape_latex = r" \times ".join(map(str, sizes))

    return rf"\mathbf{{0}}^{{{shape_latex}}}"
382+
383+
def _generate_identity(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.identity.

    Args:
        node: ast.Call node whose extracted function name is "identity".

    Returns:
        A bold "I" subscripted with the size, or None if the node has
        unsupported syntax (the call does not have exactly one positional
        argument, or that argument is not recognized as an integer).
    """
    func_name = ast_utils.extract_function_name_or_none(node)
    assert func_name == "identity"

    positional = node.args
    if len(positional) == 1:
        size = ast_utils.extract_int_or_none(positional[0])
        if size is not None:
            return rf"\mathbf{{I}}_{{{size}}}"

    # Anything else is left to the caller's generic handling.
    return None
401+
350402
def visit_Call(self, node: ast.Call) -> str:
351403
"""Visit a Call node."""
352404
func_name = ast_utils.extract_function_name_or_none(node)
353405

354406
# Special treatments for some functions.
407+
# TODO(odashi): Move these functions to some separate utility.
355408
if func_name in ("fsum", "sum", "prod"):
356409
special_latex = self._generate_sum_prod(node)
357410
elif func_name in ("array", "ndarray"):
358411
special_latex = self._generate_matrix(node)
412+
elif func_name == "zeros":
413+
special_latex = self._generate_zeros(node)
414+
elif func_name == "identity":
415+
special_latex = self._generate_identity(node)
359416
else:
360417
special_latex = None
361418

src/latexify/codegen/expression_codegen_test.py

+85-22
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import pytest
88

99
from latexify import ast_utils, exceptions, test_utils
10-
from latexify.codegen import ExpressionCodegen
10+
from latexify.codegen import expression_codegen
1111

1212

1313
def test_generic_visit() -> None:
@@ -18,7 +18,7 @@ class UnknownNode(ast.AST):
1818
exceptions.LatexifyNotSupportedError,
1919
match=r"^Unsupported AST: UnknownNode$",
2020
):
21-
ExpressionCodegen().visit(UnknownNode())
21+
expression_codegen.ExpressionCodegen().visit(UnknownNode())
2222

2323

2424
@pytest.mark.parametrize(
@@ -33,7 +33,7 @@ class UnknownNode(ast.AST):
3333
def test_visit_tuple(code: str, latex: str) -> None:
3434
node = ast_utils.parse_expr(code)
3535
assert isinstance(node, ast.Tuple)
36-
assert ExpressionCodegen().visit(node) == latex
36+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
3737

3838

3939
@pytest.mark.parametrize(
@@ -48,7 +48,7 @@ def test_visit_tuple(code: str, latex: str) -> None:
4848
def test_visit_list(code: str, latex: str) -> None:
4949
node = ast_utils.parse_expr(code)
5050
assert isinstance(node, ast.List)
51-
assert ExpressionCodegen().visit(node) == latex
51+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
5252

5353

5454
@pytest.mark.parametrize(
@@ -64,7 +64,7 @@ def test_visit_list(code: str, latex: str) -> None:
6464
def test_visit_set(code: str, latex: str) -> None:
6565
node = ast_utils.parse_expr(code)
6666
assert isinstance(node, ast.Set)
67-
assert ExpressionCodegen().visit(node) == latex
67+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
6868

6969

7070
@pytest.mark.parametrize(
@@ -114,7 +114,7 @@ def test_visit_set(code: str, latex: str) -> None:
114114
def test_visit_listcomp(code: str, latex: str) -> None:
115115
node = ast_utils.parse_expr(code)
116116
assert isinstance(node, ast.ListComp)
117-
assert ExpressionCodegen().visit(node) == latex
117+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
118118

119119

120120
@pytest.mark.parametrize(
@@ -164,7 +164,7 @@ def test_visit_listcomp(code: str, latex: str) -> None:
164164
def test_visit_setcomp(code: str, latex: str) -> None:
165165
node = ast_utils.parse_expr(code)
166166
assert isinstance(node, ast.SetComp)
167-
assert ExpressionCodegen().visit(node) == latex
167+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
168168

169169

170170
@pytest.mark.parametrize(
@@ -215,7 +215,7 @@ def test_visit_setcomp(code: str, latex: str) -> None:
215215
def test_visit_call(code: str, latex: str) -> None:
216216
node = ast_utils.parse_expr(code)
217217
assert isinstance(node, ast.Call)
218-
assert ExpressionCodegen().visit(node) == latex
218+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
219219

220220

221221
@pytest.mark.parametrize(
@@ -330,7 +330,9 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
330330
for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]:
331331
node = ast_utils.parse_expr(src_fn + src_suffix)
332332
assert isinstance(node, ast.Call)
333-
assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix
333+
assert (
334+
expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix
335+
)
334336

335337

336338
@pytest.mark.parametrize(
@@ -381,7 +383,7 @@ def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
381383
def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None:
382384
node = ast_utils.parse_expr(code)
383385
assert isinstance(node, ast.Call)
384-
assert ExpressionCodegen().visit(node) == latex
386+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
385387

386388

387389
@pytest.mark.parametrize(
@@ -407,7 +409,9 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
407409
for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]:
408410
node = ast_utils.parse_expr(src_fn + src_suffix)
409411
assert isinstance(node, ast.Call)
410-
assert ExpressionCodegen().visit(node) == dest_fn + dest_suffix
412+
assert (
413+
expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix
414+
)
411415

412416

413417
@pytest.mark.parametrize(
@@ -442,7 +446,7 @@ def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
442446
def test_if_then_else(code: str, latex: str) -> None:
443447
node = ast_utils.parse_expr(code)
444448
assert isinstance(node, ast.IfExp)
445-
assert ExpressionCodegen().visit(node) == latex
449+
assert expression_codegen.ExpressionCodegen().visit(node) == latex
446450

447451

448452
@pytest.mark.parametrize(
@@ -625,7 +629,7 @@ def test_if_then_else(code: str, latex: str) -> None:
625629
def test_visit_binop(code: str, latex: str) -> None:
626630
tree = ast_utils.parse_expr(code)
627631
assert isinstance(tree, ast.BinOp)
628-
assert ExpressionCodegen().visit(tree) == latex
632+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
629633

630634

631635
@pytest.mark.parametrize(
@@ -664,7 +668,7 @@ def test_visit_binop(code: str, latex: str) -> None:
664668
def test_visit_unaryop(code: str, latex: str) -> None:
665669
tree = ast_utils.parse_expr(code)
666670
assert isinstance(tree, ast.UnaryOp)
667-
assert ExpressionCodegen().visit(tree) == latex
671+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
668672

669673

670674
@pytest.mark.parametrize(
@@ -718,7 +722,7 @@ def test_visit_unaryop(code: str, latex: str) -> None:
718722
def test_visit_compare(code: str, latex: str) -> None:
719723
tree = ast_utils.parse_expr(code)
720724
assert isinstance(tree, ast.Compare)
721-
assert ExpressionCodegen().visit(tree) == latex
725+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
722726

723727

724728
@pytest.mark.parametrize(
@@ -764,7 +768,7 @@ def test_visit_compare(code: str, latex: str) -> None:
764768
def test_visit_boolop(code: str, latex: str) -> None:
765769
tree = ast_utils.parse_expr(code)
766770
assert isinstance(tree, ast.BoolOp)
767-
assert ExpressionCodegen().visit(tree) == latex
771+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
768772

769773

770774
@test_utils.require_at_most(7)
@@ -789,7 +793,7 @@ def test_visit_boolop(code: str, latex: str) -> None:
789793
def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> None:
790794
tree = ast_utils.parse_expr(code)
791795
assert isinstance(tree, cls)
792-
assert ExpressionCodegen().visit(tree) == latex
796+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
793797

794798

795799
@test_utils.require_at_least(8)
@@ -814,7 +818,7 @@ def test_visit_constant_lagacy(code: str, cls: type[ast.expr], latex: str) -> No
814818
def test_visit_constant(code: str, latex: str) -> None:
815819
tree = ast_utils.parse_expr(code)
816820
assert isinstance(tree, ast.Constant)
817-
assert ExpressionCodegen().visit(tree) == latex
821+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
818822

819823

820824
@pytest.mark.parametrize(
@@ -830,7 +834,7 @@ def test_visit_constant(code: str, latex: str) -> None:
830834
def test_visit_subscript(code: str, latex: str) -> None:
831835
tree = ast_utils.parse_expr(code)
832836
assert isinstance(tree, ast.Subscript)
833-
assert ExpressionCodegen().visit(tree) == latex
837+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
834838

835839

836840
@pytest.mark.parametrize(
@@ -845,7 +849,9 @@ def test_visit_subscript(code: str, latex: str) -> None:
845849
def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
846850
tree = ast_utils.parse_expr(code)
847851
assert isinstance(tree, ast.BinOp)
848-
assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
852+
assert (
853+
expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
854+
)
849855

850856

851857
@pytest.mark.parametrize(
@@ -860,7 +866,9 @@ def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
860866
def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
861867
tree = ast_utils.parse_expr(code)
862868
assert isinstance(tree, ast.Compare)
863-
assert ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
869+
assert (
870+
expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
871+
)
864872

865873

866874
@pytest.mark.parametrize(
@@ -906,4 +914,59 @@ def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
906914
def test_numpy_array(code: str, latex: str) -> None:
907915
tree = ast_utils.parse_expr(code)
908916
assert isinstance(tree, ast.Call)
909-
assert ExpressionCodegen().visit(tree) == latex
917+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
918+
919+
920+
@pytest.mark.parametrize(
    "code,latex",
    [
        # Scalar shape.
        ("zeros(0)", r"\mathbf{0}^{1 \times 0}"),
        ("zeros(1)", r"\mathbf{0}^{1 \times 1}"),
        ("zeros(2)", r"\mathbf{0}^{1 \times 2}"),
        # Tuple shapes of increasing length.
        ("zeros(())", r"0"),
        ("zeros((0,))", r"\mathbf{0}^{1 \times 0}"),
        ("zeros((1,))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2,))", r"\mathbf{0}^{1 \times 2}"),
        ("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"),
        ("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"),
        ("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"),
        ("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"),
        ("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"),
        # Unsupported: expected output is the ordinary function-call form.
        ("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"),
        ("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"),
        ("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"),
        (
            "zeros((x,))",
            r"\mathrm{zeros} \mathopen{}\left("
            r" \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right)",
        ),
    ],
)
def test_zeros(code: str, latex: str) -> None:
    """Checks the LaTeX produced for `zeros(...)` call expressions."""
    call_node = ast_utils.parse_expr(code)
    assert isinstance(call_node, ast.Call)

    generated = expression_codegen.ExpressionCodegen().visit(call_node)
    assert generated == latex
952+
953+
954+
@pytest.mark.parametrize(
    "code,latex",
    [
        # Single integer literal argument.
        ("identity(0)", r"\mathbf{I}_{0}"),
        ("identity(1)", r"\mathbf{I}_{1}"),
        ("identity(2)", r"\mathbf{I}_{2}"),
        # Unsupported: expected output is the ordinary function-call form.
        ("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"),
        ("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"),
        (
            "identity(0, x)",
            r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)",
        ),
    ],
)
def test_identity(code: str, latex: str) -> None:
    """Checks the LaTeX produced for `identity(...)` call expressions."""
    call_node = ast_utils.parse_expr(code)
    assert isinstance(call_node, ast.Call)

    generated = expression_codegen.ExpressionCodegen().visit(call_node)
    assert generated == latex

0 commit comments

Comments
 (0)