-
-
Notifications
You must be signed in to change notification settings - Fork 153
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Blockwise
Op
#1215
Open
purna135
wants to merge
23
commits into
aesara-devs:main
Choose a base branch
from
purna135:add_blockwise
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Add Blockwise
Op
#1215
Changes from all commits
Commits
Show all changes
23 commits
Select commit
Hold shift + click to select a range
b6c0e0e
Extract check method from InferShapeTester
brandonwillard 5037958
Add a Blockwise Op
brandonwillard f8af028
simply printing of Blockwise Op
purna135 2f0576d
refactor grad_signature
purna135 2c82292
use Blockwise instead of Elemwise
purna135 0bb4feb
Fix typing issues in Blockwise._bgrad
brandonwillard f67c769
Convert inputs to ndarrays in Blockwise.perform
brandonwillard 90f4791
Add a Blockwise Op
brandonwillard d24f619
fixed py_func result
purna135 daf2ea2
fixed parameters to atleast_Nd()
purna135 33ead95
added gufunc_sig to Tri and AllocDiag
purna135 fdf5af8
added gufunc_sig to Cholesky
purna135 0072f2d
added test for Blockwise Cholesky
purna135 ab66791
manage the Ops which support nd inputs
purna135 b77d50b
Use dispatch for gufunc signature and (partially) implement Subtensor…
brandonwillard 76dc6b4
Add a infer_shape_to_gufunc_sig function
brandonwillard e1c8892
Fix flat_out_shapes
purna135 be1b794
Fix AllocDiag and Tri gufunc signatures
brandonwillard f0eb8f5
Fixed output dtype
purna135 de3c1ea
Fix a core inputs computation bug and do some refactoring
brandonwillard 45f5eeb
add more tests to test_infer_shape_to_gufunc_sig
purna135 7f1b99d
fix infer_shape_to_gufunc_sig
purna135 c7b0d10
add test for Blockwise SolveTriangular
purna135 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -38,6 +38,7 @@ | |||||
from aesara.scalar.basic import ScalarConstant, ScalarVariable | ||||||
from aesara.tensor import ( | ||||||
_as_tensor_variable, | ||||||
_get_gufunc_signature, | ||||||
_get_vector_length, | ||||||
as_tensor_variable, | ||||||
get_vector_length, | ||||||
|
@@ -986,7 +987,15 @@ def nonzero_values(a): | |||||
|
||||||
class Tri(Op): | ||||||
|
||||||
__props__ = ("dtype",) | ||||||
gufunc_sig = ( | ||||||
( | ||||||
(), | ||||||
(), | ||||||
(), | ||||||
), | ||||||
(("n", "m"),), | ||||||
) | ||||||
__props__ = ("dtype", "gufunc_sig") | ||||||
|
||||||
def __init__(self, dtype=None): | ||||||
if dtype is None: | ||||||
|
@@ -3470,6 +3479,12 @@ def __setstate__(self, state): | |||||
self.axis2 = 1 | ||||||
|
||||||
|
||||||
@_get_gufunc_signature.register(ExtractDiag) | ||||||
def _get_gufunc_signature_ExtractDiag(op, blocked_inputs): | ||||||
# TODO: | ||||||
raise NotImplementedError() | ||||||
|
||||||
|
||||||
extract_diag = ExtractDiag() | ||||||
# TODO: optimization to insert ExtractDiag with view=True | ||||||
|
||||||
|
@@ -3502,7 +3517,8 @@ class AllocDiag(Op): | |||||
It does the inverse of `ExtractDiag`. | ||||||
""" | ||||||
|
||||||
__props__ = ("offset", "axis1", "axis2") | ||||||
gufunc_sig = (((),), (("m", "m"),)) | ||||||
__props__ = ("offset", "axis1", "axis2", "gufunc_sig") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
def __init__(self, offset=0, axis1=0, axis2=1): | ||||||
""" | ||||||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.