We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
nnetar
Function:
library(healthyverse)
library(dplyr)
library(recipes)
library(timetk)
library(rsample)
library(dials)
library(modeltime)

# Turn a captured function call into a `recipe()` call.
#
# Removes the arguments named `tune`, `verbose`, `colors` and `prefix` (if
# present) from `.rec_call` and replaces the called function with `recipe`.
# NOTE(review): `ts_auto_nnetar()` below passes dot-prefixed arguments
# (`.tune`, `.prefix`), which these removals do not match - confirm intent.
get_recipe_call <- function(.rec_call) {
  cl <- .rec_call
  cl$tune <- NULL
  cl$verbose <- NULL
  cl$colors <- NULL
  cl$prefix <- NULL
  rec_cl <- cl
  rec_cl[[1]] <- rlang::expr(recipe)
  rec_cl
}

# Build the text of an assignment `name <- value`, where `value` is captured
# as an expression (unquoting with `!!` is supported) and deparsed to text.
assign_value <- function(name, value, cr = TRUE) {
  value <- rlang::enexpr(value)
  value <- rlang::expr_text(value, width = 74L)
  chr_assign(name, value, cr)
}

# Paste `name`, the assignment arrow and `value` together.
#
# With `cr = TRUE` the result is a character vector of two pieces (the
# `name <-` part and the value prefixed by a newline); otherwise it is a
# single string `name <- value`.
chr_assign <- function(name, value, cr = TRUE) {
  name <- paste(name, "<-")
  if (cr) {
    res <- c(name, paste0("\n ", value))
  } else {
    res <- paste(name, value)
  }
  res
}

# Build, optionally tune, fit and calibrate an NNETAR workflow.
#
# Arguments:
#   .data           Data to model; converted with `dplyr::as_tibble()`.
#   .date_col       Unquoted date column; used in the time series signature
#                   steps and as `date_var` for the cross validation.
#   .value_col      Unquoted value column; only checked for presence here.
#   .formula        Formula handed to `recipes::recipe()`.
#   .rsamp_obj      An object of class `rsplit`.
#   .prefix         Prefix for the generated recipe syntax name.
#   .tune           If TRUE, tune the model over a latin hypercube grid.
#   .grid_size      Size of the tuning grid.
#   .num_cores      Passed to `modeltime::parallel_start()` when tuning.
#   .cv_assess, .cv_skip, .cv_slice_limit
#                   Passed to `timetk::time_series_cv()` as `assess`, `skip`
#                   and `slice_limit`.
#   .best_metric    Metric name passed to `tune::show_best()`.
#   .bootstrap_final Accepted for interface compatibility; not used in the
#                   body.
#
# Returns (invisibly) a list with `recipe_info`, `model_info`,
# `model_calibration` and, when `.tune` is TRUE, `tuned_info`.
ts_auto_nnetar <- function(.data, .date_col, .value_col, .formula, .rsamp_obj,
                           .prefix = "ts_nnetar", .tune = TRUE,
                           .grid_size = 10, .num_cores = 1, .cv_assess = 12,
                           .cv_skip = 3, .cv_slice_limit = 6,
                           .best_metric = "rmse", .bootstrap_final = FALSE) {

  # Tidyeval ----
  date_col_var_expr <- rlang::enquo(.date_col)
  value_col_var_expr <- rlang::enquo(.value_col)

  # Cross Validation
  cv_assess <- as.numeric(.cv_assess)
  cv_skip <- as.numeric(.cv_skip)
  cv_slice <- as.numeric(.cv_slice_limit)

  # Tuning Grid
  grid_size <- as.numeric(.grid_size)
  num_cores <- as.numeric(.num_cores)
  best_metric <- as.character(.best_metric)

  # Data and splits
  splits <- .rsamp_obj
  data_tbl <- dplyr::as_tibble(.data)

  # Checks ----
  if (rlang::quo_is_missing(date_col_var_expr)) {
    rlang::abort(
      message = "'.date_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (rlang::quo_is_missing(value_col_var_expr)) {
    rlang::abort(
      message = "'.value_col' must be supplied.",
      use_cli_format = TRUE
    )
  }

  if (!inherits(x = splits, what = "rsplit")) {
    rlang::abort(
      message = "'.rsamp_obj' must be have class rsplit, use the rsample package.",
      use_cli_format = TRUE
    )
  }

  # Recipe ----
  # Get the initial recipe call
  recipe_call <- get_recipe_call(match.call())

  rec_syntax <- paste0(.prefix, "_recipe") %>%
    assign_value(!!recipe_call)

  rec_obj <- recipes::recipe(formula = .formula, data = data_tbl)

  # NOTE(review): `date_col_index.num` is a hard-coded column name; it looks
  # like it assumes the date column is called `date_col` - confirm.
  rec_obj <- rec_obj %>%
    timetk::step_timeseries_signature({{date_col_var_expr}}) %>%
    timetk::step_holiday_signature({{date_col_var_expr}}) %>%
    recipes::step_novel(recipes::all_nominal_predictors()) %>%
    recipes::step_mutate_at(
      tidyselect::vars_select_helpers$where(is.character),
      fn = ~ as.factor(.)
    ) %>%
    #recipes::step_rm({{date_col_var_expr}}) %>%
    recipes::step_dummy(recipes::all_nominal(), one_hot = TRUE) %>%
    recipes::step_zv(recipes::all_predictors(), -date_col_index.num) %>%
    recipes::step_normalize(recipes::all_numeric_predictors())

  # Tune/Spec ----
  if (.tune) {
    model_spec <- modeltime::nnetar_reg(
      seasonal_period = "auto",
      non_seasonal_ar = tune::tune(),
      seasonal_ar = tune::tune(),
      hidden_units = tune::tune(),
      num_networks = tune::tune(),
      penalty = tune::tune(),
      epochs = tune::tune()
    )
  } else {
    model_spec <- modeltime::nnetar_reg()
  }

  model_spec <- model_spec %>%
    parsnip::set_mode(mode = "regression") %>%
    parsnip::set_engine("nnetar")

  # Workflow ----
  wflw <- workflows::workflow() %>%
    workflows::add_recipe(rec_obj) %>%
    workflows::add_model(model_spec)

  # Tuning Grid ----
  if (.tune) {

    # Start parallel backend. The on.exit() guard stops it again if anything
    # below fails before the explicit stop is reached.
    backend_running <- FALSE
    on.exit(if (backend_running) modeltime::parallel_stop(), add = TRUE)
    modeltime::parallel_start(num_cores)
    backend_running <- TRUE

    tuning_grid_spec <- dials::grid_latin_hypercube(
      hardhat::extract_parameter_set_dials(model_spec),
      size = grid_size
    )

    # Make TS CV ----
    tscv <- timetk::time_series_cv(
      data = rsample::training(splits),
      date_var = {{date_col_var_expr}},
      cumulative = TRUE,
      assess = cv_assess,
      skip = cv_skip,
      slice_limit = cv_slice
    )

    # Tune the workflow
    tuned_results <- wflw %>%
      tune::tune_grid(
        resamples = tscv,
        grid = tuning_grid_spec,
        metrics = modeltime::default_forecast_accuracy_metric_set()
      )

    # Get the best result set by a specified metric
    best_result_set <- tuned_results %>%
      tune::show_best(metric = best_metric, n = 1)

    # Plot results
    tune_results_plt <- tuned_results %>%
      tune::autoplot() +
      ggplot2::theme_minimal() +
      ggplot2::geom_smooth(se = FALSE)

    # Make final workflow from the best row already selected above
    wflw_fit <- wflw %>%
      tune::finalize_workflow(best_result_set) %>%
      parsnip::fit(rsample::training(splits))

    # Stop parallel backend
    modeltime::parallel_stop()
    backend_running <- FALSE

  } else {
    wflw_fit <- wflw %>%
      parsnip::fit(rsample::training(splits))
  }

  # Calibrate and Plot ----
  cap <- healthyR.ts::calibrate_and_plot(
    wflw_fit,
    .splits_obj = splits,
    .data = data_tbl,
    .interactive = TRUE,
    .print_info = FALSE
  )

  # Return ----
  output <- list(
    recipe_info = list(
      recipe_call = recipe_call,
      recipe_syntax = rec_syntax,
      rec_obj = rec_obj
    ),
    model_info = list(
      model_spec = model_spec,
      wflw = wflw,
      fitted_wflw = wflw_fit,
      was_tuned = if (.tune) "tuned" else "not_tuned"
    ),
    model_calibration = list(
      plot = cap$plot,
      calibration_tbl = cap$calibration_tbl,
      model_accuracy = cap$model_accuracy
    )
  )

  if (.tune) {
    output$tuned_info <- list(
      tuning_grid = tuning_grid_spec,
      tscv = tscv,
      tuned_results = tuned_results,
      grid_size = grid_size,
      best_metric = best_metric,
      best_result_set = best_result_set,
      tuning_grid_plot = tune_results_plt,
      plotly_grid_plot = plotly::ggplotly(tune_results_plt)
    )
  }

  invisible(output)
}
Examples:
# Example: convert AirPassengers to a tibble, create a time series split and
# run the automatic NNETAR workflow on it.
data <- AirPassengers %>%
  ts_to_tbl() %>%
  select(-index)

splits <- time_series_split(
  data,
  date_col,
  assess = 12,
  skip = 3,
  cumulative = TRUE
)

tst_nnetar <- ts_auto_nnetar(
  .data = data,
  .num_cores = 5,
  .date_col = date_col,
  .value_col = value,
  .rsamp_obj = splits,
  .formula = value ~ .,
  .grid_size = 10
)
The text was updated successfully, but these errors were encountered:
f8e4fde
Merge pull request #260 from spsanderson/development
6e7fe91
Fixes #248
spsanderson
No branches or pull requests
Function:
Examples:
The text was updated successfully, but these errors were encountered: