Skip to content

Commit

Permalink
Merge pull request #13 from sebffischer/main
Browse files Browse the repository at this point in the history
Leanify package, add resamplings to dictionary
  • Loading branch information
tdhock authored May 14, 2024
2 parents 60f639d + 53e7015 commit d1181f9
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 5 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ Description: A supervised learning algorithm inputs a train set,
test accuracy for each group; other is usually somewhat less accurate
than same; other can be just as bad as featureless baseline when the
groups have different patterns).
For more information,
For more information,
<https://tdhock.github.io/blog/2023/R-gen-new-subsets/>
describes the method in depth.
How many train samples are required to get accurate predictions on a
Expand Down
4 changes: 2 additions & 2 deletions R/ResamplingSameOtherCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ ResamplingSameOtherCV = R6::R6Class(
}
reserved.names <- c(
"row_id", "fold", "subset", "display_row",
"train.subsets", "test.fold", "test.subset", "iteration",
"train.subsets", "test.fold", "test.subset", "iteration",
"test", "train", "algorithm", "uhash", "nr", "task", "task_id",
"learner", "learner_id", "resampling", "resampling_id",
"prediction")
Expand Down Expand Up @@ -106,7 +106,7 @@ ResamplingSameOtherCV = R6::R6Class(
rows="fold",
display_row=min(display_row),
display_end=max(display_row)
), by=.(subset, fold)])
), by=.(subset, fold)])
self$instance <- list(
iteration.dt=iteration.dt,
id.dt=id.fold.subsets[order(row_id)],
Expand Down
3 changes: 1 addition & 2 deletions R/ResamplingVariableSizeTrainCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ ResamplingVariableSizeTrainCV = R6::R6Class(
instantiate = function(task) {
task = mlr3::assert_task(mlr3::as_task(task))
strata <- if(is.null(task$strata)){
data.dt <- task$data()
data.table(N=nrow(data.dt), row_id=list(1:nrow(data.dt)))
data.table(N=task$nrow, row_id=list(seq_len(task$nrow)))
}else task$strata
strata.list <- lapply(strata$row_id, private$.sample, task = task)
folds = private$.combine(strata.list)[order(row_id)]
Expand Down
16 changes: 16 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
register_mlr3 = function() {
mlr_resamplings = utils::getFromNamespace("mlr_resamplings", ns = "mlr3")
mlr_resamplings$add("same_other_sizes_cv", ResamplingSameOtherSizesCV)
}

.onLoad = function(libname, pkgname) { # nolint
# Configure Logger:
assign("lg", lgr::get_logger("mlr3"), envir = parent.env(environment()))
if (Sys.getenv("IN_PKGDOWN") == "true") {
lg$set_threshold("warn") # nolint
}

mlr3misc::register_namespace_callback(pkgname, "mlr3", register_mlr3)
}

mlr3misc::leanify_package()

0 comments on commit d1181f9

Please sign in to comment.