Special array representation #148

Merged
merged 3 commits into from
Dec 9, 2022

Conversation

LakeBlair
Contributor

Overview

Added special array representation for np.zeros and np.identity.

Details

image

References

#83

Blocked by

@LakeBlair LakeBlair requested a review from odashi as a code owner December 4, 2022 22:09
Collaborator

@odashi odashi left a comment

Thanks! It looks like this change needs some refactoring to fully support the feature and handle errors. Which would you prefer:

  • fix them yourself (in this case I will leave detailed comments)
  • delegate the remaining tasks to me (your contribution will still be credited correctly)

@@ -2,41 +2,20 @@
"cells": [
Collaborator

(action not required) I guess this notebook file is no longer necessary, as we could provide comprehensive examples in Google Colab. I will remove this file later.

@@ -356,6 +356,17 @@ def visit_Call(self, node: ast.Call) -> str:
(default_func_str + r"\mathopen{}\left(", r"\mathclose{}\right)"),
)

if func_name == "zeros":
Collaborator

  • For zeros, we don't need to restrict the input to a particular number of dimensions.
  • I think this processing won't work if the subtree has unexpected syntax. That usually happens when users pass other functions with the same name. Because the AST varies, we basically need a complete check of the underlying structure of the given subtree.

@odashi
Collaborator

odashi commented Dec 5, 2022

@LakeBlair Let's continue developing this pull request as you preferred in another thread.

Before providing additional comments, please resolve the following issues:

  • Merge main into this branch and resolve all conflicts.
  • Run ./check.sh and resolve all errors.

@LakeBlair
Contributor Author

@LakeBlair Let's continue developing this pull request as you preferred in another thread.

Before providing additional comments, please resolve the following issues:

  • Merge main into this branch and resolve all conflicts.
  • Run ./check.sh and resolve all errors.

Just did both. Check if they look good to you.

@@ -425,7 +425,26 @@ def visit_Call(self, node: ast.Call) -> str:

if special_latex is not None:
return special_latex


# Special treatment for np.zeros
Collaborator

@odashi odashi Dec 6, 2022

In the current implementation, special treatments for several functions are implemented in separate functions (see L419-L427). Each returns str | None, and visit_Call returns early if one of them returns a str. The functionality of this pull request should be implemented similarly:

  • The function returns str, the correct LaTeX, if the given ast.Call has a supported syntax.
  • Otherwise the function returns None, then visit_Call falls back to the default behavior.
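
The str | None pattern described above can be sketched in isolation. This is only an illustration under stated assumptions, not the project's implementation: `zeros_latex` and `visit_call` are hypothetical names, `ast.unparse` (Python 3.9+) stands in for the real `self.visit`, and rejecting keyword arguments is an assumption about what "supported syntax" means.

```python
import ast
from typing import Optional


def zeros_latex(node: ast.Call) -> Optional[str]:
    """Return LaTeX for `zeros((a, b, ...))`, or None for unsupported syntax."""
    # Accept both `zeros(...)` and `np.zeros(...)`.
    func = node.func
    if isinstance(func, ast.Name):
        name = func.id
    elif isinstance(func, ast.Attribute):
        name = func.attr
    else:
        return None
    if name != "zeros":
        return None
    # Assumption: exactly one positional argument and no keyword arguments.
    if len(node.args) != 1 or node.keywords:
        return None
    shape = node.args[0]
    if not isinstance(shape, ast.Tuple) or not shape.elts:
        return None
    # ast.unparse is a stand-in for the code generator's own visit method.
    dims = r" \times ".join(ast.unparse(elt) for elt in shape.elts)
    return rf"\mathbf{{0}}^{{{dims}}}"


def visit_call(node: ast.Call) -> str:
    """Try the special case first, then fall back to a default rendering."""
    special = zeros_latex(node)
    if special is not None:
        return special
    return ast.unparse(node)  # Placeholder for the default behavior.
```

With this shape, a call such as `zeros(n)` or `zeros((a, b), dtype=int)` is not an error: the helper returns None and the caller renders the call the ordinary way.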

Comment on lines 431 to 440
str = ""
open_bracket = "{"
close_bracket = "}"
for i, elt in enumerate(node.args[0].elts):
    str += "{" + self.visit(elt) + "}"
    if i != len(node.args[0].elts) - 1:
        str += " " + r"\times" + " "

matrix_str = "0^" + open_bracket + str + close_bracket
return matrix_str
Collaborator

General suggestions:

  • This requires complete syntax checking.
  • Don't use str as a variable name, as it shadows the builtin type. Overwriting builtins makes the code's behavior confusing (even if the current code has no errors, it will cause them in the future). Use latex instead if you need to store generated strings.
  • Avoid repeated string concatenation, as it is expensive. Building a sequence (or generator) and then joining it is generally better.
  • I think \mathbf{0} ( $\mathbf{0}$ ) should be used to make clear that this is not a scalar $0$.
Suggested change
if len(node.args) != 1:
    # Fall back to the default.
    return None
arg0 = node.args[0]
if not isinstance(arg0, ast.Tuple):
    # Fall back to the default.
    # Technically we don't support `zeros(n)` where `n` is a scalar.
    return None
dims_latex = r" \times ".join(self.visit(x) for x in arg0.elts)
return fr"\mathbf{{0}}^{{{dims_latex}}}"
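
To illustrate the point about concatenation versus joining, here is a small standalone comparison. The dimension names are made-up sample data, and the braces around each dimension mirror the original loop rather than the suggested output.

```python
dims = ["a", "b", "c"]  # Hypothetical, already-rendered dimension strings.

# Loop with repeated concatenation and a manual "last item" check.
concat = ""
for i, dim in enumerate(dims):
    concat += "{" + dim + "}"
    if i != len(dims) - 1:
        concat += r" \times "

# Same result: generate the pieces lazily and join them once.
joined = r" \times ".join("{" + dim + "}" for dim in dims)
assert concat == joined

# `latex` avoids shadowing the builtin `str`.
latex = rf"\mathbf{{0}}^{{{joined}}}"
```

The join form also removes the index bookkeeping, since the separator is only ever placed between items.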

Contributor Author

@odashi I made some changes to my code.

  • np.zeros and np.identity are now in a new function, treated as special numpy methods.
  • I added some if statements checking for valid ast inputs.
  • I added some parametric tests.

Some more considerations:

  • Some string concatenation is simplified, some could be simplified further.
  • In the case of np.zeros(0), I just represent it as $\mathbf{0}^{{1} \times {0}}$. There might be better ways of expressing it as a special case. Do you have any suggestions?

Let me know what you think.

Comment on lines 444 to 446
str = "{" + self.visit(node.args[0]) + "}"
matrix_str = f"I_{str}"
return matrix_str
Collaborator

Same suggestions here.

Suggested change
if len(node.args) != 1:
    # Fall back to the default.
    return None
return fr"\mathbf{{I}}_{{{self.visit(node.args[0])}}}"
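
One detail worth checking in f-strings like these is brace escaping: a literal brace must be doubled, while a single pair marks a replacement field. The sketch below uses two hypothetical helper functions to show the difference.

```python
def escaped(size: str) -> str:
    # `{{I}}` yields the literal text `{I}`; `{size}` is substituted.
    return rf"\mathbf{{I}}_{{{size}}}"


def unescaped(size: str) -> str:
    # `{I}` is a replacement field, so Python looks up a variable named `I`.
    return rf"\mathbf{I}_{{{size}}}"
```

Calling `escaped("n")` produces `\mathbf{I}_{n}`, whereas `unescaped("n")` raises a NameError unless a variable named `I` happens to be in scope.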

Comment on lines 866 to 875
tree = ast.parse(
    textwrap.dedent(
        """
        def f(a, b):
            return np.zeros((a,b))
        """
    )
).body[0]
latex = "f(a, b) = 0^{{a} \\times {b}}"
assert isinstance(tree, ast.FunctionDef)
Collaborator

We don't need to parse a complete function definition. Obtaining only the ast.Call is enough.

Suggested change
tree = ast_utils.parse_expr("zeros((a, b))")
assert isinstance(tree, ast.Call)
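
For readers without the project at hand, a helper along these lines can be approximated with the standard library alone. `parse_expr_sketch` is a hypothetical stand-in, and the real `ast_utils.parse_expr` may differ in detail.

```python
import ast


def parse_expr_sketch(code: str) -> ast.expr:
    """Parse a single expression and return its AST node."""
    # mode="eval" yields an ast.Expression whose body is the expression node.
    return ast.parse(code, mode="eval").body


tree = parse_expr_sketch("zeros((a, b))")
assert isinstance(tree, ast.Call)
```

This skips the function definition, the dedent step, and the `.body[0]` lookup, so the test exercises only the call node it cares about.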

@@ -862,6 +862,34 @@ def test_use_set_symbols_compare(code: str, latex: str) -> None:
assert function_codegen.FunctionCodegen(use_set_symbols=True).visit(tree) == latex


def test_generate_numpy_zeros():
Collaborator

@odashi odashi Dec 6, 2022

Please write comprehensive tests to cover every edge case, as in other parametric tests.
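
A table of cases that includes unsupported inputs is one way to approach this. The sketch below is only an assumption about what such coverage might look like: `identity_latex` is a hypothetical helper, `ast.unparse` (Python 3.9+) stands in for the code generator, and the plain loop could equally be a `pytest.mark.parametrize` table.

```python
import ast
from typing import Optional


def identity_latex(node: ast.Call) -> Optional[str]:
    """Return LaTeX for `identity(n)`, or None for unsupported syntax."""
    func = node.func
    name = func.id if isinstance(func, ast.Name) else getattr(func, "attr", None)
    if name != "identity" or len(node.args) != 1 or node.keywords:
        return None
    return rf"\mathbf{{I}}_{{{ast.unparse(node.args[0])}}}"


# Each case pairs source code with the expected LaTeX (None = fall back).
CASES = [
    ("identity(n)", r"\mathbf{I}_{n}"),
    ("np.identity(n + 1)", r"\mathbf{I}_{n + 1}"),
    ("identity()", None),       # Missing argument.
    ("identity(n, m)", None),   # Too many positional arguments.
    ("identity(n=3)", None),    # Keyword argument only.
    ("eye(n)", None),           # Different function name.
]


def run_cases() -> None:
    for code, expected in CASES:
        node = ast.parse(code, mode="eval").body
        assert isinstance(node, ast.Call)
        assert identity_latex(node) == expected, code
```

Listing the fallback cases next to the supported ones makes it visible which inputs are deliberately left to the default rendering.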

assert FunctionCodegen().visit(tree) == latex


def test_generate_numpy_identity():
Collaborator

ditto

@odashi odashi added the feature label Dec 7, 2022
@odashi odashi added this to the v0.3 milestone Dec 7, 2022
@LakeBlair LakeBlair force-pushed the special_array branch 2 times, most recently from acf41ab to 19bfd68 Compare December 8, 2022 03:36
Collaborator

@odashi odashi left a comment

Sorry, I still think this change requires some refactoring. I will do it on my side.

@odashi odashi merged commit e51a619 into google:main Dec 9, 2022
ZibingZhang pushed a commit to ZibingZhang/latexify_py that referenced this pull request Dec 10, 2022