Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for numpy NDarray method. Added corresponding tests #135

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 50 additions & 3 deletions examples/examples.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import math\n",
"import latexify"
"import latexify\n",
"import numpy as np"
]
},
{
Expand Down Expand Up @@ -165,11 +166,57 @@
"\n",
"solve"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\mathrm{numpy}(a) = \\begin{bmatrix} {1} & {2} & {3} & {4} \\\\ {5} & {6} & {7} & {8} \\\\ {9} & {10} & {11} & {12} \\end{bmatrix} $$"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The function name "numpy" is weird as it may overrides the package name (in some cases). Something else (e.g., myarray) is preferred.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@odashi I renamed function into myNdarray.

],
"text/plain": [
"<latexify.frontend.LatexifiedFunction at 0x7f76c00f9990>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@latexify.function\n",
"def numpy(a):\n",
" return np.ndarray([1,2,3,4],[5,6,7,8],[9,10,11,12])\n",
"\n",
"numpy"
]
}
],
"metadata": {
"kernelspec": {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove unnecessary metadata.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed kernelspec from metadata.

"display_name": "Python 3.10.8 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python"
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"vscode": {
"interpreter": {
"hash": "8a94588eda9d64d9e9a351ab8144e55b1fabf5113b54e67dd26a8c27df0381b3"
}
}
},
"nbformat": 4,
Expand Down
10 changes: 10 additions & 0 deletions src/integration_tests/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import math
import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This import requires additional dev dependency in pyproject.toml.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Odashi, I added numpy version in pyproject.toml, let me know if the version is suitable for the project.

from collections.abc import Callable
from typing import Any

Expand Down Expand Up @@ -330,3 +331,12 @@ def solve(x):

latex = r"\mathrm{solve}(x) = x"
_check_function(solve, latex)


def test_generate_ndarray() -> None:
def solve(a):
"""A 2x3 numpy matrix"""
return np.ndarray([1, 2, 3], [4, 5, 6])

latex = "\\mathrm{solve}(a) = \\begin{bmatrix} {1} & {2} & {3} \\\\ {4} & {5} & {6} \\end{bmatrix}"
_check_function(solve, latex)
12 changes: 12 additions & 0 deletions src/latexify/codegen/function_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,18 @@ def visit_Call(self, node: ast.Call) -> str:
+ rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)"
)

# Render NDarray method for numpy
if func_str == "ndarray":
# construct matrix
matrix_str = r"\begin{bmatrix} "
# iterate over rows
for row in node.args:
for col in row.elts:
matrix_str += self.visit(col) + r" & "
matrix_str = matrix_str[:-2] + r" \\ "
matrix_str = matrix_str[:-3] + r"\end{bmatrix}"
return matrix_str

arg_strs = [self.visit(arg) for arg in node.args]
return lstr + ", ".join(arg_strs) + rstr

Expand Down
17 changes: 16 additions & 1 deletion src/latexify/codegen/function_codegen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import ast
import textwrap

import numpy as np
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused.

Suggested change
import numpy as np

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed numpy import from function_codegen_test.py

import pytest

from latexify import exceptions, test_utils
Expand Down Expand Up @@ -744,3 +744,18 @@ def test_use_set_symbols_compare(code: str, latex: str) -> None:
tree = ast.parse(code).body[0].value
assert isinstance(tree, ast.Compare)
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex


def test_generate_ndarrays():
tree = ast.parse(
textwrap.dedent(
"""
def numpy(a):
return np.ndarray([1,2,3],[4,5,6])
"""
)
).body[0]

latex = "\\mathrm{numpy}(a) = \\begin{bmatrix} {1} & {2} & {3} \\\\ {4} & {5} & {6} \\end{bmatrix}"
assert isinstance(tree, ast.FunctionDef)
assert FunctionCodegen().visit(tree) == latex