@@ -2152,62 +2152,79 @@ def tensordot(
     a = as_tensor_variable(a)
     b = as_tensor_variable(b)
     runtime_shape_a = a.shape
-    bcast_a = a.broadcastable
     static_shape_a = a.type.shape
-    ndim_a = a.ndim
+    ndim_a = a.type.ndim
     runtime_shape_b = b.shape
-    bcast_b = b.broadcastable
     static_shape_b = b.type.shape
-    ndim_b = b.ndim
+    ndim_b = b.type.ndim
     if na != nb:
         raise ValueError(
             "The number of axes supplied for tensordot must be equal for each tensor. "
             f"Got {na} and {nb} respectively."
         )
     axes_a = list(normalize_axis_tuple(axes_a, ndim_a))
     axes_b = list(normalize_axis_tuple(axes_b, ndim_b))
+
+    # The operation is only valid if the original dimensions match in length
+    # The ravelling of the dimensions to coerce the operation into a single dot
+    # could mask such errors, so we add an Assert if needed.
     must_assert_runtime = False
-    for k in range(na):
-        ax_a = axes_a[k]
-        ax_b = axes_b[k]
-        if (bcast_a[ax_a] != bcast_b[ax_b]) or (
+    for ax_a, ax_b in zip(axes_a, axes_b, strict=True):
+        if (
             static_shape_a[ax_a] is not None
             and static_shape_b[ax_b] is not None
             and static_shape_a[ax_a] != static_shape_b[ax_b]
         ):
             raise ValueError(
-                "Input arrays have inconsistent broadcastable pattern or type shape along the axes "
+                "Input arrays have inconsistent type shape along the axes "
                 "that are to be reduced with tensordot."
             )
         elif static_shape_a[ax_a] is None or static_shape_b[ax_b] is None:
             if must_assert_runtime:
                 a = Assert(
                     "Input array shape along reduced axes of tensordot are not equal"
-                )(a, eq(a.shape[ax_a], b.shape[ax_b]))
+                )(a, eq(runtime_shape_a[ax_a], runtime_shape_b[ax_b]))
             must_assert_runtime = True
 
-    # Move the axes to sum over to the end of "a"
-    # and to the front of "b"
-    notin = [k for k in range(ndim_a) if k not in axes_a]
-    newaxes_a = notin + axes_a
-    N2 = 1
-    for axis in axes_a:
-        N2 *= runtime_shape_a[axis]
-    newshape_a = (-1, N2)
-    olda = [runtime_shape_a[axis] for axis in notin]
-
-    notin = [k for k in range(ndim_b) if k not in axes_b]
-    newaxes_b = axes_b + notin
-    N2 = 1
-    for axis in axes_b:
-        N2 *= runtime_shape_b[axis]
-    newshape_b = (N2, -1)
-    oldb = [runtime_shape_b[axis] for axis in notin]
-
-    at = a.transpose(newaxes_a).reshape(newshape_a)
-    bt = b.transpose(newaxes_b).reshape(newshape_b)
-    res = _dot(at, bt)
-    return res.reshape(olda + oldb)
+    # Convert tensordot into a stacked dot product.
+    # We stack the summed axes and the non-summed axes of each tensor separately,
+    # and place the summed axes at the end of a and the beginning of b
+    non_summed_axes_a = [k for k in range(ndim_a) if k not in axes_a]
+    non_summed_dims_a = [runtime_shape_a[axis] for axis in non_summed_axes_a]
+    transpose_axes_a = non_summed_axes_a + axes_a
+    # We only need a reshape when we need to combine summed or non-summed dims
+    # or introduce a new dimension (expand_dims), when doing a non-scalar outer product (axes = 0)
+    a_needs_reshape = (ndim_a != 0) and (
+        (len(non_summed_axes_a) > 1) or (len(axes_a) != 1)
+    )
+
+    non_summed_axes_b = [k for k in range(ndim_b) if k not in axes_b]
+    non_summed_dims_b = [runtime_shape_b[axis] for axis in non_summed_axes_b]
+    transpose_axes_b = axes_b + non_summed_axes_b
+    b_needs_reshape = (ndim_b != 0) and (
+        (len(non_summed_axes_b) > 1) or (len(axes_b) != 1)
+    )
+
+    # summed_size_a and summed_size_b must be the same,
+    # but to facilitate reasoning about useless reshapes we compute both from their shapes
+    at = a.transpose(transpose_axes_a)
+    if a_needs_reshape:
+        non_summed_size_a = variadic_mul(*non_summed_dims_a)
+        summed_size_a = variadic_mul(*[runtime_shape_a[axis] for axis in axes_a])
+        at = at.reshape((non_summed_size_a, summed_size_a))
+
+    bt = b.transpose(transpose_axes_b)
+    if b_needs_reshape:
+        non_summed_size_b = variadic_mul(*non_summed_dims_b)
+        summed_size_b = variadic_mul(*[runtime_shape_b[axis] for axis in axes_b])
+        bt = bt.reshape((summed_size_b, non_summed_size_b))
+
+    res = dot(at, bt)
+
+    if a_needs_reshape or b_needs_reshape:
+        res = res.reshape(non_summed_dims_a + non_summed_dims_b)
+
+    return res
 
 
 def outer(x, y):
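
For reference, here is a minimal NumPy sketch of the same transpose / reshape / dot decomposition that the rewritten code builds symbolically. The helper name tensordot_via_dot is made up for illustration and is not part of PyTensor; it always performs both reshapes, unlike the symbolic version above.

import numpy as np


def tensordot_via_dot(a, b, axes_a, axes_b):
    # Axes of each input that are not summed over.
    non_summed_a = [k for k in range(a.ndim) if k not in axes_a]
    non_summed_b = [k for k in range(b.ndim) if k not in axes_b]

    # Move the summed axes to the end of `a` and the front of `b`,
    # then collapse each group of axes into a single dimension.
    at = a.transpose(non_summed_a + list(axes_a)).reshape(
        (int(np.prod([a.shape[k] for k in non_summed_a], dtype=int)), -1)
    )
    bt = b.transpose(list(axes_b) + non_summed_b).reshape(
        (-1, int(np.prod([b.shape[k] for k in non_summed_b], dtype=int)))
    )

    # One 2D matrix product, then restore the non-summed dimensions of both inputs.
    out_shape = [a.shape[k] for k in non_summed_a] + [b.shape[k] for k in non_summed_b]
    return (at @ bt).reshape(out_shape)


rng = np.random.default_rng(0)
a = rng.normal(size=(2, 3, 4))
b = rng.normal(size=(4, 3, 5))
expected = np.tensordot(a, b, axes=[(1, 2), (1, 0)])
result = tensordot_via_dot(a, b, axes_a=(1, 2), axes_b=(1, 0))
assert np.allclose(result, expected)

The point of the a_needs_reshape / b_needs_reshape flags in the diff is that this reshape step can be skipped when the summed and non-summed groups each already consist of a single axis (for example a plain matrix-vector contraction), so the graph does not accumulate reshapes that do nothing.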