Skip to content

Commit eb99634

Browse files
authored
Ttensor implementation (#51)
* TENSOR: Fix slices ref shen return value isn't scalar or vector. #41 * TTENSOR: Add tensor creation (partial support of core tensor types) and display * SPTENSOR: Add numpy scalar type for multiplication filter. * TTENSOR: Double, full, isequal, mtimes, ndims, size, uminus, uplus, and partial innerprod. * TTENSOR: TTV (finishes innerprod), mttkrp, and norm * TTENSOR: TTM, permute and minor cleanup. * TTENSOR: Reconstruct * TTENSOR: Nvecs * SPTENSOR: * Fix argument mismatch for ttm (modes s.b. dims) * Fix ttm for rectangular matrices * Make error message consitent with tensor TENSOR: * Fix error message * TTENSOR: Improve test coverage and corresponding bug fixes discovered.
1 parent eade612 commit eb99634

File tree

6 files changed

+940
-31
lines changed

6 files changed

+940
-31
lines changed

pyttb/sptensor.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1712,7 +1712,7 @@ def __mul__(self, other):
17121712
-------
17131713
:class:`pyttb.sptensor`
17141714
"""
1715-
if isinstance(other, (float,int)):
1715+
if isinstance(other, (float, int, np.number)):
17161716
return ttb.sptensor.from_data(self.subs, self.vals*other, self.shape)
17171717

17181718
if isinstance(other, (ttb.sptensor,ttb.tensor,ttb.ktensor)) and self.shape != other.shape:
@@ -1754,7 +1754,7 @@ def __rmul__(self, other):
17541754
-------
17551755
:class:`pyttb.sptensor`
17561756
"""
1757-
if isinstance(other, (float,int)):
1757+
if isinstance(other, (float, int, np.number)):
17581758
return self.__mul__(other)
17591759
else:
17601760
assert False, "This object cannot be multiplied by sptensor"
@@ -2173,15 +2173,14 @@ def __repr__(self): # pragma: no cover
21732173

21742174
__str__ = __repr__
21752175

2176-
def ttm(self, matrices, mode, dims=None, transpose=False):
2176+
def ttm(self, matrices, dims=None, transpose=False):
21772177
"""
21782178
Sparse tensor times matrix.
21792179
21802180
Parameters
21812181
----------
21822182
matrices: A matrix or list of matrices
2183-
mode:
2184-
dims:
2183+
dims: :class:`Numpy.ndarray`, int
21852184
transpose: Transpose matrices to be multiplied
21862185
21872186
Returns
@@ -2190,10 +2189,15 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
21902189
"""
21912190
if dims is None:
21922191
dims = np.arange(self.ndims)
2192+
elif isinstance(dims, list):
2193+
dims = np.array(dims)
2194+
elif np.isscalar(dims) or isinstance(dims, list):
2195+
dims = np.array([dims])
2196+
21932197
# Handle list of matrices
21942198
if isinstance(matrices, list):
21952199
# Check dimensions are valid
2196-
[dims, vidx] = tt_dimscheck(mode, self.ndims, len(matrices))
2200+
[dims, vidx] = tt_dimscheck(dims, self.ndims, len(matrices))
21972201
# Calculate individual products
21982202
Y = self.ttm(matrices[vidx[0]], dims[0], transpose=transpose)
21992203
for i in range(1, dims.size):
@@ -2208,33 +2212,34 @@ def ttm(self, matrices, mode, dims=None, transpose=False):
22082212
if transpose:
22092213
matrices = matrices.transpose()
22102214

2211-
# Check mode
2212-
if not np.isscalar(mode) or mode < 0 or mode > self.ndims-1:
2213-
assert False, "Mode must be in [0, ndims)"
2215+
# Ensure this is the terminal single dimension case
2216+
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
2217+
assert False, "dims must contain values in [0,self.dims)"
2218+
dims = dims[0]
22142219

22152220
# Compute the product
22162221

22172222
# Check that sizes match
2218-
if self.shape[mode] != matrices.shape[1]:
2223+
if self.shape[dims] != matrices.shape[1]:
22192224
assert False, "Matrix shape doesn't match tensor shape"
22202225

22212226
# Compute the new size
22222227
siz = np.array(self.shape)
2223-
siz[mode] = matrices.shape[0]
2228+
siz[dims] = matrices.shape[0]
22242229

22252230
# Compute self[mode]'
2226-
Xnt = ttb.tt_to_sparse_matrix(self, mode, True)
2231+
Xnt = ttb.tt_to_sparse_matrix(self, dims, True)
22272232

22282233
# Reshape puts the reshaped things after the unchanged modes, transpose then puts it in front
22292234
idx = 0
22302235

22312236
# Convert to sparse matrix and do multiplication; generally result is sparse
22322237
Z = Xnt.dot(matrices.transpose())
22332238

2234-
# Rearrange back into sparse tensor of original shape
2235-
Ynt = ttb.tt_from_sparse_matrix(Z, self.shape, mode, idx)
2239+
# Rearrange back into sparse tensor of correct shape
2240+
Ynt = ttb.tt_from_sparse_matrix(Z, siz, dims, idx)
22362241

2237-
if Z.nnz <= 0.5 * np.prod(siz):
2242+
if not isinstance(Z, np.ndarray) and Z.nnz <= 0.5 * np.prod(siz):
22382243
return Ynt
22392244
else:
22402245
# TODO evaluate performance loss by casting into sptensor then tensor. I assume minimal since we are already

pyttb/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -921,7 +921,7 @@ def ttm(self, matrix, dims=None, transpose=False):
921921
assert False, "matrix must be of type numpy.ndarray"
922922

923923
if not (dims.size == 1 and np.isin(dims, np.arange(self.ndims))):
924-
assert False, "dims must contain values in [0,self.dims]"
924+
assert False, "dims must contain values in [0,self.dims)"
925925

926926
# old version (ver=0)
927927
shape = np.array(self.shape)

0 commit comments

Comments
 (0)