Skip to content

Commit

Permalink
Fix typo (#37)
Browse files Browse the repository at this point in the history
* fix typo

* better test
  • Loading branch information
mscheltienne authored May 6, 2022
1 parent 122ceef commit 8892f40
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 3 additions & 1 deletion mne_icalabel/label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def label_components(inst: Union[BaseRaw, BaseEpochs], ica: ICA, method: str):
labels_pred_proba = methods[method](inst, ica)
labels_pred = np.argmax(labels_pred_proba, axis=1)
labels = [ICLABEL_NUMERICAL_TO_STRING[label] for label in labels_pred]
y_pred_proba = labels_pred_proba[np.arange(15), labels_pred]
assert ica.n_components_ == labels_pred.size # sanity-check
assert ica.n_components_ == labels_pred_proba.shape[0] # sanity-check
y_pred_proba = labels_pred_proba[np.arange(ica.n_components_), labels_pred]

component_dict = {
"y_pred_proba": y_pred_proba,
Expand Down
10 changes: 6 additions & 4 deletions mne_icalabel/tests/test_label_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# preprocess
raw.filter(l_freq=1.0, h_freq=100.0)
raw.set_eeg_reference("average")
# fit ICA
ica = ICA(n_components=15, method="picard")
ica.fit(raw)


@pytest.mark.filterwarnings("ignore::RuntimeWarning")
def test_label_components():
@pytest.mark.parametrize("n_components", (5, 15))
def test_label_components(n_components):
"""Simple test to check that label_components runs without raising."""
ica = ICA(n_components=n_components, method="picard")
ica.fit(raw)
labels = label_components(raw, ica, method="iclabel")
assert isinstance(labels, dict)
assert labels["y_pred_proba"].ndim == 1
Expand All @@ -28,5 +28,7 @@ def test_label_components():

def test_label_components_with_wrong_arguments():
"""Test that wrong arguments raise."""
ica = ICA(n_components=3, method="picard")
ica.fit(raw)
with pytest.raises(ValueError, match="Invalid value for the 'method' parameter"):
label_components(raw, ica, method="101")

0 comments on commit 8892f40

Please sign in to comment.