
Commit c2ee280

Only do reshapes in tensordot when needed
1 parent 3cff4f5 commit c2ee280
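
For context: tensordot is lowered to a single matrix product by moving the summed axes to the end of `a` and the front of `b`, flattening each operand to 2-D, calling `dot`, and reshaping the result back. This commit makes the flattening conditional, so an operand that already has at most one summed and one non-summed axis skips the Reshape entirely. Below is a minimal NumPy sketch of the general reduction, for illustration only; `tensordot_via_dot` is a hypothetical helper, not the PyTensor implementation.

    import numpy as np

    def tensordot_via_dot(a, b, axes: int):
        # Move the summed axes to the end of `a` and the front of `b`,
        # flatten both operands to 2-D, take one matrix product, reshape back.
        axes_a = list(range(a.ndim - axes, a.ndim))
        axes_b = list(range(axes))
        keep_a = [k for k in range(a.ndim) if k not in axes_a]
        keep_b = [k for k in range(b.ndim) if k not in axes_b]
        summed = int(np.prod([a.shape[k] for k in axes_a]))
        at = a.transpose(keep_a + axes_a).reshape(-1, summed)
        bt = b.transpose(axes_b + keep_b).reshape(summed, -1)
        out = at @ bt
        return out.reshape([a.shape[k] for k in keep_a] + [b.shape[k] for k in keep_b])

    a = np.random.rand(2, 3, 4)
    b = np.random.rand(3, 4, 5)
    assert np.allclose(tensordot_via_dot(a, b, axes=2), np.tensordot(a, b, axes=2))

For 2-D operands with axes=1 both reshapes in this sketch are no-ops; that is exactly the case where the commit avoids emitting them.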

File tree: 2 files changed, +44 −33 lines


pytensor/tensor/math.py

Lines changed: 43 additions & 32 deletions
@@ -2152,62 +2152,73 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
        elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
            if must_assert_runtime:
                a = Assert(
                    "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
            must_assert_runtime = True

-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    a_needs_reshape = len(non_summed_axes_a) > 1 or len(axes_a) > 1
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = len(axes_b) > 1 or len(non_summed_axes_b) > 1
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res


 def outer(x, y):
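
With the new `a_needs_reshape` / `b_needs_reshape` guards, a matrix-matrix tensordot with `axes=1` has exactly one summed and one non-summed axis per operand, so no Reshape is inserted and the graph reduces to a transpose plus a plain `dot`. A hedged inspection sketch, assuming a PyTensor build that includes this commit (the exact Ops printed may differ):

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    y = pt.matrix("y")

    # One summed axis and one kept axis per operand:
    # both *_needs_reshape flags are False, so no at.reshape/bt.reshape is built.
    out = pt.tensordot(x, y, axes=1)

    # Expectation under this commit: no Reshape Ops appear in the printed graph.
    pytensor.dprint(out)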

tests/tensor/test_math.py

Lines changed: 1 addition & 1 deletion
@@ -2278,7 +2278,7 @@ def test_type_shape(self):
 
         with pytest.raises(
             ValueError,
-            match="Input arrays have inconsistent broadcastable pattern or type shape",
+            match="Input arrays have inconsistent type shape",
         ):
             tensordot(ones(shape=(7, 4)), ones(shape=(7, 4)), axes=1)
 
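The narrower match string reflects that the check now only compares known static type shapes (the broadcastable-pattern comparison was dropped). The case the test exercises fails because `axes=1` pairs `a`'s last axis (length 4) with `b`'s first axis (length 7). A small illustrative snippet of the same failure outside the test suite (not from the repository):

    import pytensor.tensor as pt

    a = pt.ones((7, 4))
    b = pt.ones((7, 4))

    # axes=1 reduces a's last axis (4) against b's first axis (7); the static
    # shapes disagree, so graph construction raises the updated ValueError.
    pt.tensordot(a, b, axes=1)  # ValueError: Input arrays have inconsistent type shape ...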
