Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Blockwise Op #1215

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
b6c0e0e
Extract check method from InferShapeTester
brandonwillard Jan 16, 2022
5037958
Add a Blockwise Op
brandonwillard Jan 17, 2022
f8af028
simply printing of Blockwise Op
purna135 Nov 9, 2022
2f0576d
refactor grad_signature
purna135 Nov 9, 2022
2c82292
use Blockwise instead of Elemwise
purna135 Nov 9, 2022
0bb4feb
Fix typing issues in Blockwise._bgrad
brandonwillard Nov 11, 2022
f67c769
Convert inputs to ndarrays in Blockwise.perform
brandonwillard Nov 11, 2022
90f4791
Add a Blockwise Op
brandonwillard Jan 17, 2022
d24f619
fixed py_func result
purna135 Jan 17, 2023
daf2ea2
fixed parameters to atleast_Nd()
purna135 Jan 17, 2023
33ead95
added gufunc_sig to Tri and AllocDiag
purna135 Jan 17, 2023
fdf5af8
added gufunc_sig to Cholesky
purna135 Jan 17, 2023
0072f2d
added test for Blockwise Cholesky
purna135 Jan 17, 2023
ab66791
manage the Ops which support nd inputs
purna135 Jan 17, 2023
b77d50b
Use dispatch for gufunc signature and (partially) implement Subtensor…
brandonwillard Jan 21, 2023
76dc6b4
Add a infer_shape_to_gufunc_sig function
brandonwillard Jan 23, 2023
e1c8892
Fix flat_out_shapes
purna135 Jan 28, 2023
be1b794
Fix AllocDiag and Tri gufunc signatures
brandonwillard Feb 2, 2023
f0eb8f5
Fixed output dtype
purna135 Feb 6, 2023
de3c1ea
Fix a core inputs computation bug and do some refactoring
brandonwillard Feb 19, 2023
45f5eeb
add more tests to test_infer_shape_to_gufunc_sig
purna135 Mar 10, 2023
7f1b99d
fix infer_shape_to_gufunc_sig
purna135 Mar 14, 2023
c7b0d10
add test for Blockwise SolveTriangular
purna135 Mar 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions aesara/tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,20 @@ def _get_vector_length_Constant(op: Union[Op, Variable], var: Constant) -> int:
return len(var.data)


def get_gufunc_signature(op, blocked_inputs):
sig = getattr(op, "gufunc_sig", None)

if sig is None:
return _get_gufunc_signature(op, blocked_inputs)

return sig


@singledispatch
def _get_gufunc_signature(op, blocked_inputs):
raise ValueError(f"'{op}' object has no attribute 'gufunc_sig'")


import aesara.tensor.exceptions # noqa
from aesara.gradient import consider_constant, grad, hessian, jacobian # noqa

Expand Down
20 changes: 18 additions & 2 deletions aesara/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from aesara.scalar.basic import ScalarConstant, ScalarVariable
from aesara.tensor import (
_as_tensor_variable,
_get_gufunc_signature,
_get_vector_length,
as_tensor_variable,
get_vector_length,
Expand Down Expand Up @@ -986,7 +987,15 @@ def nonzero_values(a):

class Tri(Op):

__props__ = ("dtype",)
gufunc_sig = (
(
(),
(),
(),
),
(("n", "m"),),
)
__props__ = ("dtype", "gufunc_sig")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__props__ = ("dtype", "gufunc_sig")
__props__ = ("dtype",)


def __init__(self, dtype=None):
if dtype is None:
Expand Down Expand Up @@ -3470,6 +3479,12 @@ def __setstate__(self, state):
self.axis2 = 1


@_get_gufunc_signature.register(ExtractDiag)
def _get_gufunc_signature_ExtractDiag(op, blocked_inputs):
# TODO:
raise NotImplementedError()


extract_diag = ExtractDiag()
# TODO: optimization to insert ExtractDiag with view=True

Expand Down Expand Up @@ -3502,7 +3517,8 @@ class AllocDiag(Op):
It does the inverse of `ExtractDiag`.
"""

__props__ = ("offset", "axis1", "axis2")
gufunc_sig = (((),), (("m", "m"),))
__props__ = ("offset", "axis1", "axis2", "gufunc_sig")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__props__ = ("offset", "axis1", "axis2", "gufunc_sig")
__props__ = ("offset", "axis1", "axis2",)


def __init__(self, offset=0, axis1=0, axis2=1):
"""
Expand Down
Loading