prophet_xgboost - prophet_boost #251

Closed · Tracked by #249 · Fixed by #262
spsanderson opened this issue Apr 18, 2022 · 0 comments

spsanderson (Owner) commented Apr 18, 2022

Function:

library(healthyverse)
library(dplyr)
library(recipes)
library(timetk)
library(rsample)
library(dials)
library(modeltime)

# Strip the non-recipe arguments from the captured call and swap its head
# for `recipe` so the call can be echoed back to the user as recipe syntax.
get_recipe_call <- function(.rec_call){
  cl <- .rec_call
  # `$<-` is exact-match, so the function's dotted argument names must be
  # used here (cl$tune would silently fail to remove .tune)
  cl$.tune <- NULL
  cl$.prefix <- NULL
  rec_cl <- cl
  rec_cl[[1]] <- rlang::expr(recipe)
  rec_cl
}

# Deparse a captured expression to text and hand it to chr_assign()
assign_value <- function(name, value, cr = TRUE) {
  value <- rlang::enexpr(value)
  value <- rlang::expr_text(value, width = 74L)
  chr_assign(name, value, cr)
}

# Build the text of an assignment, optionally breaking after the `<-`
chr_assign <- function(name, value, cr = TRUE) {
  name <- paste(name, "<-")
  if (cr) {
    res <- c(name, paste0("\n  ", value))
  } else {
    res <- paste(name, value)
  }
  res
}
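
For reference, a minimal sketch of what these helpers produce (the captured call below is hypothetical; expected output shown as comments):

cl <- quote(ts_auto_prophet_boost(.data = df, .formula = value ~ ., .tune = TRUE))
get_recipe_call(cl)
#> recipe(.data = df, .formula = value ~ .)

chr_assign("ts_prophet_boost_recipe", "recipe(value ~ ., data = df)", cr = FALSE)
#> [1] "ts_prophet_boost_recipe <- recipe(value ~ ., data = df)"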

ts_auto_prophet_boost <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                           .prefix = "ts_prophet_boost", .tune = TRUE, .grid_size = 10,
                           .num_cores = 1, .cv_assess = 12, .cv_skip = 3, 
                           .cv_slice_limit = 6, .best_metric = "rmse", 
                           .bootstrap_final = FALSE){
  
  # Tidyeval ----
  date_col_var_expr  <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)
  # Cross Validation
  cv_assess <- as.numeric(.cv_assess)
  cv_skip   <- as.numeric(.cv_skip)
  cv_slice  <- as.numeric(.cv_slice_limit)
  # Tuning Grid
  grid_size   <- as.numeric(.grid_size)
  num_cores   <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)
  # Data and splits
  splits   <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)
  
  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)){
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }
  
  if (rlang::quo_is_missing(value_col_var_expr)){
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }
  
  if (!inherits(x = splits, what = "rsplit")){
    rlang::abort(
      message = "'.rsamp_obj' must have class 'rsplit'; use the rsample package.",
      use_cli_format = TRUE
    )
  }
  
  # Recipe ----
  # Get the initial recipe call
  recipe_call <- get_recipe_call(match.call())
  
  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)
  
  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)
  
  rec_obj <- rec_obj %>%
    timetk::step_timeseries_signature({{date_col_var_expr}}) %>%
    timetk::step_holiday_signature({{date_col_var_expr}}) %>%
    recipes::step_novel(recipes::all_nominal_predictors()) %>%
    recipes::step_mutate_at(tidyselect::vars_select_helpers$where(is.character)
                            , fn = ~ as.factor(.)) %>%
    #recipes::step_rm({{date_col_var_expr}}) %>%
    recipes::step_dummy(recipes::all_nominal(), one_hot = TRUE) %>%
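    # NOTE: step_timeseries_signature() names the numeric index feature
    # "<date column name>_index.num", so the exclusion below assumes the date
    # column is literally named `date_col` (as it is after ts_to_tbl())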
    recipes::step_zv(recipes::all_predictors(), -date_col_index.num) %>%
    recipes::step_normalize(recipes::all_numeric_predictors())
  
  # Tune/Spec ----
  if (.tune){
    model_spec <- modeltime::prophet_boost(
      changepoint_num      = tune::tune()
      , changepoint_range  = tune::tune()
      , seasonality_yearly = FALSE
      , seasonality_weekly = FALSE
      , seasonality_daily  = FALSE
      , prior_scale_changepoints = tune::tune()
      , prior_scale_seasonality  = tune::tune()
      , prior_scale_holidays     = tune::tune()
      , trees                    = tune::tune()
      , min_n                    = tune::tune()
      , tree_depth               = tune::tune()
      , learn_rate               = tune::tune()
      , loss_reduction           = tune::tune()
      , stop_iter                = tune::tune()
    )
  } else {
    # Must be prophet_boost(), not prophet_reg(): the "prophet_xgboost"
    # engine set below exists only for the boosted specification
    model_spec <- modeltime::prophet_boost()
  }
  
  model_spec <- model_spec %>%
    parsnip::set_mode(mode = "regression") %>%
    parsnip::set_engine("prophet_xgboost")
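  # The "prophet_xgboost" engine fits Prophet for trend/seasonality and then
  # models the Prophet residuals with XGBoost on the recipe's features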
  
  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec) 
  
  # Tuning Grid ----
  if (.tune){
    
    # Start parallel backend
    modeltime::parallel_start(num_cores)
    
    # Latin hypercube sample over the tunable parameter space
    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )
    
    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data        = rsample::training(splits),
      date_var    = {{date_col_var_expr}},
      cumulative  = TRUE,
      assess      = cv_assess,
      skip        = cv_skip,
      slice_limit = cv_slice
    )
    
    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid      = tuning_grid_spec,
        metrics   = modeltime::default_forecast_accuracy_metric_set()
      )
    
    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)
    
    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() + 
      ggplot2::geom_smooth(se = FALSE)
    
    # Make final workflow
    wflw_fit <- wflw %>%
      tune::finalize_workflow(
        tuned_results %>%
          tune::show_best(metric = best_metric, n = Inf) %>%
          dplyr::slice(1)
      ) %>%
      parsnip::fit(rsample::training(splits))
    
    # Stop parallel backend
    modeltime::parallel_stop()
    
  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }
  
  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj  = splits,
    .data        = data_tbl,
    .interactive = TRUE,
    .print_info  = FALSE
  )
  
  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call   = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj       = rec_obj
    ),
    model_info = list(
      model_spec  = model_spec,
      wflw        = wflw,
      fitted_wflw = wflw_fit,
      was_tuned   = ifelse(.tune, "tuned", "not_tuned")
    ),
    model_calibration = list(
      plot = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy = cap$model_accuracy
    )
  )
  
  if (.tune){
    output$tuned_info = list(
      tuning_grid      = tuning_grid_spec,
      tscv             = tscv,
      tuned_results    = tuned_results,
      grid_size        = grid_size,
      best_metric      = best_metric,
      best_result_set  = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }
  
  return(invisible(output))
}

Example:

# healthyR.ts::ts_to_tbl() returns `index`, `date_col`, and `value` columns
data <- AirPassengers %>%
  ts_to_tbl() %>%
  select(-index)

splits <- time_series_split(
  data
  , date_col
  , assess = 12
  , skip = 3
  , cumulative = TRUE
)

tst_pboost <- ts_auto_prophet_boost(
  .data = data, 
  .num_cores = 5,
  .date_col = date_col, 
  .value_col = value, 
  .rsamp_obj = splits,
  .formula = value ~ .,
  .grid_size = 5
)
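
The function returns its list invisibly, so pull the pieces out by name (names as defined in the `output` list above):

tst_pboost$recipe_info$rec_obj               # preprocessing recipe
tst_pboost$model_info$fitted_wflw            # workflow fit on training(splits)
tst_pboost$model_calibration$model_accuracy  # accuracy from calibrate_and_plot()
tst_pboost$tuned_info$tuning_grid_plot       # present only when .tune = TRUE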


@spsanderson spsanderson self-assigned this Apr 18, 2022
@spsanderson spsanderson added this to the healthyR.ts 0.1.9 milestone Apr 18, 2022
spsanderson added a commit that referenced this issue Apr 19, 2022 — Fixes #249, Fixes #251
@spsanderson spsanderson mentioned this issue Apr 19, 2022