-
Notifications
You must be signed in to change notification settings - Fork 393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Special array representation #148
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks! It looks this change needs some refactoring to support the feature and error handling completely. Which would you like either:
- fix them by yourself (in this case I will put comprehensive comments)
- delegate remaining tasks to me (your contribution is still recorded correctly)
examples/examples.ipynb
Outdated
@@ -2,41 +2,20 @@ | |||
"cells": [ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(action not required) I guess this notebook file is no longer necessary as we could provide a comprehensive examples in Google Colab. I will remove this file later.
@@ -356,6 +356,17 @@ def visit_Call(self, node: ast.Call) -> str: | |||
(default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"), | |||
) | |||
|
|||
if func_name == "zeros": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- For
zeros
, We don't need to constrain only a particular number of dimensions. - I think these processes don't work if the subtree has unexpected syntax. It usually happens when users gave other functions with the same name. As the AST varies, we basically need complete check of the underlying structure of the given subtree.
@LakeBlair Let's continue developing this pull request as you preferred in another thread. Before providing additional comments, please resolve the following issues:
|
Just did both. Check if they look good to you. |
@@ -425,7 +425,26 @@ def visit_Call(self, node: ast.Call) -> str: | |||
|
|||
if special_latex is not None: | |||
return special_latex | |||
|
|||
|
|||
# Special treatment for np.zeros |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the current implementation, special treatments for several functions are implemented in separate functions (see L419-L427). They return str | None
, and visit_Call
attempts early returning if str
is returned by the function. The functionality of this pull request should be implemented similarly:
- The function returns
str
, the correct LaTeX, if the givenast.Call
has a supported syntax. - Otherwise the function returns
None
, thenvisit_Call
falls back to the default behavior.
str = "" | ||
open_bracket = "{" | ||
close_bracket = "}" | ||
for i, elt in enumerate(node.args[0].elts): | ||
str += "{" + self.visit(elt) + "}" | ||
if i != len(node.args[0].elts) - 1: | ||
str += " " + r"\times" + " " | ||
|
||
matrix_str = "0^" + open_bracket + str + close_bracket | ||
return matrix_str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
General suggestions:
- Requires complete syntax checking.
- Don't use
str
as a variable as it is reserved by the builtin type name. Overwriting builtins will confuse the behavior of the code (even if the current code doesn't involve any errors, it will happen in the future). Uselatex
instead if you need to store some generated strings. - Suppress string concatenation as it is significantly expensive. Making a sequence (generator), then join all of them is generally better.
- I think
\mathbf{0}
($\mathbf{0}$ ) should be used to distinguish that this is not a scalar$0$ .
str = "" | |
open_bracket = "{" | |
close_bracket = "}" | |
for i, elt in enumerate(node.args[0].elts): | |
str += "{" + self.visit(elt) + "}" | |
if i != len(node.args[0].elts) - 1: | |
str += " " + r"\times" + " " | |
matrix_str = "0^" + open_bracket + str + close_bracket | |
return matrix_str | |
if len(node.args) != 1: | |
# fall back to the default. | |
arg0 = node.args[0] | |
if not isinstance(arg0, ast.Tuple): | |
# fall back to the default. | |
# Tecunically we don't support `zeros(n)` where `n` is a scalar. | |
dims_latex = r" \times ".join(self.visit(x) for x in arg0.elts) | |
return fr"\mathbf{{0}}^{{{dims_latex}}}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@odashi I made some changes to my code.
- np.zeros and np.identity are now in a new function, treated as special numpy methods.
- I added some if statements checking for valid ast inputs.
- I added some parametric tests.
Some more considerations:
- Some string concatenation is simplified, some could be simplified further.
- In the case of np.zeros(0), I just represent it as
$\mathbf{0}^{{1} \times {0}}$ . There might be better ways of expressing it as a special case. Do you have any suggestions?
Let me know what you think.
str = "{" + self.visit(node.args[0]) + "}" | ||
matrix_str = f"I_{str}" | ||
return matrix_str |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same suggestions here.
str = "{" + self.visit(node.args[0]) + "}" | |
matrix_str = f"I_{str}" | |
return matrix_str | |
if len(node.args) != 1: | |
# fall back to the default. | |
return fr"\mathbf{I}_{{{self.visit(node.args[0])}}}" |
tree = ast.parse( | ||
textwrap.dedent( | ||
""" | ||
def f(a, b): | ||
return np.zeros((a,b)) | ||
""" | ||
) | ||
).body[0] | ||
latex = "f(a, b) = 0^{{a} \\times {b}}" | ||
assert isinstance(tree, ast.FunctionDef) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need complete parsing. It is enough to obtain only ast.Call
.
tree = ast.parse( | |
textwrap.dedent( | |
""" | |
def f(a, b): | |
return np.zeros((a,b)) | |
""" | |
) | |
).body[0] | |
latex = "f(a, b) = 0^{{a} \\times {b}}" | |
assert isinstance(tree, ast.FunctionDef) | |
tree = ast_utils.parse_expr("zeros((a, b))") | |
assert isinstance(tree, ast.Call) |
@@ -862,6 +862,34 @@ def test_use_set_symbols_compare(code: str, latex: str) -> None: | |||
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex | |||
|
|||
|
|||
def test_generate_numpy_zeros(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please write comprehensive tests to cover every edge case, as in other parametric tests.
assert FunctionCodegen().visit(tree) == latex | ||
|
||
|
||
def test_generate_numpy_identity(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
acf41ab
to
19bfd68
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry I still think this change requires several refactoring, I will do it by my side.
Overview
Added special array representation for np.zeros and np.identity.
Details
References
#83
Blocked by