-
Notifications
You must be signed in to change notification settings - Fork 1
/
cvControl.R
70 lines (68 loc) · 2.25 KB
/
cvControl.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# Build the validation-row index sets for V-fold cross-validation.
#
# Arguments:
#   N          Number of rows to distribute over the folds.
#   id         NULL, or a vector of cluster labels whose positions are the
#              row numbers. Rows sharing a label always land in the same
#              fold. Must be NULL when cvControl$stratifyCV is TRUE.
#   Y          Outcome vector; only read when cvControl$stratifyCV is TRUE.
#              It must then have exactly two distinct values, and the fold
#              count check (sum(Y), sum(!Y)) treats it as 0/1 coded.
#   cvControl  List with elements V (number of folds), stratifyCV, shuffle
#              and, optionally, validRows.
#
# Value:
#   cvControl$validRows unchanged when it is not NULL. Otherwise a list with
#   one integer vector of row numbers per fold. The list is named "1".."V"
#   except in the `id` case, where it is unnamed.
#
# Errors:
#   Stops when stratification is requested and Y is not binary, when either
#   class has fewer members than V, or when `id` is supplied.
CVFolds <- function(N, id, Y, cvControl) {
  # validRows would be a user specified list of row numbers for the
  # validation sets; it takes precedence over every other setting.
  if (!is.null(cvControl$validRows)) {
    return(cvControl$validRows)
  }
  stratifyCV <- cvControl$stratifyCV
  shuffle <- cvControl$shuffle
  V <- cvControl$V

  # Order in which n items are dealt out: a random permutation when
  # shuffling, otherwise 1..n. sample.int(n) draws exactly what
  # sample(1:n) would, without sample()'s special case for a single number.
  dealOrder <- function(n) {
    if (shuffle) sample.int(n) else seq_len(n)
  }
  # Deal the elements of x round-robin into folds 1..V.
  dealIntoFolds <- function(x) {
    split(x, rep(seq_len(V), length.out = length(x)))
  }

  if (!stratifyCV) {
    if (is.null(id)) {
      return(dealIntoFolds(dealOrder(N)))
    }
    # Folds are formed over the distinct ids, then mapped back to rows so
    # that a cluster is never split across folds.
    ids <- unique(id)
    id.split <- dealIntoFolds(dealOrder(length(ids)))
    validRows <- vector("list", V)
    for (v in seq_len(V)) {
      validRows[[v]] <- which(id %in% ids[id.split[[v]]])
    }
    return(validRows)
  }

  if (length(unique(Y)) != 2) {
    stop("stratifyCV only implemented for binary Y")
  }
  if (sum(Y) < V || sum(!Y) < V) {
    stop("number of (Y=1) or (Y=0) is less than the number of folds")
  }
  if (!is.null(id)) {
    stop("stratified sampling with id not currently implemented")
  }

  # Row numbers of each outcome class, smaller value first (Y == 0, then
  # Y == 1 for 0/1 coded outcomes).
  strata <- lapply(
    sort(unique(Y)),
    function(value) which(Y == value, useNames = FALSE)
  )
  # Index with a permutation rather than calling sample() on the rows: for
  # a single row number x, sample(x) would return a permutation of 1:x.
  # The first stratum is dealt before the second, so random draws happen in
  # that order.
  rowsY0 <- dealIntoFolds(strata[[1]][dealOrder(length(strata[[1]]))])
  rowsY1 <- dealIntoFolds(strata[[2]][dealOrder(length(strata[[2]]))])

  validRows <- vector("list", length = V)
  names(validRows) <- paste(seq_len(V))
  for (vv in seq_len(V)) {
    validRows[[vv]] <- c(rowsY0[[vv]], rowsY1[[vv]])
  }
  validRows
}