Add `Blockwise` Op #1215

base: main
Conversation
Codecov Report

```
@@            Coverage Diff            @@
##             main    #1215     +/-  ##
=========================================
+ Coverage   75.02%   79.16%   +4.14%
=========================================
  Files         194      174      -20
  Lines       50099    48677    -1422
  Branches    12096    10359    -1737
=========================================
+ Hits        37586    38536     +950
+ Misses      10189     7640    -2549
- Partials     2324     2501     +177
```
Don't forget to rebase onto
This looks great! The next step involves extending the number of `gufunc_sig`s we specify and adding the associated tests.

The big, open question is whether or not we can replace `Elemwise` with this new `Op`. When we demonstrate that this `Op` can at least handle all the standard `Elemwise` cases, then we'll start exploring this question further, though. In other words, we don't want to start considering all the other changes (e.g. `Blockwise.c_code`, Numba/JAX transpilations, etc.) until we've demonstrated good test coverage (both `Elemwise`/scalar broadcasting cases and otherwise).
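For readers less familiar with generalized ufuncs: the nested `gufunc_sig` tuples used in this PR play the same role as NumPy's gufunc signature strings. A minimal, self-contained sketch using plain NumPy (not this PR's `Blockwise`) illustrates the core-dimension idea:

```python
import numpy as np

# NumPy writes the matmul signature as the string "(m,n),(n,p)->(m,p)".
# The gufunc loops over any leading "batch" dimensions and applies the core
# operation to the trailing core dimensions of each input.
matmul = np.vectorize(np.matmul, signature="(m,n),(n,p)->(m,p)")

a = np.random.rand(5, 2, 3)  # batch of 5 matrices, core dims (2, 3)
b = np.random.rand(5, 3, 4)  # batch of 5 matrices, core dims (3, 4)
out = matmul(a, b)
print(out.shape)  # (5, 2, 4)
```

The PR's `((("m", "n"), ("n", "p")), (("m", "p"),))` tuple encodes the same information as that string: one tuple of core-dimension names per input, then one per output.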
```python
x = Blockwise(op)(*args)
x_fn = aesara.function(args, x)

x_fn(*arg_vals)
```
We're going to need to `assert` something about this output.
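One way to do that (a sketch only; the Aesara-compiled `x_fn` from the snippet above is replaced here with a hypothetical NumPy stand-in so the pattern is self-contained) is to check both the shape and the values against an independent NumPy reference:

```python
import numpy as np

# Hypothetical stand-in for the Blockwise-compiled function under test;
# here it is simply NumPy's batched matrix multiply.
def x_fn(a, b):
    return a @ b

rng = np.random.default_rng(42)
a_val = rng.standard_normal((5, 2, 3))
b_val = rng.standard_normal((5, 3, 4))

res = x_fn(a_val, b_val)
# Assert the shape and the values against an independent reference.
assert res.shape == (5, 2, 4)
np.testing.assert_allclose(res, np.matmul(a_val, b_val))
```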
```python
gufunc_sig = ((("m", "n"), ("n", "p")), (("m", "p"),))

__props__ = ("gufunc_sig",)
```
FYI: We'll need to create these kinds of signatures for every applicable `Op`.
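To illustrate the kind of signatures meant here, in the nested-tuple form this PR uses (inputs first, then outputs): the first tuple below appears in the PR's diff; the other two are hypothetical examples, not code from the PR.

```python
# (inputs, outputs) core-dimension tuples, one inner tuple per tensor.
MATMUL_SIG = ((("m", "n"), ("n", "p")), (("m", "p"),))  # matrix multiply (from the PR)
CHOLESKY_SIG = ((("m", "m"),), (("m", "m"),))  # square matrix -> square matrix (hypothetical)
DET_SIG = ((("m", "m"),), ((),))  # determinant: square matrix -> scalar (hypothetical)
```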
Force-pushed a57528c to f770ada
Force-pushed f770ada to 6cda5c3
What should be the signature for the `Subtensor` Op and `Shape` Op?
If you're talking about constructing symbolic graphs, the signatures are ultimately determined by their
Yes, got it now.
Hello, @brandonwillard. You can reproduce the error using the following command. |
It looks like … I'm guessing … Regardless, we shouldn't need new …
Force-pushed 0792e8a to fdb3045
Force-pushed 877d04d to c9ad602
Force-pushed 3ed3497 to c7b0d10
I've added comments for some of the changes we made locally during the meeting.
```python
    ),
    (("n", "m"),),
)

__props__ = ("dtype", "gufunc_sig")
```
Suggested change:

```diff
-    __props__ = ("dtype", "gufunc_sig")
+    __props__ = ("dtype",)
```
```diff
@@ -3502,7 +3517,8 @@ class AllocDiag(Op):
     It does the inverse of `ExtractDiag`.
     """

-    __props__ = ("offset", "axis1", "axis2")
+    gufunc_sig = (((),), (("m", "m"),))
+    __props__ = ("offset", "axis1", "axis2", "gufunc_sig")
```
Suggested change:

```diff
-    __props__ = ("offset", "axis1", "axis2", "gufunc_sig")
+    __props__ = ("offset", "axis1", "axis2",)
```
```python
        return Apply(self, list(inputs), outputs)

    def __str__(self):
        return f"{type(self).__name__}{{op={self.op}}}"
```
Suggested change:

```diff
-        return f"{type(self).__name__}{{op={self.op}}}"
+        return f"{type(self).__name__}{{{self.op}, {self.signature}}}"
```
```python
        # The gradient contains a constant
        # res = aesara.tensor.basic.constant(
        #     np.asarray(var.data), dtype=var.type.dtype
        # )
        res = var

        # TODO FIXME: Use dimensions of relevant/appropriate inputs.
        # What exactly are those in this case?
        nd = inputs[0].type.ndim

        return atleast_Nd(res, n=nd)
```
Suggested change:

```diff
-        # The gradient contains a constant
-        # res = aesara.tensor.basic.constant(
-        #     np.asarray(var.data), dtype=var.type.dtype
-        # )
-        res = var
-        # TODO FIXME: Use dimensions of relevant/appropriate inputs.
-        # What exactly are those in this case?
-        nd = inputs[0].type.ndim
-        return atleast_Nd(res, n=nd)
+        return var
```
```diff
-    __props__ = ("lower", "destructive", "on_error")
+    gufunc_sig = ((("m", "m"),), (("m", "m"),))
+    __props__ = ("lower", "destructive", "on_error", "gufunc_sig")
```
Suggested change:

```diff
-    __props__ = ("lower", "destructive", "on_error", "gufunc_sig")
+    __props__ = ("lower", "destructive", "on_error",)
```
```python
from aesara.tensor.basic import Tri

blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
```
Suggested change:

```diff
-from aesara.tensor.basic import Tri
-blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
+blk_op = Blockwise(op=Tri(dtype="float64"))
```
```python
blk_op = Blockwise(op=Tri(dtype="float64"), signature=(((), (), ()), (("n", "m"),)))
out_dtype, output_shapes, inputs = blk_op.get_output_info(a, b, c)

assert out_dtype == ["float64"]
```
We need to `assert` something about `output_shapes` (i.e. make sure they're correct in some way).
Inspired by: aesara-devs/aesara#1215
Co-authored-by: Brandon T. Willard <brandonwillard@users.noreply.github.com>
Co-authored-by: Purna Chandra Mansingh <purnachandramansingh135@gmail.com>
Co-authored-by: Sayam Kumar <sayamkumar049@gmail.com3>
Co-authored-by: Kaustubh <ckaustubhm06@gmail.com>
This PR builds off of #757 and closes #695.

To #757 it adds:

- `get_output_info()`, which is the same as `Elemwise`'s `get_output_info()`, to make all inputs of the same dimension
- how the `grad` is computed

Differences with #757:

- `curr_static_shape` of `core_inp_grads` use the dimensions from the end
- `perform()` of `DimShuffle` (which can be removed later)
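The "dimensions from the end" convention matches NumPy's gufunc broadcasting: core dimensions are taken from the trailing axes of each input, and the remaining leading ("loop") dimensions broadcast against each other like ordinary ufunc operands. A quick NumPy illustration (plain NumPy, not this PR's code):

```python
import numpy as np

# np.matmul is a gufunc with signature (m,n),(n,p)->(m,p): the last two axes
# of each input are the core dimensions, and the leading axes broadcast.
a = np.random.rand(4, 1, 2, 3)  # loop dims (4, 1), core dims (2, 3)
b = np.random.rand(5, 3, 6)     # loop dim  (5,),  core dims (3, 6)

out = np.matmul(a, b)
print(out.shape)  # (4, 5, 2, 6): loop dims (4, 1) x (5,) -> (4, 5), core (2, 6)
```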