Skip to content

Commit

Permalink
VariableTrainSize respects strata for all train sets
Browse files Browse the repository at this point in the history
  • Loading branch information
tdhock committed Jan 23, 2024
1 parent 3cd9119 commit 9941c2b
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 37 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: mlr3resampling
Type: Package
Title: Resampling Algorithms for 'mlr3' Framework
Version: 2024.1.8
Version: 2024.1.23
Authors@R: c(
person("Toby", "Hocking",
email="toby.hocking@r-project.org",
Expand Down
4 changes: 4 additions & 0 deletions NEWS
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
Changes in version 2024.1.23

- ResamplingVariableSizeTrainCV outputs train sets which respect strata.

Changes in version 2024.1.8

- Rename Simulations vignette to ResamplingSameOtherCV.
Expand Down
75 changes: 45 additions & 30 deletions R/ResamplingVariableSizeTrainCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,70 @@ ResamplingVariableSizeTrainCV = R6::R6Class(
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
reserved.names <- c(
"row_id", "fold", "group", "display_row",
"train.groups", "test.fold", "test.group", "iteration",
"test", "train", "algorithm", "uhash", "nr", "task", "task_id",
"learner", "learner_id", "resampling", "resampling_id",
"prediction")
## bad.names <- group.name.vec[group.name.vec %in% reserved.names]
## if(length(bad.names)){
## first.bad <- bad.names[1]
## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad))
## }
## orig.group.dt <- task$data(cols=group.name.vec)
strata <- if(is.null(task$strata)){
data.dt <- task$data()
data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
}else task$strata
sample.list <- lapply(strata$row_id, private$.sample, task = task)
folds = private$.combine(sample.list)[order(row_id)]
strata.list <- lapply(strata$row_id, private$.sample, task = task)
folds = private$.combine(strata.list)[order(row_id)]
max.train.vec <- sapply(strata.list, nrow)
small.strat.i <- which.min(max.train.vec)
min_train_data <- self$param_set$values[["min_train_data"]]
if(task$nrow <= min_train_data){
stop(sprintf(
"task$nrow=%d but should be larger than min_train_data=%d",
task$nrow, min_train_data))
}
uniq.folds <- sort(unique(folds$fold))
iteration.dt.list <- list()
for(test.fold in uniq.folds){
is.set.fold <- list(
test=folds[["fold"]] == test.fold)
is.set.fold[["train"]] <- !is.set.fold[["test"]]
i.set.list <- lapply(is.set.fold, which)
max_train_data <- length(i.set.list$train)
train.strata.list <- lapply(strata.list, function(DT)DT[fold != test.fold])
max_train_data <- nrow(train.strata.list[[small.strat.i]])
if(max_train_data <= min_train_data){
stop(sprintf(
"max_train_data=%d (in smallest stratum) but should be larger than min_train_data=%d, please fix by decreasing min_train_data",
max_train_data, min_train_data))
}
log.range.data <- log(c(min_train_data, max_train_data))
seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]]))
log.train.sizes <- do.call(seq, seq.args)
train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
train.size.vec <- as.integer(round(exp(log.train.sizes)))
size.tab <- table(train.size.vec)
if(any(size.tab>1)){
stop("train sizes not unique, please decrease train_sizes")
}
for(seed in 1:self$param_set$values[["random_seeds"]]){
set.seed(seed)
ord.i.vec <- sample(i.set.list$train)
train.seed.list <- lapply(train.strata.list, function(DT)DT[sample(.N)][, `:=`(
row_seed = .I,
prop = .I/.N
)][])
test.index.vec <- do.call(c, lapply(
strata.list, function(DT)DT[fold == test.fold, row_id]))
train.prop.dt <- train.seed.list[[small.strat.i]][train.size.vec, data.table(prop)]
train.i.list <- lapply(train.seed.list, function(DT)DT[
train.prop.dt,
.(train.i=lapply(row_seed, function(last)DT$row_id[1:last])),
on="prop",
roll="nearest"])
train.index.list <- list()
for(train.size.i in seq_along(train.size.vec)){
strata.index.list <- lapply(train.i.list, function(DT)DT[["train.i"]][[train.size.i]])
train.index.list[[train.size.i]] <- do.call(c, strata.index.list)
}
iteration.dt.list[[paste(test.fold, seed)]] <- data.table(
test.fold,
seed,
train_size=train.size.vec,
train=lapply(train.size.vec, function(last)ord.i.vec[1:last]),
test=list(i.set.list$test))
small_stratum_size=train.size.vec,
train_size_i=seq_along(train.size.vec),
train_size=sapply(train.index.list, length),
train=train.index.list,
test=list(test.index.vec))
}
}
self$instance <- list(
iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][],
iteration.dt=rbindlist(
iteration.dt.list
)[
, iteration := .I
][
, train_min_size := min(train_size), by=train_size_i
][],
id.dt=folds)
self$task_hash = task$hash
self$task_nrow = task$nrow
Expand Down
3 changes: 2 additions & 1 deletion man/ResamplingVariableSizeTrainCV.Rd
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
orderings of the train set that are considered.

For each random order of the train set, the \code{min_train_data}
parameter controls the smallest train set size considered.
parameter controls the size of the smallest stratum in the smallest
train set considered.

To determine the other train set sizes, we use an equally spaced grid
on the log scale, from \code{min_train_data} to the largest train set
Expand Down
23 changes: 18 additions & 5 deletions tests/testthat/test-CRAN.R
Original file line number Diff line number Diff line change
Expand Up @@ -83,15 +83,28 @@ test_that("error for group named test", {
}, "col with role group must not be named test; please fix by renaming test col")
})

test_that("error for 10 data", {
test_that("errors and result for 10 train data in small stratum", {
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
i10.dt <- data.table(iris)[1:10]
i10.task <- mlr3::TaskClassif$new("i10", i10.dt, target="Species")
size_cv$param_set$values$folds <- 2
i10.dt <- data.table(iris)[1:70]
i10.task <- mlr3::TaskClassif$new(
"i10", i10.dt, target="Species"
)$set_col_roles("Species",c("target","stratum"))
expect_error({
size_cv$instantiate(i10.task)
},
"task$nrow=10 but should be larger than min_train_data=10",
"max_train_data=10 (in smallest stratum) but should be larger than min_train_data=10, please fix by decreasing min_train_data",
fixed=TRUE)
size_cv$param_set$values$min_train_data <- 9
expect_error({
size_cv$instantiate(i10.task)
},
"train sizes not unique, please decrease train_sizes",
fixed=TRUE)
size_cv$param_set$values$train_sizes <- 2
size_cv$instantiate(i10.task)
size.tab <- table(size_cv$instance$iteration.dt[["small_stratum_size"]])
expect_identical(names(size.tab), c("9","10"))
})

test_that("strata respected in all sizes", {
Expand All @@ -114,7 +127,7 @@ test_that("strata respected in all sizes", {
min.row <- min.dt[min.i]
train.i <- min.row$train[[1]]
strat.tab <- table(istrat.dt[train.i, strat])
expect_identical(strat.tab, smallest.size.tab)
expect_equal(strat.tab, smallest.size.tab)
}
})

Expand Down

0 comments on commit 9941c2b

Please sign in to comment.