Skip to content

Commit 3bdb471

Browse files
dajmcdonsimonpcouchtopepo‘topepo’
authored
Quantile predictions output constructor (#1191)
* small change to predict checks * add vctrs for quantiles and test, refactor *_rq_preds * revise tests * Apply some of the suggestions from code review Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> * rename tests on suggestion from code review * export missing funs from vctrs for formatting * convert errors to snapshot tests * pass call through input check * update snapshots for caller_env * rename to parsnip_quantiles, add format snapshot tests * Apply suggestions from @topepo Co-authored-by: Max Kuhn <mxkuhn@gmail.com> * rename parsnip_quantiles to quantile_pred * rename parsnip_quantiles to quantile_pred and add vector probability check * fix: two bugs introduced earlier * add formatting tests for single quantile * replace walk with a loop to avoid "Error in map()" * remove row/col names * adjust quantile_pred format * as_tibble method * updated NEWS file * add PR number * small new update * helper methods * update docs * re-enable quantiles prediction for #1203 * update some tests * no longer needed * use tibble::new_tibble * braces * test as_tibble * remove print methods --------- Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> Co-authored-by: Max Kuhn <mxkuhn@gmail.com> Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’>
1 parent 6168556 commit 3bdb471

13 files changed

+551
-70
lines changed

NAMESPACE

+16
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
S3method(.censoring_weights_graf,default)
44
S3method(.censoring_weights_graf,model_fit)
5+
S3method(as.matrix,quantile_pred)
6+
S3method(as_tibble,quantile_pred)
57
S3method(augment,model_fit)
68
S3method(autoplot,glmnet)
79
S3method(autoplot,model_fit)
@@ -36,10 +38,12 @@ S3method(extract_spec_parsnip,model_fit)
3638
S3method(fit,model_spec)
3739
S3method(fit_xy,gen_additive_mod)
3840
S3method(fit_xy,model_spec)
41+
S3method(format,quantile_pred)
3942
S3method(glance,model_fit)
4043
S3method(has_multi_predict,default)
4144
S3method(has_multi_predict,model_fit)
4245
S3method(has_multi_predict,workflow)
46+
S3method(median,quantile_pred)
4347
S3method(multi_predict,"_C5.0")
4448
S3method(multi_predict,"_earth")
4549
S3method(multi_predict,"_elnet")
@@ -54,6 +58,7 @@ S3method(multi_predict_args,default)
5458
S3method(multi_predict_args,model_fit)
5559
S3method(multi_predict_args,workflow)
5660
S3method(nullmodel,default)
61+
S3method(obj_print_footer,quantile_pred)
5762
S3method(predict,"_elnet")
5863
S3method(predict,"_glmnetfit")
5964
S3method(predict,"_lognet")
@@ -172,6 +177,8 @@ S3method(update,svm_rbf)
172177
S3method(varying_args,model_spec)
173178
S3method(varying_args,recipe)
174179
S3method(varying_args,step)
180+
S3method(vec_ptype_abbr,quantile_pred)
181+
S3method(vec_ptype_full,quantile_pred)
175182
export("%>%")
176183
export(.censoring_weights_graf)
177184
export(.check_glmnet_penalty_fit)
@@ -226,6 +233,7 @@ export(extract_fit_engine)
226233
export(extract_fit_time)
227234
export(extract_parameter_dials)
228235
export(extract_parameter_set_dials)
236+
export(extract_quantile_levels)
229237
export(extract_spec_parsnip)
230238
export(find_engine_files)
231239
export(fit)
@@ -280,6 +288,7 @@ export(new_model_spec)
280288
export(null_model)
281289
export(null_value)
282290
export(nullmodel)
291+
export(obj_print_footer)
283292
export(parsnip_addin)
284293
export(pls)
285294
export(poisson_reg)
@@ -307,6 +316,7 @@ export(prepare_data)
307316
export(print_model_spec)
308317
export(prompt_missing_implementation)
309318
export(proportional_hazards)
319+
export(quantile_pred)
310320
export(rand_forest)
311321
export(repair_call)
312322
export(req_pkgs)
@@ -350,6 +360,8 @@ export(update_model_info_file)
350360
export(update_spec)
351361
export(varying)
352362
export(varying_args)
363+
export(vec_ptype_abbr)
364+
export(vec_ptype_full)
353365
export(xgb_predict)
354366
export(xgb_train)
355367
import(rlang)
@@ -402,6 +414,7 @@ importFrom(stats,as.formula)
402414
importFrom(stats,binomial)
403415
importFrom(stats,coef)
404416
importFrom(stats,delete.response)
417+
importFrom(stats,median)
405418
importFrom(stats,model.frame)
406419
importFrom(stats,model.matrix)
407420
importFrom(stats,model.offset)
@@ -426,5 +439,8 @@ importFrom(utils,globalVariables)
426439
importFrom(utils,head)
427440
importFrom(utils,methods)
428441
importFrom(utils,stack)
442+
importFrom(vctrs,obj_print_footer)
443+
importFrom(vctrs,vec_ptype_abbr)
444+
importFrom(vctrs,vec_ptype_full)
429445
importFrom(vctrs,vec_size)
430446
importFrom(vctrs,vec_unique)

NEWS.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# parsnip (development version)
22

3-
3+
* A new model mode (`"quantile regression"`) was added. Including:
4+
* A function to create a new vector class called `quantile_pred()` was added (#1191).
5+
* A `linear_reg()` engine for `"quantreg"`.
6+
47
* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).
58

69
* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).

R/aaa_quantiles.R

+204-25
Original file line numberDiff line numberDiff line change
@@ -1,43 +1,222 @@
11
# Helpers for quantile regression models
22

33
check_quantile_level <- function(x, object, call) {
4-
if ( object$mode != "quantile regression" ) {
4+
if (object$mode != "quantile regression") {
55
return(invisible(TRUE))
66
} else {
7-
if ( is.null(x) ) {
7+
if (is.null(x)) {
88
cli::cli_abort("In {.fn check_mode}, at least one value of
99
{.arg quantile_level} must be specified for quantile regression models.")
1010
}
1111
}
12+
if (any(is.na(x))) {
13+
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
14+
call = call)
15+
}
1216
x <- sort(unique(x))
13-
# TODO we need better vectorization here, otherwise we get things like:
14-
# "Error during wrapup: i In index: 2." in the traceback.
15-
res <-
16-
purrr::map(x,
17-
~ check_number_decimal(.x, min = 0, max = 1,
18-
arg = "quantile_level", call = call,
19-
allow_infinite = FALSE)
20-
)
17+
check_vector_probability(x, arg = "quantile_level", call = call)
2118
x
2219
}
2320

24-
# Assumes the columns have the same order as quantile_level
25-
restructure_rq_pred <- function(x, object) {
26-
num_quantiles <- NCOL(x)
27-
if ( num_quantiles == 1L ){
28-
x <- matrix(x, ncol = 1)
21+
22+
# -------------------------------------------------------------------------
23+
# A column vector of quantiles with an attribute
24+
25+
#' @importFrom vctrs vec_ptype_abbr
26+
#' @export
27+
vctrs::vec_ptype_abbr
28+
29+
#' @importFrom vctrs vec_ptype_full
30+
#' @export
31+
vctrs::vec_ptype_full
32+
33+
34+
#' @export
35+
vec_ptype_abbr.quantile_pred <- function(x, ...) {
36+
n_lvls <- length(attr(x, "quantile_levels"))
37+
cli::format_inline("qtl{?s}({n_lvls})")
38+
}
39+
40+
#' @export
41+
vec_ptype_full.quantile_pred <- function(x, ...) "quantiles"
42+
43+
new_quantile_pred <- function(values = list(), quantile_levels = double()) {
44+
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
45+
vctrs::new_vctr(
46+
values, quantile_levels = quantile_levels, class = "quantile_pred"
47+
)
48+
}
49+
50+
#' Create a vector containing sets of quantiles
51+
#'
52+
#' [quantile_pred()] is a special vector class used to efficiently store
53+
#' predictions from a quantile regression model. It requires the same quantile
54+
#' levels for each row being predicted.
55+
#'
56+
#' @param values A matrix of values. Each column should correspond to one of
57+
#' the quantile levels.
58+
#' @param quantile_levels A vector of probabilities corresponding to `values`.
59+
#' @param x An object produced by [quantile_pred()].
60+
#' @param .rows,.name_repair,rownames Arguments not used but required by the
61+
#' original S3 method.
62+
#' @param ... Not currently used.
63+
#'
64+
#' @export
65+
#' @return
66+
#' * [quantile_pred()] returns a vector of values associated with the
67+
#' quantile levels.
68+
#' * [extract_quantile_levels()] returns a numeric vector of levels.
69+
#' * [as_tibble()] returns a tibble with rows `".pred_quantile"`,
70+
#' `".quantile_levels"`, and `".row"`.
71+
#' * [as.matrix()] returns an unnamed matrix with rows as sames, columns as
72+
#' quantile levels, and entries are predictions.
73+
#' @examples
74+
#' .pred_quantile <- quantile_pred(matrix(rnorm(20), 5), c(.2, .4, .6, .8))
75+
#'
76+
#' unclass(.pred_quantile)
77+
#'
78+
#' # Access the underlying information
79+
#' extract_quantile_levels(.pred_quantile)
80+
#'
81+
#' # Matrix format
82+
#' as.matrix(.pred_quantile)
83+
#'
84+
#' # Tidy format
85+
#' tibble::as_tibble(.pred_quantile)
86+
quantile_pred <- function(values, quantile_levels = double()) {
87+
check_quantile_pred_inputs(values, quantile_levels)
88+
89+
quantile_levels <- vctrs::vec_cast(quantile_levels, double())
90+
num_lvls <- length(quantile_levels)
91+
92+
if (ncol(values) != num_lvls) {
93+
cli::cli_abort(
94+
"The number of columns in {.arg values} must be equal to the length of
95+
{.arg quantile_levels}."
96+
)
97+
}
98+
rownames(values) <- NULL
99+
colnames(values) <- NULL
100+
values <- lapply(vctrs::vec_chop(values), drop)
101+
new_quantile_pred(values, quantile_levels)
102+
}
103+
104+
check_quantile_pred_inputs <- function(values, levels, call = caller_env()) {
105+
if (any(is.na(levels))) {
106+
cli::cli_abort("Missing values are not allowed in {.arg quantile_levels}.",
107+
call = call)
29108
}
30-
n <- nrow(x)
31109

110+
if (!is.matrix(values)) {
111+
cli::cli_abort(
112+
"{.arg values} must be a {.cls matrix}, not {.obj_type_friendly {values}}.",
113+
call = call
114+
)
115+
}
116+
check_vector_probability(levels, arg = "quantile_levels", call = call)
117+
118+
if (is.unsorted(levels)) {
119+
cli::cli_abort(
120+
"{.arg quantile_levels} must be sorted in increasing order.",
121+
call = call
122+
)
123+
}
124+
invisible(NULL)
125+
}
126+
127+
#' @export
128+
format.quantile_pred <- function(x, ...) {
129+
quantile_levels <- attr(x, "quantile_levels")
130+
if (length(quantile_levels) == 1L) {
131+
x <- unlist(x)
132+
out <- round(x, 3L)
133+
out[is.na(x)] <- NA_real_
134+
} else {
135+
rng <- sapply(x, range, na.rm = TRUE)
136+
out <- paste0("[", round(rng[1, ], 3L), ", ", round(rng[2, ], 3L), "]")
137+
out[is.na(rng[1, ]) & is.na(rng[2, ])] <- NA_character_
138+
m <- median(x)
139+
out <- paste0("[", round(m, 3L), "]")
140+
}
141+
out
142+
}
143+
144+
#' @importFrom vctrs obj_print_footer
145+
#' @export
146+
vctrs::obj_print_footer
147+
148+
#' @export
149+
obj_print_footer.quantile_pred <- function(x, digits = 3, ...) {
150+
lvls <- attr(x, "quantile_levels")
151+
cat("# Quantile levels: ", format(lvls, digits = digits), "\n", sep = " ")
152+
}
153+
154+
check_vector_probability <- function(x, ...,
155+
allow_na = FALSE,
156+
allow_null = FALSE,
157+
arg = caller_arg(x),
158+
call = caller_env()) {
159+
for (d in x) {
160+
check_number_decimal(
161+
d, min = 0, max = 1,
162+
arg = arg, call = call,
163+
allow_na = allow_na,
164+
allow_null = allow_null,
165+
allow_infinite = FALSE
166+
)
167+
}
168+
}
169+
170+
#' @export
171+
median.quantile_pred <- function(x, ...) {
172+
lvls <- attr(x, "quantile_levels")
173+
loc_median <- (abs(lvls - 0.5) < sqrt(.Machine$double.eps))
174+
if (any(loc_median)) {
175+
return(map_dbl(x, ~ .x[min(which(loc_median))]))
176+
}
177+
if (length(lvls) < 2 || min(lvls) > 0.5 || max(lvls) < 0.5) {
178+
return(rep(NA, vctrs::vec_size(x)))
179+
}
180+
map_dbl(x, ~ stats::approx(lvls, .x, xout = 0.5)$y)
181+
}
182+
183+
restructure_rq_pred <- function(x, object) {
184+
if (!is.matrix(x)) {
185+
x <- as.matrix(x)
186+
}
187+
rownames(x) <- NULL
188+
n_pred_quantiles <- ncol(x)
32189
quantile_level <- object$spec$quantile_level
33-
res <-
34-
tibble::tibble(
35-
.pred_quantile = as.vector(x),
36-
.quantile_level = rep(quantile_level, each = n),
37-
.row = rep(1:n, num_quantiles))
38-
res <- vctrs::vec_split(x = res[,1:2], by = res[, ".row"])
39-
res <- vctrs::vec_cbind(res$key, tibble::new_tibble(list(.pred_quantile = res$val)))
40-
res$.row <- NULL
41-
res
190+
191+
tibble::new_tibble(x = list(.pred_quantile = quantile_pred(x, quantile_level)))
192+
}
193+
194+
#' @export
195+
#' @rdname quantile_pred
196+
extract_quantile_levels <- function(x) {
197+
if (!inherits(x, "quantile_pred")) {
198+
cli::cli_abort("{.arg x} should have class {.val quantile_pred}.")
199+
}
200+
attr(x, "quantile_levels")
42201
}
43202

203+
#' @export
204+
#' @rdname quantile_pred
205+
as_tibble.quantile_pred <-
206+
function (x, ..., .rows = NULL, .name_repair = "minimal", rownames = NULL) {
207+
lvls <- attr(x, "quantile_levels")
208+
n_samp <- length(x)
209+
n_quant <- length(lvls)
210+
tibble::tibble(
211+
.pred_quantile = unlist(x),
212+
.quantile_levels = rep(lvls, n_samp),
213+
.row = rep(1:n_samp, each = n_quant)
214+
)
215+
}
216+
217+
#' @export
218+
#' @rdname quantile_pred
219+
as.matrix.quantile_pred <- function(x, ...) {
220+
num_samp <- length(x)
221+
matrix(unlist(x), nrow = num_samp)
222+
}

R/parsnip-package.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#' @importFrom stats .checkMFClasses .getXlevels as.formula binomial coef
2222
#' @importFrom stats delete.response model.frame model.matrix model.offset
2323
#' @importFrom stats model.response model.weights na.omit na.pass predict qnorm
24-
#' @importFrom stats qt quantile setNames terms update
24+
#' @importFrom stats qt quantile setNames terms update median
2525
#' @importFrom tibble as_tibble is_tibble tibble
2626
#' @importFrom tidyr gather
2727
#' @importFrom utils capture.output getFromNamespace globalVariables head

R/predict_quantile.R

+9-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
#' @method predict_quantile model_fit
77
#' @export predict_quantile.model_fit
88
#' @export
9-
predict_quantile.model_fit <- function(object, new_data, ...) {
9+
predict_quantile.model_fit <- function(object,
10+
new_data,
11+
quantile = (1:9)/10,
12+
interval = "none",
13+
level = 0.95,
14+
...) {
1015

1116
check_spec_pred_type(object, "quantile")
1217

@@ -23,7 +28,7 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
2328
}
2429

2530
# Pass some extra arguments to be used in post-processor
26-
object$spec$method$pred$quantile$args$quantile_level <- object$quantile_level
31+
object$spec$method$pred$quantile$args$p <- quantile
2732
pred_call <- make_pred_call(object$spec$method$pred$quantile)
2833

2934
res <- eval_tidy(pred_call)
@@ -40,5 +45,6 @@ predict_quantile.model_fit <- function(object, new_data, ...) {
4045
# @keywords internal
4146
# @rdname other_predict
4247
# @inheritParams predict.model_fit
43-
predict_quantile <- function (object, ...)
48+
predict_quantile <- function (object, ...) {
4449
UseMethod("predict_quantile")
50+
}

0 commit comments

Comments
 (0)