-
Notifications
You must be signed in to change notification settings - Fork 393
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for numpy matrices #118
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add sufficient unit tests in function_codegen_test.py
.
) | ||
|
||
arg = node.args[0] | ||
# TODO: Support string being passed to numpy.matrix. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question: what is the "string" in this context?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
numpy.matrix supports passing in strings like the example below and parsing that into a ndarray
a = np.matrix('1 2; 3 4')
latex = r"\begin{bmatrix}" | ||
if type(arg) == ast.List: | ||
for elt in arg.elts: | ||
if type(elt) == ast.List: | ||
for sub_elt in elt.elts: | ||
latex += self.visit(sub_elt) + r" & " | ||
latex = latex[:-2] + r" \\ " | ||
else: | ||
latex += self.visit(elt) + r" & " | ||
latex = latex[:-3] + r"\end{bmatrix}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the following steps could be more intuitive and much efficient:
- Checks if the arg is List, and all elts are List with the same length. Stops the process if the condition is not satisfied.
-
# This should work without any breakages iff the condition 1 is satisfied. contents = r" \\ ".join( " & ".join(self.visit(x) for x in row.elts) for row in arg.elts ) return r"\begin{bmatrix} " + contents + r" \end{bmatrix}"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was thinking something like this to properly handle row, column and mxn matrices correctly. What do you think?
if isinstance(arg, ast.List):
# handle row vectors
if all(isinstance(x, ast.Constant) for x in arg.elts):
contents = " & ".join([self.visit(x) for x in arg.elts])
return r"\begin{bmatrix} " + contents + r" \end{bmatrix}"
# handle column vector
if all(isinstance(x, ast.List) for x in arg.elts) and all(len(x.elts) == 1 for x in arg.elts):
contents = r" \\ ".join([self.visit(x.elts[0]) for x in arg.elts])
return r"\begin{bmatrix} " + contents + r" \end{bmatrix}"
# handle matrices
elif all(isinstance(elt, ast.List) and len(elt.elts) == len(arg.elts) for elt in arg.elts):
contents = r" \\ ".join(
" & ".join(self.visit(x) for x in row.elts)
for row in arg.elts
)
return r"\begin{bmatrix} " + contents + r" \end{bmatrix}"
return None
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks still wrong I think:
- 1st condition looks too strict because it introduced another constraint that the member is only
Constant
. - 2nd condition is not necessary if we handle
$M \times N$ matrix. - 3rd condition looks to handle only square matrices.
I tried to complete it, and noticed that it requries more detailed checks. How about this:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(btw, I think the above examples can be used in unit tests as well.)
@kshxtij Hi, I'd appreciate it if you could proceed this pull request! |
I'll get this finished over the weekend, sorry it's been so delayed. Just very busy with deadlines and exams! |
if not row0.elts: | ||
return None | ||
|
||
nCols = len(row0.elts) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nCols is not allowed under the naming convention.
https://google.github.io/styleguide/pyguide.html#3164-guidelines-derived-from-guidos-recommendations
nCols = len(row0.elts) | |
ncols = len(row0.elts) |
arg = node.args[0] | ||
if not isinstance(arg, ast.List) or not arg.elts: | ||
return None | ||
|
||
row0 = arg.elts[0] | ||
|
||
if not isinstance(row0, ast.List): | ||
return self._generate_ndarray([self.visit(x) for x in arg.elts]) | ||
|
||
if not row0.elts: | ||
return None | ||
|
||
nCols = len(row0.elts) | ||
|
||
if not all( | ||
isinstance(row, ast.List) and len(row.elts) == nCols for row in arg.elts | ||
): | ||
return None | ||
|
||
return self._generate_ndarray( | ||
[[self.visit(x) for x in row.elts] for row in arg.elts] | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This block must be placed on _generate_ndarray, and if it returns None we need to fall back to the default behavior.
@kshxtij No problem! I just commented more. Could you fix CI errors as well? |
@kshxtij Also I'd appreciate it if you add every test case I commented above. |
Reminder: we will close this pull request tomorrow because of more than 1 week inactivity. If you are still working on this, please update the pull request. |
Closes this pull request, I'll take over the remaining work. |
Overview
Adds support for np.ndarray and np.array
Details
References
Blocked by
NA