Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tensor indexing #116

Merged
merged 5 commits into from
Jun 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 38 additions & 23 deletions pyttb/sptensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import logging
import warnings
from collections.abc import Sequence
from collections.abc import Iterable, Sequence
from typing import Any, Callable, List, Optional, Tuple, Union, cast, overload

import numpy as np
Expand Down Expand Up @@ -620,7 +620,7 @@ def innerprod(
if self.shape != other.shape:
assert False, "Sptensor and tensor must be same shape for innerproduct"
[subsSelf, valsSelf] = self.find()
valsOther = other[subsSelf, "extract"]
valsOther = other[subsSelf.transpose(), "extract"]
return valsOther.transpose().dot(valsSelf)

if isinstance(other, (ttb.ktensor, ttb.ttensor)): # pragma: no cover
Expand Down Expand Up @@ -685,7 +685,7 @@ def is_length_2(x):

if isinstance(B, ttb.tensor):
BB = sptensor.from_data(
self.subs, B[self.subs, "extract"][:, None], self.shape
self.subs, B[self.subs.transpose(), "extract"][:, None], self.shape
)
C = self.logical_and(BB)
return C
Expand Down Expand Up @@ -1053,7 +1053,7 @@ def scale(self, factor: np.ndarray, dims: Union[float, np.ndarray]) -> sptensor:
assert False, "Size mismatch in scale"
return ttb.sptensor.from_data(
self.subs,
self.vals * factor[self.subs[:, dims], "extract"][:, None],
self.vals * factor[self.subs[:, dims].transpose(), "extract"][:, None],
self.shape,
)
if isinstance(factor, ttb.sptensor):
Expand Down Expand Up @@ -1368,9 +1368,9 @@ def __getitem__(self, item):
if (
isinstance(item, np.ndarray)
and len(item.shape) == 2
and item.shape[1] == self.ndims
and item.shape[0] == self.ndims
):
srchsubs = np.array(item)
srchsubs = np.array(item.transpose())

# *** CASE 2b: Linear indexing ***
else:
Expand Down Expand Up @@ -1463,21 +1463,21 @@ def _set_subscripts(self, key, value):
tt_subscheck(newsubs, nargout=False)

# Error check on subscripts
if newsubs.shape[1] < self.ndims:
if newsubs.shape[0] < self.ndims:
assert False, "Invalid subscripts"

# Check for expanding the order
if newsubs.shape[1] > self.ndims:
if newsubs.shape[0] > self.ndims:
newshape = list(self.shape)
# TODO no need for loop, just add correct size
for _ in range(self.ndims, newsubs.shape[1]):
for _ in range(self.ndims, newsubs.shape[0]):
newshape.append(1)
if self.subs.size > 0:
self.subs = np.concatenate(
(
self.subs,
np.ones(
(self.shape[0], newsubs.shape[1] - self.ndims),
(self.shape[0], newsubs.shape[0] - self.ndims),
dtype=int,
),
),
Expand All @@ -1497,7 +1497,7 @@ def _set_subscripts(self, key, value):

# Determine number of nonzeros being inserted.
# (This is determined by number of subscripts)
newnnz = newsubs.shape[0]
newnnz = newsubs.shape[1]

# Error check on size of newvals
if newvals.size == 1:
Expand All @@ -1510,7 +1510,7 @@ def _set_subscripts(self, key, value):
assert False, "Number of subscripts and number of values do not match!"

# Remove duplicates and print warning if any duplicates were removed
newsubs, idx = np.unique(newsubs, axis=0, return_index=True)
newsubs, idx = np.unique(newsubs.transpose(), axis=0, return_index=True)
if newsubs.shape[0] != newnnz:
warnings.warn("Duplicate assignments discarded")

Expand Down Expand Up @@ -1647,6 +1647,8 @@ def _set_subtensor(self, key, value):
newsz.append(self.shape[n])
else:
newsz.append(max([self.shape[n], key[n].stop]))
elif isinstance(key[n], Iterable):
newsz.append(max([self.shape[n], max(key[n]) + 1]))
else:
newsz.append(max([self.shape[n], key[n] + 1]))

Expand All @@ -1660,7 +1662,7 @@ def _set_subtensor(self, key, value):
)
else:
newsz.append(key[n].stop)
elif isinstance(key[n], np.ndarray):
elif isinstance(key[n], (np.ndarray, Iterable)):
newsz.append(max(key[n]) + 1)
else:
newsz.append(key[n] + 1)
Expand All @@ -1671,7 +1673,8 @@ def _set_subtensor(self, key, value):
self.subs = np.append(
self.subs,
np.zeros(
shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1])
shape=(self.subs.shape[0], len(self.shape) - self.subs.shape[1]),
dtype=int,
),
axis=1,
)
Expand All @@ -1689,15 +1692,19 @@ def _set_subtensor(self, key, value):
if isinstance(value, (int, float)):
# Determine number of dimensions (may be larger than current number)
N = len(key)
keyCopy = np.array(key)
keyCopy = [None] * N
# Figure out how many indices are in each dimension
nssubs = np.zeros((N, 1))
for n in range(0, N):
if isinstance(key[n], slice):
# Generate slice explicitly to determine its length
keyCopy[n] = np.arange(0, self.shape[n])[key[n]]
indicesInN = len(keyCopy[n])
elif isinstance(key[n], Iterable):
keyCopy[n] = key[n]
indicesInN = len(key[n])
else:
keyCopy[n] = key[n]
indicesInN = 1
nssubs[n] = indicesInN

Expand Down Expand Up @@ -1806,7 +1813,7 @@ def __eq__(self, other):
]

# Find where their nonzeros intersect
othervals = other[self.subs, "extract"]
othervals = other[self.subs.transpose(), "extract"]
znzsubs = self.subs[(othervals[:, None] == self.vals).transpose()[0], :]

return sptensor.from_data(
Expand Down Expand Up @@ -1887,7 +1894,7 @@ def __ne__(self, other):
subs1 = np.empty((0, self.subs.shape[1]))
# find entries where x is nonzero but not equal to y
subs2 = self.subs[
self.vals.transpose()[0] != other[self.subs, "extract"], :
self.vals.transpose()[0] != other[self.subs.transpose(), "extract"], :
]
if subs2.size == 0:
subs2 = np.empty((0, self.subs.shape[1]))
Expand Down Expand Up @@ -2002,7 +2009,7 @@ def __mul__(self, other):
)
if isinstance(other, ttb.tensor):
csubs = self.subs
cvals = self.vals * other[csubs, "extract"][:, None]
cvals = self.vals * other[csubs.transpose(), "extract"][:, None]
return ttb.sptensor.from_data(csubs, cvals, self.shape)
if isinstance(other, ttb.ktensor):
csubs = self.subs
Expand Down Expand Up @@ -2124,7 +2131,7 @@ def __le__(self, other):

# self nonzero
subs2 = self.subs[
self.vals.transpose()[0] <= other[self.subs, "extract"], :
self.vals.transpose()[0] <= other[self.subs.transpose(), "extract"], :
]

# assemble
Expand Down Expand Up @@ -2212,7 +2219,9 @@ def __lt__(self, other):
subs1 = subs1[ttb.tt_setdiff_rows(subs1, self.subs), :]

# self nonzero
subs2 = self.subs[self.vals.transpose()[0] < other[self.subs, "extract"], :]
subs2 = self.subs[
self.vals.transpose()[0] < other[self.subs.transpose(), "extract"], :
]

# assemble
subs = np.vstack((subs1, subs2))
Expand Down Expand Up @@ -2267,7 +2276,10 @@ def __ge__(self, other):

# self nonzero
subs2 = self.subs[
(self.vals >= other[self.subs, "extract"][:, None]).transpose()[0], :
(
self.vals >= other[self.subs.transpose(), "extract"][:, None]
).transpose()[0],
:,
]

# assemble
Expand Down Expand Up @@ -2325,7 +2337,10 @@ def __gt__(self, other):

# self and other nonzero
subs2 = self.subs[
(self.vals > other[self.subs, "extract"][:, None]).transpose()[0], :
(
self.vals > other[self.subs.transpose(), "extract"][:, None]
).transpose()[0],
:,
]

# assemble
Expand Down Expand Up @@ -2428,7 +2443,7 @@ def __truediv__(self, other):

if isinstance(other, ttb.tensor):
csubs = self.subs
cvals = self.vals / other[csubs, "extract"][:, None]
cvals = self.vals / other[csubs.transpose(), "extract"][:, None]
return ttb.sptensor.from_data(csubs, cvals, self.shape)
if isinstance(other, ttb.ktensor):
# TODO consider removing epsilon and generating nans consistent with above
Expand Down
62 changes: 52 additions & 10 deletions pyttb/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import logging
from collections.abc import Iterable
from itertools import permutations
from math import factorial
from typing import Any, Callable, List, Optional, Tuple, Union
Expand Down Expand Up @@ -1276,11 +1277,14 @@ def __setitem__(self, key, value):
# Figure out if we are doing a subtensor, a list of subscripts or a list of
# linear indices
access_type = "error"
if self.ndims <= 1:
if isinstance(key, np.ndarray):
access_type = "subscripts"
else:
# TODO pull out this big decision tree into a function
if isinstance(key, (float, int, np.generic, slice)):
access_type = "linear indices"
elif self.ndims <= 1:
if isinstance(key, tuple):
access_type = "subtensor"
elif isinstance(key, np.ndarray):
access_type = "subscripts"
else:
if isinstance(key, np.ndarray):
if len(key.shape) > 1 and key.shape[1] >= self.ndims:
Expand All @@ -1289,10 +1293,14 @@ def __setitem__(self, key, value):
access_type = "linear indices"
elif isinstance(key, tuple):
validSubtensor = [
isinstance(keyElement, (int, slice)) for keyElement in key
isinstance(keyElement, (int, slice, Iterable)) for keyElement in key
]
if np.all(validSubtensor):
access_type = "subtensor"
elif isinstance(key, Iterable):
key = np.array(key)
if len(key.shape) == 1 or key.shape[1] == 1:
access_type = "linear indices"

# Case 1: Rectangular Subtensor
if access_type == "subtensor":
Expand All @@ -1310,10 +1318,14 @@ def __setitem__(self, key, value):

def _set_linear(self, key, value):
idx = key
if (idx > np.prod(self.shape)).any():
if not isinstance(idx, slice) and (idx > np.prod(self.shape)).any():
assert (
False
), "TTB:BadIndex In assignment X[I] = Y, a tensor X cannot be resized"
if isinstance(key, (int, float, np.generic)):
idx = np.array([key])
elif isinstance(key, slice):
idx = np.array(range(np.prod(self.shape))[key])
idx = tt_ind2sub(self.shape, idx)
if idx.shape[0] == 1:
self.data[tuple(idx[0, :])] = value
Expand All @@ -1333,6 +1345,14 @@ def _set_subtensor(self, key, value):
sliceCheck.append(1)
else:
sliceCheck.append(element.stop)
elif isinstance(element, Iterable):
if any(
not isinstance(entry, (float, int, np.generic)) for entry in element
):
raise ValueError(
f"Entries for setitem must be numeric but recieved, {element}"
)
sliceCheck.append(max(element))
else:
sliceCheck.append(element)
bsiz = np.array(sliceCheck)
Expand Down Expand Up @@ -1443,6 +1463,17 @@ def __getitem__(self, item):
-------
:class:`pyttb.tensor` or :class:`numpy.ndarray`
"""
# Case 0: Single Index Linear
if isinstance(item, (int, float, np.generic, slice)):
if isinstance(item, (int, float, np.generic)):
idx = np.array(item)
elif isinstance(item, slice):
idx = np.array(range(np.prod(self.shape))[item])
a = np.squeeze(
self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())]
)
# Todo if row make column?
return ttb.tt_subsubsref(a, idx)
# Case 1: Rectangular Subtensor
if (
isinstance(item, tuple)
Expand Down Expand Up @@ -1484,17 +1515,28 @@ def __getitem__(self, item):
return a

# *** CASE 2a: Subscript indexing ***
if len(item) > 1 and isinstance(item[-1], str) and item[-1] == "extract":
if isinstance(item, np.ndarray) and len(item) > 1:
# Extract array of subscripts
subs = np.array(item)
a = np.squeeze(self.data[tuple(subs)])
# TODO if is row make column?
return ttb.tt_subsubsref(a, subs)
if (
len(item) > 1
and isinstance(item[0], np.ndarray)
and isinstance(item[-1], str)
and item[-1] == "extract"
):
# TODO dry this up
subs = np.array(item[0])
a = np.squeeze(self.data[tuple(subs.transpose())])
a = np.squeeze(self.data[tuple(subs)])
# TODO if is row make column?
return ttb.tt_subsubsref(a, subs)

# Case 2b: Linear Indexing
if len(item) >= 2 and not isinstance(item[-1], str):
if isinstance(item, tuple) and len(item) >= 2 and not isinstance(item[-1], str):
assert False, "Linear indexing requires single input array"
idx = item[0]
idx = np.array(item)
a = np.squeeze(self.data[tuple(ttb.tt_ind2sub(self.shape, idx).transpose())])
# Todo if row make column?
return ttb.tt_subsubsref(a, idx)
Expand Down
Loading