Skip to content

Commit 37d47c9

Browse files
authored
Numpy transpose (#181)
1 parent b3ae7fa commit 37d47c9

File tree

3 files changed

+56
-9
lines changed

3 files changed

+56
-9
lines changed

src/latexify/codegen/expression_codegen.py

+23
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,27 @@ def _generate_identity(self, node: ast.Call) -> str | None:
218218

219219
return rf"\mathbf{{I}}_{{{ndims}}}"
220220

221+
def _generate_transpose(self, node: ast.Call) -> str | None:
222+
"""Generates LaTeX for numpy.transpose.
223+
Args:
224+
node: ast.Call node containing the appropriate method invocation.
225+
Returns:
226+
Generated LaTeX, or None if the node has unsupported syntax.
227+
Raises:
228+
LatexifyError: Unsupported argument type given.
229+
"""
230+
name = ast_utils.extract_function_name_or_none(node)
231+
assert name == "transpose"
232+
233+
if len(node.args) != 1:
234+
return None
235+
236+
func_arg = node.args[0]
237+
if isinstance(func_arg, ast.Name):
238+
return rf"\mathbf{{{func_arg.id}}}^\intercal"
239+
else:
240+
return None
241+
221242
def visit_Call(self, node: ast.Call) -> str:
222243
"""Visit a Call node."""
223244
func_name = ast_utils.extract_function_name_or_none(node)
@@ -232,6 +253,8 @@ def visit_Call(self, node: ast.Call) -> str:
232253
special_latex = self._generate_zeros(node)
233254
elif func_name == "identity":
234255
special_latex = self._generate_identity(node)
256+
elif func_name == "transpose":
257+
special_latex = self._generate_transpose(node)
235258
else:
236259
special_latex = None
237260

src/latexify/codegen/expression_codegen_test.py

+21
Original file line numberDiff line numberDiff line change
@@ -970,3 +970,24 @@ def test_identity(code: str, latex: str) -> None:
970970
tree = ast_utils.parse_expr(code)
971971
assert isinstance(tree, ast.Call)
972972
assert expression_codegen.ExpressionCodegen().visit(tree) == latex
973+
974+
975+
@pytest.mark.parametrize(
976+
"code,latex",
977+
[
978+
("transpose(A)", r"\mathbf{A}^\intercal"),
979+
("transpose(b)", r"\mathbf{b}^\intercal"),
980+
# Unsupported
981+
("transpose()", r"\mathrm{transpose} \mathopen{}\left( \mathclose{}\right)"),
982+
("transpose(2)", r"\mathrm{transpose} \mathopen{}\left( 2 \mathclose{}\right)"),
983+
(
984+
"transpose(a, (1, 0))",
985+
r"\mathrm{transpose} \mathopen{}\left( a, "
986+
r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
987+
),
988+
],
989+
)
990+
def test_transpose(code: str, latex: str) -> None:
991+
tree = ast_utils.parse_expr(code)
992+
assert isinstance(tree, ast.Call)
993+
assert expression_codegen.ExpressionCodegen().visit(tree) == latex

src/latexify/frontend.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def algorithmic(
2424

2525
def algorithmic(
2626
fn: Callable[..., Any] | None = None, **kwargs: Any
27-
) -> ipython_wrappers.LatexifiedAlgorithm | Callable[
28-
[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm
29-
]:
27+
) -> (
28+
ipython_wrappers.LatexifiedAlgorithm
29+
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]
30+
):
3031
"""Attach LaTeX pretty-printing to the given function.
3132
3233
This function works with or without specifying the target function as the
@@ -67,9 +68,10 @@ def function(
6768

6869
def function(
6970
fn: Callable[..., Any] | None = None, **kwargs: Any
70-
) -> ipython_wrappers.LatexifiedFunction | Callable[
71-
[Callable[..., Any]], ipython_wrappers.LatexifiedFunction
72-
]:
71+
) -> (
72+
ipython_wrappers.LatexifiedFunction
73+
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
74+
):
7375
"""Attach LaTeX pretty-printing to the given function.
7476
7577
This function works with or without specifying the target function as the positional
@@ -110,9 +112,10 @@ def expression(
110112

111113
def expression(
112114
fn: Callable[..., Any] | None = None, **kwargs: Any
113-
) -> ipython_wrappers.LatexifiedFunction | Callable[
114-
[Callable[..., Any]], ipython_wrappers.LatexifiedFunction
115-
]:
115+
) -> (
116+
ipython_wrappers.LatexifiedFunction
117+
| Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
118+
):
116119
"""Attach LaTeX pretty-printing to the given function.
117120
118121
This function is a shortcut for `latexify.function` with the default parameter

0 commit comments

Comments
 (0)