Merge pull request #7 from tdhock/strata-small-train
train strata counts
tdhock authored Jan 23, 2024
2 parents 119a8f6 + d95e44a commit b5a0619
Showing 7 changed files with 101 additions and 37 deletions.
3 changes: 2 additions & 1 deletion DESCRIPTION
@@ -1,7 +1,7 @@
Package: mlr3resampling
Type: Package
Title: Resampling Algorithms for 'mlr3' Framework
Version: 2024.1.8
Version: 2024.1.23
Authors@R: c(
person("Toby", "Hocking",
email="toby.hocking@r-project.org",
@@ -67,6 +67,7 @@ Imports:
mlr3misc
Suggests:
animint2,
lgr,
future,
testthat,
knitr,
4 changes: 4 additions & 0 deletions NEWS
@@ -1,3 +1,7 @@
Changes in version 2024.1.23

- ResamplingVariableSizeTrainCV outputs train sets which respect strata.

Changes in version 2024.1.8

- Rename Simulations vignette to ResamplingSameOtherCV.
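The new NEWS entry ("train sets which respect strata") can be seen in a minimal sketch, assuming a classification task with a column given the stratum role; the names (istrat.dt, istrat.task) are illustrative and mirror the new test added below in tests/testthat/test-CRAN.R.

library(data.table)
## three "A" rows for every "B" row, so the strata are imbalanced.
istrat.dt <- data.table(iris, strat = factor(rep(c("A","A","A","B"), l = nrow(iris))))
istrat.task <- mlr3::TaskClassif$new(
  "istrat", istrat.dt, target = "Species"
)$set_col_roles("strat", "stratum")
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
size_cv$instantiate(istrat.task)
## each train set now keeps the 3:1 A:B proportion (up to rounding),
## even for the smallest train sizes.
one.train <- size_cv$instance$iteration.dt$train[[1]]
table(istrat.dt[one.train, strat])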
75 changes: 45 additions & 30 deletions R/ResamplingVariableSizeTrainCV.R
@@ -21,55 +21,70 @@ ResamplingVariableSizeTrainCV = R6::R6Class(
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
reserved.names <- c(
"row_id", "fold", "group", "display_row",
"train.groups", "test.fold", "test.group", "iteration",
"test", "train", "algorithm", "uhash", "nr", "task", "task_id",
"learner", "learner_id", "resampling", "resampling_id",
"prediction")
## bad.names <- group.name.vec[group.name.vec %in% reserved.names]
## if(length(bad.names)){
## first.bad <- bad.names[1]
## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad))
## }
## orig.group.dt <- task$data(cols=group.name.vec)
strata <- if(is.null(task$strata)){
data.dt <- task$data()
data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
}else task$strata
sample.list <- lapply(strata$row_id, private$.sample, task = task)
folds = private$.combine(sample.list)[order(row_id)]
strata.list <- lapply(strata$row_id, private$.sample, task = task)
folds = private$.combine(strata.list)[order(row_id)]
max.train.vec <- sapply(strata.list, nrow)
small.strat.i <- which.min(max.train.vec)
min_train_data <- self$param_set$values[["min_train_data"]]
if(task$nrow <= min_train_data){
stop(sprintf(
"task$nrow=%d but should be larger than min_train_data=%d",
task$nrow, min_train_data))
}
uniq.folds <- sort(unique(folds$fold))
iteration.dt.list <- list()
for(test.fold in uniq.folds){
is.set.fold <- list(
test=folds[["fold"]] == test.fold)
is.set.fold[["train"]] <- !is.set.fold[["test"]]
i.set.list <- lapply(is.set.fold, which)
max_train_data <- length(i.set.list$train)
train.strata.list <- lapply(strata.list, function(DT)DT[fold != test.fold])
max_train_data <- nrow(train.strata.list[[small.strat.i]])
if(max_train_data <= min_train_data){
stop(sprintf(
"max_train_data=%d (in smallest stratum) but should be larger than min_train_data=%d, please fix by decreasing min_train_data",
max_train_data, min_train_data))
}
log.range.data <- log(c(min_train_data, max_train_data))
seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]]))
log.train.sizes <- do.call(seq, seq.args)
train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
train.size.vec <- as.integer(round(exp(log.train.sizes)))
size.tab <- table(train.size.vec)
if(any(size.tab>1)){
stop("train sizes not unique, please decrease train_sizes")
}
for(seed in 1:self$param_set$values[["random_seeds"]]){
set.seed(seed)
ord.i.vec <- sample(i.set.list$train)
train.seed.list <- lapply(train.strata.list, function(DT)DT[sample(.N)][, `:=`(
row_seed = .I,
prop = .I/.N
)][])
test.index.vec <- do.call(c, lapply(
strata.list, function(DT)DT[fold == test.fold, row_id]))
train.prop.dt <- train.seed.list[[small.strat.i]][train.size.vec, data.table(prop)]
train.i.list <- lapply(train.seed.list, function(DT)DT[
train.prop.dt,
.(train.i=lapply(row_seed, function(last)DT$row_id[1:last])),
on="prop",
roll="nearest"])
train.index.list <- list()
for(train.size.i in seq_along(train.size.vec)){
strata.index.list <- lapply(train.i.list, function(DT)DT[["train.i"]][[train.size.i]])
train.index.list[[train.size.i]] <- do.call(c, strata.index.list)
}
iteration.dt.list[[paste(test.fold, seed)]] <- data.table(
test.fold,
seed,
train_size=train.size.vec,
train=lapply(train.size.vec, function(last)ord.i.vec[1:last]),
test=list(i.set.list$test))
small_stratum_size=train.size.vec,
train_size_i=seq_along(train.size.vec),
train_size=sapply(train.index.list, length),
train=train.index.list,
test=list(test.index.vec))
}
}
self$instance <- list(
iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][],
iteration.dt=rbindlist(
iteration.dt.list
)[
, iteration := .I
][
, train_min_size := min(train_size), by=train_size_i
][],
id.dt=folds)
self$task_hash = task$hash
self$task_nrow = task$nrow
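The proportion-matching logic in the new code above (each permuted stratum gets prop = .I/.N, then a roll="nearest" join maps the sizes chosen in the smallest stratum onto every other stratum) may be easier to follow on a toy data.table outside the R6 class. A minimal sketch with illustrative sizes (one 40-row stratum, a hypothetical smallest stratum of 20 rows):

library(data.table)
set.seed(1)
## one permuted stratum of 40 rows, as built in train.seed.list above.
DT <- data.table(row_id = sample(40))[, `:=`(row_seed = .I, prop = .I/.N)][]
## proportions chosen in the (hypothetical) smallest stratum of 20 rows,
## at train sizes 5 and 10.
train.prop.dt <- data.table(prop = c(5, 10)/20)
## nearest-proportion rolling join: prop=0.25 selects the first 10 rows of
## this stratum, prop=0.5 the first 20, preserving the strata proportions.
DT[train.prop.dt,
   .(train.i = lapply(row_seed, function(last) DT$row_id[1:last])),
   on = "prop", roll = "nearest"]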
3 changes: 2 additions & 1 deletion man/ResamplingVariableSizeTrainCV.Rd
@@ -48,7 +48,8 @@
orderings of the train set that are considered.

For each random order of the train set, the \code{min_train_data}
parameter controls the smallest train set size considered.
parameter controls the size of the smallest stratum in the smallest
train set considered.

To determine the other train set sizes, we use an equally spaced grid
on the log scale, from \code{min_train_data} to the largest train set
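For concreteness, the equally spaced grid on the log scale described in the documentation above can be computed directly; a short sketch with illustrative values (not necessarily the package defaults), matching the seq/exp code in R/ResamplingVariableSizeTrainCV.R:

min_train_data <- 10  # illustrative value
max_train_data <- 25  # smallest stratum size in the largest train set
train_sizes <- 5
log.train.sizes <- seq(log(min_train_data), log(max_train_data), l = train_sizes)
(train.size.vec <- as.integer(round(exp(log.train.sizes))))
## [1] 10 13 16 20 25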
47 changes: 43 additions & 4 deletions tests/testthat/test-CRAN.R
@@ -1,3 +1,5 @@
library(testthat)
library(data.table)
test_that("resampling error if no group", {
itask <- mlr3::TaskClassif$new("iris", iris, target="Species")
same_other <- mlr3resampling::ResamplingSameOtherCV$new()
@@ -81,15 +83,52 @@ test_that("error for group named test", {
}, "col with role group must not be named test; please fix by renaming test col")
})

test_that("error for 10 data", {
test_that("errors and result for 10 train data in small stratum", {
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
i10.dt <- data.table(iris)[1:10]
i10.task <- mlr3::TaskClassif$new("i10", i10.dt, target="Species")
size_cv$param_set$values$folds <- 2
i10.dt <- data.table(iris)[1:70]
i10.task <- mlr3::TaskClassif$new(
"i10", i10.dt, target="Species"
)$set_col_roles("Species",c("target","stratum"))
expect_error({
size_cv$instantiate(i10.task)
},
"task$nrow=10 but should be larger than min_train_data=10",
"max_train_data=10 (in smallest stratum) but should be larger than min_train_data=10, please fix by decreasing min_train_data",
fixed=TRUE)
size_cv$param_set$values$min_train_data <- 9
expect_error({
size_cv$instantiate(i10.task)
},
"train sizes not unique, please decrease train_sizes",
fixed=TRUE)
size_cv$param_set$values$train_sizes <- 2
size_cv$instantiate(i10.task)
size.tab <- table(size_cv$instance$iteration.dt[["small_stratum_size"]])
expect_identical(names(size.tab), c("9","10"))
})

test_that("strata respected in all sizes", {
size_cv <- mlr3resampling::ResamplingVariableSizeTrainCV$new()
size_cv$param_set$values$min_train_data <- 5
size_cv$param_set$values$folds <- 5
N <- 100
imbalance <- 4
strat.vec <- ifelse((1:imbalance)<imbalance, "A","B")
istrat.dt <- data.table(iris[1:N,], strat=factor(rep(strat.vec, l=N)))
smallest.size.tab <- table(
istrat.dt[["strat"]]
)/N*imbalance*size_cv$param_set$values$min_train_data
istrat.task <- mlr3::TaskClassif$new(
"istrat", istrat.dt, target="Species"
)$set_col_roles("strat", "stratum")
size_cv$instantiate(istrat.task)
min.dt <- size_cv$instance$iteration.dt[train_size==min(train_size)]
for(min.i in 1:nrow(min.dt)){
min.row <- min.dt[min.i]
train.i <- min.row$train[[1]]
strat.tab <- table(istrat.dt[train.i, strat])
expect_equal(strat.tab, smallest.size.tab)
}
})

test_that("train set max size 67 for 100 data", {
4 changes: 3 additions & 1 deletion vignettes/ResamplingSameOtherCV.Rmd
@@ -156,9 +156,10 @@ In the code below, we execute the benchmark experiment (in parallel
using the multisession future plan).

```{r}
if(FALSE){
if(FALSE){#for CRAN.
if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(reg.bench.result <- mlr3::benchmark(
reg.bench.grid, store_models = TRUE))
```
@@ -495,6 +496,7 @@ iteration can be parallelized by declaring a future plan.
if(FALSE){
if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
class.bench.grid, store_models = TRUE))
```
2 changes: 2 additions & 0 deletions vignettes/ResamplingVariableSizeTrainCV.Rmd
@@ -173,6 +173,7 @@ using the multisession future plan).
if(FALSE){
if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(reg.bench.result <- mlr3::benchmark(
reg.bench.grid, store_models = TRUE))
```
@@ -500,6 +501,7 @@ defined by our benchmark grid:
if(FALSE){
if(require(future))plan("multisession")
}
if(require(lgr))get_logger("mlr3")$set_threshold("warn")
(class.bench.result <- mlr3::benchmark(
class.bench.grid, store_models = TRUE))
```
