|
1 | 1 | # Helpers for quantile regression models
|
2 | 2 |
|
3 | 3 | check_quantile_level <- function(x, object, call) {
|
4 |
| - if ( object$mode != "quantile regression" ) { |
| 4 | + if (object$mode != "quantile regression") { |
5 | 5 | return(invisible(TRUE))
|
6 | 6 | } else {
|
7 |
| - if ( is.null(x) ) { |
| 7 | + if (is.null(x)) { |
8 | 8 | cli::cli_abort("In {.fn check_mode}, at least one value of
|
9 | 9 | {.arg quantile_level} must be specified for quantile regression models.")
|
10 | 10 | }
|
11 | 11 | }
|
| 12 | + if (any(is.na(x))) { |
| 13 | + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", |
| 14 | + call = call) |
| 15 | + } |
12 | 16 | x <- sort(unique(x))
|
13 |
| - # TODO we need better vectorization here, otherwise we get things like: |
14 |
| - # "Error during wrapup: i In index: 2." in the traceback. |
15 |
| - res <- |
16 |
| - purrr::map(x, |
17 |
| - ~ check_number_decimal(.x, min = 0, max = 1, |
18 |
| - arg = "quantile_level", call = call, |
19 |
| - allow_infinite = FALSE) |
20 |
| - ) |
| 17 | + check_vector_probability(x, arg = "quantile_level", call = call) |
21 | 18 | x
|
22 | 19 | }
|
23 | 20 |
|
24 |
| -# Assumes the columns have the same order as quantile_level |
25 |
| -restructure_rq_pred <- function(x, object) { |
26 |
| - num_quantiles <- NCOL(x) |
27 |
| - if ( num_quantiles == 1L ){ |
28 |
| - x <- matrix(x, ncol = 1) |
| 21 | + |
| 22 | +# ------------------------------------------------------------------------- |
| 23 | +# A column vector of quantiles with an attribute |
| 24 | + |
| 25 | +#' @importFrom vctrs vec_ptype_abbr |
| 26 | +#' @export |
| 27 | +vctrs::vec_ptype_abbr |
| 28 | + |
| 29 | +#' @importFrom vctrs vec_ptype_full |
| 30 | +#' @export |
| 31 | +vctrs::vec_ptype_full |
| 32 | + |
| 33 | + |
| 34 | +#' @export |
| 35 | +vec_ptype_abbr.quantile_pred <- function(x, ...) { |
| 36 | + n_lvls <- length(attr(x, "quantile_levels")) |
| 37 | + cli::format_inline("qtl{?s}({n_lvls})") |
| 38 | +} |
| 39 | + |
| 40 | +#' @export |
| 41 | +vec_ptype_full.quantile_pred <- function(x, ...) "quantiles" |
| 42 | + |
| 43 | +new_quantile_pred <- function(values = list(), quantile_levels = double()) { |
| 44 | + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) |
| 45 | + vctrs::new_vctr( |
| 46 | + values, quantile_levels = quantile_levels, class = "quantile_pred" |
| 47 | + ) |
| 48 | +} |
| 49 | + |
| 50 | +#' Create a vector containing sets of quantiles |
| 51 | +#' |
| 52 | +#' [quantile_pred()] is a special vector class used to efficiently store |
| 53 | +#' predictions from a quantile regression model. It requires the same quantile |
| 54 | +#' levels for each row being predicted. |
| 55 | +#' |
| 56 | +#' @param values A matrix of values. Each column should correspond to one of |
| 57 | +#' the quantile levels. |
| 58 | +#' @param quantile_levels A vector of probabilities corresponding to `values`. |
| 59 | +#' @param x An object produced by [quantile_pred()]. |
| 60 | +#' @param .rows,.name_repair,rownames Arguments not used but required by the |
| 61 | +#' original S3 method. |
| 62 | +#' @param ... Not currently used. |
| 63 | +#' |
| 64 | +#' @export |
| 65 | +#' @return |
| 66 | +#' * [quantile_pred()] returns a vector of values associated with the |
| 67 | +#' quantile levels. |
| 68 | +#' * [extract_quantile_levels()] returns a numeric vector of levels. |
| 69 | +#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`, |
| 70 | +#' `".quantile_levels"`, and `".row"`. |
| 71 | +#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as |
| 72 | +#' quantile levels, and entries are predictions. |
| 73 | +#' @examples |
| 74 | +#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8)) |
| 75 | +#' |
| 76 | +#' unclass(.pred_quantile) |
| 77 | +#' |
| 78 | +#' # Access the underlying information |
| 79 | +#' extract_quantile_levels(.pred_quantile) |
| 80 | +#' |
| 81 | +#' # Matrix format |
| 82 | +#' as.matrix(.pred_quantile) |
| 83 | +#' |
| 84 | +#' # Tidy format |
| 85 | +#' tibble::as_tibble(.pred_quantile) |
| 86 | +quantile_pred <- function(values, quantile_levels = double()) { |
| 87 | + check_quantile_pred_inputs(values, quantile_levels) |
| 88 | + |
| 89 | + quantile_levels <- vctrs::vec_cast(quantile_levels, double()) |
| 90 | + num_lvls <- length(quantile_levels) |
| 91 | + |
| 92 | + if (ncol(values) != num_lvls) { |
| 93 | + cli::cli_abort( |
| 94 | + "The number of columns in {.arg values} must be equal to the length of |
| 95 | + {.arg quantile_levels}." |
| 96 | + ) |
| 97 | + } |
| 98 | + rownames(values) <- NULL |
| 99 | + colnames(values) <- NULL |
| 100 | + values <- lapply(vctrs::vec_chop(values), drop) |
| 101 | + new_quantile_pred(values, quantile_levels) |
| 102 | +} |
| 103 | + |
| 104 | +check_quantile_pred_inputs <- function(values, levels, call = caller_env()) { |
| 105 | + if (any(is.na(levels))) { |
| 106 | + cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.", |
| 107 | + call = call) |
29 | 108 | }
|
30 |
| - n <- nrow(x) |
31 | 109 |
|
| 110 | + if (!is.matrix(values)) { |
| 111 | + cli::cli_abort( |
| 112 | + "{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.", |
| 113 | + call = call |
| 114 | + ) |
| 115 | + } |
| 116 | + check_vector_probability(levels, arg = "quantile_levels", call = call) |
| 117 | + |
| 118 | + if (is.unsorted(levels)) { |
| 119 | + cli::cli_abort( |
| 120 | + "{.arg quantile_levels} must be sorted in increasing order.", |
| 121 | + call = call |
| 122 | + ) |
| 123 | + } |
| 124 | + invisible(NULL) |
| 125 | +} |
| 126 | + |
| 127 | +#' @export |
| 128 | +format.quantile_pred <- function(x, ...) { |
| 129 | + quantile_levels <- attr(x, "quantile_levels") |
| 130 | + if (length(quantile_levels) == 1L) { |
| 131 | + x <- unlist(x) |
| 132 | + out <- round(x, 3L) |
| 133 | + out[is.na(x)] <- NA_real_ |
| 134 | + } else { |
| 135 | + rng <- sapply(x, range, na.rm = TRUE) |
| 136 | + out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]") |
| 137 | + out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_ |
| 138 | + m <- median(x) |
| 139 | + out <- paste0("[", round(m, 3L), "]") |
| 140 | + } |
| 141 | + out |
| 142 | +} |
| 143 | + |
| 144 | +#' @importFrom vctrs obj_print_footer |
| 145 | +#' @export |
| 146 | +vctrs::obj_print_footer |
| 147 | + |
| 148 | +#' @export |
| 149 | +obj_print_footer.quantile_pred <- function(x, digits = 3, ...) { |
| 150 | + lvls <- attr(x, "quantile_levels") |
| 151 | + cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ") |
| 152 | +} |
| 153 | + |
| 154 | +check_vector_probability <- function(x, ..., |
| 155 | + allow_na = FALSE, |
| 156 | + allow_null = FALSE, |
| 157 | + arg = caller_arg(x), |
| 158 | + call = caller_env()) { |
| 159 | + for (d in x) { |
| 160 | + check_number_decimal( |
| 161 | + d, min = 0, max = 1, |
| 162 | + arg = arg, call = call, |
| 163 | + allow_na = allow_na, |
| 164 | + allow_null = allow_null, |
| 165 | + allow_infinite = FALSE |
| 166 | + ) |
| 167 | + } |
| 168 | +} |
| 169 | + |
| 170 | +#' @export |
| 171 | +median.quantile_pred <- function(x, ...) { |
| 172 | + lvls <- attr(x, "quantile_levels") |
| 173 | + loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps)) |
| 174 | + if (any(loc_median)) { |
| 175 | + return(map_dbl(x, ~ .x[min(which(loc_median))])) |
| 176 | + } |
| 177 | + if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) { |
| 178 | + return(rep(NA, vctrs::vec_size(x))) |
| 179 | + } |
| 180 | + map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y) |
| 181 | +} |
| 182 | + |
| 183 | +restructure_rq_pred <- function(x, object) { |
| 184 | + if (!is.matrix(x)) { |
| 185 | + x <- as.matrix(x) |
| 186 | + } |
| 187 | + rownames(x) <- NULL |
| 188 | + n_pred_quantiles <- ncol(x) |
32 | 189 | quantile_level <- object$spec$quantile_level
|
33 |
| - res <- |
34 |
| - tibble::tibble( |
35 |
| - .pred_quantile = as.vector(x), |
36 |
| - .quantile_level = rep(quantile_level, each = n), |
37 |
| - .row = rep(1:n, num_quantiles)) |
38 |
| - res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"]) |
39 |
| - res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val))) |
40 |
| - res$.row <- NULL |
41 |
| - res |
| 190 | + |
| 191 | + tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level))) |
| 192 | +} |
| 193 | + |
| 194 | +#' @export |
| 195 | +#' @rdname quantile_pred |
| 196 | +extract_quantile_levels <- function(x) { |
| 197 | + if (!inherits(x, "quantile_pred")) { |
| 198 | + cli::cli_abort("{.arg x} should have class {.val quantile_pred}.") |
| 199 | + } |
| 200 | + attr(x, "quantile_levels") |
42 | 201 | }
|
43 | 202 |
|
| 203 | +#' @export |
| 204 | +#' @rdname quantile_pred |
| 205 | +as_tibble.quantile_pred <- |
| 206 | + function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) { |
| 207 | + lvls <- attr(x, "quantile_levels") |
| 208 | + n_samp <- length(x) |
| 209 | + n_quant <- length(lvls) |
| 210 | + tibble::tibble( |
| 211 | + .pred_quantile = unlist(x), |
| 212 | + .quantile_levels = rep(lvls, n_samp), |
| 213 | + .row = rep(1:n_samp, each = n_quant) |
| 214 | + ) |
| 215 | + } |
| 216 | + |
| 217 | +#' @export |
| 218 | +#' @rdname quantile_pred |
| 219 | +as.matrix.quantile_pred <- function(x, ...) { |
| 220 | + num_samp <- length(x) |
| 221 | + matrix(unlist(x), nrow = num_samp) |
| 222 | +} |
0 commit comments