diff --git a/DESCRIPTION b/DESCRIPTION index 523eac2..b244a94 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -29,4 +29,4 @@ Suggests: testthat Encoding: UTF-8 Roxygen: list(markdown = TRUE) -RoxygenNote: 7.2.3 +RoxygenNote: 7.2.3.9000 diff --git a/NEWS.md b/NEWS.md index ace4446..2a61195 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # mlr3batchmark (development version) +* docs: A warning is now given when the loaded mlr3 version differs from the +mlr3 version stored in the trained learners + # mlr3batchmark 0.1.1 * feat: `mlr3batchmark` now depends on package `batchtools` to avoid having to load `batchtools` explicitly. diff --git a/R/reduceResultsBatchmark.R b/R/reduceResultsBatchmark.R index 7f9359c..04e9212 100644 --- a/R/reduceResultsBatchmark.R +++ b/R/reduceResultsBatchmark.R @@ -29,6 +29,8 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch tabs = split(tabs, by = "job.name") bmr = mlr3::BenchmarkResult$new() + version_checked = FALSE + for (tab in tabs) { job = batchtools::makeJob(tab$job.id[1L], reg = reg) bmr_tasks = bmr$tasks @@ -60,6 +62,17 @@ reduceResultsBatchmark = function(ids = NULL, store_backends = TRUE, reg = batch } results = batchtools::reduceResultsList(tab$job.id, reg = reg) + + if (!version_checked) { + version_checked = TRUE + if (mlr3::mlr_reflections$package_version != results[[1]]$learner_state$mlr3_version) { + lg$warn(paste(sep = "\n", + "The mlr3 version (%s) from one of the trained learners differs from the currently loaded mlr3 version (%s).", + "This can lead to unexpected behavior and we recommend using the same versions of all mlr3 packages for collecting the results."), + results[[1]]$learner_state$mlr3_version, mlr3::mlr_reflections$package_version) + } + } + rdata = mlr3::ResultData$new(data.table( task = list(task), learner = list(learner), diff --git a/R/zzz.R b/R/zzz.R index 0c21cd1..1693865 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -4,3 +4,8 @@ #' @rawNamespace import(batchtools, except = chunk) #' @importFrom uuid UUIDgenerate "_PACKAGE" + + +.onLoad = function(libname, pkgname) { + assign("lg", lgr::get_logger(pkgname), envir = parent.env(environment())) +} diff --git a/tests/testthat/test_reduceResultsBatchmark.R b/tests/testthat/test_reduceResultsBatchmark.R index f7d27bb..f565764 100644 --- a/tests/testthat/test_reduceResultsBatchmark.R +++ b/tests/testthat/test_reduceResultsBatchmark.R @@ -31,3 +31,19 @@ test_that("reduceResultsBatchmark", { expect_data_table(tab, nrow = 4) expect_set_equal(tab$resampling_id, ids(resamplings)) }) + +test_that("warning is given when mlr3 versions mismatch", { + mlr_reflections = mlr3::mlr_reflections + mlr3_version = mlr_reflections$package_version + reg = makeExperimentRegistry(NA) + batchmark(benchmark_grid(tsk("mtcars"), lrn("regr.featureless"), rsmp("holdout"))) + submitJobs() + waitForJobs() + + on.exit({mlr_reflections$package_version = mlr3_version}, add = TRUE) + + mlr_reflections$package_version = "100.0.0" + + capture.output(reduceResultsBatchmark(reg = reg)) + expect_true(grepl("The mlr3 version", lg$last_event$msg, fixed = TRUE)) +})