-
Notifications
You must be signed in to change notification settings - Fork 29
/
mTRFcrossval.m
399 lines (339 loc) · 14.3 KB
/
mTRFcrossval.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
function [stats,t] = mTRFcrossval(stim,resp,fs,Dir,tmin,tmax,lambda,varargin)
%MTRFCROSSVAL  Leave-one-out cross-validation.
%   STATS = MTRFCROSSVAL(STIM,RESP,FS,DIR,TMIN,TMAX,LAMBDA) cross validates
%   a forward encoding model (stimulus to neural response) or a backward
%   decoding model (neural response to stimulus) over multiple trials of
%   data as per Crosse et al. (2016). Pass in 1 for DIR to validate a
%   forward model, or -1 to validate a backward model. STIM and RESP are
%   cell arrays containing corresponding trials of continuous data. FS is a
%   scalar specifying the sample rate in Hertz, and TMIN and TMAX are
%   scalars specifying the minimum and maximum time lags in milliseconds.
%   For backward models, the time lags are automatically reversed. LAMBDA
%   is a vector of regularization values to be validated and controls
%   overfitting.
%
%   MTRFCROSSVAL returns the cross-validation statistics in a structure
%   with the following fields:
%       'r'         -- correlation coefficient based on Pearson's linear
%                      correlation coefficient
%                      (nfold-by-nlambda-by-yvar)
%       'err'       -- prediction error based on the mean squared error
%                      (nfold-by-nlambda-by-yvar)
%   When the 'window' parameter is non-zero, the first dimension holds one
%   row per window (summed over folds) instead of one row per fold. When
%   'type' is 'single', both fields have a fourth dimension with one page
%   per time lag.
%
%   MTRFCROSSVAL performs a leave-one-out cross-validation over all trials.
%   To achieve a k-fold cross-validation, arrange STIM and RESP in k-by-1
%   cell arrays. The number of folds can also be increased by an integer
%   factor using the 'split' parameter (see below).
%
%   If STIM or RESP contain matrices, it is assumed that the rows
%   correspond to observations and the columns to variables, unless
%   otherwise stated via the 'dim' parameter (see below). If they contain
%   vectors, it is assumed that the first non-singleton dimension
%   corresponds to observations. Each trial of STIM and RESP must have the
%   same number of observations.
%
%   [STATS,T] = MTRFCROSSVAL(...) returns a vector containing the time lags
%   used in milliseconds. These data are useful for interpreting the
%   results of single-lag models.
%
%   [...] = MTRFCROSSVAL(...,'PARAM1',VAL1,'PARAM2',VAL2,...) specifies
%   additional parameters and their values. Valid parameters are the
%   following:
%
%       Parameter   Value
%       'dim'       A scalar specifying the dimension to work along: pass
%                   in 1 to work along the columns (default), or 2 to work
%                   along the rows. Applies to both STIM and RESP.
%       'method'    A string specifying the regularization method to use:
%                       'ridge'     ridge regression (default): suitable
%                                   for multivariate input features
%                       'Tikhonov'  Tikhonov regularization: dampens fast
%                                   oscillatory components of the solution
%                                   but may cause cross-channel leakage for
%                                   multivariate input features
%                       'ols'       ordinary least squares: equivalent to
%                                   setting LAMBDA=0 (no regularization)
%       'type'      A string specifying the type of model to fit:
%                       'multi'     use all lags simultaneously to fit a
%                                   multi-lag model (default)
%                       'single'    use each lag individually to fit
%                                   separate single-lag models
%       'corr'      A string specifying the correlation metric to use:
%                       'Pearson'   Pearson's linear correlation
%                                   coefficient (default): suitable for
%                                   data with a linear relationship
%                       'Spearman'  Spearman's rank correlation
%                                   coefficient: suitable for data with a
%                                   non-linear relationship
%       'error'     A string specifying the error metric to use:
%                       'mse'       mean square error (default): take the
%                                   square root to convert it to the
%                                   original units of the data (i.e., RMSE)
%                       'mae'       mean absolute error: more robust to
%                                   outliers than MSE
%       'split'     A scalar specifying the number of segments in which to
%                   split each trial of data when computing the covariance
%                   matrices. This is useful for reducing memory usage on
%                   large datasets. By default, the entire trial is used.
%       'window'    A scalar specifying the window size over which to
%                   compute performance in seconds. By default, the entire
%                   trial or segment is used.
%       'zeropad'   A numeric or logical specifying whether to zero-pad the
%                   outer rows of the design matrix or delete them: pass in
%                   1 to zero-pad them (default), or 0 to delete them.
%       'fast'      A numeric or logical specifying whether to use the fast
%                   cross-validation method (requires more memory) or the
%                   slower method (requires less memory): pass in 1 to use
%                   the fast method (default), or 0 to use the slower
%                   method. Note, both methods are numerically equivalent.
%       'verbose'   A numeric or logical specifying whether to execute in
%                   verbose mode: pass in 1 for verbose mode (default), or
%                   0 for non-verbose mode.
%
%   Notes:
%   Each iteration of MTRFCROSSVAL partitions the N trials or segments of
%   data into two subsets, fitting a model to N-1 trials (training set) and
%   validating it on the left-out trial (validation set). Performance on
%   the validation set can be used to optimize hyperparameters (e.g.,
%   LAMBDA, TMAX). Once completed, it is recommended to evaluate model
%   performance on separate held-out data using the mTRFpredict function.
%
%   Discontinuous trials of data should not be concatenated prior to cross-
%   validation, as this will introduce artifacts in places where the
%   temporal integration window crosses over trial boundaries. Each trial
%   of continuous data should be input as a separate cell.
%
%   See also CROSSVAL, MTRFPARTITION, MTRFTRAIN, MTRFPREDICT.
%
%   mTRF-Toolbox https://github.com/mickcrosse/mTRF-Toolbox

%   References:
%      [1] Crosse MC, Di Liberto GM, Bednar A, Lalor EC (2016) The
%          multivariate temporal response function (mTRF) toolbox: a MATLAB
%          toolbox for relating neural signals to continuous stimuli. Front
%          Hum Neurosci 10:604.
%      [2] Alickovic E, Lunner T, Gustafsson F, Ljung L (2019) A Tutorial
%          on Auditory Attention Identification Methods. Front Neurosci
%          13:153.

%   Authors: Mick Crosse <crossemj@tcd.ie>
%            Giovanni Di Liberto <diliberg@tcd.ie>
%            Edmund Lalor <edlalor@tcd.ie>
%            Nate Zuk <zukn@tcd.ie>
%   Copyright 2014-2024 Lalor Lab, Trinity College Dublin.

% Parse input arguments
arg = parsevarargin(varargin);

% Validate input parameters
validateparamin(fs,Dir,tmin,tmax,lambda)

% Define X and Y variables
if Dir == 1
    x = stim; y = resp;
elseif Dir == -1
    % Backward model: inputs/outputs swap roles and the lag limits are
    % exchanged (their sign is flipped below by the multiplication by Dir)
    x = resp; y = stim;
    [tmin,tmax] = deal(tmax,tmin);
end

% Format data in cell arrays
[x,xobs,xvar] = formatcells(x,arg.dim,arg.split);
[y,yobs,yvar] = formatcells(y,arg.dim,arg.split);

% Check equal number of observations
if ~isequal(xobs,yobs)
    error(['STIM and RESP arguments must have the same number of '...
        'observations.'])
end

% Convert time lags to samples
tmin = floor(tmin/1e3*fs*Dir);
tmax = ceil(tmax/1e3*fs*Dir);
lags = tmin:tmax;

% Convert window size from seconds to samples (0 means "no windowing")
arg.window = round(arg.window*fs);

% Compute sampling interval
delta = 1/fs;

% Get dimensions
xvar = unique(xvar);
yvar = unique(yvar);
nreg = numel(lambda);
nfold = numel(x);
switch arg.type
    case 'multi'
        nvar = xvar*numel(lags)+1;
        nlag = 1;
    case 'single'
        nvar = xvar+1;
        nlag = numel(lags);
end

% Truncate output
if ~arg.zeropad
    [y,yobs] = truncate(y,tmin,tmax,yobs);
end

% Number of rows in the output statistics: one per window when windowing
% is enabled, otherwise one per fold
if arg.window
    nwin = sum(floor(yobs/arg.window));
else
    nwin = nfold;
end

% Verbose mode
if arg.verbose
    v = verbosemode([],[],nfold);
end

% Compute covariance matrices (the fast method also keeps the per-fold
% terms so they can be subtracted out inside the loop)
if arg.fast
    [Cxx,Cxy,folds] = olscovmat(x,y,lags,arg.type,arg.zeropad,arg.verbose);
else
    [Cxx,Cxy] = olscovmat(x,y,lags,arg.type,arg.zeropad,arg.verbose);
end

% Verbose mode
if arg.verbose
    v = verbosemode(v,0,nfold);
end

% Set up sparse regularization matrix
M = regmat(nvar,arg.method)/delta;

% Initialize variables
r = zeros(nwin,nreg,yvar,nlag);
err = zeros(nwin,nreg,yvar,nlag);
nrow = 0; % number of output rows assigned to previous folds

% Leave-one-out cross-validation
for i = 1:nfold

    % Output rows belonging to this fold. A running counter is used rather
    % than the end of the previous index vector, because a fold shorter
    % than the window yields an empty index vector.
    if arg.window
        nwini = floor(yobs(i)/arg.window);
        ii = nrow+1:nrow+nwini;
        nrow = nrow+nwini;
    else
        ii = i;
    end

    if arg.fast % fast method

        % Validation set
        xlag = folds.xlag{i};

        % Training set
        Cxxi = Cxx - folds.Cxx{i};
        Cxyi = Cxy - folds.Cxy{i};

    else % memory-efficient method

        % Validation set
        xlag = lagGen(x{i},lags,arg.zeropad,1);

        % Training set
        Cxxi = Cxx - xlag'*xlag;
        Cxyi = Cxy - xlag'*y{i};

    end

    for j = 1:nreg

        switch arg.type

            case 'multi'

                % Fit linear model
                w = (Cxxi + lambda(j)*M)\Cxyi;

                % Predict output
                pred = xlag*w;

                % Evaluate performance
                [r(ii,j,:),err(ii,j,:)] = mTRFevaluate(y{i},pred,...
                    'corr',arg.corr,'error',arg.error,...
                    'window',arg.window);

            case 'single'

                for k = 1:nlag

                    % Index lag: first column plus the xvar columns of
                    % the k-th lag
                    idx = [1,xvar*(k-1)+2:xvar*k+1];

                    % Fit linear model
                    w = (Cxxi(:,:,k) + lambda(j)*M)\Cxyi(:,:,k);

                    % Predict output
                    pred = xlag(:,idx)*w;

                    % Evaluate performance
                    [r(ii,j,:,k),err(ii,j,:,k)] = mTRFevaluate(y{i},pred,...
                        'corr',arg.corr,'error',arg.error,...
                        'window',arg.window);

                end

        end

    end

    % Verbose mode
    if arg.verbose
        v = verbosemode(v,i,nfold);
    end

end

% Format output arguments
stats = struct('r',r,'err',err);
if nargout > 1
    t = lags/fs*1e3;
end

% Verbose mode: a fold index beyond NFOLD triggers the final summary
if arg.verbose
    verbosemode(v,nfold+1,nfold,stats);
end
function v = verbosemode(v,fold,nfold,stats)
%VERBOSEMODE  Execute verbose mode.
%   V = VERBOSEMODE(V,FOLD,NFOLD,STATS) prints details about the progress
%   of the main function to the screen. The stage is selected by FOLD:
%       []          initialize the state structure V and print the header
%       0           print an empty progress bar and start the timer
%       1..NFOLD    redraw the progress bar with the mean time per fold
%       > NFOLD     print the summary line computed from STATS
%   V carries the last format string ('msg'), the number of bytes last
%   printed ('h') and the accumulated elapsed time ('tocs').

% Initialization stage: no fold index supplied
if isempty(fold)
    v = struct('msg',[],'h',[],'tocs',0);
    fprintf('\nTrain on %d folds, validate on 1 fold\n',nfold-1)
    return
end

% Summary stage: average over the first dimension, then take the maximum
if fold > nfold
    meanr = mean(stats.r,1);
    meanerr = mean(stats.err,1);
    bestr = max(meanr(:));
    besterr = max(meanerr(:));
    fprintf('val_correlation: %.4f - val_error: %.4f\n',bestr,besterr)
    return
end

% Progress stage
if fold == 0
    fprintf('Training/validating model\n')
    emptybar = repmat(' ',1,nfold);
    v.msg = ['%d/%d [',emptybar,']\n'];
    v.h = fprintf(v.msg,fold,nfold);
elseif fold <= nfold
    if fold == 1 && toc < 0.1
        pause(0.1)
    end
    v.tocs = v.tocs + toc;
    % Erase the previously printed bar using its recorded byte count
    fprintf(repmat('\b',1,v.h))
    donebar = repmat('=',1,fold);
    todobar = repmat(' ',1,nfold-fold);
    v.msg = ['%d/%d [',donebar,todobar,'] - ','%.3fs/fold\n'];
    v.h = fprintf(v.msg,fold,nfold,v.tocs/fold);
end

% Restart the timer while folds remain
if fold < nfold
    tic
end
function validateparamin(fs,Dir,tmin,tmax,lambda)
%VALIDATEPARAMIN  Validate input parameters.
%   VALIDATEPARAMIN(FS,DIR,TMIN,TMAX,LAMBDA) validates the input parameters
%   of the main function and throws an error describing the first invalid
%   one:
%       FS          must be a positive numeric scalar
%       DIR         must be the scalar 1 or -1
%       TMIN,TMAX   must be numeric scalars with TMIN not exceeding TMAX
%       LAMBDA      must be numeric with no negative values (0 is allowed)

if ~isnumeric(fs) || ~isscalar(fs) || fs <= 0
    error('FS argument must be a positive numeric scalar.')
elseif ~isscalar(Dir) || (Dir ~= 1 && Dir ~= -1)
    % The scalar check comes first because && requires scalar operands; an
    % empty or array-valued DIR would otherwise raise an unrelated error.
    error('DIR argument must have a value of 1 or -1.')
elseif ~isnumeric([tmin,tmax]) || ~isscalar(tmin) || ~isscalar(tmax)
    error('TMIN and TMAX arguments must be numeric scalars.')
elseif tmin > tmax
    error('The value of TMIN must be less than that of TMAX.')
elseif ~isnumeric(lambda) || any(lambda(:) < 0)
    % lambda(:) reduces to a single logical for any array shape, which the
    % || operator requires.
    error('LAMBDA argument must be non-negative numeric values.')
end
function arg = parsevarargin(varargin)
%PARSEVARARGIN  Parse input arguments.
%   ARG = PARSEVARARGIN(PARAMS) parses the optional parameters of the main
%   function. PARAMS is a cell array of 'PARAM',VAL pairs (the VARARGIN of
%   the caller passed on as a single argument). ARG is a structure with one
%   field per parameter, holding either the supplied value or its default.
%   Abbreviated string values are expanded to their full option names.

% Permitted values of the string-valued parameters
methodSet = {'ridge','Tikhonov','ols'};
typeSet = {'multi','single'};
corrSet = {'Pearson','Spearman'};
errorSet = {'mse','mae'};

% Validation functions
isDim = @(x) assert(x==1||x==2,...
    'It must be a positive integer scalar within indexing range.');
isMethod = @(x) any(validatestring(x,methodSet));
isType = @(x) any(validatestring(x,typeSet));
isCorr = @(x) any(validatestring(x,corrSet));
isError = @(x) any(validatestring(x,errorSet));
isSplit = @(x) assert(isnumeric(x)&&isscalar(x),...
    'It must be a positive integer scalar.');
isWindow = @(x) assert(isnumeric(x)&&isscalar(x),...
    'It must be a positive numeric scalar within indexing range.');
isFlag = @(x) assert(x==0||x==1||islogical(x),...
    'It must be a numeric scalar (0,1) or logical.');

% Register every parameter together with its default value
parser = inputParser;
parser.addParameter('dim',1,isDim);             % dimension to work along
parser.addParameter('method','ridge',isMethod); % regularization method
parser.addParameter('type','multi',isType);     % model type
parser.addParameter('corr','Pearson',isCorr);   % correlation metric
parser.addParameter('error','mse',isError);     % error metric
parser.addParameter('split',1,isSplit);         % split data
parser.addParameter('window',0,isWindow);       % window size
parser.addParameter('zeropad',true,isFlag);     % zero-pad design matrix
parser.addParameter('fast',true,isFlag);        % fast CV method
parser.addParameter('verbose',true,isFlag);     % verbose mode

% Unpack the caller's name-value pairs and parse them
pairs = varargin{1,1};
parser.parse(pairs{:});
arg = parser.Results;

% Redefine partially matched strings
arg.method = validatestring(arg.method,methodSet);
arg.type = validatestring(arg.type,typeSet);
arg.corr = validatestring(arg.corr,corrSet);
arg.error = validatestring(arg.error,errorSet);