Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove subsample default in OrthogonalProcrustesAlignment, improve tests #28

Merged
merged 8 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 20 additions & 5 deletions cebra/data/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
# https://github.com/AdaptiveMotorControlLab/CEBRA/LICENSE.md
#
import copy
import warnings
from typing import List, Optional, Union

import joblib
Expand Down Expand Up @@ -62,7 +63,7 @@ class OrthogonalProcrustesAlignment:
Procrustes problem on.
"""

def __init__(self, top_k: int = 5, subsample: int = 500):
def __init__(self, top_k: int = 5, subsample: Optional[int] = None):
self.subsample = subsample
self.top_k = top_k

Expand Down Expand Up @@ -178,14 +179,28 @@ def fit(

# Get the whole data to align and only the selected closest samples
# from the reference data.
X = data[:, None].repeat(5, axis=1).reshape(-1, data.shape[1])
X = data[:, None].repeat(self.top_k, axis=1).reshape(-1, data.shape[1])
Y = ref_data[target_idx].reshape(-1, ref_data.shape[1])

# Augment data and reference data so that same size
if self.subsample is not None:
idc = np.random.choice(len(X), self.subsample)
X = X[idc]
Y = Y[idc]
if self.subsample > len(X):
warnings.warn(
f"The number of datapoints in the dataset ({len(X)}) "
f"should be larger than the 'subsample' "
f"parameter ({self.subsample}). Ignoring subsampling and "
f"computing alignment on the full dataset instead, which will "
f"give better results.")
else:
if self.subsample < 1000:
warnings.warn(
"This function is experimental when the subsample dimension "
"is less than 1000. You can probably use the whole dataset "
"for alignment by setting subsample=None.")

idc = np.random.choice(len(X), self.subsample)
X = X[idc]
Y = Y[idc]

# Compute orthogonal matrix that most closely maps X to Y using the orthogonal Procrustes problem.
self._transform, _ = scipy.linalg.orthogonal_procrustes(X, Y)
Expand Down
47 changes: 20 additions & 27 deletions tests/test_data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,12 @@ def test_orthogonal_alignment_shapes(ref_data, data, ref_labels, labels):
assert _does_shape_match(data, aligned_embedding)

# Test with non-default parameters
alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment(
top_k=10, subsample=1000)
alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment(top_k=10)

aligned_embedding = alignment_model.fit_transform(ref_data, data,
ref_labels, labels)
assert _does_shape_match(data, aligned_embedding)
assert _does_shape_match(data, aligned_embedding), (data.shape,
aligned_embedding.shape)


@pytest.mark.parametrize("ref_data,data,ref_labels,labels,match",
Expand All @@ -144,8 +144,8 @@ def test_invalid_orthogonal_alignment(ref_data, data, ref_labels, labels,
def test_orthogonal_alignment_without_labels():
random_seed = 2160
np.random.seed(random_seed)
embedding_100_4d = np.random.uniform(0, 1, (100, 4))
embedding_100_4d_2 = np.random.uniform(0, 1, (100, 4))
embedding_100_4d = np.random.uniform(0, 1, (1000, 4))
embedding_100_4d_2 = np.random.uniform(0, 1, (1000, 4))

alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment()

Expand All @@ -156,18 +156,15 @@ def test_orthogonal_alignment_without_labels():
aligned_embedding_without_labels = alignment_model.transform(
embedding_100_4d_2)

assert np.allclose(aligned_embedding,
aligned_embedding_without_labels,
atol=0.1)
assert np.allclose(aligned_embedding, aligned_embedding_without_labels)


def test_orthogonal_alignment():
random_seed = 483
np.random.seed(random_seed)
embedding_100_4d = np.random.uniform(0, 1, (100, 4))
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4,
random_state=random_seed)
labels_100_1d = np.random.uniform(0, 1, (100, 1))
@pytest.mark.parametrize("seed", [483, 425, 166, 672, 123])
def test_orthogonal_alignment(seed):
np.random.seed(seed)
embedding_100_4d = np.random.uniform(0, 1, (1000, 4))
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4, random_state=seed)
labels_100_1d = np.random.uniform(0, 1, (1000, 1))

alignment_model = cebra_data_helper.OrthogonalProcrustesAlignment()
aligned_embedding = alignment_model.fit_transform(ref_data=embedding_100_4d,
Expand All @@ -176,14 +173,14 @@ def test_orthogonal_alignment():
orthogonal_matrix),
ref_label=labels_100_1d,
label=labels_100_1d)
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.05)
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.03)

# and without labels
aligned_embedding = alignment_model.fit_transform(ref_data=embedding_100_4d,
data=np.dot(
embedding_100_4d,
orthogonal_matrix))
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.05)
assert np.allclose(aligned_embedding, embedding_100_4d, atol=0.03)


def _initialize_embedding_ensembling_data():
Expand Down Expand Up @@ -277,9 +274,7 @@ def test_embeddings_ensembling_without_labels():
embeddings=[embedding_100_4d, embedding_100_4d_2], labels=[None, None])
joint_embedding_without_labels = cebra_data_helper.ensemble_embeddings(
embeddings=[embedding_100_4d, embedding_100_4d_2])
assert np.allclose(joint_embedding,
joint_embedding_without_labels,
atol=0.05)
assert np.allclose(joint_embedding, joint_embedding_without_labels)


@pytest.mark.parametrize("embeddings,labels,n_jobs,match",
Expand All @@ -293,16 +288,14 @@ def test_invalid_embedding_ensembling(embeddings, labels, n_jobs, match):
)


def test_embedding_ensembling():
random_seed = 27
np.random.seed(random_seed)
@pytest.mark.parametrize("seed", [483, 426, 166, 674, 123])
def test_embedding_ensembling(seed):
np.random.seed(seed)
embedding_100_4d = np.random.uniform(0, 1, (100, 4))
labels_100_1d = np.random.uniform(0, 1, (100, 1))
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4,
random_state=random_seed)
orthogonal_matrix = scipy.stats.ortho_group.rvs(dim=4, random_state=seed)
orthogonal_matrix_2 = scipy.stats.ortho_group.rvs(dim=4,
random_state=random_seed +
1)
random_state=seed + 1)

embedding_100_4d_2 = np.dot(embedding_100_4d, orthogonal_matrix)
embedding_100_4d_3 = np.dot(embedding_100_4d, orthogonal_matrix_2)
Expand Down