From f7e25d5f1d2f6a5daa94edd563746c16dc221845 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Thu, 15 Aug 2024 10:42:11 -0700 Subject: [PATCH 01/31] small change to predict checks --- R/predict.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/predict.R b/R/predict.R index 327bd80ef..4b61d97d7 100644 --- a/R/predict.R +++ b/R/predict.R @@ -198,7 +198,8 @@ check_pred_type <- function(object, type, ...) { regression = "numeric", classification = "class", "censored regression" = "time", - rlang::abort("`type` should be 'regression', 'censored regression', or 'classification'.")) + "quantile regression" = "quantile", + rlang::abort("`type` should be 'regression', 'censored regression', 'quantile regression', or 'classification'.")) } if (!(type %in% pred_types)) rlang::abort( From 43d19187b7d9c264619b88d780f87487fb528202 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 12:37:03 -0700 Subject: [PATCH 02/31] add vctrs for quantiles and test, refactor *_rq_preds --- NAMESPACE | 9 +++ R/aaa_quantiles.R | 106 ++++++++++++++++++++++++---- man/vec_quantiles.Rd | 20 ++++++ tests/testthat/test-vec_quantiles.R | 25 +++++++ 4 files changed, 145 insertions(+), 15 deletions(-) create mode 100644 man/vec_quantiles.Rd create mode 100644 tests/testthat/test-vec_quantiles.R diff --git a/NAMESPACE b/NAMESPACE index d48f33586..fe34c25dd 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,6 +36,7 @@ S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) +S3method(format,vctrs_quantiles) S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) @@ -54,6 +55,7 @@ S3method(multi_predict_args,default) S3method(multi_predict_args,model_fit) S3method(multi_predict_args,workflow) S3method(nullmodel,default) +S3method(obj_print_footer,vctrs_quantiles) S3method(predict,"_elnet") S3method(predict,"_glmnetfit") S3method(predict,"_lognet") @@ -172,6 +174,8 @@ S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) +S3method(vec_ptype_abbr,vctrs_quantiles) +S3method(vec_ptype_full,vctrs_quantiles) export("%>%") export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) @@ -280,6 +284,7 @@ export(new_model_spec) export(null_model) export(null_value) export(nullmodel) +export(obj_print_footer) export(parsnip_addin) export(pls) export(poisson_reg) @@ -350,6 +355,7 @@ export(update_model_info_file) export(update_spec) export(varying) export(varying_args) +export(vec_quantiles) export(xgb_predict) export(xgb_train) import(rlang) @@ -396,6 +402,8 @@ importFrom(purrr,map) importFrom(purrr,map_chr) importFrom(purrr,map_dbl) importFrom(purrr,map_lgl) +importFrom(rlang,"!!!") +importFrom(rlang,is_double) importFrom(stats,.checkMFClasses) importFrom(stats,.getXlevels) importFrom(stats,as.formula) @@ -426,5 +434,6 @@ importFrom(utils,globalVariables) importFrom(utils,head) importFrom(utils,methods) importFrom(utils,stack) +importFrom(vctrs,obj_print_footer) importFrom(vctrs,vec_size) importFrom(vctrs,vec_unique) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index a2920757e..c3371ec40 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -21,23 +21,99 @@ check_quantile_level <- function(x, object, call) { x } -# Assumes the columns have the same order as quantile_level -restructure_rq_pred <- function(x, object) { - num_quantiles <- NCOL(x) - if ( num_quantiles == 1L ){ - x <- matrix(x, ncol = 1) + +# ------------------------------------------------------------------------- +# A column vector of quantiles with an attribute + +#' @export +vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls" + +#' @export +vec_ptype_full.vctrs_quantiles <- function(x, ...) "quantiles" + +#' @importFrom rlang is_double !!! +new_vec_quantiles <- function(values = list(), quantile_levels = double()) { + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + vctrs::new_vctr( + values, quantile_levels = quantile_levels, class = "vctrs_quantiles" + ) +} + + +#' Create a vector containing sets of quantiles +#' +#' @param values A matrix of values. Each column should correspond to one of +#' the quantile levels. +#' @param quantile_levels A vector of probabilities corresponding to `values`. +#' +#' @export +#' @return A vector of values associated with the quantile levels. +#' +#' @examples +#' v <- vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' +#' # Access the underlying information +#' attr(v, "quantile_levels") +#' vctrs::vec_data(v) +vec_quantiles <- function(values, quantile_levels = double()) { + check_vec_quantiles_inputs(values, quantile_levels) + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) + num_lvls <- length(quantile_levels) + + if (ncol(values) != num_lvls) { + cli::cli_abort( + "The number of columns in {.arg values} must be equal to the length of + {.arg quantile_levels}." + ) } - n <- nrow(x) + values <- lapply(vctrs::vec_chop(values), drop) + new_vec_quantiles(values, quantile_levels) +} +check_vec_quantiles_inputs <- function(values, levels) { + if (!is.matrix(values)) { + cls <- class(values)[1] + cli::cli_abort("{.arg values} must be a {.cls matrix} not a {.cls {cls}}.") + } + purrr::walk(levels, + ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels") + ) + if (is.unsorted(levels)) { + cli::cli_abort("{.arg quantile_levels} must be sorted in increasing order.") + } + invisible(NULL) +} + +#' @export +format.vctrs_quantiles <- function(x, ...) { + quantile_levels <- attr(x, "levels") + if (length(quantile_levels) == 1L) { + x <- unlist(x) + out <- round(x, 3L) + out[is.na(x)] <- NA + } else { + rng <- sapply(x, range) + out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") + out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA + } + out +} + +#' @importFrom vctrs obj_print_footer +#' @export +vctrs::obj_print_footer + +#' @export +obj_print_footer.vctrs_quantiles <- function(x, ...) { + lvls <- attr(x, "quantile_levels") + cat("# Quantile levels: ", format(lvls, digits = 3), "\n", sep = " ") +} + +restructure_rq_pred <- function(x, object) { + if (!is.matrix(x)) x <- as.matrix(x) + n_pred_quantiles <- ncol(x) + # TODO check p = length(quantile_level) quantile_level <- object$spec$quantile_level - res <- - tibble::tibble( - .pred_quantile = as.vector(x), - .quantile_level = rep(quantile_level, each = n), - .row = rep(1:n, num_quantiles)) - res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"]) - res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val))) - res$.row <- NULL - res + tibble::tibble(.pred_quantile = vec_quantiles(x, quantile_level)) } diff --git a/man/vec_quantiles.Rd b/man/vec_quantiles.Rd new file mode 100644 index 000000000..aeda8f8fa --- /dev/null +++ b/man/vec_quantiles.Rd @@ -0,0 +1,20 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/aaa_quantiles.R +\name{vec_quantiles} +\alias{vec_quantiles} +\title{A vector containing sets of quantiles} +\usage{ +vec_quantiles(values, quantile_levels = double()) +} +\arguments{ +\item{values}{A matrix of values. Each column should correspond to one of +the quantile levels.} + +\item{quantile_levels}{A vector of probabilities corresponding to \code{values}.} +} +\description{ +A vector containing sets of quantiles +} +\examples{ +vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +} diff --git a/tests/testthat/test-vec_quantiles.R b/tests/testthat/test-vec_quantiles.R new file mode 100644 index 000000000..758859ae9 --- /dev/null +++ b/tests/testthat/test-vec_quantiles.R @@ -0,0 +1,25 @@ +test_that("vec_quantiles error types", { + expect_error(vec_quantiles(1:10, 1:4 / 5), "matrix") + expect_error( + vec_quantiles(matrix(1:20, 5), -1:4 / 5), + "`quantile_levels` must be a number between 0 and 1" + ) + expect_error( + vec_quantiles(matrix(1:20, 5), 1:5 / 6), + "The number of columns in `values` must be equal to" + ) + expect_error( + vec_quantiles(matrix(1:20, 5), 4:1 / 5), + "must be sorted in increasing order" + ) +}) + +test_that("vec_quantiles outputs", { + v <- vec_quantiles(matrix(1:20, 5), 1:4 / 5) + expect_s3_class(v, "vctrs_quantiles") + expect_identical(attr(v, "quantile_levels"), 1:4 / 5) + expect_identical( + vctrs::vec_data(v), + lapply(vctrs::vec_chop(matrix(1:20, 5)), drop) + ) +}) From 728c0462e45bf6b398f22dad0583cc8b54425029 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 12:57:28 -0700 Subject: [PATCH 03/31] revise tests --- R/aaa_quantiles.R | 1 + man/reexports.Rd | 5 +++- man/vec_quantiles.Rd | 13 ++++++++--- tests/testthat/test-linear_reg_quantreg.R | 28 +++++++++++++---------- 4 files changed, 31 insertions(+), 16 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index c3371ec40..b4ac3bd0b 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -111,6 +111,7 @@ obj_print_footer.vctrs_quantiles <- function(x, ...) { restructure_rq_pred <- function(x, object) { if (!is.matrix(x)) x <- as.matrix(x) + rownames(x) <- NULL n_pred_quantiles <- ncol(x) # TODO check p = length(quantile_level) quantile_level <- object$spec$quantile_level diff --git a/man/reexports.Rd b/man/reexports.Rd index f87bde459..f051744e2 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -1,8 +1,9 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/reexports.R, R/varying.R +% Please edit documentation in R/aaa_quantiles.R, R/reexports.R, R/varying.R \docType{import} \name{reexports} \alias{reexports} +\alias{obj_print_footer} \alias{autoplot} \alias{\%>\%} \alias{fit} @@ -34,5 +35,7 @@ below to see their documentation. \item{hardhat}{\code{\link[hardhat:hardhat-extract]{extract_fit_engine}}, \code{\link[hardhat:hardhat-extract]{extract_fit_time}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_dials}}, \code{\link[hardhat:hardhat-extract]{extract_parameter_set_dials}}, \code{\link[hardhat:hardhat-extract]{extract_spec_parsnip}}, \code{\link[hardhat]{frequency_weights}}, \code{\link[hardhat]{importance_weights}}, \code{\link[hardhat]{tune}}} \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} + + \item{vctrs}{\code{\link[vctrs:obj_print]{obj_print_footer}}} }} diff --git a/man/vec_quantiles.Rd b/man/vec_quantiles.Rd index aeda8f8fa..94d8488fb 100644 --- a/man/vec_quantiles.Rd +++ b/man/vec_quantiles.Rd @@ -2,7 +2,7 @@ % Please edit documentation in R/aaa_quantiles.R \name{vec_quantiles} \alias{vec_quantiles} -\title{A vector containing sets of quantiles} +\title{Create a vector containing sets of quantiles} \usage{ vec_quantiles(values, quantile_levels = double()) } @@ -12,9 +12,16 @@ the quantile levels.} \item{quantile_levels}{A vector of probabilities corresponding to \code{values}.} } +\value{ +A vector of values associated with the quantile levels. +} \description{ -A vector containing sets of quantiles +Create a vector containing sets of quantiles } \examples{ -vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +v <- vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) + +# Access the underlying information +attr(v, "quantile_levels") +vctrs::vec_data(v) } diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 97d310bbf..9378b58d7 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -24,9 +24,10 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_pred) == nrow(sac_test)) expect_named(one_quant_pred, ".pred_quantile") expect_true(is.list(one_quant_pred[[1]])) - expect_s3_class(one_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(one_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(one_quant_pred$.pred_quantile[[1]]) == 1L) + expect_s3_class(one_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_identical(class(one_quant_pred$.pred_quantile[[1]]), "numeric") + expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L) + expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) ### @@ -34,9 +35,10 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_one_row) == 1L) expect_named(one_quant_one_row, ".pred_quantile") expect_true(is.list(one_quant_one_row[[1]])) - expect_s3_class(one_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(one_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(one_quant_one_row$.pred_quantile[[1]]) == 1L) + expect_s3_class(one_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric") + expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L) + expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) }) test_that('linear quantile regression via quantreg - multiple quantiles', { @@ -65,9 +67,10 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_pred) == nrow(sac_test)) expect_named(ten_quant_pred, ".pred_quantile") expect_true(is.list(ten_quant_pred[[1]])) - expect_s3_class(ten_quant_pred$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(ten_quant_pred$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(ten_quant_pred$.pred_quantile[[1]]) == 10L) + expect_s3_class(ten_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric") + expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) + expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9) ### @@ -75,9 +78,10 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_one_row) == 1L) expect_named(ten_quant_one_row, ".pred_quantile") expect_true(is.list(ten_quant_one_row[[1]])) - expect_s3_class(ten_quant_one_row$.pred_quantile[[1]], c("tbl_df", "tbl", "data.frame")) - expect_named(ten_quant_one_row$.pred_quantile[[1]], c(".pred_quantile", ".quantile_level")) - expect_true(nrow(ten_quant_one_row$.pred_quantile[[1]]) == 10L) + expect_s3_class(ten_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_identical(class(ten_quant_one_row$.pred_quantile[[1]]), "numeric") + expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L) + expect_identical(attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), (0:9)/9) }) From 32ea877f7f8c71a0b9873e7755ab634ef835de39 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Mon, 9 Sep 2024 13:26:28 -0700 Subject: [PATCH 04/31] Apply some of the suggestions from code review Co-authored-by: Simon P. Couch --- R/aaa_quantiles.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index b4ac3bd0b..1134b0e0e 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -31,7 +31,6 @@ vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls" #' @export vec_ptype_full.vctrs_quantiles <- function(x, ...) "quantiles" -#' @importFrom rlang is_double !!! new_vec_quantiles <- function(values = list(), quantile_levels = double()) { quantile_levels <- vctrs::vec_cast(quantile_levels, double()) vctrs::new_vctr( @@ -72,8 +71,7 @@ vec_quantiles <- function(values, quantile_levels = double()) { check_vec_quantiles_inputs <- function(values, levels) { if (!is.matrix(values)) { - cls <- class(values)[1] - cli::cli_abort("{.arg values} must be a {.cls matrix} not a {.cls {cls}}.") + cli::cli_abort("{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.") } purrr::walk(levels, ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels") @@ -110,7 +108,9 @@ obj_print_footer.vctrs_quantiles <- function(x, ...) { } restructure_rq_pred <- function(x, object) { - if (!is.matrix(x)) x <- as.matrix(x) + if (!is.matrix(x)) { + x <- as.matrix(x) + } rownames(x) <- NULL n_pred_quantiles <- ncol(x) # TODO check p = length(quantile_level) From cc1f8de29cee1c28c861481d5df24ba176389f07 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 13:32:43 -0700 Subject: [PATCH 05/31] rename tests on suggestion from code review --- ...st-vec_quantiles.R => test-aaa_quantiles.R} | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) rename tests/testthat/{test-vec_quantiles.R => test-aaa_quantiles.R} (53%) diff --git a/tests/testthat/test-vec_quantiles.R b/tests/testthat/test-aaa_quantiles.R similarity index 53% rename from tests/testthat/test-vec_quantiles.R rename to tests/testthat/test-aaa_quantiles.R index 758859ae9..4c96311d3 100644 --- a/tests/testthat/test-vec_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -1,16 +1,16 @@ test_that("vec_quantiles error types", { expect_error(vec_quantiles(1:10, 1:4 / 5), "matrix") - expect_error( - vec_quantiles(matrix(1:20, 5), -1:4 / 5), - "`quantile_levels` must be a number between 0 and 1" + expect_snapshot( + error = TRUE, + vec_quantiles(matrix(1:20, 5), -1:4 / 5) ) - expect_error( - vec_quantiles(matrix(1:20, 5), 1:5 / 6), - "The number of columns in `values` must be equal to" + expect_snapshot( + error = TRUE, + vec_quantiles(matrix(1:20, 5), 1:5 / 6) ) - expect_error( - vec_quantiles(matrix(1:20, 5), 4:1 / 5), - "must be sorted in increasing order" + expect_snapshot( + error = TRUE, + vec_quantiles(matrix(1:20, 5), 4:1 / 5) ) }) From 1d27996056a3f1d43993ddc9dbd3405963e8ffdc Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 13:33:05 -0700 Subject: [PATCH 06/31] export missing funs from vctrs for formatting --- NAMESPACE | 6 ++++-- R/aaa_quantiles.R | 9 +++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index fe34c25dd..3fbdf3377 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -355,6 +355,8 @@ export(update_model_info_file) export(update_spec) export(varying) export(varying_args) +export(vec_ptype_abbr) +export(vec_ptype_full) export(vec_quantiles) export(xgb_predict) export(xgb_train) @@ -402,8 +404,6 @@ importFrom(purrr,map) importFrom(purrr,map_chr) importFrom(purrr,map_dbl) importFrom(purrr,map_lgl) -importFrom(rlang,"!!!") -importFrom(rlang,is_double) importFrom(stats,.checkMFClasses) importFrom(stats,.getXlevels) importFrom(stats,as.formula) @@ -435,5 +435,7 @@ importFrom(utils,head) importFrom(utils,methods) importFrom(utils,stack) importFrom(vctrs,obj_print_footer) +importFrom(vctrs,vec_ptype_abbr) +importFrom(vctrs,vec_ptype_full) importFrom(vctrs,vec_size) importFrom(vctrs,vec_unique) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 1134b0e0e..b1c784506 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -25,6 +25,15 @@ check_quantile_level <- function(x, object, call) { # ------------------------------------------------------------------------- # A column vector of quantiles with an attribute +#' @importFrom vctrs vec_ptype_abbr +#' @export +vctrs::vec_ptype_abbr + +#' @importFrom vctrs vec_ptype_full +#' @export +vctrs::vec_ptype_full + + #' @export vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls" From 4a996c2e380c55c7c8cf8bf196528e9ba9344e63 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 13:33:36 -0700 Subject: [PATCH 07/31] convert errors to snapshot tests --- tests/testthat/_snaps/aaa_quantiles.md | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/testthat/_snaps/aaa_quantiles.md diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md new file mode 100644 index 000000000..0ebdb7aac --- /dev/null +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -0,0 +1,26 @@ +# vec_quantiles error types + + Code + vec_quantiles(matrix(1:20, 5), -1:4 / 5) + Condition + Error in `map()`: + i In index: 1. + Caused by error in `.f()`: + ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. + +--- + + Code + vec_quantiles(matrix(1:20, 5), 1:5 / 6) + Condition + Error in `vec_quantiles()`: + ! The number of columns in `values` must be equal to the length of `quantile_levels`. + +--- + + Code + vec_quantiles(matrix(1:20, 5), 4:1 / 5) + Condition + Error in `check_vec_quantiles_inputs()`: + ! `quantile_levels` must be sorted in increasing order. + From f03bcc3953c9a220f7d9357efc5787deb04efc2d Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 13:37:09 -0700 Subject: [PATCH 08/31] pass call through input check --- R/aaa_quantiles.R | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index b1c784506..caf85b79f 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -78,15 +78,21 @@ vec_quantiles <- function(values, quantile_levels = double()) { new_vec_quantiles(values, quantile_levels) } -check_vec_quantiles_inputs <- function(values, levels) { +check_vec_quantiles_inputs <- function(values, levels, call = caller_env()) { if (!is.matrix(values)) { - cli::cli_abort("{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.") + cli::cli_abort( + "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", + call = call + ) } purrr::walk(levels, - ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels") + ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels", call = call) ) if (is.unsorted(levels)) { - cli::cli_abort("{.arg quantile_levels} must be sorted in increasing order.") + cli::cli_abort( + "{.arg quantile_levels} must be sorted in increasing order.", + call = call + ) } invisible(NULL) } From 73e43e9fd641cc7de1cc0795b12d2248d6f8c16d Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 13:38:53 -0700 Subject: [PATCH 09/31] update snapshots for caller_env --- tests/testthat/_snaps/aaa_quantiles.md | 12 ++++++++++-- tests/testthat/test-aaa_quantiles.R | 5 ++++- 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md index 0ebdb7aac..a12261789 100644 --- a/tests/testthat/_snaps/aaa_quantiles.md +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -1,11 +1,19 @@ # vec_quantiles error types + Code + vec_quantiles(1:10, 1:4 / 5) + Condition + Error in `vec_quantiles()`: + ! `values` must be a , not an integer vector. + +--- + Code vec_quantiles(matrix(1:20, 5), -1:4 / 5) Condition Error in `map()`: i In index: 1. - Caused by error in `.f()`: + Caused by error in `vec_quantiles()`: ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. --- @@ -21,6 +29,6 @@ Code vec_quantiles(matrix(1:20, 5), 4:1 / 5) Condition - Error in `check_vec_quantiles_inputs()`: + Error in `vec_quantiles()`: ! `quantile_levels` must be sorted in increasing order. diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R index 4c96311d3..670147bad 100644 --- a/tests/testthat/test-aaa_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -1,5 +1,8 @@ test_that("vec_quantiles error types", { - expect_error(vec_quantiles(1:10, 1:4 / 5), "matrix") + expect_snapshot( + error = TRUE, + vec_quantiles(1:10, 1:4 / 5) + ) expect_snapshot( error = TRUE, vec_quantiles(matrix(1:20, 5), -1:4 / 5) From 7ca367e6dcf2a69f37f8d7227825165b39222647 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Mon, 9 Sep 2024 16:49:43 -0700 Subject: [PATCH 10/31] rename to parsnip_quantiles, add format snapshot tests --- NAMESPACE | 10 +-- R/aaa_quantiles.R | 26 +++---- ...{vec_quantiles.Rd => parsnip_quantiles.Rd} | 8 +-- man/reexports.Rd | 4 +- tests/testthat/_snaps/aaa_quantiles.md | 69 ++++++++++++++++--- tests/testthat/test-aaa_quantiles.R | 31 ++++++--- tests/testthat/test-linear_reg_quantreg.R | 25 +++++-- 7 files changed, 128 insertions(+), 45 deletions(-) rename man/{vec_quantiles.Rd => parsnip_quantiles.Rd} (77%) diff --git a/NAMESPACE b/NAMESPACE index 3fbdf3377..456e74935 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,7 +36,7 @@ S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) -S3method(format,vctrs_quantiles) +S3method(format,parsnip_quantiles) S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) @@ -55,7 +55,7 @@ S3method(multi_predict_args,default) S3method(multi_predict_args,model_fit) S3method(multi_predict_args,workflow) S3method(nullmodel,default) -S3method(obj_print_footer,vctrs_quantiles) +S3method(obj_print_footer,parsnip_quantiles) S3method(predict,"_elnet") S3method(predict,"_glmnetfit") S3method(predict,"_lognet") @@ -174,8 +174,8 @@ S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) -S3method(vec_ptype_abbr,vctrs_quantiles) -S3method(vec_ptype_full,vctrs_quantiles) +S3method(vec_ptype_abbr,parsnip_quantiles) +S3method(vec_ptype_full,parsnip_quantiles) export("%>%") export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) @@ -286,6 +286,7 @@ export(null_value) export(nullmodel) export(obj_print_footer) export(parsnip_addin) +export(parsnip_quantiles) export(pls) export(poisson_reg) export(pred_value_template) @@ -357,7 +358,6 @@ export(varying) export(varying_args) export(vec_ptype_abbr) export(vec_ptype_full) -export(vec_quantiles) export(xgb_predict) export(xgb_train) import(rlang) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index caf85b79f..3c1415ac3 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -35,15 +35,15 @@ vctrs::vec_ptype_full #' @export -vec_ptype_abbr.vctrs_quantiles <- function(x, ...) "qntls" +vec_ptype_abbr.parsnip_quantiles <- function(x, ...) "qntls" #' @export -vec_ptype_full.vctrs_quantiles <- function(x, ...) "quantiles" +vec_ptype_full.parsnip_quantiles <- function(x, ...) "quantiles" -new_vec_quantiles <- function(values = list(), quantile_levels = double()) { +new_parsnip_quantiles <- function(values = list(), quantile_levels = double()) { quantile_levels <- vctrs::vec_cast(quantile_levels, double()) vctrs::new_vctr( - values, quantile_levels = quantile_levels, class = "vctrs_quantiles" + values, quantile_levels = quantile_levels, class = "parsnip_quantiles" ) } @@ -58,13 +58,13 @@ new_vec_quantiles <- function(values = list(), quantile_levels = double()) { #' @return A vector of values associated with the quantile levels. #' #' @examples -#' v <- vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' v <- parsnip_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) #' #' # Access the underlying information #' attr(v, "quantile_levels") #' vctrs::vec_data(v) -vec_quantiles <- function(values, quantile_levels = double()) { - check_vec_quantiles_inputs(values, quantile_levels) +parsnip_quantiles <- function(values, quantile_levels = double()) { + check_parsnip_quantiles_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) num_lvls <- length(quantile_levels) @@ -75,10 +75,10 @@ vec_quantiles <- function(values, quantile_levels = double()) { ) } values <- lapply(vctrs::vec_chop(values), drop) - new_vec_quantiles(values, quantile_levels) + new_parsnip_quantiles(values, quantile_levels) } -check_vec_quantiles_inputs <- function(values, levels, call = caller_env()) { +check_parsnip_quantiles_inputs <- function(values, levels, call = caller_env()) { if (!is.matrix(values)) { cli::cli_abort( "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", @@ -98,8 +98,8 @@ check_vec_quantiles_inputs <- function(values, levels, call = caller_env()) { } #' @export -format.vctrs_quantiles <- function(x, ...) { - quantile_levels <- attr(x, "levels") +format.parsnip_quantiles <- function(x, ...) { + quantile_levels <- attr(x, "quantile_levels") if (length(quantile_levels) == 1L) { x <- unlist(x) out <- round(x, 3L) @@ -117,7 +117,7 @@ format.vctrs_quantiles <- function(x, ...) { vctrs::obj_print_footer #' @export -obj_print_footer.vctrs_quantiles <- function(x, ...) { +obj_print_footer.parsnip_quantiles <- function(x, ...) { lvls <- attr(x, "quantile_levels") cat("# Quantile levels: ", format(lvls, digits = 3), "\n", sep = " ") } @@ -130,6 +130,6 @@ restructure_rq_pred <- function(x, object) { n_pred_quantiles <- ncol(x) # TODO check p = length(quantile_level) quantile_level <- object$spec$quantile_level - tibble::tibble(.pred_quantile = vec_quantiles(x, quantile_level)) + tibble::tibble(.pred_quantile = parsnip_quantiles(x, quantile_level)) } diff --git a/man/vec_quantiles.Rd b/man/parsnip_quantiles.Rd similarity index 77% rename from man/vec_quantiles.Rd rename to man/parsnip_quantiles.Rd index 94d8488fb..cbd859db4 100644 --- a/man/vec_quantiles.Rd +++ b/man/parsnip_quantiles.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/aaa_quantiles.R -\name{vec_quantiles} -\alias{vec_quantiles} +\name{parsnip_quantiles} +\alias{parsnip_quantiles} \title{Create a vector containing sets of quantiles} \usage{ -vec_quantiles(values, quantile_levels = double()) +parsnip_quantiles(values, quantile_levels = double()) } \arguments{ \item{values}{A matrix of values. Each column should correspond to one of @@ -19,7 +19,7 @@ A vector of values associated with the quantile levels. Create a vector containing sets of quantiles } \examples{ -v <- vec_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +v <- parsnip_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) # Access the underlying information attr(v, "quantile_levels") diff --git a/man/reexports.Rd b/man/reexports.Rd index f051744e2..13baaa850 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -3,6 +3,8 @@ \docType{import} \name{reexports} \alias{reexports} +\alias{vec_ptype_abbr} +\alias{vec_ptype_full} \alias{obj_print_footer} \alias{autoplot} \alias{\%>\%} @@ -36,6 +38,6 @@ below to see their documentation. \item{magrittr}{\code{\link[magrittr:pipe]{\%>\%}}} - \item{vctrs}{\code{\link[vctrs:obj_print]{obj_print_footer}}} + \item{vctrs}{\code{\link[vctrs:obj_print]{obj_print_footer}}, \code{\link[vctrs:vec_ptype_full]{vec_ptype_abbr}}, \code{\link[vctrs]{vec_ptype_full}}} }} diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md index a12261789..1a659ce0c 100644 --- a/tests/testthat/_snaps/aaa_quantiles.md +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -1,34 +1,85 @@ -# vec_quantiles error types +# parsnip_quantiles error types Code - vec_quantiles(1:10, 1:4 / 5) + parsnip_quantiles(1:10, 1:4 / 5) Condition - Error in `vec_quantiles()`: + Error in `parsnip_quantiles()`: ! `values` must be a , not an integer vector. --- Code - vec_quantiles(matrix(1:20, 5), -1:4 / 5) + parsnip_quantiles(matrix(1:20, 5), -1:4 / 5) Condition Error in `map()`: i In index: 1. - Caused by error in `vec_quantiles()`: + Caused by error in `parsnip_quantiles()`: ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. --- Code - vec_quantiles(matrix(1:20, 5), 1:5 / 6) + parsnip_quantiles(matrix(1:20, 5), 1:5 / 6) Condition - Error in `vec_quantiles()`: + Error in `parsnip_quantiles()`: ! The number of columns in `values` must be equal to the length of `quantile_levels`. --- Code - vec_quantiles(matrix(1:20, 5), 4:1 / 5) + parsnip_quantiles(matrix(1:20, 5), 4:1 / 5) Condition - Error in `vec_quantiles()`: + Error in `parsnip_quantiles()`: ! `quantile_levels` must be sorted in increasing order. +# parsnip_quantiles formatting + + Code + print(v) + Output + + [1] [1, 16] [2, 17] [3, 18] [4, 19] [5, 20] + # Quantile levels: 0.2 0.4 0.6 0.8 + +--- + + Code + print(parsnip_quantiles(matrix(1:18, 9), c(1 / 3, 2 / 3))) + Output + + [1] [1, 10] [2, 11] [3, 12] [4, 13] [5, 14] [6, 15] [7, 16] [8, 17] [9, 18] + # Quantile levels: 0.333 0.667 + +--- + + Code + print(parsnip_quantiles(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, + 0.8))) + Output + + [1] [0.01, 0.598] [0.206, 0.794] [0.402, 0.99] + # Quantile levels: 0.2 0.8 + +--- + + Code + print(tibble(qntls = v)) + Output + # A tibble: 5 x 1 + qntls + + 1 [1, 16] + 2 [2, 17] + 3 [3, 18] + 4 [4, 19] + 5 [5, 20] + +--- + + Code + print(parsnip_quantiles(m, 1:4 / 5)) + Output + + [1] [1, 16] [3, 18] [5, 20] + # Quantile levels: 0.2 0.4 0.6 0.8 + diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R index 670147bad..acf8f6ea4 100644 --- a/tests/testthat/test-aaa_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -1,28 +1,43 @@ -test_that("vec_quantiles error types", { +test_that("parsnip_quantiles error types", { expect_snapshot( error = TRUE, - vec_quantiles(1:10, 1:4 / 5) + parsnip_quantiles(1:10, 1:4 / 5) ) expect_snapshot( error = TRUE, - vec_quantiles(matrix(1:20, 5), -1:4 / 5) + parsnip_quantiles(matrix(1:20, 5), -1:4 / 5) ) expect_snapshot( error = TRUE, - vec_quantiles(matrix(1:20, 5), 1:5 / 6) + parsnip_quantiles(matrix(1:20, 5), 1:5 / 6) ) expect_snapshot( error = TRUE, - vec_quantiles(matrix(1:20, 5), 4:1 / 5) + parsnip_quantiles(matrix(1:20, 5), 4:1 / 5) ) }) -test_that("vec_quantiles outputs", { - v <- vec_quantiles(matrix(1:20, 5), 1:4 / 5) - expect_s3_class(v, "vctrs_quantiles") +test_that("parsnip_quantiles outputs", { + v <- parsnip_quantiles(matrix(1:20, 5), 1:4 / 5) + expect_s3_class(v, "parsnip_quantiles") expect_identical(attr(v, "quantile_levels"), 1:4 / 5) expect_identical( vctrs::vec_data(v), lapply(vctrs::vec_chop(matrix(1:20, 5)), drop) ) }) + +test_that("parsnip_quantiles formatting", { + v <- parsnip_quantiles(matrix(1:20, 5), 1:4 / 5) + expect_snapshot(print(v)) + expect_snapshot(print(parsnip_quantiles(matrix(1:18, 9), c(1/3, 2/3)))) + expect_snapshot(print( + parsnip_quantiles(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) + )) + expect_snapshot(print(tibble(qntls = v))) + m <- matrix(1:20, 5) + m[2, 3] <- NA + m[4, 2] <- NA + expect_snapshot(print(parsnip_quantiles(m, 1:4 / 5))) + +}) diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 9378b58d7..91ff51fb6 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -24,7 +24,10 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_pred) == nrow(sac_test)) expect_named(one_quant_pred, ".pred_quantile") expect_true(is.list(one_quant_pred[[1]])) - expect_s3_class(one_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_s3_class( + one_quant_pred$.pred_quantile[1], + c("parsnip_quantiles", "vctrs_vctr", "list") + ) expect_identical(class(one_quant_pred$.pred_quantile[[1]]), "numeric") expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L) expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) @@ -35,7 +38,10 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(nrow(one_quant_one_row) == 1L) expect_named(one_quant_one_row, ".pred_quantile") expect_true(is.list(one_quant_one_row[[1]])) - expect_s3_class(one_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_s3_class( + one_quant_one_row$.pred_quantile[1], + c("parsnip_quantiles", "vctrs_vctr", "list") + ) expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric") expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L) expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) @@ -67,7 +73,10 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_pred) == nrow(sac_test)) expect_named(ten_quant_pred, ".pred_quantile") expect_true(is.list(ten_quant_pred[[1]])) - expect_s3_class(ten_quant_pred$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_s3_class( + ten_quant_pred$.pred_quantile[1], + c("parsnip_quantiles", "vctrs_vctr", "list") + ) expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric") expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9) @@ -78,10 +87,16 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_one_row) == 1L) expect_named(ten_quant_one_row, ".pred_quantile") expect_true(is.list(ten_quant_one_row[[1]])) - expect_s3_class(ten_quant_one_row$.pred_quantile[1], c("vctrs_quantiles", "vctrs_vctr", "list")) + expect_s3_class( + ten_quant_one_row$.pred_quantile[1], + c("parsnip_quantiles", "vctrs_vctr", "list") + ) expect_identical(class(ten_quant_one_row$.pred_quantile[[1]]), "numeric") expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L) - expect_identical(attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), (0:9)/9) + expect_identical( + attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), + (0:9)/9 + ) }) From 49cc02eb362748d75d86be5b9e55a109bea73760 Mon Sep 17 00:00:00 2001 From: Daniel McDonald Date: Tue, 10 Sep 2024 06:51:23 -0700 Subject: [PATCH 11/31] Apply suggestions from @topepo Co-authored-by: Max Kuhn --- R/aaa_quantiles.R | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 3c1415ac3..3afb0fc07 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -105,7 +105,7 @@ format.parsnip_quantiles <- function(x, ...) { out <- round(x, 3L) out[is.na(x)] <- NA } else { - rng <- sapply(x, range) + rng <- sapply(x, range, na.rm = TRUE) out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA } @@ -117,9 +117,9 @@ format.parsnip_quantiles <- function(x, ...) { vctrs::obj_print_footer #' @export -obj_print_footer.parsnip_quantiles <- function(x, ...) { +obj_print_footer.parsnip_quantiles <- function(x, digits = 3, ...) { lvls <- attr(x, "quantile_levels") - cat("# Quantile levels: ", format(lvls, digits = 3), "\n", sep = " ") + cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") } restructure_rq_pred <- function(x, object) { From 3ff693055438e9553a31574235e8abb534d89339 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 07:11:13 -0700 Subject: [PATCH 12/31] rename parsnip_quantiles to quantile_pred --- NAMESPACE | 10 +++---- ...{parsnip_quantiles.Rd => quantile_pred.Rd} | 8 +++--- tests/testthat/test-aaa_quantiles.R | 26 +++++++++---------- tests/testthat/test-linear_reg_quantreg.R | 8 +++--- 4 files changed, 26 insertions(+), 26 deletions(-) rename man/{parsnip_quantiles.Rd => quantile_pred.Rd} (77%) diff --git a/NAMESPACE b/NAMESPACE index 456e74935..5d40f4fb9 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -36,7 +36,7 @@ S3method(extract_spec_parsnip,model_fit) S3method(fit,model_spec) S3method(fit_xy,gen_additive_mod) S3method(fit_xy,model_spec) -S3method(format,parsnip_quantiles) +S3method(format,quantile_pred) S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) @@ -55,7 +55,7 @@ S3method(multi_predict_args,default) S3method(multi_predict_args,model_fit) S3method(multi_predict_args,workflow) S3method(nullmodel,default) -S3method(obj_print_footer,parsnip_quantiles) +S3method(obj_print_footer,quantile_pred) S3method(predict,"_elnet") S3method(predict,"_glmnetfit") S3method(predict,"_lognet") @@ -174,8 +174,8 @@ S3method(update,svm_rbf) S3method(varying_args,model_spec) S3method(varying_args,recipe) S3method(varying_args,step) -S3method(vec_ptype_abbr,parsnip_quantiles) -S3method(vec_ptype_full,parsnip_quantiles) +S3method(vec_ptype_abbr,quantile_pred) +S3method(vec_ptype_full,quantile_pred) export("%>%") export(.censoring_weights_graf) export(.check_glmnet_penalty_fit) @@ -286,7 +286,6 @@ export(null_value) export(nullmodel) export(obj_print_footer) export(parsnip_addin) -export(parsnip_quantiles) export(pls) export(poisson_reg) export(pred_value_template) @@ -313,6 +312,7 @@ export(prepare_data) export(print_model_spec) export(prompt_missing_implementation) export(proportional_hazards) +export(quantile_pred) export(rand_forest) export(repair_call) export(req_pkgs) diff --git a/man/parsnip_quantiles.Rd b/man/quantile_pred.Rd similarity index 77% rename from man/parsnip_quantiles.Rd rename to man/quantile_pred.Rd index cbd859db4..bf8f20c03 100644 --- a/man/parsnip_quantiles.Rd +++ b/man/quantile_pred.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand % Please edit documentation in R/aaa_quantiles.R -\name{parsnip_quantiles} -\alias{parsnip_quantiles} +\name{quantile_pred} +\alias{quantile_pred} \title{Create a vector containing sets of quantiles} \usage{ -parsnip_quantiles(values, quantile_levels = double()) +quantile_pred(values, quantile_levels = double()) } \arguments{ \item{values}{A matrix of values. Each column should correspond to one of @@ -19,7 +19,7 @@ A vector of values associated with the quantile levels. Create a vector containing sets of quantiles } \examples{ -v <- parsnip_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +v <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) # Access the underlying information attr(v, "quantile_levels") diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R index acf8f6ea4..c3c46a8f1 100644 --- a/tests/testthat/test-aaa_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -1,25 +1,25 @@ -test_that("parsnip_quantiles error types", { +test_that("quantile_pred error types", { expect_snapshot( error = TRUE, - parsnip_quantiles(1:10, 1:4 / 5) + quantile_pred(1:10, 1:4 / 5) ) expect_snapshot( error = TRUE, - parsnip_quantiles(matrix(1:20, 5), -1:4 / 5) + quantile_pred(matrix(1:20, 5), -1:4 / 5) ) expect_snapshot( error = TRUE, - parsnip_quantiles(matrix(1:20, 5), 1:5 / 6) + quantile_pred(matrix(1:20, 5), 1:5 / 6) ) expect_snapshot( error = TRUE, - parsnip_quantiles(matrix(1:20, 5), 4:1 / 5) + quantile_pred(matrix(1:20, 5), 4:1 / 5) ) }) -test_that("parsnip_quantiles outputs", { - v <- parsnip_quantiles(matrix(1:20, 5), 1:4 / 5) - expect_s3_class(v, "parsnip_quantiles") +test_that("quantile_pred outputs", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + expect_s3_class(v, "quantile_pred") expect_identical(attr(v, "quantile_levels"), 1:4 / 5) expect_identical( vctrs::vec_data(v), @@ -27,17 +27,17 @@ test_that("parsnip_quantiles outputs", { ) }) -test_that("parsnip_quantiles formatting", { - v <- parsnip_quantiles(matrix(1:20, 5), 1:4 / 5) +test_that("quantile_pred formatting", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) expect_snapshot(print(v)) - expect_snapshot(print(parsnip_quantiles(matrix(1:18, 9), c(1/3, 2/3)))) + expect_snapshot(print(quantile_pred(matrix(1:18, 9), c(1/3, 2/3)))) expect_snapshot(print( - parsnip_quantiles(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) )) expect_snapshot(print(tibble(qntls = v))) m <- matrix(1:20, 5) m[2, 3] <- NA m[4, 2] <- NA - expect_snapshot(print(parsnip_quantiles(m, 1:4 / 5))) + expect_snapshot(print(quantile_pred(m, 1:4 / 5))) }) diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 91ff51fb6..9c3ac41d7 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -26,7 +26,7 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(is.list(one_quant_pred[[1]])) expect_s3_class( one_quant_pred$.pred_quantile[1], - c("parsnip_quantiles", "vctrs_vctr", "list") + c("quantile_pred", "vctrs_vctr", "list") ) expect_identical(class(one_quant_pred$.pred_quantile[[1]]), "numeric") expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L) @@ -40,7 +40,7 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(is.list(one_quant_one_row[[1]])) expect_s3_class( one_quant_one_row$.pred_quantile[1], - c("parsnip_quantiles", "vctrs_vctr", "list") + c("quantile_pred", "vctrs_vctr", "list") ) expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric") expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L) @@ -75,7 +75,7 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(is.list(ten_quant_pred[[1]])) expect_s3_class( ten_quant_pred$.pred_quantile[1], - c("parsnip_quantiles", "vctrs_vctr", "list") + c("quantile_pred", "vctrs_vctr", "list") ) expect_identical(class(ten_quant_pred$.pred_quantile[[1]]), "numeric") expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) @@ -89,7 +89,7 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(is.list(ten_quant_one_row[[1]])) expect_s3_class( ten_quant_one_row$.pred_quantile[1], - c("parsnip_quantiles", "vctrs_vctr", "list") + c("quantile_pred", "vctrs_vctr", "list") ) expect_identical(class(ten_quant_one_row$.pred_quantile[[1]]), "numeric") expect_true(length(ten_quant_one_row$.pred_quantile[[1]]) == 10L) From 8e601c5be8a1f0b822b2e2c81963131aa7f21e32 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 07:11:48 -0700 Subject: [PATCH 13/31] rename parsnip_quantiles to quantile_pred and add vector probability check --- R/aaa_quantiles.R | 54 +++++++++++++++++++++++++++-------------------- 1 file changed, 31 insertions(+), 23 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 3afb0fc07..0c0189942 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -12,12 +12,7 @@ check_quantile_level <- function(x, object, call) { x <- sort(unique(x)) # TODO we need better vectorization here, otherwise we get things like: # "Error during wrapup: i In index: 2." in the traceback. - res <- - purrr::map(x, - ~ check_number_decimal(.x, min = 0, max = 1, - arg = "quantile_level", call = call, - allow_infinite = FALSE) - ) + check_vector_probability(x, arg = "quantile_level", call = call) x } @@ -35,15 +30,15 @@ vctrs::vec_ptype_full #' @export -vec_ptype_abbr.parsnip_quantiles <- function(x, ...) "qntls" +vec_ptype_abbr.quantile_pred <- function(x, ...) "qntls" #' @export -vec_ptype_full.parsnip_quantiles <- function(x, ...) "quantiles" +vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" -new_parsnip_quantiles <- function(values = list(), quantile_levels = double()) { +new_quantile_pred <- function(values = list(), quantile_levels = double()) { quantile_levels <- vctrs::vec_cast(quantile_levels, double()) vctrs::new_vctr( - values, quantile_levels = quantile_levels, class = "parsnip_quantiles" + values, quantile_levels = quantile_levels, class = "quantile_pred" ) } @@ -58,13 +53,13 @@ new_parsnip_quantiles <- function(values = list(), quantile_levels = double()) { #' @return A vector of values associated with the quantile levels. #' #' @examples -#' v <- parsnip_quantiles(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' v <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) #' #' # Access the underlying information #' attr(v, "quantile_levels") #' vctrs::vec_data(v) -parsnip_quantiles <- function(values, quantile_levels = double()) { - check_parsnip_quantiles_inputs(values, quantile_levels) +quantile_pred <- function(values, quantile_levels = double()) { + check_quantile_pred_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) num_lvls <- length(quantile_levels) @@ -75,19 +70,18 @@ parsnip_quantiles <- function(values, quantile_levels = double()) { ) } values <- lapply(vctrs::vec_chop(values), drop) - new_parsnip_quantiles(values, quantile_levels) + new_quantile_pred(values, quantile_levels) } -check_parsnip_quantiles_inputs <- function(values, levels, call = caller_env()) { +check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { if (!is.matrix(values)) { cli::cli_abort( "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", call = call ) } - purrr::walk(levels, - ~ check_number_decimal(.x, min = 0, max = 1, arg = "quantile_levels", call = call) - ) + check_vector_probability(values, arg = "quantile_levels", call = call) + if (is.unsorted(levels)) { cli::cli_abort( "{.arg quantile_levels} must be sorted in increasing order.", @@ -98,16 +92,16 @@ check_parsnip_quantiles_inputs <- function(values, levels, call = caller_env()) } #' @export -format.parsnip_quantiles <- function(x, ...) { +format.quantile_pred <- function(x, ...) { quantile_levels <- attr(x, "quantile_levels") if (length(quantile_levels) == 1L) { x <- unlist(x) out <- round(x, 3L) - out[is.na(x)] <- NA + out[is.na(x)] <- NA_character_ } else { rng <- sapply(x, range, na.rm = TRUE) out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") - out[is.na(rng[1, ]) | is.na(rng[2, ])] <- NA + out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ } out } @@ -117,11 +111,25 @@ format.parsnip_quantiles <- function(x, ...) { vctrs::obj_print_footer #' @export -obj_print_footer.parsnip_quantiles <- function(x, digits = 3, ...) { +obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { lvls <- attr(x, "quantile_levels") cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") } +check_vector_probability <- function(x, ..., + allow_na = FALSE, + allow_null = FALSE, + arg = caller_arg(x), + call = caller_env()) { + purrr::walk(x, ~ check_number_decimal( + .x, min = 0, max = 1, + arg = arg, call = call, + allow_na = allow_na, + allow_null = allow_null, + allow_infinite = FALSE + )) +} + restructure_rq_pred <- function(x, object) { if (!is.matrix(x)) { x <- as.matrix(x) @@ -130,6 +138,6 @@ restructure_rq_pred <- function(x, object) { n_pred_quantiles <- ncol(x) # TODO check p = length(quantile_level) quantile_level <- object$spec$quantile_level - tibble::tibble(.pred_quantile = parsnip_quantiles(x, quantile_level)) + tibble::tibble(.pred_quantile = quantile_pred(x, quantile_level)) } From f4c90ca064d2bc7c3d6d30dd48eacac5989a2ad4 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 07:25:28 -0700 Subject: [PATCH 14/31] fix: two bugs introduced earlier --- R/aaa_quantiles.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 0c0189942..879eac9fc 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -80,7 +80,7 @@ check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { call = call ) } - check_vector_probability(values, arg = "quantile_levels", call = call) + check_vector_probability(levels, arg = "quantile_levels", call = call) if (is.unsorted(levels)) { cli::cli_abort( @@ -97,7 +97,7 @@ format.quantile_pred <- function(x, ...) { if (length(quantile_levels) == 1L) { x <- unlist(x) out <- round(x, 3L) - out[is.na(x)] <- NA_character_ + out[is.na(x)] <- NA_real_ } else { rng <- sapply(x, range, na.rm = TRUE) out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") From 13b6010dea3245eaded0aa6986410095da6ee41e Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 07:25:46 -0700 Subject: [PATCH 15/31] add formatting tests for single quantile --- tests/testthat/_snaps/aaa_quantiles.md | 61 +++++++++++++++++++------- tests/testthat/test-aaa_quantiles.R | 8 ++++ 2 files changed, 54 insertions(+), 15 deletions(-) diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md index 1a659ce0c..f8db2b326 100644 --- a/tests/testthat/_snaps/aaa_quantiles.md +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -1,38 +1,38 @@ -# parsnip_quantiles error types +# quantile_pred error types Code - parsnip_quantiles(1:10, 1:4 / 5) + quantile_pred(1:10, 1:4 / 5) Condition - Error in `parsnip_quantiles()`: + Error in `quantile_pred()`: ! `values` must be a , not an integer vector. --- Code - parsnip_quantiles(matrix(1:20, 5), -1:4 / 5) + quantile_pred(matrix(1:20, 5), -1:4 / 5) Condition Error in `map()`: i In index: 1. - Caused by error in `parsnip_quantiles()`: + Caused by error in `quantile_pred()`: ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. --- Code - parsnip_quantiles(matrix(1:20, 5), 1:5 / 6) + quantile_pred(matrix(1:20, 5), 1:5 / 6) Condition - Error in `parsnip_quantiles()`: + Error in `quantile_pred()`: ! The number of columns in `values` must be equal to the length of `quantile_levels`. --- Code - parsnip_quantiles(matrix(1:20, 5), 4:1 / 5) + quantile_pred(matrix(1:20, 5), 4:1 / 5) Condition - Error in `parsnip_quantiles()`: + Error in `quantile_pred()`: ! `quantile_levels` must be sorted in increasing order. -# parsnip_quantiles formatting +# quantile_pred formatting Code print(v) @@ -44,7 +44,7 @@ --- Code - print(parsnip_quantiles(matrix(1:18, 9), c(1 / 3, 2 / 3))) + print(quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3))) Output [1] [1, 10] [2, 11] [3, 12] [4, 13] [5, 14] [6, 15] [7, 16] [8, 17] [9, 18] @@ -53,8 +53,7 @@ --- Code - print(parsnip_quantiles(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, - 0.8))) + print(quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8))) Output [1] [0.01, 0.598] [0.206, 0.794] [0.402, 0.99] @@ -77,9 +76,41 @@ --- Code - print(parsnip_quantiles(m, 1:4 / 5)) + print(quantile_pred(m, 1:4 / 5)) Output - [1] [1, 16] [3, 18] [5, 20] + [1] [1, 16] [2, 17] [3, 18] [4, 19] [5, 20] # Quantile levels: 0.2 0.4 0.6 0.8 +--- + + Code + print(one_quantile) + Output + + [1] 1 2 3 4 5 + # Quantile levels: 0.556 + +--- + + Code + print(tibble(qntls = one_quantile)) + Output + # A tibble: 5 x 1 + qntls + + 1 1 + 2 2 + 3 3 + 4 4 + 5 5 + +--- + + Code + print(quantile_pred(m, 5 / 9)) + Output + + [1] 1 NA 3 4 5 + # Quantile levels: 0.556 + diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R index c3c46a8f1..da19f0819 100644 --- a/tests/testthat/test-aaa_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -28,6 +28,7 @@ test_that("quantile_pred outputs", { }) test_that("quantile_pred formatting", { + # multiple quantiles v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) expect_snapshot(print(v)) expect_snapshot(print(quantile_pred(matrix(1:18, 9), c(1/3, 2/3)))) @@ -40,4 +41,11 @@ test_that("quantile_pred formatting", { m[4, 2] <- NA expect_snapshot(print(quantile_pred(m, 1:4 / 5))) + # single quantile + m <- matrix(1:5) + one_quantile <- quantile_pred(m, 5/9) + expect_snapshot(print(one_quantile)) + expect_snapshot(print(tibble(qntls = one_quantile))) + m[2] <- NA + expect_snapshot(print(quantile_pred(m, 5/9))) }) From f3ac33e61f28a292c714f82a1fcc8b07fb5239dd Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 07:33:24 -0700 Subject: [PATCH 16/31] replace walk with a loop to avoid "Error in map()" --- R/aaa_quantiles.R | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 879eac9fc..fc029da63 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -121,13 +121,15 @@ check_vector_probability <- function(x, ..., allow_null = FALSE, arg = caller_arg(x), call = caller_env()) { - purrr::walk(x, ~ check_number_decimal( - .x, min = 0, max = 1, - arg = arg, call = call, - allow_na = allow_na, - allow_null = allow_null, - allow_infinite = FALSE - )) + for (d in x) { + check_number_decimal( + d, min = 0, max = 1, + arg = arg, call = call, + allow_na = allow_na, + allow_null = allow_null, + allow_infinite = FALSE + ) + } } restructure_rq_pred <- function(x, object) { From 7ffcb3886d99ed0da659fafca285c4b7b516d207 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 09:00:26 -0700 Subject: [PATCH 17/31] remove row/col names --- R/aaa_quantiles.R | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index fc029da63..80a772874 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -57,7 +57,7 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) { #' #' # Access the underlying information #' attr(v, "quantile_levels") -#' vctrs::vec_data(v) +#' unclass(v) quantile_pred <- function(values, quantile_levels = double()) { check_quantile_pred_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) @@ -69,6 +69,8 @@ quantile_pred <- function(values, quantile_levels = double()) { {.arg quantile_levels}." ) } + rownames(values) <- NULL + colnames(values) <- NULL values <- lapply(vctrs::vec_chop(values), drop) new_quantile_pred(values, quantile_levels) } From 90655c96f9e93ece3baa3a4ed2b91293c0a4c2c5 Mon Sep 17 00:00:00 2001 From: "Daniel J. McDonald" Date: Tue, 10 Sep 2024 14:57:58 -0700 Subject: [PATCH 18/31] adjust quantile_pred format --- NAMESPACE | 1 + R/aaa_quantiles.R | 20 ++++++++++++- tests/testthat/_snaps/aaa_quantiles.md | 40 ++++++++++++-------------- 3 files changed, 39 insertions(+), 22 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 5d40f4fb9..f830f0d3d 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -41,6 +41,7 @@ S3method(glance,model_fit) S3method(has_multi_predict,default) S3method(has_multi_predict,model_fit) S3method(has_multi_predict,workflow) +S3method(median,quantile_pred) S3method(multi_predict,"_C5.0") S3method(multi_predict,"_earth") S3method(multi_predict,"_elnet") diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 80a772874..7b301ec1c 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -30,7 +30,10 @@ vctrs::vec_ptype_full #' @export -vec_ptype_abbr.quantile_pred <- function(x, ...) "qntls" +vec_ptype_abbr.quantile_pred <- function(x, ...) { + n_lvls <- length(attr(x, "quantile_levels")) + cli::format_inline("qtl{?s}({n_lvls})") +} #' @export vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" @@ -104,6 +107,8 @@ format.quantile_pred <- function(x, ...) { rng <- sapply(x, range, na.rm = TRUE) out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ + m <- median(x) + out <- paste0("[", round(m, 3L), "]") } out } @@ -134,6 +139,19 @@ check_vector_probability <- function(x, ..., } } +#' @export +median.quantile_pred <- function(x, ...) { + lvls <- attr(x, "quantile_levels") + loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) + if (any(loc_median)) { + return(map_dbl(x, ~ .x[min(which(loc_median))])) + } + if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { + return(rep(NA, vctrs::vec_size(x))) + } + map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) +} + restructure_rq_pred <- function(x, object) { if (!is.matrix(x)) { x <- as.matrix(x) diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md index f8db2b326..b06554a58 100644 --- a/tests/testthat/_snaps/aaa_quantiles.md +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -11,9 +11,7 @@ Code quantile_pred(matrix(1:20, 5), -1:4 / 5) Condition - Error in `map()`: - i In index: 1. - Caused by error in `quantile_pred()`: + Error in `quantile_pred()`: ! `quantile_levels` must be a number between 0 and 1, not the number -0.2. --- @@ -38,7 +36,7 @@ print(v) Output - [1] [1, 16] [2, 17] [3, 18] [4, 19] [5, 20] + [1] [8.5] [9.5] [10.5] [11.5] [12.5] # Quantile levels: 0.2 0.4 0.6 0.8 --- @@ -47,7 +45,7 @@ print(quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3))) Output - [1] [1, 10] [2, 11] [3, 12] [4, 13] [5, 14] [6, 15] [7, 16] [8, 17] [9, 18] + [1] [5.5] [6.5] [7.5] [8.5] [9.5] [10.5] [11.5] [12.5] [13.5] # Quantile levels: 0.333 0.667 --- @@ -56,7 +54,7 @@ print(quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8))) Output - [1] [0.01, 0.598] [0.206, 0.794] [0.402, 0.99] + [1] [0.304] [0.5] [0.696] # Quantile levels: 0.2 0.8 --- @@ -65,13 +63,13 @@ print(tibble(qntls = v)) Output # A tibble: 5 x 1 - qntls - - 1 [1, 16] - 2 [2, 17] - 3 [3, 18] - 4 [4, 19] - 5 [5, 20] + qntls + + 1 [8.5] + 2 [9.5] + 3 [10.5] + 4 [11.5] + 5 [12.5] --- @@ -79,7 +77,7 @@ print(quantile_pred(m, 1:4 / 5)) Output - [1] [1, 16] [2, 17] [3, 18] [4, 19] [5, 20] + [1] [8.5] [9.5] [10.5] [11.5] [12.5] # Quantile levels: 0.2 0.4 0.6 0.8 --- @@ -97,13 +95,13 @@ print(tibble(qntls = one_quantile)) Output # A tibble: 5 x 1 - qntls - - 1 1 - 2 2 - 3 3 - 4 4 - 5 5 + qntls + + 1 1 + 2 2 + 3 3 + 4 4 + 5 5 --- From e8feed3aaf28309bf005bd1e9302ec2885e1febd Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 10 Sep 2024 20:15:15 -0400 Subject: [PATCH 19/31] as_tibble method --- NAMESPACE | 1 + R/aaa_quantiles.R | 15 +++++++++++++++ man/quantile_pred.Rd | 5 ++++- tests/testthat/test-linear_reg_quantreg.R | 20 ++++++++++++++++++++ 4 files changed, 40 insertions(+), 1 deletion(-) diff --git a/NAMESPACE b/NAMESPACE index f830f0d3d..6633c1c60 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(.censoring_weights_graf,default) S3method(.censoring_weights_graf,model_fit) +S3method(as_tibble,quantile_pred) S3method(augment,model_fit) S3method(autoplot,glmnet) S3method(autoplot,model_fit) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 7b301ec1c..7fb3790bb 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -61,6 +61,9 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) { #' # Access the underlying information #' attr(v, "quantile_levels") #' unclass(v) +#' +#' # tidy format +#' as_tibble(v) quantile_pred <- function(values, quantile_levels = double()) { check_quantile_pred_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) @@ -163,3 +166,15 @@ restructure_rq_pred <- function(x, object) { tibble::tibble(.pred_quantile = quantile_pred(x, quantile_level)) } +#' @export +as_tibble.quantile_pred <- + function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { + lvls <- attr(x, "quantile_levels") + n_samp <- length(x) + n_quant <- length(lvls) + tibble::tibble( + .pred_quantile = unlist(x), + .quantile_levels = rep(lvls, n_samp), + .row = rep(1:n_samp, each = n_quant) + ) + } diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd index bf8f20c03..a89621716 100644 --- a/man/quantile_pred.Rd +++ b/man/quantile_pred.Rd @@ -23,5 +23,8 @@ v <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) # Access the underlying information attr(v, "quantile_levels") -vctrs::vec_data(v) +unclass(v) + +# tidy format +as_tibble(v) } diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 9c3ac41d7..f6ab7e23b 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -32,6 +32,11 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_true(length(one_quant_pred$.pred_quantile[[1]]) == 1L) expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) + one_quant_df <- as_tibble(one_quant_pred$.pred_quantile) + expect_s3_class(one_quant_df, c("tbl_df", "tbl", "data.frame")) + expect_named(one_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(one_quant_df) == nrow(sac_test) * 1) + ### one_quant_one_row <- predict(one_quant, new_data = sac_test[1,]) @@ -45,6 +50,11 @@ test_that('linear quantile regression via quantreg - single quantile', { expect_identical(class(one_quant_one_row$.pred_quantile[[1]]), "numeric") expect_true(length(one_quant_one_row$.pred_quantile[[1]]) == 1L) expect_identical(attr(one_quant_pred$.pred_quantile, "quantile_levels"), .5) + + one_quant_one_row_df <- as_tibble(one_quant_one_row$.pred_quantile) + expect_s3_class(one_quant_one_row_df, c("tbl_df", "tbl", "data.frame")) + expect_named(one_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(one_quant_one_row_df) == nrow(sac_test[1,]) * 1) }) test_that('linear quantile regression via quantreg - multiple quantiles', { @@ -81,6 +91,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(length(ten_quant_pred$.pred_quantile[[1]]) == 10L) expect_identical(attr(ten_quant_pred$.pred_quantile, "quantile_levels"), (0:9)/9) + ten_quant_df <- as_tibble(ten_quant_pred$.pred_quantile) + expect_s3_class(ten_quant_df, c("tbl_df", "tbl", "data.frame")) + expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) + ### ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,]) @@ -97,6 +112,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { attr(ten_quant_one_row$.pred_quantile, "quantile_levels"), (0:9)/9 ) + + ten_quant_one_row_df <- as_tibble(ten_quant_one_row$.pred_quantile) + expect_s3_class(ten_quant_one_row_df, c("tbl_df", "tbl", "data.frame")) + expect_named(ten_quant_one_row_df, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(ten_quant_one_row_df) == nrow(sac_test[1,]) * 10) }) From 2748d06fbcf8c7b745d94e2f05cc91199bdf2622 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 10 Sep 2024 20:18:50 -0400 Subject: [PATCH 20/31] updated NEWS file --- NEWS.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index c51afb0e7..c2c270811 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,9 @@ # parsnip (development version) - +* A new model mode, "quantile regression" was added. Including: + * A function to create a new vector class called `quantile_pred()` was added. + * A `linear_reg()` engine for `"quantreg"`. + * `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775). * Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083). From 5b09175a6d63fa2c4d3a6ca24ec6695da52e9be2 Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 10 Sep 2024 20:19:13 -0400 Subject: [PATCH 21/31] add PR number --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index c2c270811..21c1c85b4 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,7 +1,7 @@ # parsnip (development version) * A new model mode, "quantile regression" was added. Including: - * A function to create a new vector class called `quantile_pred()` was added. + * A function to create a new vector class called `quantile_pred()` was added (#1191). * A `linear_reg()` engine for `"quantreg"`. * `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775). From 30760de25df05ef353e820be61c25b8b5989af7d Mon Sep 17 00:00:00 2001 From: topepo Date: Tue, 10 Sep 2024 20:20:08 -0400 Subject: [PATCH 22/31] small new update --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 21c1c85b4..62f6ea8da 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,6 +1,6 @@ # parsnip (development version) -* A new model mode, "quantile regression" was added. Including: +* A new model mode (`"quantile regression"`) was added. Including: * A function to create a new vector class called `quantile_pred()` was added (#1191). * A `linear_reg()` engine for `"quantreg"`. From 926d587e711212c3314a2dcf5e4a247f9df97440 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 12 Sep 2024 09:09:21 -0400 Subject: [PATCH 23/31] helper methods --- NAMESPACE | 2 ++ R/aaa_quantiles.R | 18 ++++++++++++++++++ man/quantile_pred.Rd | 6 ++++++ 3 files changed, 26 insertions(+) diff --git a/NAMESPACE b/NAMESPACE index 6633c1c60..9425a6db5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ S3method(.censoring_weights_graf,default) S3method(.censoring_weights_graf,model_fit) +S3method(as.matrix,quantile_pred) S3method(as_tibble,quantile_pred) S3method(augment,model_fit) S3method(autoplot,glmnet) @@ -232,6 +233,7 @@ export(extract_fit_engine) export(extract_fit_time) export(extract_parameter_dials) export(extract_parameter_set_dials) +export(extract_quantile_levels) export(extract_spec_parsnip) export(find_engine_files) export(fit) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 7fb3790bb..38f5fecce 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -64,6 +64,9 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) { #' #' # tidy format #' as_tibble(v) +#' +#' # matrix format +#' as.matrix(v) quantile_pred <- function(values, quantile_levels = double()) { check_quantile_pred_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) @@ -178,3 +181,18 @@ as_tibble.quantile_pred <- .row = rep(1:n_samp, each = n_quant) ) } + +#' @export +as.matrix.quantile_pred <- function(x, ...) { + num_samp <- length(x) + matrix(unlist(x), nrow = num_samp) +} + +#' @export +#' @rdname quantile_pred +extract_quantile_levels <- function(x) { + if ( !inherits(x, "quantile_pred") ) { + cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") + } + attr(x, "quantile_levels") +} diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd index a89621716..26019ed26 100644 --- a/man/quantile_pred.Rd +++ b/man/quantile_pred.Rd @@ -2,9 +2,12 @@ % Please edit documentation in R/aaa_quantiles.R \name{quantile_pred} \alias{quantile_pred} +\alias{extract_quantile_levels} \title{Create a vector containing sets of quantiles} \usage{ quantile_pred(values, quantile_levels = double()) + +extract_quantile_levels(x) } \arguments{ \item{values}{A matrix of values. Each column should correspond to one of @@ -27,4 +30,7 @@ unclass(v) # tidy format as_tibble(v) + +# matrix format +as.matrix(v) } From b575c347069bbbf3f5ef4e79fa0c31d4fc0b68cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Thu, 12 Sep 2024 10:29:20 -0400 Subject: [PATCH 24/31] update docs --- NAMESPACE | 1 + R/aaa_quantiles.R | 54 +++++++++++++++-------- R/parsnip-package.R | 2 +- man/quantile_pred.Rd | 42 ++++++++++++++---- tests/testthat/helper-objects.R | 13 ++++++ tests/testthat/test-linear_reg_quantreg.R | 18 +------- 6 files changed, 85 insertions(+), 45 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 9425a6db5..55d0da025 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -414,6 +414,7 @@ importFrom(stats,as.formula) importFrom(stats,binomial) importFrom(stats,coef) importFrom(stats,delete.response) +importFrom(stats,median) importFrom(stats,model.frame) importFrom(stats,model.matrix) importFrom(stats,model.offset) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 38f5fecce..99c74ab58 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -45,28 +45,42 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) { ) } - #' Create a vector containing sets of quantiles #' +#' [quantile_pred()] is a special vector class used to efficiently store +#' predictions from a quantile regression model. It requires the same quantile +#' levels for each row being predicted. +#' #' @param values A matrix of values. Each column should correspond to one of #' the quantile levels. #' @param quantile_levels A vector of probabilities corresponding to `values`. +#' @param x An object produced by [quantile_pred()]. +#' @param .rows,.name_repair,rownames Arguments not used but required by the +#' original S3 method. +#' @param ... Not currently used. #' #' @export -#' @return A vector of values associated with the quantile levels. -#' +#' @return +#' * [quantile_pred()] returns a vector of values associated with the +#' quantile levels. +#' * [extract_quantile_levels()] returns a numeric vector of levels. +#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`, +#' `".quantile_levels"`, and `".row"`. +#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as +#' quantile levels, and entries are predictions. #' @examples -#' v <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +#' +#' unclass(.pred_quantile) #' #' # Access the underlying information -#' attr(v, "quantile_levels") -#' unclass(v) +#' extract_quantile_levels(.pred_quantile) #' -#' # tidy format -#' as_tibble(v) +#' # Matrix format +#' as.matrix(.pred_quantile) #' -#' # matrix format -#' as.matrix(v) +#' # Tidy format +#' tibble::as_tibble(.pred_quantile) quantile_pred <- function(values, quantile_levels = double()) { check_quantile_pred_inputs(values, quantile_levels) quantile_levels <- vctrs::vec_cast(quantile_levels, double()) @@ -170,6 +184,16 @@ restructure_rq_pred <- function(x, object) { } #' @export +#' @rdname quantile_pred +extract_quantile_levels <- function(x) { + if ( !inherits(x, "quantile_pred") ) { + cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") + } + attr(x, "quantile_levels") +} + +#' @export +#' @rdname quantile_pred as_tibble.quantile_pred <- function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { lvls <- attr(x, "quantile_levels") @@ -183,16 +207,8 @@ as_tibble.quantile_pred <- } #' @export +#' @rdname quantile_pred as.matrix.quantile_pred <- function(x, ...) { num_samp <- length(x) matrix(unlist(x), nrow = num_samp) } - -#' @export -#' @rdname quantile_pred -extract_quantile_levels <- function(x) { - if ( !inherits(x, "quantile_pred") ) { - cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") - } - attr(x, "quantile_levels") -} diff --git a/R/parsnip-package.R b/R/parsnip-package.R index 01f1f42c1..c4dd3c81d 100644 --- a/R/parsnip-package.R +++ b/R/parsnip-package.R @@ -21,7 +21,7 @@ #' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef #' @importFrom stats delete.response model.frame model.matrix model.offset #' @importFrom stats model.response model.weights na.omit na.pass predict qnorm -#' @importFrom stats qt quantile setNames terms update +#' @importFrom stats qt quantile setNames terms update median #' @importFrom tibble as_tibble is_tibble tibble #' @importFrom tidyr gather #' @importFrom utils capture.output getFromNamespace globalVariables head diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd index 26019ed26..c96971aae 100644 --- a/man/quantile_pred.Rd +++ b/man/quantile_pred.Rd @@ -3,34 +3,58 @@ \name{quantile_pred} \alias{quantile_pred} \alias{extract_quantile_levels} +\alias{as_tibble.quantile_pred} +\alias{as.matrix.quantile_pred} \title{Create a vector containing sets of quantiles} \usage{ quantile_pred(values, quantile_levels = double()) extract_quantile_levels(x) + +\method{as_tibble}{quantile_pred}(x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) + +\method{as.matrix}{quantile_pred}(x, ...) } \arguments{ \item{values}{A matrix of values. Each column should correspond to one of the quantile levels.} \item{quantile_levels}{A vector of probabilities corresponding to \code{values}.} + +\item{x}{An object produced by \code{\link[=quantile_pred]{quantile_pred()}}.} + +\item{...}{Not currently used.} + +\item{.rows, .name_repair, rownames}{Arguments not used but required by the +original S3 method.} } \value{ -A vector of values associated with the quantile levels. +\itemize{ +\item \code{\link[=quantile_pred]{quantile_pred()}} returns a vector of values associated with the +quantile levels. +\item \code{\link[=extract_quantile_levels]{extract_quantile_levels()}} returns a numeric vector of levels. +\item \code{\link[=as_tibble]{as_tibble()}} returns a tibble with rows \code{".pred_quantile"}, +\code{".quantile_levels"}, and \code{".row"}. +\item \code{\link[=as.matrix]{as.matrix()}} returns an unnamed matrix with rows as sames, columns as +quantile levels, and entries are predictions. +} } \description{ -Create a vector containing sets of quantiles +\code{\link[=quantile_pred]{quantile_pred()}} is a special vector class used to efficiently store +predictions from a quantile regression model. It requires the same quantile +levels for each row being predicted. } \examples{ -v <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) +.pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) -# Access the underlying information -attr(v, "quantile_levels") -unclass(v) +unclass(.pred_quantile) -# tidy format -as_tibble(v) +# Access the underlying information +extract_quantile_levels(.pred_quantile) # matrix format -as.matrix(v) +as.matrix(.pred_quantile) + +# tidy format +tibble::as_tibble(.pred_quantile) } diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index a9297a65a..38633ab4b 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -24,3 +24,16 @@ is_tf_ok <- function() { } res } + +# ------------------------------------------------------------------------------ +# for quantile regression tests + +data("Sacramento") + +Sacramento_small <- + Sacramento %>% + dplyr::mutate(price = log10(price)) %>% + dplyr::select(price, beds, baths, sqft, latitude, longitude) + +sac_train <- Sacramento_small[-(1:5), ] +sac_test <- Sacramento_small[ 1:5 , ] diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index f6ab7e23b..47a9f7c88 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -1,14 +1,7 @@ test_that('linear quantile regression via quantreg - single quantile', { skip_if_not_installed("quantreg") - data("Sacramento") - - Sacramento_small <- - Sacramento %>% - dplyr::select(price, beds, baths, sqft, latitude, longitude) - - sac_train <- Sacramento_small[-(1:5), ] - sac_test <- Sacramento_small[ 1:5 , ] + # data in `helper-objects.R` one_quant <- linear_reg() %>% @@ -60,14 +53,7 @@ test_that('linear quantile regression via quantreg - single quantile', { test_that('linear quantile regression via quantreg - multiple quantiles', { skip_if_not_installed("quantreg") - data("Sacramento") - - Sacramento_small <- - Sacramento %>% - dplyr::select(price, beds, baths, sqft, latitude, longitude) - - sac_train <- Sacramento_small[-(1:5), ] - sac_test <- Sacramento_small[ 1:5 , ] + # data in `helper-objects.R` ten_quant <- linear_reg() %>% From 83c744bccb666bd290bc5e97abaeff529545f0e8 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 12 Sep 2024 18:52:09 -0400 Subject: [PATCH 25/31] re-enable quantiles prediction for #1203 --- R/predict_quantile.R | 15 ++++++++++----- man/other_predict.Rd | 9 ++++++++- 2 files changed, 18 insertions(+), 6 deletions(-) diff --git a/R/predict_quantile.R b/R/predict_quantile.R index f9154d6a9..efe0458f8 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -6,7 +6,12 @@ #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export -predict_quantile.model_fit <- function(object, new_data, ...) { +predict_quantile.model_fit <- function(object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ...) { check_spec_pred_type(object, "quantile") @@ -18,12 +23,11 @@ predict_quantile.model_fit <- function(object, new_data, ...) { new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$quantile$pre)) { + if (!is.null(object$spec$method$pred$quantile$pre)) new_data <- object$spec$method$pred$quantile$pre(new_data, object) - } # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level + object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) @@ -40,5 +44,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) { # @keywords internal # @rdname other_predict # @inheritParams predict.model_fit -predict_quantile <- function (object, ...) +predict_quantile <- function (object, ...) { UseMethod("predict_quantile") +} diff --git a/man/other_predict.Rd b/man/other_predict.Rd index bc1d104bf..6c997e28d 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -46,7 +46,14 @@ predict_linear_pred(object, ...) predict_numeric(object, ...) -\method{predict_quantile}{model_fit}(object, new_data, ...) +\method{predict_quantile}{model_fit}( + object, + new_data, + quantile = (1:9)/10, + interval = "none", + level = 0.95, + ... +) \method{predict_survival}{model_fit}( object, From 11dd169a55eadb481c35bf7b947184587276e50a Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 12 Sep 2024 18:52:27 -0400 Subject: [PATCH 26/31] update some tests --- R/aaa_quantiles.R | 12 ++++++++++-- man/quantile_pred.Rd | 4 ++-- tests/testthat/_snaps/quantile-reg-specs.md | 14 ++++---------- tests/testthat/helper-objects.R | 2 +- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 99c74ab58..679fe0671 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -9,9 +9,11 @@ check_quantile_level <- function(x, object, call) { {.arg quantile_level} must be specified for quantile regression models.") } } + if ( any(is.na(x)) ) { + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", + call = call) + } x <- sort(unique(x)) - # TODO we need better vectorization here, otherwise we get things like: - # "Error during wrapup: i In index: 2." in the traceback. check_vector_probability(x, arg = "quantile_level", call = call) x } @@ -83,6 +85,7 @@ new_quantile_pred <- function(values = list(), quantile_levels = double()) { #' tibble::as_tibble(.pred_quantile) quantile_pred <- function(values, quantile_levels = double()) { check_quantile_pred_inputs(values, quantile_levels) + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) num_lvls <- length(quantile_levels) @@ -99,6 +102,11 @@ quantile_pred <- function(values, quantile_levels = double()) { } check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { + if ( any(is.na(levels)) ) { + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", + call = call) + } + if (!is.matrix(values)) { cli::cli_abort( "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", diff --git a/man/quantile_pred.Rd b/man/quantile_pred.Rd index c96971aae..abb34ca20 100644 --- a/man/quantile_pred.Rd +++ b/man/quantile_pred.Rd @@ -52,9 +52,9 @@ unclass(.pred_quantile) # Access the underlying information extract_quantile_levels(.pred_quantile) -# matrix format +# Matrix format as.matrix(.pred_quantile) -# tidy format +# Tidy format tibble::as_tibble(.pred_quantile) } diff --git a/tests/testthat/_snaps/quantile-reg-specs.md b/tests/testthat/_snaps/quantile-reg-specs.md index f7c24584c..627a5248b 100644 --- a/tests/testthat/_snaps/quantile-reg-specs.md +++ b/tests/testthat/_snaps/quantile-reg-specs.md @@ -20,9 +20,7 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = 2) Condition - Error in `purrr::map()`: - i In index: 1. - Caused by error in `set_mode()`: + Error in `set_mode()`: ! `quantile_level` must be a number between 0 and 1, not the number 2. --- @@ -31,9 +29,7 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = 1:2) Condition - Error in `purrr::map()`: - i In index: 2. - Caused by error in `set_mode()`: + Error in `set_mode()`: ! `quantile_level` must be a number between 0 and 1, not the number 2. --- @@ -42,8 +38,6 @@ linear_reg() %>% set_engine("quantreg") %>% set_mode("quantile regression", quantile_level = NA_real_) Condition - Error in `purrr::map()`: - i In index: 1. - Caused by error in `set_mode()`: - ! `quantile_level` must be a number, not a numeric `NA`. + Error in `set_mode()`: + ! Missing values are not allowed in `quantile_levels`. diff --git a/tests/testthat/helper-objects.R b/tests/testthat/helper-objects.R index 38633ab4b..14c3931fe 100644 --- a/tests/testthat/helper-objects.R +++ b/tests/testthat/helper-objects.R @@ -31,7 +31,7 @@ is_tf_ok <- function() { data("Sacramento") Sacramento_small <- - Sacramento %>% + modeldata::Sacramento %>% dplyr::mutate(price = log10(price)) %>% dplyr::select(price, beds, baths, sqft, latitude, longitude) From 9fa5bf0fa2ec5a11a50bc5a9b121e231b747f666 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Fri, 13 Sep 2024 08:38:43 -0400 Subject: [PATCH 27/31] no longer needed --- R/aaa_quantiles.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 679fe0671..75abc197c 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -186,7 +186,6 @@ restructure_rq_pred <- function(x, object) { } rownames(x) <- NULL n_pred_quantiles <- ncol(x) - # TODO check p = length(quantile_level) quantile_level <- object$spec$quantile_level tibble::tibble(.pred_quantile = quantile_pred(x, quantile_level)) } From 1e74bae0c4d4830a952d0b131f2b2051cc7d04ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Fri, 13 Sep 2024 08:53:37 -0400 Subject: [PATCH 28/31] use tibble::new_tibble --- R/aaa_quantiles.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 75abc197c..746fc46e3 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -187,7 +187,8 @@ restructure_rq_pred <- function(x, object) { rownames(x) <- NULL n_pred_quantiles <- ncol(x) quantile_level <- object$spec$quantile_level - tibble::tibble(.pred_quantile = quantile_pred(x, quantile_level)) + + tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level))) } #' @export From 91220736edcfa93cab82e1c72b9ef4c04b9c5303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Fri, 13 Sep 2024 10:20:10 -0400 Subject: [PATCH 29/31] braces --- R/aaa_quantiles.R | 10 +++++----- R/predict_quantile.R | 3 ++- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/R/aaa_quantiles.R b/R/aaa_quantiles.R index 746fc46e3..2b62dfcf8 100644 --- a/R/aaa_quantiles.R +++ b/R/aaa_quantiles.R @@ -1,15 +1,15 @@ # Helpers for quantile regression models check_quantile_level <- function(x, object, call) { - if ( object$mode != "quantile regression" ) { + if (object$mode != "quantile regression") { return(invisible(TRUE)) } else { - if ( is.null(x) ) { + if (is.null(x)) { cli::cli_abort("In {.fn check_mode}, at least one value of {.arg quantile_level} must be specified for quantile regression models.") } } - if ( any(is.na(x)) ) { + if (any(is.na(x))) { cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", call = call) } @@ -102,7 +102,7 @@ quantile_pred <- function(values, quantile_levels = double()) { } check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { - if ( any(is.na(levels)) ) { + if (any(is.na(levels))) { cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", call = call) } @@ -194,7 +194,7 @@ restructure_rq_pred <- function(x, object) { #' @export #' @rdname quantile_pred extract_quantile_levels <- function(x) { - if ( !inherits(x, "quantile_pred") ) { + if (!inherits(x, "quantile_pred")) { cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") } attr(x, "quantile_levels") diff --git a/R/predict_quantile.R b/R/predict_quantile.R index efe0458f8..fc2d91b15 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -23,8 +23,9 @@ predict_quantile.model_fit <- function(object, new_data <- prepare_data(object, new_data) # preprocess data - if (!is.null(object$spec$method$pred$quantile$pre)) + if (!is.null(object$spec$method$pred$quantile$pre)) { new_data <- object$spec$method$pred$quantile$pre(new_data, object) + } # Pass some extra arguments to be used in post-processor object$spec$method$pred$quantile$args$p <- quantile From 9ee98e9e235dbdda95619f5dca2f5f96be90b124 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Fri, 13 Sep 2024 10:28:25 -0400 Subject: [PATCH 30/31] test as_tibble --- tests/testthat/_snaps/aaa_quantiles.md | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/testthat/_snaps/aaa_quantiles.md b/tests/testthat/_snaps/aaa_quantiles.md index b06554a58..0925f61df 100644 --- a/tests/testthat/_snaps/aaa_quantiles.md +++ b/tests/testthat/_snaps/aaa_quantiles.md @@ -33,7 +33,7 @@ # quantile_pred formatting Code - print(v) + v Output [1] [8.5] [9.5] [10.5] [11.5] [12.5] @@ -42,7 +42,7 @@ --- Code - print(quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3))) + quantile_pred(matrix(1:18, 9), c(1 / 3, 2 / 3)) Output [1] [5.5] [6.5] [7.5] [8.5] [9.5] [10.5] [11.5] [12.5] [13.5] @@ -51,7 +51,7 @@ --- Code - print(quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8))) + quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(0.2, 0.8)) Output [1] [0.304] [0.5] [0.696] @@ -60,7 +60,7 @@ --- Code - print(tibble(qntls = v)) + tibble(qntls = v) Output # A tibble: 5 x 1 qntls @@ -74,7 +74,7 @@ --- Code - print(quantile_pred(m, 1:4 / 5)) + quantile_pred(m, 1:4 / 5) Output [1] [8.5] [9.5] [10.5] [11.5] [12.5] @@ -83,7 +83,7 @@ --- Code - print(one_quantile) + one_quantile Output [1] 1 2 3 4 5 @@ -92,7 +92,7 @@ --- Code - print(tibble(qntls = one_quantile)) + tibble(qntls = one_quantile) Output # A tibble: 5 x 1 qntls @@ -106,7 +106,7 @@ --- Code - print(quantile_pred(m, 5 / 9)) + quantile_pred(m, 5 / 9) Output [1] 1 NA 3 4 5 From 9ce72c0ff49efbd0201755041a916046f4816b25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Fri, 13 Sep 2024 10:28:33 -0400 Subject: [PATCH 31/31] remove print methods --- tests/testthat/test-aaa_quantiles.R | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/testthat/test-aaa_quantiles.R b/tests/testthat/test-aaa_quantiles.R index da19f0819..cdf71aa7d 100644 --- a/tests/testthat/test-aaa_quantiles.R +++ b/tests/testthat/test-aaa_quantiles.R @@ -30,22 +30,30 @@ test_that("quantile_pred outputs", { test_that("quantile_pred formatting", { # multiple quantiles v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) - expect_snapshot(print(v)) - expect_snapshot(print(quantile_pred(matrix(1:18, 9), c(1/3, 2/3)))) - expect_snapshot(print( + expect_snapshot(v) + expect_snapshot(quantile_pred(matrix(1:18, 9), c(1/3, 2/3))) + expect_snapshot( quantile_pred(matrix(seq(0.01, 1 - 0.01, length.out = 6), 3), c(.2, .8)) - )) - expect_snapshot(print(tibble(qntls = v))) + ) + expect_snapshot(tibble(qntls = v)) m <- matrix(1:20, 5) m[2, 3] <- NA m[4, 2] <- NA - expect_snapshot(print(quantile_pred(m, 1:4 / 5))) + expect_snapshot(quantile_pred(m, 1:4 / 5)) # single quantile m <- matrix(1:5) one_quantile <- quantile_pred(m, 5/9) - expect_snapshot(print(one_quantile)) - expect_snapshot(print(tibble(qntls = one_quantile))) + expect_snapshot(one_quantile) + expect_snapshot(tibble(qntls = one_quantile)) m[2] <- NA - expect_snapshot(print(quantile_pred(m, 5/9))) + expect_snapshot(quantile_pred(m, 5/9)) +}) + +test_that("as_tibble() for quantile_pred", { + v <- quantile_pred(matrix(1:20, 5), 1:4 / 5) + tbl <- as_tibble(v) + expect_s3_class(tbl, c("tbl_df", "tbl", "data.frame")) + expect_named(tbl, c(".pred_quantile", ".quantile_levels", ".row")) + expect_true(nrow(tbl) == 20) })