diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 0a75a686..40627bbf 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -29,17 +29,17 @@ Enhancements Bug ~~~ -- +- Fix shape of ``'y_pred_proba'`` output from `mne_icalabel.label_components` API ~~~ -- +- Authors ~~~~~~~ -* +* `Mathieu Scheltienne`_ :doc:`Find out what was new in previous releases ` diff --git a/mne_icalabel/label_components.py b/mne_icalabel/label_components.py index 3f82b304..93c75bd0 100644 --- a/mne_icalabel/label_components.py +++ b/mne_icalabel/label_components.py @@ -56,7 +56,7 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str): labels_pred_proba = methods[method](inst, ica) labels_pred = np.argmax(labels_pred_proba, axis=1) labels = [ICLABEL_NUMERICAL_TO_STRING[label] for label in labels_pred] - y_pred_proba = labels_pred_proba[:, labels_pred] + y_pred_proba = labels_pred_proba[np.arange(15), labels_pred] component_dict = { "y_pred_proba": y_pred_proba, diff --git a/mne_icalabel/tests/test_label_components.py b/mne_icalabel/tests/test_label_components.py index a78d2cd8..79a44c15 100644 --- a/mne_icalabel/tests/test_label_components.py +++ b/mne_icalabel/tests/test_label_components.py @@ -21,7 +21,9 @@ def test_label_components(): """Simple test to check that label_components runs without raising.""" labels = label_components(raw, ica, method="iclabel") - assert labels is not None + assert isinstance(labels, dict) + assert labels["y_pred_proba"].ndim == 1 + assert labels["y_pred_proba"].shape[0] == ica.n_components_ def test_label_components_with_wrong_arguments():