forked from nilearn/nilearn
-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_haxby_multiclass.py
97 lines (75 loc) · 3.6 KB
/
plot_haxby_multiclass.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
The haxby dataset: different multi-class strategies
=======================================================
We compare one vs all and one vs one multi-class strategies: the overall
cross-validated accuracy and the confusion matrix.
"""
# Import matplotlib for plotting
from matplotlib import pyplot as plt
### Load Haxby dataset ########################################################
from nilearn import datasets
import numpy as np
dataset_files = datasets.fetch_haxby_simple()
# fmri_data and mask are copied to break any reference to the original object
y, session = np.loadtxt(dataset_files.session_target).astype("int").T
conditions = np.recfromtxt(dataset_files.conditions_target)['f0']
# Remove the rest condition, it is not very interesting
non_rest = conditions != 'rest'
conditions = conditions[non_rest]
y = y[non_rest]
session = session[non_rest]
# Get the labels of the numerical conditions represented by the vector y
unique_conditions, order = np.unique(conditions, return_index=True)
# Sort the conditions by the order of appearance
unique_conditions = unique_conditions[np.argsort(order)]
### Loading step ##############################################################
from nilearn.input_data import NiftiMasker
# For decoding, standardizing is often very important
# NOTE(review): `mask=` and `sessions=` match the old NiftiMasker signature
# used here; newer nilearn renamed these keywords — confirm library version.
nifti_masker = NiftiMasker(mask=dataset_files.mask, standardize=True,
sessions=session, smoothing_fwhm=4,
memory="nilearn_cache", memory_level=1)
# Mask the 4D functional image into a row-per-scan data matrix, then drop
# the rest scans so rows line up with the filtered y/session vectors.
X = nifti_masker.fit_transform(dataset_files.func)
X = X[non_rest]
### Predictor #################################################################
### Define the prediction function to be used.
# Here we use a Support Vector Classification, with a linear kernel
from sklearn.svm import SVC
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.multiclass import OneVsOneClassifier, OneVsRestClassifier
from sklearn.pipeline import Pipeline


def _make_anova_svc():
    """Return a fresh ANOVA feature-selection + linear-SVC pipeline.

    The ANOVA F-test keeps the 500 most discriminative voxels before the
    SVC.  Building a new pipeline per strategy (instead of duplicating the
    literal, as the original did) keeps the two estimators independent and
    the configuration in a single place.
    """
    return Pipeline([
        ('anova', SelectKBest(f_classif, k=500)),
        ('svc', SVC(kernel='linear')),
    ])


# One-vs-one: one binary SVC per pair of classes.
svc_ovo = OneVsOneClassifier(_make_anova_svc())
# One-vs-all (one-vs-rest): one binary SVC per class.
svc_ova = OneVsRestClassifier(_make_anova_svc())
### Cross-validation scores ###################################################
# NOTE(review): sklearn.cross_validation was removed in scikit-learn 0.20;
# this script targets an old sklearn — confirm before upgrading.
from sklearn.cross_validation import cross_val_score

# 5-fold cross-validated accuracy for each multi-class strategy.
cv_scores_ovo = cross_val_score(svc_ovo, X, y, cv=5, verbose=True)
cv_scores_ova = cross_val_score(svc_ova, X, y, cv=5, verbose=True)

# print() calls instead of Python-2-only print statements: the original
# statements are a SyntaxError under Python 3.  '%s' formatting reproduces
# the exact "label value" text the original comma-style print emitted.
print(79 * "_")
print('OvO %s' % cv_scores_ovo.mean())
print('OvA %s' % cv_scores_ova.mean())

# Box plot of the per-fold accuracy scores, one box per strategy.
plt.figure(figsize=(4, 3))
plt.boxplot([cv_scores_ova, cv_scores_ovo])
plt.xticks([1, 2], ['One vs All', 'One vs One'])
plt.title('Prediction: accuracy score')
### Plot a confusion matrix ###################################################
# Fit on the the first 10 sessions and plot a confusion matrix on the
# last 2 sessions
from sklearn.metrics import confusion_matrix

svc_ovo.fit(X[session < 10], y[session < 10])
y_pred_ovo = svc_ovo.predict(X[session >= 10])
# confusion_matrix expects (y_true, y_pred); the original passed the
# predictions first, which transposed the plotted matrix and so swapped
# the meaning of the row/column labels.
plt.matshow(confusion_matrix(y[session >= 10], y_pred_ovo))
plt.title('Confusion matrix: One vs One')
plt.xticks(np.arange(len(unique_conditions)), unique_conditions)
plt.yticks(np.arange(len(unique_conditions)), unique_conditions)

svc_ova.fit(X[session < 10], y[session < 10])
y_pred_ova = svc_ova.predict(X[session >= 10])
plt.matshow(confusion_matrix(y[session >= 10], y_pred_ova))
plt.title('Confusion matrix: One vs All')
plt.xticks(np.arange(len(unique_conditions)), unique_conditions)
plt.yticks(np.arange(len(unique_conditions)), unique_conditions)
plt.show()