Skip to content

Commit

Permalink
TUCKER_ALS: TTM with negative values is broken in ttensor (sandialabs#62
Browse files Browse the repository at this point in the history
)

* Replace usage in tucker_als
* Update test for tucker_als to ensure result matches expectation
* Add early error handling in ttensor ttm for negative dims
  • Loading branch information
ntjohnson1 committed Mar 11, 2023
1 parent 992772b commit 3b305f5
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pyttb/pyttb_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def tt_dimscheck(

# Fix "minus" case
if np.max(dims) < 0:
# Check that all memebers in range
# Check that all members in range
if not np.all(np.isin(-dims, np.arange(0, N + 1))):
assert False, "Invalid magnitude for negative dims selection"
dims = np.setdiff1d(np.arange(1, N + 1), -dims) - 1
Expand Down
4 changes: 3 additions & 1 deletion pyttb/ttensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,9 @@ def ttm(self, matrix, dims=None, transpose=False):
dims = np.arange(self.ndims)
elif isinstance(dims, list):
dims = np.array(dims)
elif np.isscalar(dims) or isinstance(dims, list):
elif np.isscalar(dims):
if dims < 0:
raise ValueError("Negative dims is currently unsupported, see #62")
dims = np.array([dims])

if not isinstance(matrix, list):
Expand Down
15 changes: 5 additions & 10 deletions pyttb/tucker_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,11 @@ def tucker_als(

# Iterate over all N modes of the tensor
for n in dimorder:
if (
n == 0
): # TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
dims = np.arange(1, tensor.ndims)
Utilde = tensor.ttm(U, dims, True)
else:
Utilde = tensor.ttm(U, -n, True)

# TODO proposal to change ttm to include_dims and exclude_dims to resolve -0 ambiguity
dims = np.arange(0, tensor.ndims)
dims = dims[dims != n]
Utilde = tensor.ttm(U, dims, True)
print(f"Utilde[{n}] = {Utilde}")
# Maximize norm(Utilde x_n W') wrt W and
# maintain orthonormality of W
U[n] = Utilde.nvecs(n, rank[n])
Expand All @@ -140,13 +137,11 @@ def tucker_als(
core = Utilde.ttm(U, n, True)

# Compute fit
# TODO this abs is missing from MATLAB, but I get negative numbers for trivial examples
normresidual = np.sqrt(abs(normX**2 - core.norm() ** 2))
fit = 1 - (normresidual / normX) # fraction explained by model
fitchange = abs(fitold - fit)

if iter % printitn == 0:
print(f" NormX: {normX} Core norm: {core.norm()}")
print(f" Iter {iter}: fit = {fit:e} fitdelta = {fitchange:7.1e}\n")

# Check for convergence
Expand Down
10 changes: 8 additions & 2 deletions tests/test_ttensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,9 +310,15 @@ def test_ttensor_ttm(random_ttensor):

# Negative Tests
big_wrong_size = 123
matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
bad_matrices = matrices.copy()
bad_matrices[0] = np.random.random((big_wrong_size, big_wrong_size))
with pytest.raises(ValueError):
_ = ttensorInstance.ttm(matrices, np.arange(len(matrices)))
_ = ttensorInstance.ttm(bad_matrices, np.arange(len(bad_matrices)))

with pytest.raises(ValueError):
# Negative dims currently broken, ensure we catch early and
# remove once resolved
ttensorInstance.ttm(matrices, -1)


@pytest.mark.indevelopment
Expand Down
1 change: 1 addition & 0 deletions tests/test_tucker_als.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def test_tucker_als_tensor_default_init(capsys, sample_tensor):
(Solution, Uinit, output) = ttb.tucker_als(T, 2)
capsys.readouterr()
assert pytest.approx(output["fit"], 1) == 0
assert np.all(np.isclose(Solution.double(), T.double()))

(Solution, Uinit, output) = ttb.tucker_als(T, 2, init=Uinit)
capsys.readouterr()
Expand Down

0 comments on commit 3b305f5

Please sign in to comment.