Skip to content

Commit

Permalink
Merge pull request #29 from tdhock/reduceResultsList.fun
Browse files Browse the repository at this point in the history
add reduceResultsList.fun arg to reduceResultsBatchmark
  • Loading branch information
sebffischer authored Apr 9, 2024
2 parents 1aed0c0 + e762b84 commit 5b2b081
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 6 deletions.
6 changes: 4 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ Authors@R: c(
person("Marc", "Becker", , "marcbecker@posteo.de", role = c("cre", "aut"),
comment = c(ORCID = "0000-0002-8115-0400")),
person("Michel", "Lang", , "michellang@gmail.com", role = "aut",
comment = c(ORCID = "0000-0001-9754-0393"))
comment = c(ORCID = "0000-0001-9754-0393")),
person("Toby", "Hocking", role="ctb",
comment = c(ORCID="0000-0002-3146-0865"))
)
Description: Extends the 'mlr3' package with a connector to the package
'batchtools'. This allows to run large-scale benchmark experiments on
Expand All @@ -29,4 +31,4 @@ Suggests:
testthat
Encoding: UTF-8
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3.9000
RoxygenNote: 7.3.1
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# mlr3batchmark (development version)

* feat: `reduceResultsBatchmark` gains argument `fun` which is passed on to `batchtools::reduceResultsList`, useful for deleting model data to avoid running out of memory, https://github.com/mlr-org/mlr3batchmark/issues/18 Thanks to Toby Dylan Hocking @tdhock for the PR.
* docs: A warning is now given when the loaded mlr3 version differs from the
mlr3 version stored in the trained learners

Expand Down
4 changes: 2 additions & 2 deletions R/reduceResultsBatchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#'
#' @return [mlr3::BenchmarkResult].
#' @export
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry()) { # nolint
reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batchtools::getDefaultRegistry(), fun=NULL) { # nolint
if (is.null(ids)) {
ids = batchtools::findDone(ids, reg = reg)
} else {
Expand Down Expand Up @@ -61,7 +61,7 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch
learner = get_export(needle, reg)
}

results = batchtools::reduceResultsList(tab$job.id, reg = reg)
results = batchtools::reduceResultsList(tab$job.id, reg = reg, fun = fun)

if (!version_checked) {
version_checked = TRUE
Expand Down
5 changes: 5 additions & 0 deletions man/mlr3batchmark-package.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion man/reduceResultsBatchmark.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion tests/testthat/test_reduceResultsBatchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ test_that("reduceResultsBatchmark", {
)

reg = batchtools::makeExperimentRegistry(NA)
batchmark(design, reg = reg)
batchmark(design, reg = reg, store_models = TRUE)
batchtools::submitJobs(reg = reg)
batchtools::waitForJobs(reg = reg)

Expand All @@ -30,6 +30,17 @@ test_that("reduceResultsBatchmark", {
tab = bmr$resamplings
expect_data_table(tab, nrow = 4)
expect_set_equal(tab$resampling_id, ids(resamplings))

rpart_model = function(b){
b$score()[learner_id == "classif.rpart"]$learner[[1]]$model
}
expect_is(rpart_model(bmr), "rpart")
no_models = reduceResultsBatchmark(reg = reg, fun = function(L){
L$learner_state$model = NULL
L
})
expect_null(rpart_model(no_models))

})

test_that("warning is given when mlr3 versions mismatch", {
Expand Down

0 comments on commit 5b2b081

Please sign in to comment.