-
Notifications
You must be signed in to change notification settings - Fork 0
/
cosmo_nfold_partitioner.m
53 lines (46 loc) · 1.29 KB
/
cosmo_nfold_partitioner.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
function partitions = cosmo_nfold_partitioner(chunks)
% generates an n-fold partition scheme
%
% partitions=cosmo_nfold_partitioner(chunks)
%
% Input
%  - chunks        Px1 chunk indices for P samples. It can also be a
%                  dataset with field .sa.chunks. A 1xP row vector is
%                  accepted as well and gives the same result.
%
% Output:
%  - partitions    A struct with fields .train_indices and .test_indices.
%                  Each of these is an 1xQ cell for Q partitions, where Q
%                  is the number of unique values in chunks, and
%                  .train_indices{k} and .test_indices{k} contain the
%                  sample indices (as row vectors) for the k-th fold.
%                  In the k-th fold, the samples whose chunk equals the
%                  k-th unique chunk value (in the sorted order returned
%                  by unique) are used for testing, and all other samples
%                  are used for training.
%
% Example:
%   p=cosmo_nfold_partitioner([1 1 2 2 3 3 3])
%   > p = train_indices: {1x3 cell}
%   >     test_indices: {1x3 cell}
%   p.train_indices{1}
%   > [3 4 5 6 7]
%   p.test_indices{1}
%   > [1 2]
%
% Notes:
%  - chunk values are matched with ==, and unique() keeps every NaN as a
%    separate value; each NaN chunk therefore produces a fold with an
%    empty test set.
%
% NNO Aug 2013

if isstruct(chunks)
    % dataset input: take the sample attribute .sa.chunks
    if isfield(chunks,'sa') && isfield(chunks.sa,'chunks')
        chunks=chunks.sa.chunks;
    else
        error('illegal input')
    end
end

% flatten to a column so that the orientation of the output indices does
% not depend on whether chunks was given as a row or a column vector
chunks=chunks(:);

unq=unique(chunks);
nchunks=numel(unq);

% allocate space for output
train_indices=cell(1,nchunks);
test_indices=cell(1,nchunks);

% >>
for k=1:nchunks
    test_msk=unq(k)==chunks;
    % find on a column mask returns a column; transpose to a row vector
    train_indices{k}=find(~test_msk)';
    test_indices{k}=find(test_msk)';
end
% <<

partitions.train_indices=train_indices;
partitions.test_indices=test_indices;