Skip to content

Commit 861a64a

Browse files
topepo‘topepo’
and
‘topepo’
authored
Change to quantile argument to quantile levels (#1208)
* quantile -> quantile_levels for #1203 * defer test until censored updates in new PR * update docs for quantile_levels * update test * disable quantile predictions for surv_reg --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’>
1 parent bef131b commit 861a64a

9 files changed

+43
-61
lines changed

Diff for: NEWS.md

+4
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,10 @@
1111

1212
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
1313

14+
## Breaking Change
15+
16+
* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`.
17+
* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.
1418

1519
# parsnip 1.2.1
1620

Diff for: R/predict.R

+1-1
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
344344

345345
# ----------------------------------------------------------------------------
346346

347-
other_args <- c("interval", "level", "std_error", "quantile",
347+
other_args <- c("interval", "level", "std_error", "quantile_levels",
348348
"time", "eval_time", "increasing")
349349
is_pred_arg <- names(the_dots) %in% other_args
350350
if (any(!is_pred_arg)) {

Diff for: R/predict_quantile.R

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
#' @keywords internal
22
#' @rdname other_predict
3-
#' @param quantile A vector of numbers between 0 and 1 for the quantile being
4-
#' predicted.
3+
#' @param quantile_levels A vector of values between zero and one for the
4+
#' quantile to be predicted. If the model has a `"censored regression"` mode,
5+
#' this value should be `NULL`. For other modes, the default is `(1:9)/10`.
56
#' @inheritParams predict.model_fit
67
#' @method predict_quantile model_fit
78
#' @export predict_quantile.model_fit
89
#' @export
910
predict_quantile.model_fit <- function(object,
1011
new_data,
11-
quantile = (1:9)/10,
12+
quantile_levels = NULL,
1213
interval = "none",
1314
level = 0.95,
1415
...) {
@@ -20,15 +21,27 @@ predict_quantile.model_fit <- function(object,
2021
return(NULL)
2122
}
2223

24+
if (object$spec$mode == "quantile regression") {
25+
if (!is.null(quantile_levels)) {
26+
cli::cli_abort("When the mode is {.val quantile regression},
27+
{.arg quantile_levels} are specified by {.fn set_mode}.")
28+
}
29+
} else {
30+
if (is.null(quantile_levels)) {
31+
quantile_levels <- (1:9)/10
32+
}
33+
hardhat::check_quantile_levels(quantile_levels)
34+
# Pass some extra arguments to be used in post-processor
35+
object$quantile_levels <- quantile_levels
36+
}
37+
2338
new_data <- prepare_data(object, new_data)
2439

2540
# preprocess data
2641
if (!is.null(object$spec$method$pred$quantile$pre)) {
2742
new_data <- object$spec$method$pred$quantile$pre(new_data, object)
2843
}
2944

30-
# Pass some extra arguments to be used in post-processor
31-
object$spec$method$pred$quantile$args$p <- quantile
3245
pred_call <- make_pred_call(object$spec$method$pred$quantile)
3346

3447
res <- eval_tidy(pred_call)

Diff for: R/surv_reg_data.R

-38
Original file line numberDiff line numberDiff line change
@@ -59,25 +59,6 @@ set_pred(
5959
)
6060
)
6161

62-
set_pred(
63-
model = "surv_reg",
64-
eng = "flexsurv",
65-
mode = "regression",
66-
type = "quantile",
67-
value = list(
68-
pre = NULL,
69-
post = flexsurv_quant,
70-
func = c(fun = "summary"),
71-
args =
72-
list(
73-
object = expr(object$fit),
74-
newdata = expr(new_data),
75-
type = "quantile",
76-
quantiles = expr(quantile)
77-
)
78-
)
79-
)
80-
8162
# ------------------------------------------------------------------------------
8263

8364
set_model_engine("surv_reg", mode = "regression", eng = "survival")
@@ -133,22 +114,3 @@ set_pred(
133114
)
134115
)
135116
)
136-
137-
set_pred(
138-
model = "surv_reg",
139-
eng = "survival",
140-
mode = "regression",
141-
type = "quantile",
142-
value = list(
143-
pre = NULL,
144-
post = survreg_quant,
145-
func = c(fun = "predict"),
146-
args =
147-
list(
148-
object = expr(object$fit),
149-
newdata = expr(new_data),
150-
type = "quantile",
151-
p = expr(quantile)
152-
)
153-
)
154-
)

Diff for: man/other_predict.Rd

+4-3
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: man/set_args.Rd

+1-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Diff for: tests/testthat/_snaps/linear_reg_quantreg.md

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# linear quantile regression via quantreg - multiple quantiles
2+
3+
Code
4+
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:
5+
9) / 9)
6+
Condition
7+
Error in `predict_quantile()`:
8+
! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`.
9+

Diff for: tests/testthat/test-linear_reg_quantreg.R

+5
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', {
8383
expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row"))
8484
expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10)
8585

86+
expect_snapshot(
87+
ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9),
88+
error = TRUE
89+
)
90+
8691
###
8792

8893
ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,])

Diff for: tests/testthat/test-surv_reg_survreg.R

+1-13
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ complete_form <- survival::Surv(time) ~ group
1010
# ------------------------------------------------------------------------------
1111

1212
test_that('survival execution', {
13-
skip_on_travis()
14-
1513
rlang::local_options(lifecycle_verbosity = "quiet")
1614
surv_basic <- surv_reg() %>% set_engine("survival")
1715
surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival")
@@ -46,7 +44,7 @@ test_that('survival execution', {
4644
})
4745

4846
test_that('survival prediction', {
49-
skip_on_travis()
47+
skip_if_not_installed("survival")
5048

5149
rlang::local_options(lifecycle_verbosity = "quiet")
5250
surv_basic <- surv_reg() %>% set_engine("survival")
@@ -61,16 +59,6 @@ test_that('survival prediction', {
6159
exp_pred <- predict(extract_fit_engine(res), head(lung))
6260
exp_pred <- tibble(.pred = unname(exp_pred))
6361
expect_equal(exp_pred, predict(res, head(lung)))
64-
65-
exp_quant <- predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile")
66-
exp_quant <-
67-
apply(exp_quant, 1, function(x)
68-
tibble(.pred = x, .quantile = (2:4) / 5))
69-
exp_quant <- tibble(.pred = exp_quant)
70-
obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5)
71-
72-
expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))
73-
7462
})
7563

7664

0 commit comments

Comments
 (0)