Description
First of all, thank you all so much for the effort put into parsnip. It's a really great package!
As a first-time user, I found it difficult to track down what the original engine arguments were mapped to in parsnip. I knew very well what the native xgboost parameter `colsample_bytree` means for tuning, but I had to dig into the source code to find that parsnip translates it into `mtry`. Translating some of my existing models therefore took some trial and error on my side.
I think it would be beneficial if the mapping between original and standardized arguments were documented either (i) in each model function's documentation or (ii) in a complete reference table, e.g.:
| model | engine | parsnip | original |
|---|---|---|---|
| boost_tree | xgboost | tree_depth | max_depth |
| boost_tree | xgboost | trees | nrounds |
| boost_tree | xgboost | learn_rate | eta |
| boost_tree | xgboost | mtry | colsample_bytree |
| boost_tree | xgboost | min_n | min_child_weight |
| boost_tree | xgboost | loss_reduction | gamma |
| boost_tree | xgboost | sample_size | subsample |
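With a reference table like this, the lookup I actually needed (native argument name to parsnip name) becomes a one-line filter. A minimal base-R sketch over just the xgboost rows above; `xgb_map` and `lookup_parsnip()` are hypothetical names for illustration, not parsnip functions:

```r
# The xgboost rows from the proposed reference table above
xgb_map <- data.frame(
  model    = "boost_tree",
  engine   = "xgboost",
  parsnip  = c("tree_depth", "trees", "learn_rate", "mtry",
               "min_n", "loss_reduction", "sample_size"),
  original = c("max_depth", "nrounds", "eta", "colsample_bytree",
               "min_child_weight", "gamma", "subsample"),
  stringsAsFactors = FALSE
)

# Hypothetical helper: map a native engine argument to its parsnip name
# (returns NA when the argument is not in the table)
lookup_parsnip <- function(map, engine, original) {
  hit <- map$parsnip[map$engine == engine & map$original == original]
  if (length(hit) == 0) NA_character_ else hit
}

lookup_parsnip(xgb_map, "xgboost", "colsample_bytree")
#> [1] "mtry"
```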
I am not sure how one would go about implementing this. Below is a very hacky solution I threw together that lists the translation of every current argument.
```r
library(tidyverse)

# Helper: pull the (engine, parsnip, original) triples out of the
# set_model_arg() calls in one parsnip *_data.R source file
extract_params <- function(content) {
  content %>%
    str_replace_all("\\n", "") %>%
    str_extract_all("set_model_arg.*?\\)") %>%
    purrr::pluck(1) %>%
    tibble(line = .) %>%
    mutate(engine   = str_extract(line, "(?<=eng = ).*?(?=,)"),
           parsnip  = str_extract(line, "(?<=parsnip = ).*?(?=,)"),
           original = str_extract(line, "(?<=original = ).*?(?=,)")) %>%
    select(-line) %>%
    mutate_all(str_replace_all, '"', "")
}

##################################### #
# Download parsnip from GITHUB ----
##################################### #
file_name <- "parsnip.zip"
download.file(url = "https://github.com/tidymodels/parsnip/archive/master.zip",
              destfile = file_name, mode = "wb")
unzip(file_name)
file.remove(file_name)

##################################### #
# List arguments of each model ----
##################################### #
params <- dir("parsnip-master/R", recursive = TRUE, pattern = "_data\\.R", full.names = TRUE) %>%
  tibble(file_name = .) %>%
  mutate(model = basename(file_name) %>% str_replace_all("_data\\.R", "")) %>%
  filter(!model %in% c("nullmodel", "convert")) %>%
  mutate(content = map_chr(file_name, read_file)) %>%
  mutate(params = pbapply::pblapply(content, extract_params)) %>%
  unnest(params) %>%
  select(model, engine, parsnip, original)
```
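To see what `extract_params()` actually pulls out of a file, here is a dependency-free base-R sketch of the same regex idea, applied to a short, made-up excerpt written in the style of parsnip's `*_data.R` files. The excerpt and the `get_field()`/`extract_args()` names are illustrative assumptions, not copied from parsnip:

```r
# Hypothetical excerpt mimicking the set_model_arg() calls in a *_data.R file
sample_src <- '
set_model_arg(
  model = "boost_tree",
  eng = "xgboost",
  parsnip = "mtry",
  original = "colsample_bytree",
  func = list(pkg = "dials", fun = "mtry"),
  has_submodel = FALSE
)
set_model_arg(
  model = "boost_tree",
  eng = "xgboost",
  parsnip = "trees",
  original = "nrounds",
  func = list(pkg = "dials", fun = "trees"),
  has_submodel = TRUE
)'

# Return the quoted value of `field = "..."` in one call, or NA if absent
get_field <- function(call, field) {
  pattern <- paste0('(?<=', field, ' = ")[^"]*')
  hit <- regmatches(call, regexpr(pattern, call, perl = TRUE))
  if (length(hit) == 0) NA_character_ else hit
}

extract_args <- function(content) {
  flat <- gsub("\n", "", content)
  # Lazily match each set_model_arg( ... ) up to its first ")",
  # just like the tidyverse helper above
  calls <- regmatches(flat, gregexpr("set_model_arg.*?\\)", flat, perl = TRUE))[[1]]
  data.frame(
    engine   = vapply(calls, get_field, character(1), field = "eng", USE.NAMES = FALSE),
    parsnip  = vapply(calls, get_field, character(1), field = "parsnip", USE.NAMES = FALSE),
    original = vapply(calls, get_field, character(1), field = "original", USE.NAMES = FALSE),
    stringsAsFactors = FALSE
  )
}

extract_args(sample_src)
```

Like the tidyverse helper, this only works because each call lists `eng`, `parsnip` and `original` before the first closing parenthesis (the one ending `func = list(...)`); a field placed after it would come back as `NA`.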
Full mapping table
```r
params %>%
  knitr::kable()
```
| model | engine | parsnip | original |
|---|---|---|---|
| boost_tree | xgboost | tree_depth | max_depth |
| boost_tree | xgboost | trees | nrounds |
| boost_tree | xgboost | learn_rate | eta |
| boost_tree | xgboost | mtry | colsample_bytree |
| boost_tree | xgboost | min_n | min_child_weight |
| boost_tree | xgboost | loss_reduction | gamma |
| boost_tree | xgboost | sample_size | subsample |
| boost_tree | C5.0 | trees | trials |
| boost_tree | C5.0 | min_n | minCases |
| boost_tree | C5.0 | sample_size | sample |
| boost_tree | spark | tree_depth | max_depth |
| boost_tree | spark | trees | max_iter |
| boost_tree | spark | learn_rate | step_size |
| boost_tree | spark | mtry | feature_subset_strategy |
| boost_tree | spark | min_n | min_instances_per_node |
| boost_tree | spark | min_info_gain | gamma |
| boost_tree | spark | sample_size | subsampling_rate |
| decision_tree | rpart | tree_depth | maxdepth |
| decision_tree | rpart | min_n | minsplit |
| decision_tree | rpart | cost_complexity | cp |
| decision_tree | C5.0 | min_n | minCases |
| decision_tree | spark | tree_depth | max_depth |
| decision_tree | spark | min_n | min_instances_per_node |
| linear_reg | glmnet | penalty | lambda |
| linear_reg | glmnet | mixture | alpha |
| linear_reg | spark | penalty | reg_param |
| linear_reg | spark | mixture | elastic_net_param |
| logistic_reg | glmnet | penalty | lambda |
| logistic_reg | glmnet | mixture | alpha |
| logistic_reg | spark | penalty | reg_param |
| logistic_reg | spark | mixture | elastic_net_param |
| logistic_reg | keras | decay | decay |
| mars | earth | num_terms | nprune |
| mars | earth | prod_degree | degree |
| mars | earth | prune_method | pmethod |
| mlp | keras | hidden_units | hidden_units |
| mlp | keras | penalty | penalty |
| mlp | keras | dropout | dropout |
| mlp | keras | epochs | epochs |
| mlp | keras | activation | activation |
| mlp | nnet | hidden_units | size |
| mlp | nnet | penalty | decay |
| mlp | nnet | epochs | maxit |
| multinom_reg | glmnet | penalty | lambda |
| multinom_reg | glmnet | mixture | alpha |
| multinom_reg | spark | penalty | reg_param |
| multinom_reg | spark | mixture | elastic_net_param |
| multinom_reg | keras | decay | decay |
| nearest_neighbor | kknn | neighbors | ks |
| nearest_neighbor | kknn | weight_func | kernel |
| nearest_neighbor | kknn | dist_power | distance |
| rand_forest | ranger | mtry | mtry |
| rand_forest | ranger | trees | num.trees |
| rand_forest | ranger | min_n | min.node.size |
| rand_forest | randomForest | mtry | mtry |
| rand_forest | randomForest | trees | ntree |
| rand_forest | randomForest | min_n | nodesize |
| rand_forest | spark | mtry | feature_subset_strategy |
| rand_forest | spark | trees | num_trees |
| rand_forest | spark | min_n | min_instances_per_node |
| surv_reg | flexsurv | dist | dist |
| surv_reg | survival | dist | dist |
| svm_poly | kernlab | cost | C |
| svm_poly | kernlab | degree | degree |
| svm_poly | kernlab | scale_factor | scale |
| svm_poly | kernlab | margin | epsilon |
| svm_rbf | kernlab | cost | C |
| svm_rbf | kernlab | rbf_sigma | sigma |
| svm_rbf | kernlab | margin | epsilon |
Mapping table by model
```r
params %>%
  split(.$model) %>%
  map(spread, engine, original) %>%
  map(knitr::kable)
```
$boost_tree
| model | parsnip | C5.0 | spark | xgboost |
|---|---|---|---|---|
| boost_tree | learn_rate | NA | step_size | eta |
| boost_tree | loss_reduction | NA | NA | gamma |
| boost_tree | min_info_gain | NA | gamma | NA |
| boost_tree | min_n | minCases | min_instances_per_node | min_child_weight |
| boost_tree | mtry | NA | feature_subset_strategy | colsample_bytree |
| boost_tree | sample_size | sample | subsampling_rate | subsample |
| boost_tree | tree_depth | NA | max_depth | max_depth |
| boost_tree | trees | trials | max_iter | nrounds |
$decision_tree
| model | parsnip | C5.0 | rpart | spark |
|---|---|---|---|---|
| decision_tree | cost_complexity | NA | cp | NA |
| decision_tree | min_n | minCases | minsplit | min_instances_per_node |
| decision_tree | tree_depth | NA | maxdepth | max_depth |
$linear_reg
| model | parsnip | glmnet | spark |
|---|---|---|---|
| linear_reg | mixture | alpha | elastic_net_param |
| linear_reg | penalty | lambda | reg_param |
$logistic_reg
| model | parsnip | glmnet | keras | spark |
|---|---|---|---|---|
| logistic_reg | decay | NA | decay | NA |
| logistic_reg | mixture | alpha | NA | elastic_net_param |
| logistic_reg | penalty | lambda | NA | reg_param |
$mars
| model | parsnip | earth |
|---|---|---|
| mars | num_terms | nprune |
| mars | prod_degree | degree |
| mars | prune_method | pmethod |
$mlp
| model | parsnip | keras | nnet |
|---|---|---|---|
| mlp | activation | activation | NA |
| mlp | dropout | dropout | NA |
| mlp | epochs | epochs | maxit |
| mlp | hidden_units | hidden_units | size |
| mlp | penalty | penalty | decay |
$multinom_reg
| model | parsnip | glmnet | keras | spark |
|---|---|---|---|---|
| multinom_reg | decay | NA | decay | NA |
| multinom_reg | mixture | alpha | NA | elastic_net_param |
| multinom_reg | penalty | lambda | NA | reg_param |
$nearest_neighbor
| model | parsnip | kknn |
|---|---|---|
| nearest_neighbor | dist_power | distance |
| nearest_neighbor | neighbors | ks |
| nearest_neighbor | weight_func | kernel |
$rand_forest
| model | parsnip | randomForest | ranger | spark |
|---|---|---|---|---|
| rand_forest | min_n | nodesize | min.node.size | min_instances_per_node |
| rand_forest | mtry | mtry | mtry | feature_subset_strategy |
| rand_forest | trees | ntree | num.trees | num_trees |
$surv_reg
| model | parsnip | flexsurv | survival |
|---|---|---|---|
| surv_reg | dist | dist | dist |
$svm_poly
| model | parsnip | kernlab |
|---|---|---|
| svm_poly | cost | C |
| svm_poly | degree | degree |
| svm_poly | margin | epsilon |
| svm_poly | scale_factor | scale |
$svm_rbf
| model | parsnip | kernlab |
|---|---|---|
| svm_rbf | cost | C |
| svm_rbf | margin | epsilon |
| svm_rbf | rbf_sigma | sigma |