From 9941c2bd87da244c41b7203359a79e112fe4fef3 Mon Sep 17 00:00:00 2001 From: Toby Dylan Hocking Date: Tue, 23 Jan 2024 19:08:52 +0000 Subject: [PATCH] VariableTrainSize respects strata for all train sets --- DESCRIPTION | 2 +- NEWS | 4 ++ R/ResamplingVariableSizeTrainCV.R | 75 +++++++++++++++++----------- man/ResamplingVariableSizeTrainCV.Rd | 3 +- tests/testthat/test-CRAN.R | 23 +++++++-- 5 files changed, 70 insertions(+), 37 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5e60caf..b07bf74 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,7 +1,7 @@ Package: mlr3resampling Type: Package Title: Resampling Algorithms for 'mlr3' Framework -Version: 2024.1.8 +Version: 2024.1.23 Authors@R: c( person("Toby", "Hocking", email="toby.hocking@r-project.org", diff --git a/NEWS b/NEWS index 85fe607..f63534f 100644 --- a/NEWS +++ b/NEWS @@ -1,3 +1,7 @@ +Changes in version 2024.1.23 + +- ResamplingVariableSizeTrainCV outputs train sets which respect strata. + Changes in version 2024.1.8 - Rename Simulations vignette to ResamplingSameOtherCV. diff --git a/R/ResamplingVariableSizeTrainCV.R b/R/ResamplingVariableSizeTrainCV.R index f781f69..1e5fc72 100644 --- a/R/ResamplingVariableSizeTrainCV.R +++ b/R/ResamplingVariableSizeTrainCV.R @@ -21,55 +21,70 @@ ResamplingVariableSizeTrainCV = R6::R6Class( }, instantiate = function(task) { task = mlr3::assert_task(mlr3::as_task(task)) - reserved.names <- c( - "row_id", "fold", "group", "display_row", - "train.groups", "test.fold", "test.group", "iteration", - "test", "train", "algorithm", "uhash", "nr", "task", "task_id", - "learner", "learner_id", "resampling", "resampling_id", - "prediction") - ## bad.names <- group.name.vec[group.name.vec %in% reserved.names] - ## if(length(bad.names)){ - ## first.bad <- bad.names[1] - ## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad)) - ## } - ## orig.group.dt <- task$data(cols=group.name.vec) strata <- if(is.null(task$strata)){ data.dt <- task$data() data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt))) }else task$strata - sample.list <- lapply(strata$row_id, private$.sample, task = task) - folds = private$.combine(sample.list)[order(row_id)] + strata.list <- lapply(strata$row_id, private$.sample, task = task) + folds = private$.combine(strata.list)[order(row_id)] + max.train.vec <- sapply(strata.list, nrow) + small.strat.i <- which.min(max.train.vec) min_train_data <- self$param_set$values[["min_train_data"]] - if(task$nrow <= min_train_data){ - stop(sprintf( - "task$nrow=%d but should be larger than min_train_data=%d", - task$nrow, min_train_data)) - } uniq.folds <- sort(unique(folds$fold)) iteration.dt.list <- list() for(test.fold in uniq.folds){ - is.set.fold <- list( - test=folds[["fold"]] == test.fold) - is.set.fold[["train"]] <- !is.set.fold[["test"]] - i.set.list <- lapply(is.set.fold, which) - max_train_data <- length(i.set.list$train) + train.strata.list <- lapply(strata.list, function(DT)DT[fold != test.fold]) + max_train_data <- nrow(train.strata.list[[small.strat.i]]) + if(max_train_data <= min_train_data){ + stop(sprintf( + "max_train_data=%d (in smallest stratum) but should be larger than min_train_data=%d, please fix by decreasing min_train_data", + max_train_data, min_train_data)) + } log.range.data <- log(c(min_train_data, max_train_data)) seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]])) log.train.sizes <- do.call(seq, seq.args) - train.size.vec <- unique(as.integer(round(exp(log.train.sizes)))) + train.size.vec <- as.integer(round(exp(log.train.sizes))) + size.tab <- table(train.size.vec) + if(any(size.tab>1)){ + stop("train sizes not unique, please decrease train_sizes") + } for(seed in 1:self$param_set$values[["random_seeds"]]){ set.seed(seed) - ord.i.vec <- sample(i.set.list$train) + train.seed.list <- lapply(train.strata.list, function(DT)DT[sample(.N)][, `:=`( + row_seed = .I, + prop = .I/.N + )][]) + test.index.vec <- do.call(c, lapply( + strata.list, function(DT)DT[fold == test.fold, row_id])) + train.prop.dt <- train.seed.list[[small.strat.i]][train.size.vec, data.table(prop)] + train.i.list <- lapply(train.seed.list, function(DT)DT[ + train.prop.dt, + .(train.i=lapply(row_seed, function(last)DT$row_id[1:last])), + on="prop", + roll="nearest"]) + train.index.list <- list() + for(train.size.i in seq_along(train.size.vec)){ + strata.index.list <- lapply(train.i.list, function(DT)DT[["train.i"]][[train.size.i]]) + train.index.list[[train.size.i]] <- do.call(c, strata.index.list) + } iteration.dt.list[[paste(test.fold, seed)]] <- data.table( test.fold, seed, - train_size=train.size.vec, - train=lapply(train.size.vec, function(last)ord.i.vec[1:last]), - test=list(i.set.list$test)) + small_stratum_size=train.size.vec, + train_size_i=seq_along(train.size.vec), + train_size=sapply(train.index.list, length), + train=train.index.list, + test=list(test.index.vec)) } } self$instance <- list( - iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][], + iteration.dt=rbindlist( + iteration.dt.list + )[ + , iteration := .I + ][ + , train_min_size := min(train_size), by=train_size_i + ][], id.dt=folds) self$task_hash = task$hash self$task_nrow = task$nrow diff --git a/man/ResamplingVariableSizeTrainCV.Rd b/man/ResamplingVariableSizeTrainCV.Rd index 4f56325..997a992 100644 --- a/man/ResamplingVariableSizeTrainCV.Rd +++ b/man/ResamplingVariableSizeTrainCV.Rd @@ -48,7 +48,8 @@ orderings of the train set that are considered. For each random order of the train set, the \code{min_train_data} - parameter controls the smallest train set size considered. + parameter controls the size of the smallest stratum in the smallest + train set considered. To determine the other train set sizes, we use an equally spaced grid on the log scale, from \code{min_train_data} to the largest train set diff --git a/tests/testthat/test-CRAN.R b/tests/testthat/test-CRAN.R index dc3c006..dd3790e 100644 --- a/tests/testthat/test-CRAN.R +++ b/tests/testthat/test-CRAN.R @@ -83,15 +83,28 @@ test_that("error for group named test", { }, "col with role group must not be named test; please fix by renaming test col") }) -test_that("error for 10 data", { +test_that("errors and result for 10 train data in small stratum", { size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new() - i10.dt <- data.table(iris)[1:10] - i10.task <- mlr3::TaskClassif$new("i10", i10.dt, target="Species") + size_cv$param_set$values$folds <- 2 + i10.dt <- data.table(iris)[1:70] + i10.task <- mlr3::TaskClassif$new( + "i10", i10.dt, target="Species" + )$set_col_roles("Species",c("target","stratum")) expect_error({ size_cv$instantiate(i10.task) }, - "task$nrow=10 but should be larger than min_train_data=10", + "max_train_data=10 (in smallest stratum) but should be larger than min_train_data=10, please fix by decreasing min_train_data", fixed=TRUE) + size_cv$param_set$values$min_train_data <- 9 + expect_error({ + size_cv$instantiate(i10.task) + }, + "train sizes not unique, please decrease train_sizes", + fixed=TRUE) + size_cv$param_set$values$train_sizes <- 2 + size_cv$instantiate(i10.task) + size.tab <- table(size_cv$instance$iteration.dt[["small_stratum_size"]]) + expect_identical(names(size.tab), c("9","10")) }) test_that("strata respected in all sizes", { @@ -114,7 +127,7 @@ test_that("strata respected in all sizes", { min.row <- min.dt[min.i] train.i <- min.row$train[[1]] strat.tab <- table(istrat.dt[train.i, strat]) - expect_identical(strat.tab, smallest.size.tab) + expect_equal(strat.tab, smallest.size.tab) } })