ResamplingVariableSizeCV #3

Merged: 11 commits, Dec 31, 2023
5 changes: 4 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: mlr3resampling
Type: Package
Title: Resampling Algorithms for 'mlr3' Framework
Version: 2023.12.23
Version: 2023.12.28
Authors@R: c(
person("Toby", "Hocking",
email="toby.hocking@r-project.org",
@@ -52,6 +52,9 @@ Description: A supervised learning algorithm inputs a train set,
For more information,
<https://tdhock.github.io/blog/2023/R-gen-new-subsets/>
describes the method in depth.
How many train samples are required to get accurate predictions on a
test set? Cross-validation with variable size train sets can be used
to answer this question.
License: GPL-3
URL: https://github.com/tdhock/mlr3resampling
BugReports: https://github.com/tdhock/mlr3resampling/issues
2 changes: 1 addition & 1 deletion NAMESPACE
@@ -1,3 +1,3 @@
import(R6, checkmate, data.table, mlr3, mlr3misc, paradox)
export(ResamplingSameOtherCV, score)
export(ResamplingSameOtherCV, score, ResamplingVariableSizeTrainCV)

5 changes: 5 additions & 0 deletions NEWS
@@ -1,3 +1,8 @@
Changes in version 2023.12.28

- Rename Simulations vignette to ResamplingSameOtherCV.
- New ResamplingVariableSizeTrainCV class and vignette.

Changes in version 2023.12.23

- To get data set names in Simulations vignette, use task data names instead of learner$state$data_prototype.
83 changes: 83 additions & 0 deletions R/ResamplingBase.R
@@ -0,0 +1,83 @@
ResamplingBase = R6::R6Class(
"Resampling",
public = list(
id = NULL,
label = NULL,
param_set = NULL,
instance = NULL,
task_hash = NA_character_,
task_nrow = NA_integer_,
duplicated_ids = NULL,
man = NULL,
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = checkmate::assert_string(id, min.chars = 1L)
self$label = checkmate::assert_string(label, na.ok = TRUE)
self$param_set = paradox::assert_param_set(param_set)
self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
self$man = checkmate::assert_string(man, na.ok = TRUE)
},
format = function(...) {
sprintf("<%s>", class(self)[1L])
},
print = function(...) {
cat(
format(self),
if (is.null(self$label) || is.na(self$label))
"" else paste0(": ", self$label)
)
cat("\n* Iterations:", self$iters)
cat("\n* Instantiated:", self$is_instantiated)
cat("\n* Parameters:\n")
str(self$param_set$values)
},
help = function() {
self$man
},
train_set = function(i) {
self$instance$iteration.dt$train[[i]]
},
test_set = function(i) {
self$instance$iteration.dt$test[[i]]
}
),
active = list(
iters = function(rhs) {
nrow(self$instance$iteration.dt)
},
is_instantiated = function(rhs) {
!is.null(self$instance)
},
hash = function(rhs) {
if (!self$is_instantiated) {
return(NA_character_)
}
mlr3misc::calculate_hash(list(
class(self),
self$id,
self$param_set$values,
self$instance))
}
),
private = list(
.sample = function(ids, ...) {
data.table(
row_id = ids,
fold = sample(
seq(0, length(ids)-1) %%
as.integer(self$param_set$values$folds) + 1L
),
key = "fold"
)
},
.combine = function(instances) {
rbindlist(instances, use.names = TRUE)
},
deep_clone = function(name, value) {
switch(name,
"instance" = copy(value),
"param_set" = value$clone(deep = TRUE),
value
)
}
)
)
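An editorial aside, not part of the diff: the private `.sample()` method above assigns each row id a fold by recycling the labels `1..folds` over the ids and then shuffling. A minimal standalone sketch of that fold-assignment arithmetic:

```r
## Standalone demo of the fold assignment used in ResamplingBase's
## .sample(): recycle fold labels 1..folds over the row ids, then shuffle.
library(data.table)
set.seed(1)
ids <- 1:10
folds <- 3L
fold.dt <- data.table(
  row_id = ids,
  fold = sample(seq(0, length(ids) - 1) %% folds + 1L),
  key = "fold")
## Fold sizes are as equal as possible regardless of the shuffle: 4, 3, 3.
print(table(fold.dt$fold))
```

Because the labels are recycled before shuffling, fold sizes differ by at most one row, which is the usual balanced K-fold construction.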
119 changes: 13 additions & 106 deletions R/ResamplingSameOther.R → R/ResamplingSameOtherCV.R
@@ -1,33 +1,17 @@
ResamplingSameOther = R6::R6Class(
"Resampling",
ResamplingSameOtherCV = R6::R6Class(
"ResamplingSameOtherCV",
inherit=ResamplingBase,
public = list(
id = NULL,
label = NULL,
param_set = NULL,
instance = NULL,
task_hash = NA_character_,
task_nrow = NA_integer_,
duplicated_ids = NULL,
man = NULL,
initialize = function(id, param_set = ps(), duplicated_ids = FALSE, label = NA_character_, man = NA_character_) {
self$id = checkmate::assert_string(id, min.chars = 1L)
self$label = checkmate::assert_string(label, na.ok = TRUE)
self$param_set = paradox::assert_param_set(param_set)
self$duplicated_ids = checkmate::assert_flag(duplicated_ids)
self$man = checkmate::assert_string(man, na.ok = TRUE)
},
format = function(...) {
sprintf("<%s>", class(self)[1L])
},
print = function(...) {
cat(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
cat("\n* Iterations:", self$iters)
cat("\n* Instantiated:", self$is_instantiated)
cat("\n* Parameters:\n")
str(self$param_set$values)
},
help = function() {
self$man
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required")
)
ps$values = list(folds = 3L)
super$initialize(
id = "same_other_cv",
param_set = ps,
label = "Same versus Other Cross-Validation",
man = "ResamplingSameOtherCV")
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
@@ -125,83 +109,6 @@ ResamplingSameOther = R6::R6Class(
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
},
train_set = function(i) {
self$instance$iteration.dt$train[[i]]
},
test_set = function(i) {
self$instance$iteration.dt$test[[i]]
}
),
active = list(
is_instantiated = function(rhs) {
!is.null(self$instance)
},
hash = function(rhs) {
if (!self$is_instantiated) {
return(NA_character_)
}
mlr3misc::calculate_hash(list(class(self), self$id, self$param_set$values, self$instance))
}
)
)

ResamplingSameOtherCV = R6::R6Class(
"ResamplingSameOtherCV",
inherit = ResamplingSameOther,
public = list(
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required")
)
ps$values = list(folds = 3L)
super$initialize(
id = "same_other_cv",
param_set = ps,
label = "Cross-Validation",
man = "ResamplingSameOtherCV")
}
),
active = list(
iters = function(rhs) {
nrow(self$instance$iteration.dt)
}
),
private = list(
.sample = function(ids, ...) {
data.table(
row_id = ids,
fold = sample(
seq(0, length(ids)-1) %%
as.integer(self$param_set$values$folds) + 1L
),
key = "fold"
)
},
.combine = function(instances) {
rbindlist(instances, use.names = TRUE)
},
deep_clone = function(name, value) {
switch(name,
"instance" = copy(value),
"param_set" = value$clone(deep = TRUE),
value
)
}
)
)

score <- function(bench.result, ...){
algorithm <- learner_id <- NULL
## Above to avoid CRAN NOTE.
bench.score <- bench.result$score(...)
out.dt.list <- list()
for(score.i in 1:nrow(bench.score)){
bench.row <- bench.score[score.i]
it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
out.dt.list[[score.i]] <- it.dt[
bench.row, on="iteration"
][, algorithm := sub(".*[.]", "", learner_id)]
}
rbindlist(out.dt.list)
}
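A hedged usage sketch, not part of the diff: per the man page below, `ResamplingSameOtherCV` expects a task with at least one column of role `"group"`. The task id, data, and `subset_id` column here are illustrative only, and assume mlr3 and mlr3resampling are installed:

```r
## Illustrative use of ResamplingSameOtherCV: the "subset_id" column
## marks which group each observation belongs to.
library(mlr3)
library(mlr3resampling)
set.seed(1)
dat <- data.frame(
  x = rnorm(40),
  y = rnorm(40),
  subset_id = rep(c("A", "B"), each = 20))
task <- TaskRegr$new("demo", dat, target = "y")
task$col_roles$group <- "subset_id"  # required: column with role "group"
same_other <- ResamplingSameOtherCV$new()
same_other$instantiate(task)
same_other$iters  # number of resampling iterations after instantiation
```

After instantiation, `same_other$train_set(i)` and `same_other$test_set(i)` return the row ids for iteration `i`, as with any mlr3 resampling.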
80 changes: 80 additions & 0 deletions R/ResamplingVariableSizeTrainCV.R
@@ -0,0 +1,80 @@
ResamplingVariableSizeTrainCV = R6::R6Class(
"ResamplingVariableSizeTrainCV",
inherit=ResamplingBase,
public = list(
initialize = function() {
ps = paradox::ps(
folds = paradox::p_int(2L, tags = "required"),
min_train_data=paradox::p_int(1L, tags = "required"),
random_seeds=paradox::p_int(1L, tags = "required"),
train_sizes = paradox::p_int(2L, tags = "required"))
ps$values = list(
folds = 3L,
min_train_data=10L,
random_seeds=3L,
train_sizes=5L)
super$initialize(
id = "variable_size_train_cv",
param_set = ps,
label = "Cross-Validation with variable size train sets",
man = "ResamplingVariableSizeTrainCV")
},
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
reserved.names <- c(
"row_id", "fold", "group", "display_row",
"train.groups", "test.fold", "test.group", "iteration",
"test", "train", "algorithm", "uhash", "nr", "task", "task_id",
"learner", "learner_id", "resampling", "resampling_id",
"prediction")
## bad.names <- group.name.vec[group.name.vec %in% reserved.names]
## if(length(bad.names)){
## first.bad <- bad.names[1]
## stop(sprintf("col with role group must not be named %s; please fix by renaming %s col", first.bad, first.bad))
## }
## orig.group.dt <- task$data(cols=group.name.vec)
strata <- if(is.null(task$strata)){
data.dt <- task$data()
data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
}else task$strata
folds = private$.combine(
lapply(strata$row_id, private$.sample, task = task)
)[order(row_id)]
min_train_data <- self$param_set$values[["min_train_data"]]
if(task$nrow <= min_train_data){
stop(sprintf(
"task$nrow=%d but should be larger than min_train_data=%d",
task$nrow, min_train_data))
}
uniq.folds <- sort(unique(folds$fold))
iteration.dt.list <- list()
for(test.fold in uniq.folds){
is.set.fold <- list(
test=folds[["fold"]] == test.fold)
is.set.fold[["train"]] <- !is.set.fold[["test"]]
i.set.list <- lapply(is.set.fold, which)
max_train_data <- length(i.set.list$train)
log.range.data <- log(c(min_train_data, max_train_data))
seq.args <- c(as.list(log.range.data), list(l=self$param_set$values[["train_sizes"]]))
log.train.sizes <- do.call(seq, seq.args)
train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
for(seed in 1:self$param_set$values[["random_seeds"]]){
set.seed(seed)
ord.i.vec <- sample(i.set.list$train)
iteration.dt.list[[paste(test.fold, seed)]] <- data.table(
test.fold,
seed,
train_size=train.size.vec,
train=lapply(train.size.vec, function(last)ord.i.vec[1:last]),
test=list(i.set.list$test))
}
}
self$instance <- list(
iteration.dt=rbindlist(iteration.dt.list)[, iteration := .I][],
id.dt=folds)
self$task_hash = task$hash
self$task_nrow = task$nrow
invisible(self)
}
)
)
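An aside for readers, with illustrative numbers not taken from the diff: `instantiate()` above spaces the train set sizes evenly on the log scale between `min_train_data` and the number of rows outside the test fold. A sketch of that computation:

```r
## Sketch of the log-spaced train sizes computed in instantiate(),
## using assumed values for min_train_data and max_train_data.
min_train_data <- 10L
max_train_data <- 100L  # e.g. number of rows outside the test fold
train_sizes <- 5L
log.train.sizes <- seq(
  log(min_train_data), log(max_train_data), l = train_sizes)
train.size.vec <- unique(as.integer(round(exp(log.train.sizes))))
## Sizes grow multiplicatively rather than additively:
print(train.size.vec)  # 10 18 32 56 100
```

The `unique()` call matters for small ranges: when `min_train_data` and the max are close, rounding can collapse adjacent sizes into duplicates, which would otherwise produce redundant iterations.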
14 changes: 14 additions & 0 deletions R/score.R
@@ -0,0 +1,14 @@
score <- function(bench.result, ...){
algorithm <- learner_id <- NULL
## Above to avoid CRAN NOTE.
bench.score <- bench.result$score(...)
out.dt.list <- list()
for(score.i in 1:nrow(bench.score)){
bench.row <- bench.score[score.i]
it.dt <- bench.row$resampling[[1]]$instance$iteration.dt
out.dt.list[[score.i]] <- it.dt[
bench.row, on="iteration"
][, algorithm := sub(".*[.]", "", learner_id)]
}
rbindlist(out.dt.list)
}
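A hedged usage sketch of `score()`, not from the diff (assumes mlr3, mlr3resampling, and the built-in `mtcars` task and featureless learner are available):

```r
## score() joins each benchmark score row with the resampling's
## iteration metadata (e.g. test.fold, and for variable-size CV also
## train_size and seed), and adds an "algorithm" column derived from
## the learner id.
library(mlr3)
library(mlr3resampling)
bench.grid <- benchmark_grid(
  tasks = tsk("mtcars"),
  learners = lrn("regr.featureless"),
  resamplings = ResamplingVariableSizeTrainCV$new())
bench.result <- benchmark(bench.grid)
score.dt <- score(bench.result)
## Expect one row per iteration, with columns such as train_size,
## seed, test.fold, the measure (e.g. regr.mse), and algorithm.
```

This joined table is what makes it easy to plot test error as a function of `train_size`, which is the point of the new resampling class.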
32 changes: 15 additions & 17 deletions man/ResamplingSameOther.Rd → man/ResamplingSameOtherCV.Rd
@@ -1,22 +1,20 @@
\name{ResamplingSameOther}
\alias{ResamplingSameOther}
\name{ResamplingSameOtherCV}
\alias{ResamplingSameOtherCV}
\title{Resampling for comparing training on same or other groups}
\description{
\code{\link{ResamplingSameOther}} is the abstract base class for
\code{\link{ResamplingSameOtherCV}},
which defines how a task is partitioned for
resampling, for example in
\code{\link[mlr3:resample]{resample()}} or
\code{\link[mlr3:benchmark]{benchmark()}}.

Resampling objects can be instantiated on a
\code{\link[mlr3:Task]{Task}},
which should define at least one group variable.

After instantiation, sets can be accessed via
\verb{$train_set(i)} and
\verb{$test_set(i)}, respectively.
\code{\link{ResamplingSameOtherCV}}
defines how a task is partitioned for
resampling, for example in
\code{\link[mlr3:resample]{resample()}} or
\code{\link[mlr3:benchmark]{benchmark()}}.

Resampling objects can be instantiated on a
\code{\link[mlr3:Task]{Task}},
which should define at least one group variable.

After instantiation, sets can be accessed via
\verb{$train_set(i)} and
\verb{$test_set(i)}, respectively.
}
\details{
A supervised learning algorithm inputs a train set, and outputs a
@@ -50,7 +48,7 @@ each combination of the values of the stratification variables forms a stratum.
The grouping variable is assumed to be discrete,
and must be stored in the \link{Task} with column role \code{"group"}.

Then number of cross-validation folds K should be defined as the
The number of cross-validation folds K should be defined as the
\code{fold} parameter.

In each group, there will be about an equal number of observations