From d0a4a1d03fa4800586af9a8ff872eb97c26e4455 Mon Sep 17 00:00:00 2001 From: Jeremy Coyle Date: Mon, 22 Jan 2024 14:12:37 -0800 Subject: [PATCH 1/2] fix subset covariates to support out of order covariates. Covariates passed to learner should match the order of the covariate params, not the order of the covariates in the task --- R/Lrnr_base.R | 32 ++++++++----------------- tests/testthat/test-subset_covariates.R | 11 +++++++-- 2 files changed, 19 insertions(+), 24 deletions(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 101b8eae..7846c158 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -56,6 +56,13 @@ Lrnr_base <- R6Class( if (length(delta_idx) > 0) { delta_missing <- task_covs_missing[delta_idx] task_covs_missing <- task_covs_missing[-delta_idx] + + delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx)) + colnames(delta_missing_data) <- delta_missing + cols <- task$add_columns(data.table(delta_missing_data)) + + } else{ + cols <- task$column_names } # error when task is missing covariates @@ -68,29 +75,10 @@ Lrnr_base <- R6Class( ) } - # subset task covariates to only includes those in learner covariates - covs_subset <- intersect(task_covs, learner_covs) - - # return updated task - if (length(delta_idx) == 0) { - # re-order the covariate subset to match order of learner covariates - ordered_covs_subset <- covs_subset[match(covs_subset, learner_covs)] - return(task$next_in_chain(covariates = ordered_covs_subset)) - } else { - # incorporate missingness indicators in task covariates subset & sort - covs_subset_delta <- c(covs_subset, delta_missing) - ord_covs <- covs_subset_delta[match(covs_subset_delta, learner_covs)] - - # incorporate missingness indicators in task data - delta_missing_data <- matrix(0, nrow(task$data), length(delta_idx)) - colnames(delta_missing_data) <- delta_missing - cols <- task$add_columns(data.table(delta_missing_data)) - - return(task$next_in_chain( - covariates = ord_covs, + return(task$next_in_chain( + covariates = learner_covs, column_names = cols - )) - } + )) } else { return(task) } diff --git a/tests/testthat/test-subset_covariates.R b/tests/testthat/test-subset_covariates.R index 8e47600f..52df3f75 100644 --- a/tests/testthat/test-subset_covariates.R +++ b/tests/testthat/test-subset_covariates.R @@ -31,6 +31,14 @@ full_preds <- glm_fit_pre_subset$predict(task) training_preds <- glm_fit_pre_subset$predict() test_that("extra covariates in prediction set get dropped correctly", expect_equal(full_preds, training_preds)) + +shuffled_subset <- sample(covariate_subset) +task_pre_subset_shuffled <- sl3_Task$new(mtcars, covariates = shuffled_subset, outcome = outcome) +# debugonce(glm_fit_pre_subset$subset_covariates) +shuffled_preds <- glm_fit_pre_subset$predict(task_pre_subset_shuffled) +test_that("covariates out of order prediction set get shuffled correctly", expect_equal(full_preds, shuffled_preds)) + + task_train <- sl3_Task$new(mtcars, covariates = covariates, outcome = outcome) task_predict <- sl3_Task$new(mtcars, covariates = covariate_subset, outcome = outcome) glm_fit <- lrnr_glm$train(task_train) @@ -47,11 +55,10 @@ task_missing_data <- suppressWarnings( sl3_Task$new(missing_data, covariates = covs, outcome = Y) ) -lrnr_glm <- make_learner(Lrnr_glm_fast, name = "test") +lrnr_glm <- make_learner(Lrnr_glm_fast) glm_fit <- lrnr_glm$train(task_missing_data) task_complete_data <- sl3_Task$new(mtcars, covariates = covs, outcome = Y) - test_that("missingness indicators in prediction task works", { expect_vector(glm_fit$predict(task_complete_data)) }) From 5462dd6b4edde048d489a9c61218895bde11710a Mon Sep 17 00:00:00 2001 From: Jeremy Coyle Date: Mon, 22 Jan 2024 14:18:31 -0800 Subject: [PATCH 2/2] don't warn on covariate param --- R/utils.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/utils.R b/R/utils.R index 8003df15..b49c4431 100644 --- a/R/utils.R +++ b/R/utils.R @@ -105,10 +105,12 @@ call_with_args <- function(fun, args, other_valid = list(), keep_all = FALSE, # subset arguments to pass args <- args[which(names(args) %in% all_valid)] + # don't warn on covariate param + invalid <- setdiff(invalid, "covariates") # return warnings when dropping arguments if (!silent & length(invalid) > 0) { message(sprintf( - "Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.", + "Learner called function %s with unknown args: %s. These will be dropped.\nCheck the params supported by this learner.\n", as.character(substitute(fun)), paste(invalid, collapse = ", ") )) }