
Change Dot Op to only accept matrix inputs #1538


Merged
ricardoV94 merged 4 commits into pymc-devs:main from dot_is_2d on Jul 25, 2025

Conversation

ricardoV94
Member

@ricardoV94 ricardoV94 commented Jul 13, 2025

This builds on top of #1471, removing extra complexity from our representation of matmul. All dots are now on 2d inputs, and the mat-vec, vec-mat, and vec-vec variants can be detected by introspecting the broadcastable pattern of the inputs. This is information that should never be lost, and not having to worry about variants where it doesn't matter makes our lives easier.
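
As a rough illustration of what introspecting the broadcastable pattern means (a sketch with made-up variable names, not code from this PR):

import pytensor.tensor as pt

A = pt.matrix("A")                         # broadcastable: (False, False)
v_col = pt.vector("v").dimshuffle(0, "x")  # promoted to an (n, 1) column matrix

out = pt.dot(A, v_col)
# The second input is broadcastable in its last dimension, so a rewrite can
# recognize this as a mat-vec product without any dedicated Op variant.
assert out.owner.inputs[1].type.broadcastable == (False, True)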

This PR also removes scipy_ger and uses scipy directly in the perform method of Ger. scipy_ger is an artifact from the old Theano times, when scipy was an optional dependency.

With these changes, the whole concept of Dot22 also loses its meaning. We can remove it next and just port the C implementation to Dot directly.

Closes #453
Closes #946


📚 Documentation preview 📚: https://pytensor--1538.org.readthedocs.build/en/1538/

@ricardoV94 ricardoV94 force-pushed the dot_is_2d branch 2 times, most recently from 2d5ed60 to bb45939 Compare July 22, 2025 10:42
@ricardoV94 ricardoV94 changed the title from "Canonicalize Dot as a matrix-matrix operation" to "Canonicalize dot as a matrix-matrix operation" Jul 22, 2025

@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
def test_ger_benchmark(n, inplace, benchmark):
Member Author


I added this because at some point I considered taking away GER in favor of outer multiplication. The benchmark showed the outer-based version was underperforming.
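
The test body is not shown in this excerpt; a hedged sketch of what such a pytest-benchmark test could look like (the graph construction and setup here are assumptions):

import numpy as np
import pytest
import pytensor
import pytensor.tensor as pt


@pytest.mark.parametrize("inplace", (True, False), ids=["inplace", "no_inplace"])
@pytest.mark.parametrize("n", [2**7, 2**9, 2**13])
def test_ger_benchmark(n, inplace, benchmark):
    A = pt.matrix("A")
    x = pt.vector("x")
    y = pt.vector("y")
    # A rank-1 update; the BLAS rewrites are expected to turn this into GER
    out = A + pt.outer(x, y)
    fn = pytensor.function([pytensor.In(A, mutable=inplace), x, y], out)

    rng = np.random.default_rng(42)
    A_val = rng.normal(size=(n, n))
    x_val = rng.normal(size=n)
    y_val = rng.normal(size=n)
    benchmark(fn, A_val, x_val, y_val)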

@ricardoV94 ricardoV94 force-pushed the dot_is_2d branch 3 times, most recently from 115c865 to 199396e Compare July 24, 2025 10:16
@ricardoV94 ricardoV94 changed the title from "Canonicalize dot as a matrix-matrix operation" to "Change Dot Op to only accept matrix inputs" Jul 24, 2025
@ricardoV94 ricardoV94 marked this pull request as ready for review July 24, 2025 11:54

@Copilot Copilot AI left a comment


Pull Request Overview

This PR refactors the Dot operation to only accept 2D matrix inputs, simplifying the codebase by removing support for vector inputs and delegating vector operations to helper functions. The changes eliminate the old scipy_ger module and streamline the dot product implementation.

  • Restricts Dot Op to only accept 2D tensors (matrices), removing vector support
  • Removes scipy_ger module and integrates scipy directly into Ger.perform method
  • Updates dot interface function to handle vector promotion to matrices internally (sketched below)
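
A sketch of what that promotion could look like (a hypothetical helper, not the PR's actual dense_dot code):

import pytensor.tensor as pt


def dot_with_vector_promotion(a, b):
    # Promote 1d inputs to row/column matrices so the core Dot only ever
    # sees 2d operands, then squeeze the added dims back out.
    a_is_vec = a.type.ndim == 1
    b_is_vec = b.type.ndim == 1
    a2 = a.dimshuffle("x", 0) if a_is_vec else a  # vec -> (1, m) row
    b2 = b.dimshuffle(0, "x") if b_is_vec else b  # vec -> (n, 1) column
    out = pt.dot(a2, b2)  # always a matrix-matrix Dot
    axes_to_drop = tuple(ax for ax, added in ((0, a_is_vec), (1, b_is_vec)) if added)
    return pt.squeeze(out, axis=axes_to_drop) if axes_to_drop else out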

Reviewed Changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.

Reviewed files:
  • pytensor/tensor/math.py: Major refactor of Dot class to only accept 2D inputs, updated dense_dot function
  • pytensor/tensor/blas.py: Updated Ger.perform to use scipy directly, simplified Dot22.perform
  • pytensor/tensor/blas_scipy.py: Removed entire ScipyGer implementation
  • pytensor/tensor/rewriting/blas_scipy.py: Removed scipy-specific BLAS rewrites
  • pytensor/tensor/rewriting/math.py: Updated optimization rules for new Dot constraints
  • tests/tensor/test_math.py: Updated tests to reflect new Dot API and removed vector tests
Comments suppressed due to low confidence (1)

tests/tensor/test_math.py:2011

  • The test is checking that _dot(d1, d2) raises TypeError, but this line should be inside a pytest.raises context manager to properly test the exception.
            _dot(d1, d2)
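
A sketch of the fix being suggested (the import path for _dot is an assumption, and d1/d2 stand for the vector inputs used in the test):

import pytest
import pytensor.tensor as pt
from pytensor.tensor.math import _dot  # assumed location of the raw Dot helper


def test_dot_rejects_vectors():
    d1 = pt.vector("d1")
    d2 = pt.vector("d2")
    # The raw Dot Op now only accepts matrices, so vectors should raise
    with pytest.raises(TypeError):
        _dot(d1, d2)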

    # Work on transposed system to avoid copying
    A = ger_func(alpha, y, x, a=A.T, overwrite_a=self.destructive).T
else:
    A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)

Copilot AI Jul 24, 2025


The output should handle the case when A.size == 0. When A is empty, the code should still copy A if not destructive, but currently it just assigns A directly regardless of the destructive flag.

Suggested change
- A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
+ A = ger_func(alpha, x, y, a=A, overwrite_a=self.destructive)
+ else:
+     # Handle the case where A.size == 0
+     if self.destructive:
+         # No-op for destructive mode
+         pass
+     else:
+         # Create a copy for non-destructive mode
+         A = A.copy()


Member Author


There's no point in copying an empty array; you can't store anything in it.

constant_zero = assert_op(constant_zero, eq(x.shape[0], y.shape[0]))
return [constant_zero]

):
    return [zeros((x.shape[0], y.shape[1]), dtype=node.outputs[0].type.dtype)]

Copilot AI Jul 24, 2025


This assumes both inputs are 2D matrices, but the function doesn't validate the input dimensions. If either x or y has fewer than 2 dimensions, accessing x.shape[0] or y.shape[1] could cause an IndexError.


Member Author


That's the whole point of this PR: if you see a Dot, it must be with two matrices. make_node validates it.
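
A rough sketch (not the verbatim source) of the kind of make_node validation being described, with the output shape rule discussed further down (sz = (sx[0], sy[1])):

from pytensor.graph.basic import Apply
from pytensor.scalar import upcast
from pytensor.tensor import as_tensor_variable, tensor


def make_node(self, x, y):
    x = as_tensor_variable(x)
    y = as_tensor_variable(y)
    # Reject anything that is not a matrix up front, so every downstream
    # rewrite can assume 2d inputs
    if x.type.ndim != 2 or y.type.ndim != 2:
        raise TypeError(
            f"Dot expects two matrices, got {x.type.ndim}d and {y.type.ndim}d inputs"
        )
    out_type = tensor(
        dtype=upcast(x.type.dtype, y.type.dtype),
        shape=(x.type.shape[0], y.type.shape[1]),
    )
    return Apply(self, [x, y], [out_type])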


codecov bot commented Jul 24, 2025

Codecov Report

Attention: Patch coverage is 95.62044% with 6 lines in your changes missing coverage. Please review.

Project coverage is 81.53%. Comparing base (12213d0) to head (7bba232).
Report is 6 commits behind head on main.

Files with missing lines Patch % Lines
pytensor/tensor/math.py 88.09% 2 Missing and 3 partials ⚠️
pytensor/tensor/rewriting/linalg.py 50.00% 0 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (95.62%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1538      +/-   ##
==========================================
+ Coverage   81.49%   81.53%   +0.03%     
==========================================
  Files         232      230       -2     
  Lines       53122    53003     -119     
  Branches     9444     9410      -34     
==========================================
- Hits        43292    43215      -77     
+ Misses       7382     7360      -22     
+ Partials     2448     2428      -20     
Files with missing lines Coverage Δ
pytensor/tensor/basic.py 91.84% <ø> (ø)
pytensor/tensor/blas.py 73.22% <100.00%> (-0.33%) ⬇️
pytensor/tensor/rewriting/blas.py 91.10% <100.00%> (+1.82%) ⬆️
pytensor/tensor/rewriting/math.py 90.30% <100.00%> (+1.02%) ⬆️
pytensor/tensor/rewriting/subtensor_lift.py 90.95% <100.00%> (-0.11%) ⬇️
pytensor/tensor/rewriting/linalg.py 92.06% <50.00%> (-0.02%) ⬇️
pytensor/tensor/math.py 92.87% <88.09%> (+0.09%) ⬆️

... and 4 files with indirect coverage changes



x, y = node.inputs[0].owner.inputs
if len(clients) != 1:
Member


So we wouldn't rewrite the following graph:

x = dot(a, b).T
y = x + 1
fn = pytensor.function([a, b], [x, y])

?

Member Author

@ricardoV94 ricardoV94 Jul 25, 2025


We would. dot is only used by the transpose.

Instead, we wouldn't rewrite

d = dot(x, y)
out1 = d.T
out2 = d+1
function([x, y], [out1, out2])

Member


Otherwise we'd be introducing a 2nd dot computation right?

Member Author

@ricardoV94 ricardoV94 Jul 25, 2025


Yup, and then we would need to go and try to avoid that by doing something like #1537.

So better not to create potential duplication. And if we care that much about lifting it, we need a global rewriter (if you're using clients, you're a global rewriter) that reasons about all combinations of dot at once to decide the canonical form and how to make everyone else use it.

A bit like the Solve -> LU factor -> LU Solve chain, where we have to orchestrate how everyone will use the common LU factor.

Member

@jessegrabowski jessegrabowski left a comment


approving with some questions

if core_a_ndim > 2 or core_b_ndim > 2:
    # Shouldn't happen, but here just in case

# Check if we have matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
if not (a_static_shape[-1] == 1 or b_static_shape[-2] == 1):
Member


Why not check the broadcastable flag here (instead of the static shape)?

Member Author


It's the same, no reason

if sx[-1] is not None and sy[0] is not None and sx[-1] != sy[0]:
    raise ValueError(
        f"Incompatible shared dimension for dot product: {sx}, {sy}"
    )
sz = sx[:-1] + sy[-1:]
Member


Suggested change
- sz = sx[:-1] + sy[-1:]
+ sz = [*sx[:-1], sy[-1]]

Slightly clearer?

Member Author

@ricardoV94 ricardoV94 Jul 25, 2025


I don't need the slice, it's sz = (sx[0], sy[1]) now

)

rval = xgrad, ygrad
if xgrad.type.shape != x.type.shape:
Member


It seems like the logic the comment is worried about should be implemented in the make_node, so that these checks here aren't necessary?

Member Author


I don't see how that's related to the grad here


def __str__(self):
    return "dot"

    return [[xshp[0], yshp[1]]]
Member


The output shape in make_node seemed to be worried about batch dims, but infer_shape doesn't have to?

Member Author


infer_shape takes a well-defined node and its input shapes. Otherwise every op.infer_shape would have to be skeptical that it got the right number of shapes, and we don't do that.

copy_stack_trace(node.outputs, new_out)
return new_out

z, a, x, y, b = node.inputs
Member


I wish we consistently used capital letters for matrix inputs and lowercase for scalar/vector.

These are consistent with the code for GEMM so it shouldn't be changed. I'm just complaining.

Member Author


ancient code, don't want to change

if bval == 1:  # best case a natural GER
    rval = ger(z, a, xv, yv)
    new_out = [rval]
elif bval == 0:  # GER on zeros_like should be faster than GEMM
Member


did you test this? When b=0 I thought GEMM totally ignores the z terms

Member Author


Ancient code again.

I know gemv and ger (in C) ignore it. I have no experience with gemm directly; it never showed up in my work, only dot22 and dot22scalar. I don't even know if we are introducing it these days.
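
For reference, the BLAS behavior in question can be checked directly with scipy's wrappers (a standalone check, not code from this codebase). Per the BLAS docs, C need not be set on input when beta is zero:

import numpy as np
from scipy.linalg.blas import dgemm

a = np.ones((2, 2))
b = np.ones((2, 2))
c = np.full((2, 2), np.nan)  # deliberately garbage input
# gemm computes alpha * a @ b + beta * c; with beta=0 the contents of c are ignored
out = dgemm(alpha=1.0, a=a, b=b, beta=0.0, c=c)
assert not np.isnan(out).any()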

one = ptb.as_tensor_variable(np.asarray(1, dtype=x.dtype))
zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
    # x and y are both vectors so this might qualifies for a GER
Member


Suggested change
- # x and y are both vectors so this might qualifies for a GER
+ # x and y are both vectors so this might qualify as a GER

Member Author


fixing decade-old typos?

Member


I'm just reviewing what I am given, man

Member Author


Yeah, GitHub shows all these as green because of a tab change. Not the best visual diff indicator.

Member


Obviously the lesson here is to never change tabs

zero = ptb.as_tensor_variable(np.asarray(0, dtype=x.dtype))
if xb[1] and yb[0]:
    # x and y are both vectors so this might qualifies for a GER
    xv = x.dimshuffle(0)
Member


why dimshuffle vs mT?

Member Author

@ricardoV94 ricardoV94 Jul 25, 2025


This is ancient code; if you look at the relevant commit you'll notice I just removed one level of indentation by removing the outer check that's already enforced by tracks.
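
To illustrate the difference being asked about (a hypothetical x with static shape (n, 1)): dimshuffle(0) drops the broadcastable column and yields the vector that ger expects, while .mT would keep the result 2d as a row matrix.

import pytensor.tensor as pt

x = pt.tensor("x", shape=(None, 1))  # column matrix, broadcastable (False, True)
xv = x.dimshuffle(0)                 # drops the broadcastable dim -> vector, shape (None,)
xr = x.mT                            # transpose instead -> row matrix, shape (1, None)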

@ricardoV94 ricardoV94 merged commit 0c13849 into pymc-devs:main Jul 25, 2025
69 checks passed

Successfully merging this pull request may close these issues:

  • Do we need dot Op to handle the vector cases?
  • Merge blas_scipy.py and blas.py