@@ -1712,7 +1712,7 @@ def __mul__(self, other):
1712
1712
-------
1713
1713
:class:`pyttb.sptensor`
1714
1714
"""
1715
- if isinstance (other , (float ,int )):
1715
+ if isinstance (other , (float , int , np . number )):
1716
1716
return ttb .sptensor .from_data (self .subs , self .vals * other , self .shape )
1717
1717
1718
1718
if isinstance (other , (ttb .sptensor ,ttb .tensor ,ttb .ktensor )) and self .shape != other .shape :
@@ -1754,7 +1754,7 @@ def __rmul__(self, other):
1754
1754
-------
1755
1755
:class:`pyttb.sptensor`
1756
1756
"""
1757
- if isinstance (other , (float ,int )):
1757
+ if isinstance (other , (float , int , np . number )):
1758
1758
return self .__mul__ (other )
1759
1759
else :
1760
1760
assert False , "This object cannot be multiplied by sptensor"
@@ -2173,15 +2173,14 @@ def __repr__(self): # pragma: no cover
2173
2173
2174
2174
__str__ = __repr__
2175
2175
2176
- def ttm (self , matrices , mode , dims = None , transpose = False ):
2176
+ def ttm (self , matrices , dims = None , transpose = False ):
2177
2177
"""
2178
2178
Sparse tensor times matrix.
2179
2179
2180
2180
Parameters
2181
2181
----------
2182
2182
matrices: A matrix or list of matrices
2183
- mode:
2184
- dims:
2183
+ dims: :class:`Numpy.ndarray`, int
2185
2184
transpose: Transpose matrices to be multiplied
2186
2185
2187
2186
Returns
@@ -2190,10 +2189,15 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
2190
2189
"""
2191
2190
if dims is None :
2192
2191
dims = np .arange (self .ndims )
2192
+ elif isinstance (dims , list ):
2193
+ dims = np .array (dims )
2194
+ elif np .isscalar (dims ) or isinstance (dims , list ):
2195
+ dims = np .array ([dims ])
2196
+
2193
2197
# Handle list of matrices
2194
2198
if isinstance (matrices , list ):
2195
2199
# Check dimensions are valid
2196
- [dims , vidx ] = tt_dimscheck (mode , self .ndims , len (matrices ))
2200
+ [dims , vidx ] = tt_dimscheck (dims , self .ndims , len (matrices ))
2197
2201
# Calculate individual products
2198
2202
Y = self .ttm (matrices [vidx [0 ]], dims [0 ], transpose = transpose )
2199
2203
for i in range (1 , dims .size ):
@@ -2208,33 +2212,34 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
2208
2212
if transpose :
2209
2213
matrices = matrices .transpose ()
2210
2214
2211
- # Check mode
2212
- if not np .isscalar (mode ) or mode < 0 or mode > self .ndims - 1 :
2213
- assert False , "Mode must be in [0, ndims)"
2215
+ # Ensure this is the terminal single dimension case
2216
+ if not (dims .size == 1 and np .isin (dims , np .arange (self .ndims ))):
2217
+ assert False , "dims must contain values in [0,self.dims)"
2218
+ dims = dims [0 ]
2214
2219
2215
2220
# Compute the product
2216
2221
2217
2222
# Check that sizes match
2218
- if self .shape [mode ] != matrices .shape [1 ]:
2223
+ if self .shape [dims ] != matrices .shape [1 ]:
2219
2224
assert False , "Matrix shape doesn't match tensor shape"
2220
2225
2221
2226
# Compute the new size
2222
2227
siz = np .array (self .shape )
2223
- siz [mode ] = matrices .shape [0 ]
2228
+ siz [dims ] = matrices .shape [0 ]
2224
2229
2225
2230
# Compute self[mode]'
2226
- Xnt = ttb .tt_to_sparse_matrix (self , mode , True )
2231
+ Xnt = ttb .tt_to_sparse_matrix (self , dims , True )
2227
2232
2228
2233
# Reshape puts the reshaped things after the unchanged modes, transpose then puts it in front
2229
2234
idx = 0
2230
2235
2231
2236
# Convert to sparse matrix and do multiplication; generally result is sparse
2232
2237
Z = Xnt .dot (matrices .transpose ())
2233
2238
2234
- # Rearrange back into sparse tensor of original shape
2235
- Ynt = ttb .tt_from_sparse_matrix (Z , self . shape , mode , idx )
2239
+ # Rearrange back into sparse tensor of correct shape
2240
+ Ynt = ttb .tt_from_sparse_matrix (Z , siz , dims , idx )
2236
2241
2237
- if Z .nnz <= 0.5 * np .prod (siz ):
2242
+ if not isinstance ( Z , np . ndarray ) and Z .nnz <= 0.5 * np .prod (siz ):
2238
2243
return Ynt
2239
2244
else :
2240
2245
# TODO evaluate performance loss by casting into sptensor then tensor. I assume minimal since we are already
0 commit comments