@@ -2152,62 +2152,73 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    a_needs_reshape = len(non_summed_axes_a) > 1 or len(axes_a) > 1
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = len(axes_b) > 1 or len(non_summed_axes_b) > 1
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
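
For reference, the reduction this hunk builds symbolically can be sketched with plain NumPy: the summed axes are moved to the end of a and the front of b, both operands are flattened to matrices, a single dot performs the contraction, and the result is reshaped back to the non-summed dimensions. The helper name tensordot_via_dot and the test shapes below are illustrative only, not part of PyTensor, and the sketch omits the useless-reshape avoidance and the runtime shape Assert that the actual change adds.

import numpy as np

# Illustrative sketch (not the PyTensor implementation) of the
# transpose -> reshape -> dot -> reshape reduction.
def tensordot_via_dot(a, b, axes_a, axes_b):
    non_summed_a = [k for k in range(a.ndim) if k not in axes_a]
    non_summed_b = [k for k in range(b.ndim) if k not in axes_b]

    # Summed axes go to the end of `a` and the front of `b`,
    # then each operand is collapsed to a 2D matrix.
    at = a.transpose(non_summed_a + list(axes_a)).reshape(
        (int(np.prod([a.shape[k] for k in non_summed_a])), -1)
    )
    bt = b.transpose(list(axes_b) + non_summed_b).reshape(
        (-1, int(np.prod([b.shape[k] for k in non_summed_b])))
    )

    # One matrix product, then restore the non-summed dimensions.
    out = at @ bt
    return out.reshape(
        [a.shape[k] for k in non_summed_a] + [b.shape[k] for k in non_summed_b]
    )

rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 3, 5))
expected = np.tensordot(a, b, axes=[(1, 2), (1, 0)])
np.testing.assert_allclose(tensordot_via_dot(a, b, (1, 2), (1, 0)), expected)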