|
1 | 1 | set_new_model("linear_reg") |
2 | 2 |
|
3 | 3 | set_model_mode("linear_reg", "regression") |
| 4 | +set_model_mode("linear_reg", "quantile regression") |
4 | 5 |
|
5 | 6 | # ------------------------------------------------------------------------------ |
6 | 7 |
|
@@ -582,3 +583,74 @@ set_pred( |
582 | 583 | ) |
583 | 584 | ) |
584 | 585 |
|
| 586 | +# ------------------------------------------------------------------------------ |
| 587 | + |
| 588 | +set_model_engine("linear_reg", "quantile regression", "quantreg") |
| 589 | +set_dependency("linear_reg", "quantreg", "quantreg") |
| 590 | + |
| 591 | +set_fit( |
| 592 | + model = "linear_reg", |
| 593 | + eng = "quantreg", |
| 594 | + mode = "quantile regression", |
| 595 | + value = list( |
| 596 | + interface = "formula", |
| 597 | + protect = c("formula", "data", "weights"), |
| 598 | + func = c(pkg = "quantreg", fun = "rq"), |
| 599 | + defaults = list() |
| 600 | + ) |
| 601 | +) |
| 602 | + |
| 603 | +set_encoding( |
| 604 | + model = "linear_reg", |
| 605 | + eng = "quantreg", |
| 606 | + mode = "quantile regression", |
| 607 | + options = list( |
| 608 | + predictor_indicators = "traditional", |
| 609 | + compute_intercept = TRUE, |
| 610 | + remove_intercept = TRUE, |
| 611 | + allow_sparse_x = FALSE |
| 612 | + ) |
| 613 | +) |
| 614 | + |
| 615 | +set_pred( |
| 616 | + model = "linear_reg", |
| 617 | + eng = "quantreg", |
| 618 | + mode = "quantile regression", |
| 619 | + type = "numeric", |
| 620 | + value = list( |
| 621 | + pre = NULL, |
| 622 | + post = NULL, |
| 623 | + func = c(fun = "predict"), |
| 624 | + args = |
| 625 | + list( |
| 626 | + object = expr(object$fit), |
| 627 | + newdata = expr(new_data), |
| 628 | + type = "response", |
| 629 | + rankdeficient = "simple" |
| 630 | + ) |
| 631 | + ) |
| 632 | +) |
| 633 | + |
| 634 | +set_pred( |
| 635 | + model = "linear_reg", |
| 636 | + eng = "quantreg", |
| 637 | + mode = "quantile regression", |
| 638 | + type = "conf_int", |
| 639 | + value = list( |
| 640 | + pre = NULL, |
| 641 | + post = function(results, object) { |
| 642 | + tibble::as_tibble(results) %>% |
| 643 | + dplyr::select(-fit) %>% |
| 644 | + setNames(c(".pred_lower", ".pred_upper")) |
| 645 | + }, |
| 646 | + func = c(fun = "predict"), |
| 647 | + args = |
| 648 | + list( |
| 649 | + object = expr(object$fit), |
| 650 | + newdata = expr(new_data), |
| 651 | + interval = "confidence", |
| 652 | + level = expr(level) |
| 653 | + ) |
| 654 | + ) |
| 655 | +) |
| 656 | + |
0 commit comments