Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Error handling for cross-validated predictions #408

Open
wants to merge 5 commits into
base: devel
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .github/workflows/R-CMD-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,18 @@ jobs:

steps:
- name: Checkout repo
uses: actions/checkout@v2
uses: actions/checkout@v3

- name: Setup R
uses: r-lib/actions/setup-r@v2
with:
r-version: ${{ matrix.config.r }}

- name: Install pandoc
uses: r-lib/actions/setup-pandoc@v1
uses: r-lib/actions/setup-pandoc@v2

- name: Install tinyTeX
uses: r-lib/actions/setup-tinytex@v1
uses: r-lib/actions/setup-tinytex@v2

- name: Install system dependencies
if: runner.os == 'Linux'
Expand All @@ -50,7 +50,7 @@ jobs:

- name: Install package dependencies
run: |
install.packages(c("remotes", "rcmdcheck", "covr", "sessioninfo"))
install.packages(c("remotes", "rcmdcheck", "covr", "sessioninfo", "devtools"))
if(Sys.info()["sysname"] == "Windows") install.packages("igraph", type = "binary")
remotes::install_deps(dependencies = TRUE)
shell: Rscript {0}
Expand Down
46 changes: 24 additions & 22 deletions R/Lrnr_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,47 +102,45 @@ Lrnr_cv <- R6Class(
predict_fold = function(task, fold_number = "validation", pred_unique_ts = FALSE) {
fold_number <- interpret_fold_number(fold_number)
if (fold_number == "validation") {
# return cross validation predicitons (what Lrnr_cv$predict does, so use that)

# return cross validation predictions (what Lrnr_cv$predict does)
preds <- self$predict(task)

### Time-series addition:
# Each time point gets an unique final prediction
if (pred_unique_ts) {
folds <- task$folds
index_val <- unlist(lapply(folds, function(fold) {
fold$validation_set
}))
index_val <- unlist(lapply(folds, function(fold) fold$validation_set))
preds_unique <- unique(index_val)

if (length(unique(index_val)) != length(index_val)) {
# Average over the same predictions:
preds <- data.table(index_val, preds)

preds <- preds %>%
group_by(index_val) %>%
summarise_all(mean) %>%
select(-1)
}
}
return(preds)
} else if (fold_number == "full") {
# check if we did a fold fit, and use that fit if available
if (self$params$full_fit) {
fold_fit <- self$fit_object$full_fit
} else {
stop("full fit requested, but Lrnr_cv was constructed with full_fit=FALSE")
}
} else {
# use the requested fold fit
fold_number <- as.numeric(fold_number)
if (is.na(fold_number) || !(fold_number > 0)) {
stop("fold_number must be 'full', 'validation', or a positive integer")
if (fold_number == "full") {
# check if we did a fold fit, and use that fit if available
if (self$params$full_fit) {
fold_fit <- self$fit_object$full_fit
} else {
stop("full fit requested, but Lrnr_cv was constructed with full_fit=FALSE")
}
} else {
# use the requested fold fit
fold_number <- as.numeric(fold_number)
if (is.na(fold_number) || !(fold_number > 0)) {
stop("fold_number must be 'full', 'validation', or a positive integer")
}
fold_fit <- self$fit_object$fold_fits[[as.numeric(fold_number)]]
}
fold_fit <- self$fit_object$fold_fits[[as.numeric(fold_number)]]
revere_task <- task$revere_fold_task(fold_number)
preds <- fold_fit$predict(revere_task)
}

revere_task <- task$revere_fold_task(fold_number)
preds <- fold_fit$predict(revere_task)
return(preds)
},
chain_fold = function(task, fold_number = "validation") {
Expand Down Expand Up @@ -334,6 +332,10 @@ Lrnr_cv <- R6Class(
return(fit_object)
},
.predict = function(task) {
if (length(self$training_task$folds) != length(task$folds)) {
stop("Training and prediction tasks have different numbers of folds")
}

folds <- task$folds
fold_fits <- private$.fit_object$fold_fits

Expand Down Expand Up @@ -376,7 +378,7 @@ Lrnr_cv <- R6Class(

# don't convert to vector if learner is stack, as stack won't
if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) {
predictions <- unlist(predictions)
predictions <- as.numeric(unlist(predictions))
}
return(predictions)
},
Expand Down
1 change: 1 addition & 0 deletions docs/articles/custom_lrnrs.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions docs/articles/intro_sl3.html

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions docs/pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ pkgdown_sha: ~
articles:
custom_lrnrs: custom_lrnrs.html
intro_sl3: intro_sl3.html

last_built: 2023-02-01T18:15Z

urls:
reference: https://tlverse.org/sl3/reference
article: https://tlverse.org/sl3/articles
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/test-cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,26 @@ if (Sys.info()["sysname"] == "Windows") {
learners <- learners[!(learners == "Lrnr_grfcate")]
lapply(learners, test_loocv_learner, loocv_task)
test_loocv_learner("Lrnr_grfcate", loocv_task, A = "apgar1")


###################### test CV predictions with new tasks ######################
# Lrnr_cv is trained on a task with 3 folds. Predicting on a new task is only
# expected to work when that task has the same number of folds; otherwise an
# error should be raised rather than returning misaligned predictions.
data(mtcars)

# training task: 10 rows, 3 folds
mtcars_task <- make_sl3_Task(
  data = mtcars[1:10, ], outcome = "mpg",
  covariates = c("cyl", "disp", "hp", "drat", "wt"), folds = 3
)

# new task without a `folds` argument, so the default fold structure is used.
# NOTE(review): the expectations below rely on that default differing from
# 3 folds -- confirm against make_sl3_Task's default.
mtcars_task2 <- make_sl3_Task(
  data = mtcars[11:30, ], outcome = "mpg",
  covariates = c("cyl", "disp", "hp", "drat", "wt")
)

lrnr_cv_glm <- Lrnr_cv$new(Lrnr_glm$new(), full_fit = TRUE)
cv_glm_fit <- lrnr_cv_glm$train(mtcars_task)

# mismatched number of folds: both prediction entry points must error
expect_error(cv_glm_fit$predict(mtcars_task2))
expect_error(cv_glm_fit$predict_fold(mtcars_task2, "validation"))

# new task with the same number of folds as the training task (3)
mtcars_task3 <- make_sl3_Task(
  data = mtcars[11:30, ], outcome = "mpg",
  covariates = c("cyl", "disp", "hp", "drat", "wt"), folds = 3
)

# matching number of folds: one prediction per row of the new task (20 rows).
# Both checks use mtcars_task3; mtcars_task2 is asserted to error above, so it
# cannot also be expected to return 20 predictions.
expect_equal(length(cv_glm_fit$predict(mtcars_task3)), 20)
expect_equal(length(cv_glm_fit$predict_fold(mtcars_task3, "validation")), 20)