Fix consistency labels ordering and simplify #87

Merged
21 changes: 18 additions & 3 deletions cebra/integrations/matplotlib.py
@@ -524,9 +524,10 @@ def __init__(
self._define_ax(axis)
scores = self._check_array(scores)
# Check the values dimensions
if scores.ndim > 2:
if scores.ndim >= 2:
raise ValueError(
f"Invalid scores dimensions, expect 1D, got {scores.ndim}D.")

self.labels = self._compute_labels(scores,
pairs=pairs,
datasets=datasets)
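
For illustration, the effect of the tightened check can be sketched with a hypothetical helper (`check_scores` is a stand-in for the constructor's validation, not part of CEBRA): a 2D scores array used to slip past `> 2` and is now rejected by `>= 2`.

    import numpy as np

    def check_scores(scores: np.ndarray) -> np.ndarray:
        # Anything that is not 1D is now rejected; `> 2` let 2D arrays through.
        if scores.ndim >= 2:
            raise ValueError(f"Invalid scores dimensions, expect 1D, got {scores.ndim}D.")
        return scores

    check_scores(np.array([0.9, 0.8, 0.7]))   # accepted: 1D
    try:
        check_scores(np.array([[0.9, 0.8]]))  # rejected: 2D
    except ValueError as error:
        print(error)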
@@ -609,7 +610,7 @@ def _compute_labels(
"got either both or one of them set to None.")
else:
datasets = self._check_array(datasets)
pairs = self._check_array(pairs)
pairs = self.pairs = self._check_array(pairs)

if len(pairs.shape) == 2:
compared_items = list(sorted(set(pairs[:, 0])))
@@ -651,12 +652,26 @@ def _to_heatmap_format(

values = np.concatenate(values)

pairs = self.pairs

if pairs.ndim == 3:
pairs = pairs[0]

assert len(pairs) == len(values), (self.pairs.shape, len(values))
score_dict = {tuple(pair): value for pair, value in zip(pairs, values)}

if self.labels is None:
n_grid = self.score

heatmap_values = np.zeros((len(self.labels), len(self.labels)))

heatmap_values[:] = float("nan")
heatmap_values[np.eye(len(self.labels)) == 0] = values
for i, label_i in enumerate(self.labels):
for j, label_j in enumerate(self.labels):
if i == j:
heatmap_values[i, j] = float("nan")
else:
heatmap_values[i, j] = score_dict[label_i, label_j]

return np.minimum(heatmap_values * 100, 99)

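For illustration, the new fill logic of `_to_heatmap_format` can be sketched standalone: each score is keyed by the ordered pair it was computed for, so cell (i, j) always receives the score for (labels[i], labels[j]) instead of relying on the positional `np.eye` mask. The labels, pairs, and values below are hypothetical stand-ins.

    import numpy as np

    # Hypothetical inputs: three compared entities and one score per ordered pair.
    labels = ["rat1", "rat2", "rat3"]
    pairs = np.array([["rat1", "rat2"], ["rat1", "rat3"],
                      ["rat2", "rat1"], ["rat2", "rat3"],
                      ["rat3", "rat1"], ["rat3", "rat2"]])
    values = np.array([0.90, 0.80, 0.85, 0.70, 0.75, 0.65])

    # Key each score by its pair, then fill the grid label by label so the
    # ordering of `values` no longer has to match the label grid.
    score_dict = {tuple(pair): value for pair, value in zip(pairs, values)}
    heatmap_values = np.full((len(labels), len(labels)), float("nan"))
    for i, label_i in enumerate(labels):
        for j, label_j in enumerate(labels):
            if i != j:
                heatmap_values[i, j] = score_dict[label_i, label_j]

    print(np.minimum(heatmap_values * 100, 99))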
84 changes: 30 additions & 54 deletions cebra/integrations/sklearn/metrics.py
@@ -171,7 +171,7 @@ def _consistency_datasets(
Returns:
A list of scores obtained between embeddings from different datasets (first element),
a list of pairs of IDs corresponding to the scores (second element), and a list of the
datasets (third element).
dataset IDs (third element).

"""
if labels is None:
@@ -217,7 +217,7 @@ def _consistency_datasets(
pairs = np.array(pairs)[between_dataset]
scores = _average_scores(np.array(scores)[between_dataset], pairs)

return (scores, pairs, datasets)
return (scores, pairs, np.array(dataset_ids))


def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,
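
As a rough sketch of the shapes the corrected between-datasets return values follow (matching the assertions in the updated usage docs below), with hypothetical dataset names, placeholder scores, and one possible enumeration of the pairs:

    import numpy as np
    from itertools import permutations

    # Hypothetical dataset IDs; in practice these come from the user's `dataset_ids`.
    dataset_ids = ["achilles", "buddy", "cicero"]
    n_datasets = len(set(dataset_ids))

    # One score per ordered pair of distinct datasets: n_datasets**2 - n_datasets in total.
    pairs = np.array(list(permutations(sorted(set(dataset_ids)), 2)))
    scores = np.random.uniform(0, 1, len(pairs))  # placeholder scores
    ids = np.array(dataset_ids)

    assert scores.shape == (n_datasets**2 - n_datasets,)
    assert pairs.shape == (n_datasets**2 - n_datasets, 2)
    assert ids.shape == (n_datasets,)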
@@ -246,61 +246,34 @@ def _average_scores(scores: Union[npt.NDArray, list], pairs: Union[npt.NDArray,

def _consistency_runs(
embeddings: List[Union[npt.NDArray, torch.Tensor]],
dataset_ids: Optional[List[Union[int, str, float]]],
) -> Tuple[npt.NDArray, npt.NDArray, npt.NDArray]:
"""Compute consistency between embeddings coming from the same dataset.

If no `dataset_ids` is provided, then the embeddings are considered to be coming from the
same dataset and consequently not realigned.

For both modes (``between=runs`` or ``between=datasets``), if no `dataset_ids` is provided
(default value is ``None``), then the embeddings are considered individually and the consistency
is computed for possible pairs.

Args:
embeddings: List of embedding matrices.
dataset_ids: List of dataset ID associated to each embedding. Multiple embeddings can be
associated to the same dataset.

Returns:
A list of lists of scores obtained between embeddings of the same dataset (first element),
a list of lists of pairs of ids of the embeddings of the same datasets that were compared
(second element), they are identified with :py:class:`numpy.int` from 0 to the number of
embeddings for the dataset, and a list of the datasets (third element).
embeddings for the dataset, and a list of the unique IDs (third element).
"""
# we consider all embeddings as the same dataset
if dataset_ids is None:
datasets = np.array(["unique"])
dataset_ids = ["unique" for i in range(len(embeddings))]
else:
datasets = np.array(sorted(set(dataset_ids)))

within_dataset_scores = []
within_dataset_pairs = []
for dataset in datasets:
# get all embeddings for `dataset`
dataset_embeddings = [
embeddings[i]
for i, dataset_id in enumerate(dataset_ids)
if dataset_id == dataset
]
if len(dataset_embeddings) <= 1:
raise ValueError(
f"Invalid number of embeddings for dataset {dataset}, expect at least 2 embeddings "
f"to be able to compare them, got {len(dataset_embeddings)}")
score, pairs = _consistency_scores(embeddings=dataset_embeddings,
datasets=np.arange(
len(dataset_embeddings)))
within_dataset_scores.append(score)
within_dataset_pairs.append(pairs)
# NOTE(celia): The number of samples of the embeddings should be the same for all as there is
# no realignment, the number of output dimensions can vary between the embeddings we compare.
if not all(embeddings[0].shape[0] == embeddings[i].shape[0]
for i in range(1, len(embeddings))):
raise ValueError(
f"Invalid embeddings, all embeddings should be the same shape to be compared in a between-runs way."
f"If your embeddings are coming from different models, you can use between-datasets"
)

scores = np.array(within_dataset_scores)
pairs = np.array(within_dataset_pairs)
run_ids = np.arange(len(embeddings))
scores, pairs = _consistency_scores(embeddings=embeddings, datasets=run_ids)

return (
_average_scores(scores, pairs),
pairs,
datasets,
np.array(pairs),
np.array(run_ids),
)
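
A minimal standalone sketch of what the simplified between-runs path assumes and produces: runs must share the same number of samples (there is no realignment), they are identified by their index, and every ordered pair of distinct runs is compared. The embeddings below are hypothetical; the shapes mirror the assertions in the updated usage docs.

    import numpy as np
    from itertools import permutations

    # Hypothetical runs: same number of samples, output dimensions may differ.
    embeddings = [np.random.uniform(0, 1, (1000, 8)),
                  np.random.uniform(0, 1, (1000, 4)),
                  np.random.uniform(0, 1, (1000, 8))]

    # Same sample count is required because runs are compared without realignment.
    if not all(e.shape[0] == embeddings[0].shape[0] for e in embeddings):
        raise ValueError("All embeddings must have the same number of samples.")

    run_ids = np.arange(len(embeddings))              # runs are identified as 0..n-1
    pairs = np.array(list(permutations(run_ids, 2)))  # ordered pairs of distinct runs

    assert pairs.shape == (len(embeddings) ** 2 - len(embeddings), 2)
    assert run_ids.shape == (len(embeddings),)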


@@ -328,15 +301,17 @@ def consistency_score(
trained on the **same dataset**. *Consistency between datasets* means the consistency between embeddings
obtained from models trained on **different datasets**, such as different animals, sessions, etc.
num_discretization_bins: Number of values for the digitalized common labels. The discretized labels are used
for embedding alignment. Also see the ``n_bins`` argument in
:py:mod:`cebra.integrations.sklearn.helpers.align_embeddings` for more information on how this
parameter is used internally. This argument is only used if ``labels``
is not ``None``, alignment between datasets is used (``between = "datasets"``), and the given labels
are continuous and not already discrete.

Returns:
The list of scores computed between the embeddings (first returns), the list of pairs corresponding
to each computed score (second returns) and the list of datasets present in the comparison (third returns).
to each computed score (second returns) and the list of IDs of the entities present in the comparison,
either the different datasets in the between-datasets comparison or the runs in the between-runs comparison
(third returns).

Example:

@@ -346,13 +321,13 @@
>>> embedding2 = np.random.uniform(0, 1, (1000, 8))
>>> labels1 = np.random.uniform(0, 1, (1000, ))
>>> labels2 = np.random.uniform(0, 1, (1000, ))
>>> # Between-runs, with dataset IDs (optional)
>>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
... dataset_ids=["achilles", "achilles"],
>>> # Between-runs consistency
>>> scores, pairs, ids_runs = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
... between="runs")
>>> # Between-datasets consistency, by aligning on the labels
>>> scores, pairs, datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
>>> scores, pairs, ids_datasets = cebra.sklearn.metrics.consistency_score(embeddings=[embedding1, embedding2],
... labels=[labels1, labels2],
... dataset_ids=["achilles", "buddy"],
... between="datasets")

"""
@@ -369,12 +344,13 @@
if labels is not None:
raise ValueError(
f"No labels should be provided for between-runs consistency.")
scores, pairs, datasets = _consistency_runs(
embeddings=embeddings,
dataset_ids=dataset_ids,
)
if dataset_ids is not None:
raise ValueError(
f"No dataset ID should be provided for between-runs consistency."
f"All embeddings should be computed on the same dataset.")
scores, pairs, ids = _consistency_runs(embeddings=embeddings,)
elif between == "datasets":
scores, pairs, datasets = _consistency_datasets(
scores, pairs, ids = _consistency_datasets(
embeddings=embeddings,
dataset_ids=dataset_ids,
labels=labels,
@@ -383,4 +359,4 @@
raise NotImplementedError(
f"Invalid comparison, got between={between}, expects either datasets or runs."
)
return scores.squeeze(), pairs.squeeze(), datasets
return scores.squeeze(), pairs.squeeze(), ids
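
Assuming `cebra` is installed and exposes `cebra.sklearn.metrics` as in the doctest above, the user-facing effect of the change can be sketched as follows: between-runs consistency now takes only the embeddings, and passing `dataset_ids` (or `labels`) with `between="runs"` raises a `ValueError` instead of being silently grouped.

    import numpy as np
    import cebra

    embedding1 = np.random.uniform(0, 1, (1000, 8))
    embedding2 = np.random.uniform(0, 1, (1000, 8))

    # Between-runs: only the embeddings themselves are needed.
    scores, pairs, ids = cebra.sklearn.metrics.consistency_score(
        embeddings=[embedding1, embedding2], between="runs")

    # Passing dataset IDs together with between="runs" now raises.
    try:
        cebra.sklearn.metrics.consistency_score(
            embeddings=[embedding1, embedding2],
            dataset_ids=["achilles", "achilles"],
            between="runs")
    except ValueError as error:
        print(error)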
39 changes: 20 additions & 19 deletions docs/source/usage.rst
@@ -904,47 +904,48 @@ We first create the embeddings to compare: we use two different datasets of data
.. testcode::

n_runs = 3
dataset_ids = ["session1", "session2"]

cebra_model = CEBRA(model_architecture='offset10-model',
batch_size=512,
output_dimension=32,
max_iterations=5,
time_offsets=10)

embeddings, dataset_ids, labels = [], [], []
embeddings_runs = []
embeddings_datasets, ids, labels = [], [], []
for i in range(n_runs):
embeddings.append(cebra_model.fit_transform(neural_session1, continuous_label1))
dataset_ids.append("session1")
labels.append(continuous_label1[:, 0])
embeddings_runs.append(cebra_model.fit_transform(neural_session1, continuous_label1))

embeddings.append(cebra_model.fit_transform(neural_session2, continuous_label2))
dataset_ids.append("session2")
labels.append(continuous_label2[:, 0])
labels.append(continuous_label1[:, 0])
embeddings_datasets.append(embeddings_runs[-1])

n_datasets = len(set(dataset_ids))
embeddings_datasets.append(cebra_model.fit_transform(neural_session2, continuous_label2))
labels.append(continuous_label2[:, 0])

n_datasets = len(dataset_ids)

To get the :py:func:`~.consistency_score` on the set of embeddings that we just generated:

.. testcode::

# Between-runs, with dataset IDs (optional)
scores_runs, pairs_runs, datasets_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
dataset_ids=dataset_ids,
between="runs")
# Between-runs
scores_runs, pairs_runs, ids_runs = cebra.sklearn.metrics.consistency_score(embeddings=embeddings_runs,
between="runs")
assert scores_runs.shape == (n_runs**2 - n_runs, )
assert pairs_runs.shape == (n_datasets, n_runs*n_datasets, 2)
assert datasets_runs.shape == (n_datasets, )
assert pairs_runs.shape == (n_runs**2 - n_runs, 2)
assert ids_runs.shape == (n_runs, )

# Between-datasets, by aligning on the labels
(scores_datasets,
pairs_datasets,
datasets_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings,
ids_datasets) = cebra.sklearn.metrics.consistency_score(embeddings=embeddings_datasets,
labels=labels,
dataset_ids=dataset_ids,
between="datasets")
assert scores_datasets.shape == (n_datasets**2 - n_datasets, )
assert pairs_datasets.shape == (n_runs*(n_runs*n_datasets), 2)
assert datasets_datasets.shape == (n_datasets, )
assert pairs_datasets.shape == (n_datasets**2 - n_datasets, 2)
assert ids_datasets.shape == (n_datasets, )

.. admonition:: See API docs
:class: dropdown
Expand All @@ -961,8 +962,8 @@ You can then display the resulting scores using :py:func:`~.plot_consistency`.
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)

ax1 = cebra.plot_consistency(scores_runs, pairs_runs, datasets_runs, vmin=0, vmax=100, ax=ax1, title="Between-runs consistencies")
ax2 = cebra.plot_consistency(scores_datasets, pairs_datasets, datasets_datasets, vmin=0, vmax=100, ax=ax2, title="Between-subjects consistencies")
ax1 = cebra.plot_consistency(scores_runs, pairs_runs, ids_runs, vmin=0, vmax=100, ax=ax1, title="Between-runs consistencies")
ax2 = cebra.plot_consistency(scores_datasets, pairs_datasets, ids_datasets, vmin=0, vmax=100, ax=ax2, title="Between-subjects consistencies")


.. figure:: docs-imgs/consistency-score.png