diff --git a/DESCRIPTION b/DESCRIPTION index 3e8db4117..4da92afe6 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: epipredict Title: Basic epidemiology forecasting methods -Version: 0.0.18 +Version: 0.0.19 Authors@R: c( person("Daniel", "McDonald", , "daniel@stat.ubc.ca", role = c("aut", "cre")), person("Ryan", "Tibshirani", , "ryantibs@cmu.edu", role = "aut"), @@ -36,10 +36,8 @@ Imports: glue, hardhat (>= 1.3.0), magrittr, - quantreg, recipes (>= 1.0.4), rlang (>= 1.0.0), - smoothqr, stats, tibble, tidyr, @@ -52,13 +50,16 @@ Suggests: data.table, epidatr (>= 1.0.0), fs, + grf, knitr, lubridate, poissonreg, purrr, + quantreg, ranger, RcppRoll, rmarkdown, + smoothqr, testthat (>= 3.0.0), usethis, xgboost diff --git a/NAMESPACE b/NAMESPACE index 5d045ec8f..118b12aff 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -240,7 +240,6 @@ importFrom(ggplot2,autoplot) importFrom(hardhat,refresh_blueprint) importFrom(hardhat,run_mold) importFrom(magrittr,"%>%") -importFrom(quantreg,rq) importFrom(recipes,bake) importFrom(recipes,prep) importFrom(rlang,"!!!") @@ -253,13 +252,13 @@ importFrom(rlang,as_function) importFrom(rlang,caller_env) importFrom(rlang,enquo) importFrom(rlang,enquos) +importFrom(rlang,expr) importFrom(rlang,global_env) importFrom(rlang,inject) importFrom(rlang,is_logical) importFrom(rlang,is_null) importFrom(rlang,is_true) importFrom(rlang,set_names) -importFrom(smoothqr,smooth_qr) importFrom(stats,as.formula) importFrom(stats,family) importFrom(stats,lm) diff --git a/NEWS.md b/NEWS.md index 12780f208..bb9aad743 100644 --- a/NEWS.md +++ b/NEWS.md @@ -52,4 +52,5 @@ Pre-1.0.0 numbering scheme: 0.x will indicate releases, while 0.0.x will indicat `...` args intended for `predict.model_fit()` - `bake.epi_recipe()` will now re-infer the geo and time type in case baking the steps has changed the appropriate values -- Add `step_epi_slide` to produce generic sliding computations over an `epi_df` \ No newline at end of file +- Add `step_epi_slide` to 
produce generic sliding computations over an `epi_df` +- Add quantile random forests (via `{grf}`) as a parsnip engine diff --git a/R/epipredict-package.R b/R/epipredict-package.R index 554374533..6ca349570 100644 --- a/R/epipredict-package.R +++ b/R/epipredict-package.R @@ -1,7 +1,7 @@ ## usethis namespace: start #' @importFrom tibble tibble #' @importFrom rlang := !! %||% as_function global_env set_names !!! -#' @importFrom rlang is_logical is_true inject enquo enquos +#' @importFrom rlang is_logical is_true inject enquo enquos expr #' @importFrom stats poly predict lm residuals quantile #' @importFrom cli cli_abort #' @importFrom checkmate assert assert_character assert_int assert_scalar diff --git a/R/layer_quantile_distn.R b/R/layer_quantile_distn.R index 734ccec9e..ea96969da 100644 --- a/R/layer_quantile_distn.R +++ b/R/layer_quantile_distn.R @@ -1,9 +1,14 @@ #' Returns predictive quantiles #' #' This function calculates quantiles when the prediction was _distributional_. -#' Currently, the only distributional engine is `quantile_reg()`. -#' If this engine is used, then this layer will grab out estimated (or extrapolated) -#' quantiles at the requested quantile values. +#' +#' Currently, the only distributional modes/engines are +#' * `quantile_reg()` +#' * `smooth_quantile_reg()` +#' * `rand_forest(mode = "regression") %>% set_engine("grf_quantiles")` +#' +#' If these engines were used, then this layer will grab out estimated +#' (or extrapolated) quantiles at the requested quantile values. #' #' @param frosting a `frosting` postprocessor #' @param ... Unused, include for consistency with other layers. diff --git a/R/make_grf_quantiles.R b/R/make_grf_quantiles.R new file mode 100644 index 000000000..253ea1ac7 --- /dev/null +++ b/R/make_grf_quantiles.R @@ -0,0 +1,193 @@ +#' Random quantile forests via grf +#' +#' [grf::quantile_forest()] fits random forests in a way that makes it easy +#' to calculate _quantile_ forests. 
Currently, this is the only engine +#' provided here, since quantile regression is the typical use-case. +#' +#' @section Tuning Parameters: +#' +#' This model has 3 tuning parameters: +#' +#' - `mtry`: # Randomly Selected Predictors (type: integer, default: see below) +#' - `trees`: # Trees (type: integer, default: 2000L) +#' - `min_n`: Minimal Node Size (type: integer, default: 5) +#' +#' `mtry` depends on the number of columns in the design matrix. +#' The default in [grf::quantile_forest()] is `min(ceiling(sqrt(ncol(X)) + 20), ncol(X))`. +#' +#' For categorical predictors, a one-hot encoding is always used. This makes +#' splitting efficient, but has implications for the `mtry` choice. A factor +#' with many levels will become a large number of columns in the design matrix +#' which means that some of these may be selected frequently for potential splits. +#' This is different than in other implementations of random forest. For more +#' details, see [the `grf` discussion](https://grf-labs.github.io/grf/articles/categorical_inputs.html). +#' +#' @section Translation from parsnip to the original package: +#' +#' ```{r, translate-engine} +#' rand_forest( +#' mode = "regression", # you must specify the `mode = regression` +#' mtry = integer(1), +#' trees = integer(1), +#' min_n = integer(1) +#' ) %>% +#' set_engine("grf_quantiles") %>% +#' translate() +#' ``` +#' +#' @section Case weights: +#' +#' Case weights are not supported. 
+#' +#' @examples +#' library(grf) +#' tib <- data.frame( +#' y = rnorm(100), x = rnorm(100), z = rnorm(100), +#' f = factor(sample(letters[1:3], 100, replace = TRUE)) +#' ) +#' spec <- rand_forest(engine = "grf_quantiles", mode = "regression") +#' out <- fit(spec, formula = y ~ x + z, data = tib) +#' predict(out, new_data = tib[1:5, ]) %>% +#' pivot_quantiles_wider(.pred) +#' +#' # -- adjusting the desired quantiles +#' +#' spec <- rand_forest(mode = "regression") %>% +#' set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10)) +#' out <- fit(spec, formula = y ~ x + z, data = tib) +#' predict(out, new_data = tib[1:5, ]) %>% +#' pivot_quantiles_wider(.pred) +#' +#' # -- a more complicated task +#' +#' library(dplyr) +#' dat <- case_death_rate_subset %>% +#' filter(time_value > as.Date("2021-10-01")) +#' rec <- epi_recipe(dat) %>% +#' step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) %>% +#' step_epi_ahead(death_rate, ahead = 7) %>% +#' step_epi_naomit() +#' frost <- frosting() %>% +#' layer_predict() %>% +#' layer_threshold(.pred) +#' spec <- rand_forest(mode = "regression") %>% +#' set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75)) +#' +#' ewf <- epi_workflow(rec, spec, frost) %>% +#' fit(dat) %>% +#' forecast() +#' ewf %>% +#' rename(forecast_date = time_value) %>% +#' mutate(target_date = forecast_date + 7L) %>% +#' pivot_quantiles_wider(.pred) +#' +#' @name grf_quantiles +NULL + + + +make_grf_quantiles <- function() { + parsnip::set_model_engine( + model = "rand_forest", mode = "regression", eng = "grf_quantiles" + ) + parsnip::set_dependency( + model = "rand_forest", eng = "grf_quantiles", pkg = "grf", + mode = "regression" + ) + + + # These are the arguments to the parsnip::rand_forest() that must be + # translated from grf::quantile_forest + parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "mtry", + original = "mtry", + func = list(pkg = "dials", fun = "mtry"), + has_submodel = FALSE + ) + 
parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "trees", + original = "num.trees", + func = list(pkg = "dials", fun = "trees"), + has_submodel = FALSE + ) + parsnip::set_model_arg( + model = "rand_forest", + eng = "grf_quantiles", + parsnip = "min_n", + original = "min.node.size", + func = list(pkg = "dials", fun = "min_n"), + has_submodel = FALSE + ) + + # the `value` list describes how grf::quantile_forest expects to receive + # arguments. In particular, it needs X and Y to be passed in as a matrices. + # But the matrix interface in parsnip calls these x and y. So the data + # slot translates them + # + # protect - prevents the user from passing X and Y arguments themselves + # defaults - engine specific arguments (not model specific) that we allow + # the user to change + parsnip::set_fit( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + value = list( + interface = "matrix", + protect = c("X", "Y"), + data = c(x = "X", y = "Y"), + func = c(pkg = "grf", fun = "quantile_forest"), + defaults = list( + quantiles = c(0.1, 0.5, 0.9), + num.threads = 1L, + seed = rlang::expr(stats::runif(1, 0, .Machine$integer.max)) + ) + ) + ) + + parsnip::set_encoding( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + options = list( + # one hot is the closest to typical factor handling in randomForest + # (1 vs all splitting), though since we aren't bagging, + # factors with many levels could be visited frequently + predictor_indicators = "one_hot", + compute_intercept = FALSE, + remove_intercept = FALSE, + allow_sparse_x = FALSE + ) + ) + + # turn the predictions into a tibble with a dist_quantiles column + process_qrf_preds <- function(x, object) { + quantile_levels <- parsnip::extract_fit_engine(object)$quantiles.orig + x <- x$predictions + out <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) + out <- dist_quantiles(out, list(quantile_levels)) + return(dplyr::tibble(.pred = out)) + } 
+ + parsnip::set_pred( + model = "rand_forest", + eng = "grf_quantiles", + mode = "regression", + type = "numeric", + value = list( + pre = NULL, + post = process_qrf_preds, + func = c(fun = "predict"), + # map between parsnip::predict args and grf::quantile_forest args + args = list( + object = quote(object$fit), + newdata = quote(new_data), + seed = rlang::expr(sample.int(10^5, 1)), + verbose = FALSE + ) + ) + ) +} diff --git a/R/make_quantile_reg.R b/R/make_quantile_reg.R index 832ef50f8..865e169e9 100644 --- a/R/make_quantile_reg.R +++ b/R/make_quantile_reg.R @@ -3,12 +3,14 @@ #' @description #' `quantile_reg()` generates a quantile regression model _specification_ for #' the [tidymodels](https://www.tidymodels.org/) framework. Currently, the -#' only supported engine is "rq" which uses [quantreg::rq()]. +#' only supported engine is "rq", which uses [quantreg::rq()]. +#' Quantile regression is also possible by combining [parsnip::rand_forest()] +#' with the `grf_quantiles` engine. See [grf_quantiles]. #' #' @param mode A single character string for the type of model. #' The only possible value for this model is "regression". #' @param engine Character string naming the fitting function. Currently, only -#' "rq" is supported. +#' "rq" is supported; for quantile forests, see [grf_quantiles]. #' @param quantile_levels A scalar or vector of values in (0, 1) to determine which #' quantiles to estimate (default is 0.5). 
#' @@ -16,8 +18,9 @@ #' #' @seealso [fit.model_spec()], [set_engine()] #' -#' @importFrom quantreg rq +#' #' @examples +#' library(quantreg) #' tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100)) #' rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) %>% set_engine("rq") #' ff <- rq_spec %>% fit(y ~ ., data = tib) @@ -106,7 +109,7 @@ make_quantile_reg <- function() { out <- switch(type, rq = dist_quantiles(unname(as.list(x)), object$quantile_levels), # one quantile rqs = { - x <- lapply(unname(split(x, seq(nrow(x)))), function(q) sort(q)) + x <- lapply(vctrs::vec_chop(x), function(x) sort(drop(x))) dist_quantiles(x, list(object$tau)) }, cli_abort(c( @@ -114,7 +117,7 @@ make_quantile_reg <- function() { i = "See {.fun quantreg::rq}." )) ) - return(data.frame(.pred = out)) + return(dplyr::tibble(.pred = out)) } diff --git a/R/make_smooth_quantile_reg.R b/R/make_smooth_quantile_reg.R index 9ab3a366b..448ee0fa5 100644 --- a/R/make_smooth_quantile_reg.R +++ b/R/make_smooth_quantile_reg.R @@ -21,8 +21,8 @@ #' #' @seealso [fit.model_spec()], [set_engine()] #' -#' @importFrom smoothqr smooth_qr #' @examples +#' library(smoothqr) #' tib <- data.frame( #' y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100), #' y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100), @@ -62,17 +62,16 @@ #' lines(pl$x, pl$`0.8`, col = "blue") #' lines(pl$x, pl$`0.5`, col = "red") #' -#' if (require("ggplot2")) { -#' ggplot(data.frame(x = x, y = y), aes(x)) + -#' geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + -#' geom_point(aes(y = y), colour = "grey") + # observed data -#' geom_function(fun = sin, colour = "black") + # truth -#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data -#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction -#' theme_bw() + -#' coord_cartesian(xlim = c(0, NA)) + -#' ylab("y") -#' } +#' library(ggplot2) +#' ggplot(data.frame(x = x, y = y), aes(x)) + +#' geom_ribbon(data = pl, 
aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + +#' geom_point(aes(y = y), colour = "grey") + # observed data +#' geom_function(fun = sin, colour = "black") + # truth +#' geom_vline(xintercept = fd, linetype = "dashed") + # end of training data +#' geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction +#' theme_bw() + +#' coord_cartesian(xlim = c(0, NA)) + +#' ylab("y") smooth_quantile_reg <- function( mode = "regression", engine = "smoothqr", diff --git a/R/zzz.R b/R/zzz.R index bb7cff9bf..7e335b67d 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -8,4 +8,5 @@ make_flatline_reg() make_quantile_reg() make_smooth_quantile_reg() + make_grf_quantiles() } diff --git a/man/grf_quantiles.Rd b/man/grf_quantiles.Rd new file mode 100644 index 000000000..e6852a55b --- /dev/null +++ b/man/grf_quantiles.Rd @@ -0,0 +1,108 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/make_grf_quantiles.R +\name{grf_quantiles} +\alias{grf_quantiles} +\title{Random quantile forests via grf} +\description{ +\code{\link[grf:quantile_forest]{grf::quantile_forest()}} fits random forests in a way that makes it easy +to calculate \emph{quantile} forests. Currently, this is the only engine +provided here, since quantile regression is the typical use-case. +} +\section{Tuning Parameters}{ + + +This model has 3 tuning parameters: +\itemize{ +\item \code{mtry}: # Randomly Selected Predictors (type: integer, default: see below) +\item \code{trees}: # Trees (type: integer, default: 2000L) +\item \code{min_n}: Minimal Node Size (type: integer, default: 5) +} + +\code{mtry} depends on the number of columns in the design matrix. +The default in \code{\link[grf:quantile_forest]{grf::quantile_forest()}} is \code{min(ceiling(sqrt(ncol(X)) + 20), ncol(X))}. + +For categorical predictors, a one-hot encoding is always used. This makes +splitting efficient, but has implications for the \code{mtry} choice. 
A factor +with many levels will become a large number of columns in the design matrix +which means that some of these may be selected frequently for potential splits. +This is different than in other implementations of random forest. For more +details, see \href{https://grf-labs.github.io/grf/articles/categorical_inputs.html}{the \code{grf} discussion}. +} + +\section{Translation from parsnip to the original package}{ + + +\if{html}{\out{
}}\preformatted{rand_forest( + mode = "regression", # you must specify the `mode = regression` + mtry = integer(1), + trees = integer(1), + min_n = integer(1) +) \%>\% + set_engine("grf_quantiles") \%>\% + translate() +#> Random Forest Model Specification (regression) +#> +#> Main Arguments: +#> mtry = integer(1) +#> trees = integer(1) +#> min_n = integer(1) +#> +#> Computational engine: grf_quantiles +#> +#> Model fit template: +#> grf::quantile_forest(X = missing_arg(), Y = missing_arg(), mtry = min_cols(~integer(1), +#> x), num.trees = integer(1), min.node.size = min_rows(~integer(1), +#> x), quantiles = c(0.1, 0.5, 0.9), num.threads = 1L, seed = stats::runif(1, +#> 0, .Machine$integer.max)) +}\if{html}{\out{
}} +} + +\section{Case weights}{ + + +Case weights are not supported. +} + +\examples{ +library(grf) +tib <- data.frame( + y = rnorm(100), x = rnorm(100), z = rnorm(100), + f = factor(sample(letters[1:3], 100, replace = TRUE)) +) +spec <- rand_forest(engine = "grf_quantiles", mode = "regression") +out <- fit(spec, formula = y ~ x + z, data = tib) +predict(out, new_data = tib[1:5, ]) \%>\% + pivot_quantiles_wider(.pred) + +# -- adjusting the desired quantiles + +spec <- rand_forest(mode = "regression") \%>\% + set_engine(engine = "grf_quantiles", quantiles = c(1:9 / 10)) +out <- fit(spec, formula = y ~ x + z, data = tib) +predict(out, new_data = tib[1:5, ]) \%>\% + pivot_quantiles_wider(.pred) + +# -- a more complicated task + +library(dplyr) +dat <- case_death_rate_subset \%>\% + filter(time_value > as.Date("2021-10-01")) +rec <- epi_recipe(dat) \%>\% + step_epi_lag(case_rate, death_rate, lag = c(0, 7, 14)) \%>\% + step_epi_ahead(death_rate, ahead = 7) \%>\% + step_epi_naomit() +frost <- frosting() \%>\% + layer_predict() \%>\% + layer_threshold(.pred) +spec <- rand_forest(mode = "regression") \%>\% + set_engine(engine = "grf_quantiles", quantiles = c(.25, .5, .75)) + +ewf <- epi_workflow(rec, spec, frost) \%>\% + fit(dat) \%>\% + forecast() +ewf \%>\% + rename(forecast_date = time_value) \%>\% + mutate(target_date = forecast_date + 7L) \%>\% + pivot_quantiles_wider(.pred) + +} diff --git a/man/layer_quantile_distn.Rd b/man/layer_quantile_distn.Rd index 695a1d12d..f5de4aa19 100644 --- a/man/layer_quantile_distn.Rd +++ b/man/layer_quantile_distn.Rd @@ -32,9 +32,17 @@ quantiles will be added to the predictions. } \description{ This function calculates quantiles when the prediction was \emph{distributional}. -Currently, the only distributional engine is \code{quantile_reg()}. -If this engine is used, then this layer will grab out estimated (or extrapolated) -quantiles at the requested quantile values. 
+} +\details{ +Currently, the only distributional modes/engines are +\itemize{ +\item \code{quantile_reg()} +\item \code{smooth_quantile_reg()} +\item \code{rand_forest(mode = "regression") \%>\% set_engine("grf_quantiles")} +} + +If these engines were used, then this layer will grab out estimated +(or extrapolated) quantiles at the requested quantile values. } \examples{ jhu <- case_death_rate_subset \%>\% diff --git a/man/quantile_reg.Rd b/man/quantile_reg.Rd index 8e576ac84..981918fbe 100644 --- a/man/quantile_reg.Rd +++ b/man/quantile_reg.Rd @@ -11,7 +11,7 @@ quantile_reg(mode = "regression", engine = "rq", quantile_levels = 0.5) The only possible value for this model is "regression".} \item{engine}{Character string naming the fitting function. Currently, only -"rq" is supported.} +"rq" is supported; for quantile forests, see \link{grf_quantiles}.} \item{quantile_levels}{A scalar or vector of values in (0, 1) to determine which quantiles to estimate (default is 0.5).} @@ -19,9 +19,12 @@ quantiles to estimate (default is 0.5).} \description{ \code{quantile_reg()} generates a quantile regression model \emph{specification} for the \href{https://www.tidymodels.org/}{tidymodels} framework. Currently, the -only supported engine is "rq" which uses \code{\link[quantreg:rq]{quantreg::rq()}}. +only supported engine is "rq", which uses \code{\link[quantreg:rq]{quantreg::rq()}}. +Quantile regression is also possible by combining \code{\link[parsnip:rand_forest]{parsnip::rand_forest()}} +with the \code{grf_quantiles} engine. See \link{grf_quantiles}. } \examples{ +library(quantreg) tib <- data.frame(y = rnorm(100), x1 = rnorm(100), x2 = rnorm(100)) rq_spec <- quantile_reg(quantile_levels = c(.2, .8)) \%>\% set_engine("rq") ff <- rq_spec \%>\% fit(y ~ ., data = tib) diff --git a/man/smooth_quantile_reg.Rd b/man/smooth_quantile_reg.Rd index bd8c012f2..c6b17dd86 100644 --- a/man/smooth_quantile_reg.Rd +++ b/man/smooth_quantile_reg.Rd @@ -36,6 +36,7 @@ the \href{https://www.tidymodels.org/}{tidymodels} framework. 
Currently, the only supported engine is \code{\link[smoothqr:smooth_qr]{smoothqr::smooth_qr()}}. } \examples{ +library(smoothqr) tib <- data.frame( y1 = rnorm(100), y2 = rnorm(100), y3 = rnorm(100), y4 = rnorm(100), y5 = rnorm(100), y6 = rnorm(100), @@ -75,17 +76,16 @@ lines(pl$x, pl$`0.2`, col = "blue") lines(pl$x, pl$`0.8`, col = "blue") lines(pl$x, pl$`0.5`, col = "red") -if (require("ggplot2")) { - ggplot(data.frame(x = x, y = y), aes(x)) + - geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + - geom_point(aes(y = y), colour = "grey") + # observed data - geom_function(fun = sin, colour = "black") + # truth - geom_vline(xintercept = fd, linetype = "dashed") + # end of training data - geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction - theme_bw() + - coord_cartesian(xlim = c(0, NA)) + - ylab("y") -} +library(ggplot2) +ggplot(data.frame(x = x, y = y), aes(x)) + + geom_ribbon(data = pl, aes(ymin = `0.2`, ymax = `0.8`), fill = "lightblue") + + geom_point(aes(y = y), colour = "grey") + # observed data + geom_function(fun = sin, colour = "black") + # truth + geom_vline(xintercept = fd, linetype = "dashed") + # end of training data + geom_line(data = pl, aes(y = `0.5`), colour = "red") + # median prediction + theme_bw() + + coord_cartesian(xlim = c(0, NA)) + + ylab("y") } \seealso{ \code{\link[=fit.model_spec]{fit.model_spec()}}, \code{\link[=set_engine]{set_engine()}} diff --git a/tests/testthat/test-grf_quantiles.R b/tests/testthat/test-grf_quantiles.R new file mode 100644 index 000000000..2570c247d --- /dev/null +++ b/tests/testthat/test-grf_quantiles.R @@ -0,0 +1,52 @@ +set.seed(12345) +library(grf) +tib <- tibble( + y = rnorm(100), x = rnorm(100), z = rnorm(100), + f = factor(sample(letters[1:3], 100, replace = TRUE)) +) + +test_that("quantile_rand_forest defaults work", { + spec <- rand_forest(engine = "grf_quantiles", mode = "regression") + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars 
<- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(0.1, 0.5, 0.9)) + expect_identical(pars$quantiles.orig, manual$quantiles) + expect_identical(pars$`_num_trees`, manual$`_num_trees`) + + fseed <- 12345 + spec_seed <- rand_forest(mode = "regression", mtry = 2L, min_n = 10) %>% + set_engine("grf_quantiles", seed = fseed) + out <- fit(spec_seed, formula = y ~ x + z - 1, data = tib) + manual <- quantile_forest( + as.matrix(tib[, 2:3]), tib$y, + quantiles = c(0.1, 0.5, 0.9), seed = fseed, + mtry = 2L, min.node.size = 10 + ) + p_pars <- predict(out, new_data = tib[1:5, ]) %>% + pivot_quantiles_wider(.pred) + p_manual <- predict(manual, newdata = as.matrix(tib[1:5, 2:3]))$predictions + colnames(p_manual) <- c("0.1", "0.5", "0.9") + p_manual <- tibble::as_tibble(p_manual) + # not equal despite the seed, etc + # expect_equal(p_pars, p_manual) +}) + +test_that("quantile_rand_forest handles alternative quantiles", { + spec <- rand_forest(mode = "regression") %>% + set_engine("grf_quantiles", quantiles = c(.2, .5, .8)) + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars <- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, quantiles = c(.2, .5, .8)) + expect_identical(pars$quantiles.orig, manual$quantiles.orig) + expect_identical(pars$`_num_trees`, manual$`_num_trees`) +}) + + +test_that("quantile_rand_forest handles allows setting the trees and mtry", { + spec <- rand_forest(mode = "regression", mtry = 2, trees = 100, engine = "grf_quantiles") + expect_silent(out <- fit(spec, formula = y ~ x + z, data = tib)) + pars <- parsnip::extract_fit_engine(out) + manual <- quantile_forest(as.matrix(tib[, 2:3]), tib$y, mtry = 2, num.trees = 100) + expect_identical(pars$quantiles.orig, manual$quantiles.orig) + expect_identical(pars$`_num_trees`, manual$`_num_trees`) +})