Change Dot Op to only accept matrix inputs #1538
Conversation
```python
@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
def test_ger_benchmark(n, inplace, benchmark):
```
I added this because at some point I considered taking away GER in favor of outer multiplication. The benchmark showed that was underperforming.
Pull Request Overview
This PR refactors the Dot operation to only accept 2D matrix inputs, simplifying the codebase by removing support for vector inputs and delegating vector operations to helper functions. The changes eliminate the old scipy_ger module and streamline the dot product implementation.
- Restricts Dot Op to only accept 2D tensors (matrices), removing vector support
- Removes scipy_ger module and integrates scipy directly into Ger.perform method
- Updates dot interface function to handle vector promotion to matrices internally (sketched below)
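A minimal NumPy sketch of that promotion idea (illustrative only; the helper name and implementation are hypothetical, not the actual pytensor `dot`):

```python
import numpy as np

def dot_via_matrix_core(x, y):
    """Promote vector operands to matrices, call a 2D-only core dot,
    then squeeze the added dimensions back out (illustrative sketch)."""
    x, y = np.asarray(x), np.asarray(y)
    x_is_vec, y_is_vec = x.ndim == 1, y.ndim == 1
    xm = x[None, :] if x_is_vec else x  # (n,)  -> (1, n)
    ym = y[:, None] if y_is_vec else y  # (m,)  -> (m, 1)
    out = xm @ ym                       # the core op only ever sees matrices
    if x_is_vec:
        out = out[0]
    if y_is_vec:
        out = out[..., 0]
    return out
```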
Reviewed Changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.
Summary per file:
File | Description |
---|---|
pytensor/tensor/math.py | Major refactor of Dot class to only accept 2D inputs, updated dense_dot function |
pytensor/tensor/blas.py | Updated Ger.perform to use scipy directly, simplified Dot22.perform |
pytensor/tensor/blas_scipy.py | Removed entire ScipyGer implementation |
pytensor/tensor/rewriting/blas_scipy.py | Removed scipy-specific BLAS rewrites |
pytensor/tensor/rewriting/math.py | Updated optimization rules for new Dot constraints |
tests/tensor/test_math.py | Updated tests to reflect new Dot API and removed vector tests |
Comments suppressed due to low confidence (1)
tests/tensor/test_math.py:2011
- The test is checking that `_dot(d1, d2)` raises TypeError, but this line should be inside a `pytest.raises` context manager to properly test the exception.

```python
_dot(d1, d2)
```
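For reference, the pattern the suppressed comment is asking for looks roughly like this (the `_dot` here is a stand-in, not the real pytensor helper, and `d1`/`d2` are placeholder inputs):

```python
import numpy as np
import pytest

def _dot(a, b):
    # Stand-in for the 2D-only dot: reject anything that isn't a matrix.
    if np.asarray(a).ndim != 2 or np.asarray(b).ndim != 2:
        raise TypeError("Dot only accepts 2D (matrix) inputs")
    return a @ b

def test_dot_rejects_vectors():
    d1, d2 = np.ones(3), np.ones(3)
    # The pytest.raises context manager is what actually asserts the error.
    with pytest.raises(TypeError):
        _dot(d1, d2)
```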
```python
    # Work on transposed system to avoid copying
    A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
    A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
```
The output should handle the case when A.size == 0. When A is empty, the code should still copy A if not destructive, but currently it just assigns A directly regardless of the destructive flag.
Suggested change:

```python
    A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
else:
    # Handle the case where A.size == 0
    if self.destructive:
        # No-op for destructive mode
        pass
    else:
        # Create a copy for non-destructive mode
        A = A.copy()
```
There's no point in copying an empty array; you can't store anything in it.
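For context, a rough standalone sketch of a scipy-based rank-1 update along the lines of the quoted perform code (the function name and structure are illustrative, not the actual Ger.perform; the empty-array branch follows the reply above and simply passes A through):

```python
import numpy as np
from scipy.linalg.blas import get_blas_funcs

def rank1_update(alpha, x, y, A, destructive=False):
    """Compute A + alpha * outer(x, y) with BLAS ?ger (illustrative sketch)."""
    if A.size == 0:
        # Empty output: nothing to compute or store, so just pass A through.
        return A
    (ger_func,) = get_blas_funcs(("ger",), (A, x, y))
    if A.flags["C_CONTIGUOUS"]:
        # Work on the transposed system so BLAS sees Fortran order without a copy.
        return ger_func(alpha, y, x, a=A.T, overwrite_a=destructive).T
    return ger_func(alpha, x, y, a=A, overwrite_a=destructive)
```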
```python
constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
return [constant_zero]
):
    return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]
```
This assumes both inputs are 2D matrices, but the function doesn't validate the input dimensions. If either x or y has fewer than 2 dimensions, accessing x.shape[0] or y.shape[1] could cause an IndexError.
That's the whole point of this PR. If you see a Dot, it must be with two matrices; make_node validates it.
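A sketch of the kind of validation meant here (the function name is hypothetical; the actual check lives in Dot.make_node):

```python
import pytensor.tensor as pt

def check_dot_inputs(x, y):
    """Both inputs must already be matrices, so rewrites that see a Dot node
    can safely index shape[0] / shape[1] (illustrative sketch)."""
    x, y = pt.as_tensor_variable(x), pt.as_tensor_variable(y)
    if x.type.ndim != 2 or y.type.ndim != 2:
        raise TypeError(
            f"Dot expects two matrices, got ndim={x.type.ndim} and ndim={y.type.ndim}"
        )
    return x, y
```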
Codecov Report

❌ Your patch status has failed because the patch coverage (95.62%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##              main    #1538      +/-  ##
==========================================
+ Coverage    81.49%   81.53%   +0.03%
==========================================
  Files          232      230       -2
  Lines        53122    53003     -119
  Branches      9444     9410      -34
==========================================
- Hits         43292    43215      -77
+ Misses        7382     7360      -22
+ Partials      2448     2428      -20
```
```python
x, y = node.inputs[0].owner.inputs
if len(clients) != 1:
```
So we wouldn't rewrite the following graph?

```python
x = dot(a, b).T
y = x + 1
fn = pytensor.function([a, b], [x, y])
```
We would; dot is only used by the transpose. Instead, we wouldn't rewrite:

```python
d = dot(x, y)
out1 = d.T
out2 = d + 1
function([x, y], [out1, out2])
```
Otherwise we'd be introducing a second dot computation, right?
Yup, and then we would need to go and try to avoid that by doing something like #1537.
So better not to create potential duplication. If we care that much about lifting it, we need a global rewriter (if you're using clients, you're a global rewriter) that reasons about all combinations of dot at once to decide the canonical form and how to make everyone else use that.
A bit like the Solve -> LU factor -> LU solve case, where we have to orchestrate how everyone will use the common LU factor.
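A rough sketch of the single-client guard being discussed (names are illustrative, not the actual rewrite):

```python
def transpose_of_dot_is_liftable(fgraph, dot_var):
    """Only lift dot(a, b).T into dot(b.T, a.T) when the dot result has a single
    client; otherwise the rewrite would duplicate the dot computation."""
    clients = fgraph.clients[dot_var]  # list of (apply_node, input_index) pairs
    return len(clients) == 1
```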
Approving with some questions.
```python
if core_a_ndim > 2 or core_b_ndim > 2:
    # Shouldn't happen, but here just in case
# Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
```
Why not check the broadcastable flag here (instead of the static shape)?
It's the same, no reason.
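For reference, in pytensor a dimension with static shape 1 is also flagged broadcastable, so the two checks agree here (illustrative only):

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 1))
# The trailing dimension has static shape 1, so it is also broadcastable;
# checking either property gives the same answer.
assert x.type.shape[-1] == 1
assert x.type.broadcastable[-1]
```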
pytensor/tensor/math.py (outdated)
```python
if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
    raise ValueError(
        f"Incompatible shared dimension for dot product: {sx}, {sy}"
    )
sz = sx[:-1] + sy[-1:]
```
Suggested change:

```python
sz = [*sx[:-1], sy[-1]]
```

Slightly clearer?
I don't need the slice; it's `sz = (sx[0], sy[1])` now.
```python
rval = xgrad, ygrad
if xgrad.type.shape != x.type.shape:
```
It seems like the logic the comment is worried about should be implemented in make_node, so that these checks here aren't necessary?
I don't see how that's related to the grad here.
```python
def __str__(self):
    return "dot"

return [[xshp[0], yshp[1]]]
```
The output shape in make_node seemed to be worried about batch dims, but infer_shape doesn't have to?
infer_shape takes a well-defined node and input shapes. Otherwise every op's infer_shape would be skeptical that it got the right number of shapes, and we don't do that.
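In other words, something along these lines is enough (sketch only, matching the quoted return value):

```python
def infer_shape(self, fgraph, node, shapes):
    # With Dot restricted to matrices, the output shape is just
    # (rows of x, columns of y); there are no batch dimensions to handle.
    xshp, yshp = shapes
    return [[xshp[0], yshp[1]]]
```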
```python
copy_stack_trace(node.outputs, new_out)
return new_out

z, a, x, y, b = node.inputs
```
I wish we consistently used capital letters for matrix inputs and lowercase for scalar/vector inputs. These are consistent with the code for GEMM, so they shouldn't be changed; I'm just complaining.
Ancient code; I don't want to change it.
```python
if bval == 1:  # best case a natural GER
    rval = ger(z, a, xv, yv)
    new_out = [rval]
elif bval == 0:  # GER on zeros_like should be faster than GEMM
```
Did you test this? When b=0, I thought GEMM totally ignores the z terms.
Ancient code again.
I know gemv and ger (c) ignore it. I have no experience with gemm directly; it never showed up in my work, only dot22 and dot22scalar. I don't even know if we are introducing it these days.
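A small NumPy check of the identity behind the b == 0 branch (illustrative only): whether the backend ignores z or a GER is applied to zeros_like(z), the value is the same.

```python
import numpy as np

rng = np.random.default_rng(0)
z = rng.normal(size=(3, 4))
x, y = rng.normal(size=3), rng.normal(size=4)
a, b = 0.5, 0.0

gemm_out = b * z + a * np.outer(x, y)                 # gemm-style update with b == 0
ger_on_zeros = np.zeros_like(z) + a * np.outer(x, y)  # the rewrite's GER-on-zeros path
assert np.allclose(gemm_out, ger_on_zeros)
```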
```python
one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
    # x and y are both vectors so this might qualifies for a GER
```
Suggested change:

```python
# x and y are both vectors so this might qualify as a GER
```
Fixing decade-old typos?
I'm just reviewing what I am given, man.
Yeah, GitHub shows all these as green because of a tab change. Not the best visual diff indicator.
Obviously the lesson here is to never change tabs
```python
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
    # x and y are both vectors so this might qualifies for a GER
    xv = x.dimshuffle(0)
```
Why dimshuffle vs mT?
This is ancient code; if you look at the relevant commit you'll notice I just removed one level of indentation by removing the outer check that's now enforced by tracks.
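For context, the two are not interchangeable on a broadcastable column (assuming a pytensor version that exposes .mT): dimshuffle(0) drops the length-1 axis to give a vector, while mT keeps the result 2D.

```python
import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 1))  # column matrix with a broadcastable trailing dim
xv = x.dimshuffle(0)                 # drops the length-1 axis -> 1D vector
xt = x.mT                            # matrix transpose -> still 2D, shape (1, None)
print(xv.type.ndim, xt.type.ndim)    # 1 2
```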
This builds on top of #1471, removing extra complexity from our representation of matmul. All dots are now on 2D inputs, and the mat-vec, vec-mat, and vec-vec cases can be detected by introspecting the broadcastable pattern of the inputs (see the example below). This is information that should never be lost, and not having to worry about variants where it doesn't matter makes our lives easier.
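An illustrative example of what "introspecting the broadcastable pattern" means here (not code from the PR): a promoted vector is just a matrix with a length-1, broadcastable dimension.

```python
import pytensor.tensor as pt

row = pt.tensor("row", shape=(1, None))  # a vector promoted to a row matrix
col = pt.tensor("col", shape=(None, 1))  # a vector promoted to a column matrix
mat = pt.matrix("mat")

print(row.broadcastable)  # (True, False) -> dot(row, mat) is really vec-mat
print(col.broadcastable)  # (False, True) -> dot(mat, col) is really mat-vec
```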
This PR also removes scipy_ger and uses scipy directly in the perform method of Ger. scipy_ger was an artifact from the old Theano times, when scipy was an optional dependency.
With these changes, the whole concept of Dot22 also loses its meaning. We can remove it next and just port the C implementation to Dot directly.
Closes #453
Closes #946
📚 Documentation preview 📚: https://pytensor--1538.org.readthedocs.build/en/1538/