diff --git a/.github/workflows/CRAN-R-CMD-check.yaml b/.github/workflows/CRAN-R-CMD-check.yaml index 64c38eb6..9f16274a 100644 --- a/.github/workflows/CRAN-R-CMD-check.yaml +++ b/.github/workflows/CRAN-R-CMD-check.yaml @@ -1,7 +1,13 @@ on: + push: + branches: + - master + pull_request: + branches: + - master schedule: - # runs tests every day at 1am - - cron: '0 1 * * *' + # runs tests every day at 1am EST + - cron: '0 5 * * *' name: CRAN-R-CMD-check @@ -15,7 +21,7 @@ jobs: fail-fast: false matrix: config: - - {os: macOS-latest, r: 'devel'} + - {os: ubuntu-16.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"} - {os: windows-latest, r: 'release'} env: @@ -32,6 +38,12 @@ jobs: - uses: r-lib/actions/setup-pandoc@master + - uses: actions/setup-java@v1 + with: + java-version: '8' # The JDK version to make available on the path. + java-package: jdk # (jre, jdk, or jdk+fx) - defaults to jdk + architecture: x64 # (x64 or x86) - defaults to x64 + - name: Query dependencies run: | install.packages('remotes') @@ -52,9 +64,11 @@ jobs: env: RHUB_PLATFORM: linux-x86_64-ubuntu-gcc run: | - Rscript -e "remotes::install_github('r-hub/sysreqs')" - sysreqs=$(Rscript -e "cat(sysreqs::sysreq_commands('DESCRIPTION'))") - sudo -s eval "$sysreqs" + while read -r cmd + do + eval sudo $cmd + done < <(Rscript -e 'cat(remotes::system_requirements("ubuntu", "16.04"), sep = "\n")') + - name: Install dependencies run: | @@ -82,7 +96,7 @@ jobs: - name: Upload check results if: failure() - uses: actions/upload-artifact@master + uses: actions/upload-artifact@main with: name: ${{ runner.os }}-r${{ matrix.config.r }}-results path: check diff --git a/.github/workflows/GH-R-CMD-check.yaml b/.github/workflows/GH-R-CMD-check.yaml index ee2fa1a3..2128d19d 100644 --- a/.github/workflows/GH-R-CMD-check.yaml +++ b/.github/workflows/GH-R-CMD-check.yaml @@ -1,7 +1,13 @@ on: + push: + branches: + - master + pull_request: + branches: + - master schedule: - # runs tests every 
day at 1am - - cron: '0 1 * * *' + # runs tests every day at 1am EST + - cron: '0 5 * * *' name: GH-R-CMD-check @@ -15,7 +21,7 @@ jobs: fail-fast: false matrix: config: - - {os: macOS-latest, r: 'devel'} + - {os: ubuntu-16.04, r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/xenial/latest"} - {os: windows-latest, r: 'release'} env: @@ -32,6 +38,12 @@ jobs: - uses: r-lib/actions/setup-pandoc@master + - uses: actions/setup-java@v1 + with: + java-version: '8' # The JDK version to make available on the path. + java-package: jdk # (jre, jdk, or jdk+fx) - defaults to jdk + architecture: x64 # (x64 or x86) - defaults to x64 + - name: Query dependencies run: | install.packages('remotes') @@ -52,9 +64,11 @@ jobs: env: RHUB_PLATFORM: linux-x86_64-ubuntu-gcc run: | - Rscript -e "remotes::install_github('r-hub/sysreqs')" - sysreqs=$(Rscript -e "cat(sysreqs::sysreq_commands('DESCRIPTION'))") - sudo -s eval "$sysreqs" + while read -r cmd + do + eval sudo $cmd + done < <(Rscript -e 'cat(remotes::system_requirements("ubuntu", "16.04"), sep = "\n")') + - name: Install dependencies run: | @@ -76,7 +90,6 @@ jobs: try(remotes::install_github("tidymodels/modeldata"), silent = TRUE) shell: Rscript {0} - - name: Session info run: | options(width = 100) @@ -97,7 +110,7 @@ jobs: - name: Upload check results if: failure() - uses: actions/upload-artifact@master + uses: actions/upload-artifact@main with: name: ${{ runner.os }}-r${{ matrix.config.r }}-results path: check diff --git a/.github/workflows/pr-commands.yaml b/.github/workflows/pr-commands.yaml index 0d3cb716..1ae5d594 100644 --- a/.github/workflows/pr-commands.yaml +++ b/.github/workflows/pr-commands.yaml @@ -21,6 +21,8 @@ jobs: run: Rscript -e 'roxygen2::roxygenise()' - name: commit run: | + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" git add man/\* NAMESPACE git commit -m 'Document' - uses: r-lib/actions/pr-push@master @@ -44,6 +46,8 @@ jobs: run: Rscript 
-e 'styler::style_pkg()' - name: commit run: | + git config --local user.email "actions@github.com" + git config --local user.name "GitHub Actions" git add \*.R git commit -m 'Style' - uses: r-lib/actions/pr-push@master diff --git a/DESCRIPTION b/DESCRIPTION index fb89491b..16495701 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -21,9 +21,17 @@ Suggests: ranger, randomForest, C50, - earth + earth, + sparklyr, + dplyr, + glmnet, + rstanarm, + modeldata, + parsnip, + purrr, + tidyr, + tibble, + rlang Language: en-US Depends: - tidymodels, - modeldata, - parsnip + tidymodels diff --git a/inst/WORDLIST b/inst/WORDLIST index 76f574e8..5addb065 100644 --- a/inst/WORDLIST +++ b/inst/WORDLIST @@ -1,3 +1,4 @@ Lifecycle tidymodels cron +GH diff --git a/tests/testthat.R b/tests/testthat.R index 80ce43d8..8d6fba89 100644 --- a/tests/testthat.R +++ b/tests/testthat.R @@ -1,4 +1,48 @@ library(testthat) library(extratests) -test_check("extratests") +## ----------------------------------------------------------------------------- + +spark_install_winutils <- function(version) { + hadoop_version <- if (utils::compareVersion(version, "2.0.0") < 0) "2.6" else "2.7" # compare as versions, not as strings + spark_dir <- paste0("spark-", version, "-bin-hadoop", hadoop_version) + winutils_dir <- file.path(Sys.getenv("LOCALAPPDATA"), "spark", spark_dir, "tmp", "hadoop", "bin", fsep = "\\") + + if (!dir.exists(winutils_dir)) { + message("Installing winutils...") + + dir.create(winutils_dir, recursive = TRUE) + winutils_path <- file.path(winutils_dir, "winutils.exe", fsep = "\\") + + download.file( + "https://github.com/steveloughran/winutils/raw/master/hadoop-2.6.0/bin/winutils.exe", + winutils_path, + mode = "wb" + ) + + message("Installed winutils in ", winutils_path) + } +} + +## ----------------------------------------------------------------------------- + +library(sparklyr) + +if (.Platform$OS.type == "windows") { + # Right now, this does not seem to be working on windows; working on fixing it. + # Leaving it as-is as a place to start. 
+ spark_install_winutils("2.4") + sparklyr::spark_install(verbose = TRUE, version = "2.4", hadoop_version = "2.7") +} else { + sparklyr::spark_install(verbose = TRUE, version = "2.4") +} + +sc <- try(sparklyr::spark_connect(master = "local"), silent = TRUE) + +if(inherits(sc, "try-error")) { + print(sc) +} + +## ----------------------------------------------------------------------------- + +test_check("extratests", reporter = "summary") diff --git a/tests/testthat/parsnip-helper-objects.R b/tests/testthat/parsnip-helper-objects.R new file mode 100644 index 00000000..a66ee072 --- /dev/null +++ b/tests/testthat/parsnip-helper-objects.R @@ -0,0 +1,16 @@ +library(modeldata) +library(parsnip) + +## ----------------------------------------------------------------------------- + +data("wa_churn") +data("lending_club") +data("hpc_data") + +# ------------------------------------------------------------------------------ + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 diff --git a/tests/testthat/test-aaa-spark-install.R b/tests/testthat/test-aaa-spark-install.R new file mode 100644 index 00000000..9b06a25c --- /dev/null +++ b/tests/testthat/test-aaa-spark-install.R @@ -0,0 +1,5 @@ +context("spark installation") + +test_that('is spark installed?', { + expect_true(sparklyr::spark_install_find()$installed) +}) diff --git a/tests/testthat/test-data-descriptors-spark.R b/tests/testthat/test-data-descriptors-spark.R new file mode 100644 index 00000000..67668915 --- /dev/null +++ b/tests/testthat/test-data-descriptors-spark.R @@ -0,0 +1,84 @@ +library(testthat) +library(parsnip) + +context("spark descriptors") + +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] %>% as.data.frame() + +# 
------------------------------------------------------------------------------ + +context("descriptor variables") + +# ------------------------------------------------------------------------------ + +template <- function(col, pred, ob, lev, fact, dat, x, y) { + lst <- list(.cols = col, .preds = pred, .obs = ob, + .lvls = lev, .facts = fact, .dat = dat, + .x = x, .y = y) + + Filter(Negate(is.null), lst) +} + +eval_descrs <- function(descrs, not = NULL) { + + if (!is.null(not)) { + for (descr in not) { + descrs[[descr]] <- NULL + } + } + + lapply(descrs, do.call, list()) +} + +class_tab <- table(hpc$class, dnn = NULL) + +# ------------------------------------------------------------------------------ + + + +test_that("spark descriptor", { + + skip_if_not_installed("sparklyr") + + library(sparklyr) + library(dplyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + npk_descr <- copy_to(sc, npk[, 1:4], "npk_descr", overwrite = TRUE) + hpc_descr <- copy_to(sc, hpc, "hpc_descr", overwrite = TRUE) + + # spark does not allow .x, .y, .dat; spark handles factors differently + template2 <- purrr::partial(template, x = NULL, y = NULL, dat = NULL) + eval_descrs2 <- purrr::partial(eval_descrs, not = c(".x", ".y", ".dat")) + class_tab2 <- table(as.character(hpc$class), dnn = NULL) + + expect_equal( + template2(6, 4, 150, NA, 1), + eval_descrs2(parsnip:::get_descr_form(compounds ~ ., data = hpc_descr)) + ) + expect_equal( + template2(3, 1, 150, NA, 1), + eval_descrs2(parsnip:::get_descr_form(compounds ~ class, data = hpc_descr)) + ) + expect_equal( + template2(1, 1, 150, NA, 0), + eval_descrs2(parsnip:::get_descr_form(compounds ~ input_fields, data = hpc_descr)) + ) + expect_equivalent( + template2(4, 4, 150, class_tab2, 0), + eval_descrs2(parsnip:::get_descr_form(class ~ ., data = hpc_descr)) + ) + expect_equal( + template2(1, 1, 150, class_tab2, 0), + eval_descrs2(parsnip:::get_descr_form(class ~ input_fields, data = 
hpc_descr)) + ) + expect_equivalent( + template2(7, 3, 24, rev(table(npk$K, dnn = NULL)), 3), + eval_descrs2(parsnip:::get_descr_form(K ~ ., data = npk_descr)) + ) + +}) diff --git a/tests/testthat/test-encodings-glmnet.R b/tests/testthat/test-encodings-glmnet.R new file mode 100644 index 00000000..358e4574 --- /dev/null +++ b/tests/testthat/test-encodings-glmnet.R @@ -0,0 +1,71 @@ +context("encodings - glmnet ") + +library(tidymodels) +data(ames, package = "modeldata") + +## ----------------------------------------------------------------------------- + +parsnip_mod <- + linear_reg(penalty = .1) %>% + set_engine("glmnet") + +## ----------------------------------------------------------------------------- + +test_that('parsnip models with formula interface', { + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + parsnip_form_fit <- + parsnip_mod %>% + fit(Sale_Price ~ Year_Built + Alley, data = ames) + + parsnip_form_names <- tidy(parsnip_form_fit)$term + + expect_true(sum(grepl("(Intercept)", parsnip_form_names)) == 1) + expect_true(sum(grepl("^Alley", parsnip_form_names)) == 2) +}) + +test_that('parsnip models with xy interface', { + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + expect_warning( + expect_error( + parsnip_mod %>% + fit_xy(x = ames[, c("Year_Built", "Alley")], y = ames$Sale_Price) + ) + ) + + parsnip_xy_fit <- + parsnip_mod %>% + fit_xy(x = ames[, c("Year_Built", "Longitude")], y = ames$Sale_Price) + + parsnip_xy_names <- tidy(parsnip_xy_fit)$term + + expect_true(sum(grepl("(Intercept)", parsnip_xy_names)) == 1) +}) + +## ----------------------------------------------------------------------------- + +test_that('workflows', { + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + wflow <- + workflow() %>% + add_model(parsnip_mod) %>% + add_formula(Sale_Price ~ Year_Built + Alley) + + parsnip_wflow_fit <- + wflow %>% + fit(data = ames) + + parsnip_wflow_names <- + parsnip_wflow_fit %>% + pull_workflow_fit() %>% + tidy() %>% + 
pull(term) + + expect_true(sum(grepl("(Intercept)", parsnip_wflow_names)) == 1) + expect_true(sum(grepl("^Alley", parsnip_wflow_names)) == 2) +}) + + + diff --git a/tests/testthat/test-encodings-randomForest.R b/tests/testthat/test-encodings-randomForest.R index 9474f535..45d7973f 100644 --- a/tests/testthat/test-encodings-randomForest.R +++ b/tests/testthat/test-encodings-randomForest.R @@ -1,4 +1,4 @@ -context("randomForest encodings") +context("encodings - randomForest") library(tidymodels) data(scat, package = "modeldata") @@ -57,7 +57,7 @@ test_that('workflows', { pluck("importance") %>% rownames() - # expect_equal(sum(grepl("Location", parsnip_wflow_names)) == 1) + expect_true(sum(grepl("Location", parsnip_wflow_names)) == 1) }) diff --git a/tests/testthat/test-encodings-ranger.R b/tests/testthat/test-encodings-ranger.R index ff8a56d2..0703a947 100644 --- a/tests/testthat/test-encodings-ranger.R +++ b/tests/testthat/test-encodings-ranger.R @@ -1,4 +1,4 @@ -context("ranger encodings") +context("encodings - ranger") library(tidymodels) data(scat, package = "modeldata") @@ -57,7 +57,7 @@ test_that('workflows', { pluck("variable.importance") %>% names() - # expect_equal(sum(grepl("Location", parsnip_wflow_names)) == 1) + expect_true(sum(grepl("Location", parsnip_wflow_names)) == 1) }) diff --git a/tests/testthat/test-engine-parameters-c5.R b/tests/testthat/test-engine-parameters-c5.R index 9c75f854..01aef925 100644 --- a/tests/testthat/test-engine-parameters-c5.R +++ b/tests/testthat/test-engine-parameters-c5.R @@ -1,4 +1,4 @@ -context("engine-specific parameters with C50") +context("engine-specific parameters - C50") library(tidymodels) data(two_class_dat, package = "modeldata") diff --git a/tests/testthat/test-engine-parameters-earth.R b/tests/testthat/test-engine-parameters-earth.R index 96cfdf0a..2702af47 100644 --- a/tests/testthat/test-engine-parameters-earth.R +++ b/tests/testthat/test-engine-parameters-earth.R @@ -1,4 +1,4 @@ -context("engine-specific 
parameters with earth") +context("engine-specific parameters - earth") library(tidymodels) diff --git a/tests/testthat/test-engine-parameters-randomForest.R b/tests/testthat/test-engine-parameters-randomForest.R index b0f70994..e5f6e2b5 100644 --- a/tests/testthat/test-engine-parameters-randomForest.R +++ b/tests/testthat/test-engine-parameters-randomForest.R @@ -1,4 +1,4 @@ -context("engine-specific parameters with randomForest") +context("engine-specific parameters - randomForest") library(tidymodels) diff --git a/tests/testthat/test-engine-parameters-ranger.R b/tests/testthat/test-engine-parameters-ranger.R index cf399db0..e498ecf6 100644 --- a/tests/testthat/test-engine-parameters-ranger.R +++ b/tests/testthat/test-engine-parameters-ranger.R @@ -1,4 +1,4 @@ -context("engine-specific parameters with ranger") +context("engine-specific parameters - ranger") library(tidymodels) diff --git a/tests/testthat/test-glmnet-linear.R b/tests/testthat/test-glmnet-linear.R new file mode 100644 index 00000000..3ef35679 --- /dev/null +++ b/tests/testthat/test-glmnet-linear.R @@ -0,0 +1,312 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tidyr) +library(modeldata) + +# ------------------------------------------------------------------------------ + +context("engine - glmnet - linear regression") + +# ------------------------------------------------------------------------------ + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 + +## ----------------------------------------------------------------------------- + +data("hpc_data") + +hpc <- hpc_data[1:150, c(2:5, 8)] +num_pred <- c("compounds", "iterations", "num_pending") +hpc_bad_form <- as.formula(class ~ term) +hpc_basic <- linear_reg(penalty = .1, mixture = .3) %>% + set_engine("glmnet", nlambda = 15) 
+no_lambda <- linear_reg(mixture = .3) %>% + set_engine("glmnet") + +# ------------------------------------------------------------------------------ + +test_that('glmnet execution', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + expect_error( + res <- fit_xy( + hpc_basic, + control = ctrl, + x = hpc[, num_pred], + y = hpc$input_fields + ), + regexp = NA + ) + + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + + expect_error( + fit( + hpc_basic, + hpc_bad_form, + data = hpc, + control = ctrl + ) + ) + + glmnet_xy_catch <- fit_xy( + hpc_basic, + x = hpc[, num_pred], + y = factor(hpc$input_fields), + control = caught_ctrl + ) + expect_true(inherits(glmnet_xy_catch$fit, "try-error")) + +}) + +test_that('glmnet prediction, single lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + res_xy <- fit_xy( + hpc_basic, + control = ctrl, + x = hpc[, num_pred], + y = hpc$input_fields + ) + + # glmn_mod <- glmnet::glmnet(x = as.matrix(hpc[, num_pred]), y = hpc$input_fields, + # alpha = .3, nlambda = 15) + + uni_pred <- c(640.599944271351, 196.646976529848, 186.279646400216, 194.673852228774, + 198.126819755653) + + expect_equal(uni_pred, predict(res_xy, hpc[1:5, num_pred])$.pred, tolerance = 0.0001) + + res_form <- fit( + hpc_basic, + input_fields ~ log(compounds) + class, + data = hpc, + control = ctrl + ) + + form_pred <- c(570.504089227118, 162.413061474088, 167.022896537861, 157.609071878082, + 165.887783741483) + + expect_equal(form_pred, predict(res_form, hpc[1:5,])$.pred, tolerance = 0.0001) +}) + + +test_that('glmnet prediction, multiple lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + lams <- c(.01, 0.1) + + hpc_mult <- linear_reg(penalty = lams, mixture = .3) %>% + set_engine("glmnet") + + res_xy <- fit_xy( + hpc_mult, + control = ctrl, + x = hpc[, num_pred], + y = hpc$input_fields + ) + + # mult_pred <- + # predict(res_xy$fit, + # newx = as.matrix(hpc[1:5, num_pred]), 
+ # s = lams) + # mult_pred <- stack(as.data.frame(mult_pred)) + # mult_pred$penalty <- rep(lams, each = 5) + # mult_pred$rows <- rep(1:5, 2) + # mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ] + # mult_pred <- mult_pred[, c("penalty", "values")] + # names(mult_pred) <- c("penalty", ".pred") + # mult_pred <- tibble::as_tibble(mult_pred) + mult_pred <- + tibble::tribble( + ~penalty, ~.pred, + 0.01, 639.672880668187, + 0.1, 639.672880668187, + 0.01, 197.744613311359, + 0.1, 197.744613311359, + 0.01, 187.737940787615, + 0.1, 187.737940787615, + 0.01, 195.780487678662, + 0.1, 195.780487678662, + 0.01, 199.217707535882, + 0.1, 199.217707535882 + ) + + expect_equal( + as.data.frame(mult_pred), + multi_predict(res_xy, new_data = hpc[1:5, num_pred], lambda = lams) %>% + unnest(cols = c(.pred)) %>% + as.data.frame(), + tolerance = 0.0001 + ) + + res_form <- fit( + hpc_mult, + input_fields ~ log(compounds) + class, + data = hpc, + control = ctrl + ) + + # form_mat <- model.matrix(input_fields ~ log(compounds) + class, data = hpc) + # form_mat <- form_mat[1:5, -1] + # + # form_pred <- + # predict(res_form$fit, + # newx = form_mat, + # s = lams) + # form_pred <- stack(as.data.frame(form_pred)) + # form_pred$penalty <- rep(lams, each = 5) + # form_pred$rows <- rep(1:5, 2) + # form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + # form_pred <- form_pred[, c("penalty", "values")] + # names(form_pred) <- c("penalty", ".pred") + # form_pred <- tibble::as_tibble(form_pred) + + form_pred <- + tibble::tribble( + ~penalty, ~.pred, + 0.01, 570.474473760044, + 0.1, 570.474473760044, + 0.01, 164.040104978709, + 0.1, 164.040104978709, + 0.01, 168.709676954287, + 0.1, 168.709676954287, + 0.01, 159.173862504055, + 0.1, 159.173862504055, + 0.01, 167.559854709074, + 0.1, 167.559854709074 + ) + + expect_equal( + as.data.frame(form_pred), + multi_predict(res_form, new_data = hpc[1:5, ], lambda = lams) %>% + unnest(cols = c(.pred)) %>% + as.data.frame(), + 
tolerance = 0.0001 + ) +}) + +test_that('glmnet prediction, all lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + hpc_all <- linear_reg(mixture = .3) %>% + set_engine("glmnet", nlambda = 7) + + res_xy <- fit_xy( + hpc_all, + control = ctrl, + x = hpc[, num_pred], + y = hpc$input_fields + ) + + all_pred <- predict(res_xy$fit, newx = as.matrix(hpc[1:5, num_pred])) + all_pred <- stack(as.data.frame(all_pred)) + all_pred$penalty <- rep(res_xy$fit$lambda, each = 5) + all_pred$rows <- rep(1:5, length(res_xy$fit$lambda)) + all_pred <- all_pred[order(all_pred$rows, all_pred$penalty), ] + all_pred <- all_pred[, c("penalty", "values")] + names(all_pred) <- c("penalty", ".pred") + all_pred <- tibble::as_tibble(all_pred) + + expect_equal(all_pred, multi_predict(res_xy, new_data = hpc[1:5,num_pred ]) %>% unnest(cols = c(.pred))) + + res_form <- fit( + hpc_all, + input_fields ~ log(compounds) + class, + data = hpc, + control = ctrl + ) + + form_mat <- model.matrix(input_fields ~ log(compounds) + class, data = hpc) + form_mat <- form_mat[1:5, -1] + + form_pred <- predict(res_form$fit, newx = form_mat) + form_pred <- stack(as.data.frame(form_pred)) + form_pred$penalty <- rep(res_form$fit$lambda, each = 5) + form_pred$rows <- rep(1:5, length(res_form$fit$lambda)) + form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + form_pred <- form_pred[, c("penalty", "values")] + names(form_pred) <- c("penalty", ".pred") + form_pred <- tibble::as_tibble(form_pred) + + expect_equal(form_pred, multi_predict(res_form, hpc[1:5, c("compounds", "class")]) %>% unnest(cols = c(.pred))) +}) + + +test_that('submodel prediction', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + reg_fit <- + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) + + pred_glmn <- predict(reg_fit$fit, as.matrix(mtcars[1:4, -1]), s = .1) + + mp_res <- multi_predict(reg_fit, new_data = mtcars[1:4, -1], penalty = .1) + mp_res <- do.call("rbind", 
mp_res$.pred) + expect_equal(mp_res[[".pred"]], unname(pred_glmn[,1])) + + expect_error( + multi_predict(reg_fit, newdata = mtcars[1:4, -1], penalty = .1), + "Did you mean" + ) + + reg_fit <- + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) + + + pred_glmn_all <- + predict(reg_fit$fit, as.matrix(mtcars[1:2, -1])) %>% + as.data.frame() %>% + stack() %>% + dplyr::arrange(ind) + + + mp_res_all <- + multi_predict(reg_fit, new_data = mtcars[1:2, -1]) %>% + tidyr::unnest(cols = c(.pred)) + + expect_equal(sort(mp_res_all$.pred), sort(pred_glmn_all$values)) + +}) + + +test_that('error traps', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ], penalty = 0:1) + ) + expect_error( + linear_reg() %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars[-(1:4), ]) %>% + predict(mtcars[-(1:4), ]) + ) + +}) + diff --git a/tests/testthat/test-glmnet-logistic.R b/tests/testthat/test-glmnet-logistic.R new file mode 100644 index 00000000..5af80c9d --- /dev/null +++ b/tests/testthat/test-glmnet-logistic.R @@ -0,0 +1,436 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(tidyr) +library(modeldata) + +# ------------------------------------------------------------------------------ + +context("engine - glmnet - logistic regressiont") + +# ------------------------------------------------------------------------------ + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 + +## ----------------------------------------------------------------------------- + +data(lending_club) +data(wa_churn) +lending_club <- head(lending_club, 200) +lc_form <- as.formula(Class ~ log(funded_amnt) 
+ int_rate) +num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") +lc_bad_form <- as.formula(funded_amnt ~ term) +lc_basic <- logistic_reg() %>% set_engine("glmnet") + +# ------------------------------------------------------------------------------ + +test_that('glmnet execution', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + expect_error( + res <- fit_xy( + lc_basic, + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ), + regexp = NA + ) + + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + + expect_error( + glmnet_xy_catch <- fit_xy( + lc_basic, + x = lending_club[, num_pred], + y = lending_club$total_bal_il, + control = caught_ctrl + ) + ) +}) + +test_that('glmnet prediction, one lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + xy_fit <- fit_xy( + logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + uni_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + s = 0.1, type = "response")[,1] + uni_pred <- ifelse(uni_pred >= 0.5, "good", "bad") + uni_pred <- factor(uni_pred, levels = levels(lending_club$Class)) + uni_pred <- unname(uni_pred) + + expect_equal(uni_pred, predict(xy_fit, lending_club[1:7, num_pred])$.pred_class) + + res_form <- fit( + logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + s = 0.1, type = "response")[,1] + form_pred <- ifelse(form_pred >= 0.5, "good", "bad") + form_pred <- factor(form_pred, levels = levels(lending_club$Class)) + form_pred <- unname(form_pred) + + expect_equal( + form_pred, + predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type 
= "class")$.pred_class + ) + +}) + + +test_that('glmnet prediction, mulitiple lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + lams <- c(0.01, 0.1) + + xy_fit <- fit_xy( + logistic_reg(penalty = lams) %>% set_engine("glmnet"), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + mult_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + s = lams, type = "response") + mult_pred <- stack(as.data.frame(mult_pred)) + mult_pred$values <- ifelse(mult_pred$values >= 0.5, "good", "bad") + mult_pred$values <- factor(mult_pred$values, levels = levels(lending_club$Class)) + mult_pred$penalty <- rep(lams, each = 7) + mult_pred$rows <- rep(1:7, 2) + mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ] + mult_pred <- mult_pred[, c("penalty", "values")] + names(mult_pred) <- c("penalty", ".pred_class") + mult_pred <- tibble::as_tibble(mult_pred) + + expect_equal( + mult_pred, + multi_predict(xy_fit, lending_club[1:7, num_pred], type = "class") %>% unnest(cols = c(.pred)) + ) + + res_form <- fit( + logistic_reg(penalty = lams) %>% set_engine("glmnet"), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + s = lams, type = "response") # probabilities, so the 0.5 cut-off below is on the right scale + form_pred <- stack(as.data.frame(form_pred)) + form_pred$values <- ifelse(form_pred$values >= 0.5, "good", "bad") + form_pred$values <- factor(form_pred$values, levels = levels(lending_club$Class)) + form_pred$penalty <- rep(lams, each = 7) + form_pred$rows <- rep(1:7, 2) + form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + form_pred <- form_pred[, c("penalty", "values")] + names(form_pred) <- c("penalty", ".pred_class") + form_pred <- tibble::as_tibble(form_pred) + + expect_equal( + form_pred, + multi_predict(res_form, lending_club[1:7, 
c("funded_amnt", "int_rate")]) %>% unnest(cols = c(.pred)) + ) + +}) + +test_that('glmnet prediction, no lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + xy_fit <- fit_xy( + logistic_reg() %>% set_engine("glmnet", nlambda = 11), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + mult_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + s = xy_fit$fit$lambda, type = "response") + mult_pred <- stack(as.data.frame(mult_pred)) + mult_pred$values <- ifelse(mult_pred$values >= 0.5, "good", "bad") + mult_pred$values <- factor(mult_pred$values, levels = levels(lending_club$Class)) + mult_pred$penalty <- rep(xy_fit$fit$lambda, each = 7) + mult_pred$rows <- rep(1:7, length(xy_fit$fit$lambda)) # one set of row ids per lambda on the fitted path + mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ] + mult_pred <- mult_pred[, c("penalty", "values")] + names(mult_pred) <- c("penalty", ".pred_class") + mult_pred <- tibble::as_tibble(mult_pred) + + expect_equal(mult_pred, multi_predict(xy_fit, lending_club[1:7, num_pred]) %>% unnest(cols = c(.pred))) + + res_form <- fit( + logistic_reg() %>% set_engine("glmnet", nlambda = 11), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + type = "response") + form_pred <- stack(as.data.frame(form_pred)) + form_pred$values <- ifelse(form_pred$values >= 0.5, "good", "bad") + form_pred$values <- factor(form_pred$values, levels = levels(lending_club$Class)) + form_pred$penalty <- rep(res_form$fit$lambda, each = 7) + form_pred$rows <- rep(1:7, length(res_form$fit$lambda)) # one set of row ids per lambda on the fitted path + form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + form_pred <- form_pred[, c("penalty", "values")] + names(form_pred) <- c("penalty", ".pred_class") + form_pred <- tibble::as_tibble(form_pred) + + expect_equal( + form_pred, + multi_predict(res_form, 
lending_club[1:7, c("funded_amnt", "int_rate")]) %>% unnest(cols = c(.pred)) + ) + +}) + + +test_that('glmnet probabilities, one lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + xy_fit <- fit_xy( + logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + uni_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + s = 0.1, type = "response")[,1] + uni_pred <- tibble(.pred_bad = 1 - uni_pred, .pred_good = uni_pred) + + expect_equal( + uni_pred, + predict(xy_fit, lending_club[1:7, num_pred], type = "prob") + ) + + res_form <- fit( + logistic_reg(penalty = 0.1) %>% set_engine("glmnet"), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + unname(predict(res_form$fit, + newx = form_mat, + s = 0.1, type = "response")[, 1]) + form_pred <- tibble(.pred_bad = 1 - form_pred, .pred_good = form_pred) + + expect_equal( + form_pred, + predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type = "prob") + ) + + one_row <- predict(res_form, lending_club[1, c("funded_amnt", "int_rate")], type = "prob") + expect_equivalent(form_pred[1,], one_row) + +}) + +test_that('glmnet probabilities, mulitiple lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + lams <- c(0.01, 0.1) + + xy_fit <- fit_xy( + logistic_reg(penalty = lams) %>% set_engine("glmnet"), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + mult_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + s = lams, type = "response") + mult_pred <- stack(as.data.frame(mult_pred)) + mult_pred$penalty <- rep(lams, each = 7) + mult_pred$rows <- rep(1:7, 2) + mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ] + mult_pred$.pred_bad <- 1 - 
mult_pred$values + mult_pred <- mult_pred[, c("penalty", ".pred_bad", "values")] + names(mult_pred) <- c("penalty", ".pred_bad", ".pred_good") + mult_pred <- tibble::as_tibble(mult_pred) + + expect_equal( + mult_pred, + multi_predict(xy_fit, lending_club[1:7, num_pred], lambda = lams, type = "prob") %>% + unnest(cols = c(.pred)) + ) + + res_form <- fit( + logistic_reg(penalty = lams) %>% set_engine("glmnet"), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + s = lams, type = "response") + form_pred <- stack(as.data.frame(form_pred)) + form_pred$penalty <- rep(lams, each = 7) + form_pred$rows <- rep(1:7, 2) + form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + form_pred$.pred_bad <- 1 - form_pred$values + form_pred <- form_pred[, c("penalty", ".pred_bad", "values")] + names(form_pred) <- c("penalty", ".pred_bad", ".pred_good") + form_pred <- tibble::as_tibble(form_pred) + + expect_equal( + form_pred, + multi_predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type = "prob") %>% + unnest(cols = c(.pred)) + ) + +}) + + +test_that('glmnet probabilities, no lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + xy_fit <- fit_xy( + logistic_reg() %>% set_engine("glmnet"), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + mult_pred <- + predict(xy_fit$fit, + newx = as.matrix(lending_club[1:7, num_pred]), + type = "response") + mult_pred <- stack(as.data.frame(mult_pred)) + mult_pred$penalty <- rep(xy_fit$fit$lambda, each = 7) + mult_pred$rows <- rep(1:7, length(xy_fit$fit$lambda)) + mult_pred <- mult_pred[order(mult_pred$rows, mult_pred$penalty), ] + mult_pred$.pred_bad <- 1 - mult_pred$values + mult_pred <- mult_pred[, c("penalty", ".pred_bad", "values")] + names(mult_pred) <- 
c("penalty", ".pred_bad", ".pred_good") + mult_pred <- tibble::as_tibble(mult_pred) + + expect_equal( + mult_pred, + multi_predict(xy_fit, lending_club[1:7, num_pred], type = "prob") %>% + unnest(cols = c(.pred)) + ) + + res_form <- fit( + logistic_reg() %>% set_engine("glmnet"), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_mat <- model.matrix(Class ~ log(funded_amnt) + int_rate, data = lending_club) + form_mat <- form_mat[1:7, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + type = "response") + form_pred <- stack(as.data.frame(form_pred)) + form_pred$penalty <- rep(res_form$fit$lambda, each = 7) + form_pred$rows <- rep(1:7, length(res_form$fit$lambda)) + form_pred <- form_pred[order(form_pred$rows, form_pred$penalty), ] + form_pred$.pred_bad <- 1 - form_pred$values + form_pred <- form_pred[, c("penalty", ".pred_bad", "values")] + names(form_pred) <- c("penalty", ".pred_bad", ".pred_good") + form_pred <- tibble::as_tibble(form_pred) + + expect_equal( + form_pred, + multi_predict(res_form, lending_club[1:7, c("funded_amnt", "int_rate")], type = "prob") %>% unnest(cols = c(.pred)) + ) + +}) + + +test_that('submodel prediction', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + vars <- c("female", "tenure", "total_charges", "phone_service", "monthly_charges") + class_fit <- + logistic_reg() %>% + set_engine("glmnet") %>% + fit(churn ~ ., data = wa_churn[-(1:4), c("churn", vars)]) + + pred_glmn <- predict(class_fit$fit, as.matrix(wa_churn[1:4, vars]), s = .1, type = "response") + + mp_res <- multi_predict(class_fit, new_data = wa_churn[1:4, vars], penalty = .1, type = "prob") + mp_res <- do.call("rbind", mp_res$.pred) + expect_equal(mp_res[[".pred_No"]], unname(pred_glmn[,1])) + + expect_error( + multi_predict(class_fit, newdata = wa_churn[1:4, vars], penalty = .1, type = "prob"), + "Did you mean" + ) + + # Can predict using default penalty. 
See #108 + expect_error( + multi_predict(class_fit, new_data = wa_churn[1:4, vars]), + NA + ) + +}) diff --git a/tests/testthat/test-glmnet-multinom.R b/tests/testthat/test-glmnet-multinom.R new file mode 100644 index 00000000..6623a739 --- /dev/null +++ b/tests/testthat/test-glmnet-multinom.R @@ -0,0 +1,182 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("engine - glmnet - multinom regression") + +# ------------------------------------------------------------------------------ + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +run_glmnet <- utils::compareVersion('3.6.0', as.character(getRversion())) > 0 + +data("hpc_data") +hpc <- hpc_data[, c(2:5, 8)] +rows <- c(1, 51, 101) + +# ------------------------------------------------------------------------------ + +test_that('glmnet execution', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + expect_error( + res <- fit_xy( + multinom_reg() %>% set_engine("glmnet"), + control = ctrl, + x = hpc[, 1:4], + y = hpc$class + ), + regexp = NA + ) + + expect_true(has_multi_predict(res)) + expect_equal(multi_predict_args(res), "penalty") + + expect_error( + glmnet_xy_catch <- fit_xy( + multinom_reg() %>% set_engine("glmnet"), + x = hpc[, 2:5], + y = hpc$compounds, + control = caught_ctrl + ) + ) + +}) + +test_that('glmnet prediction, one lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + xy_fit <- fit_xy( + multinom_reg(penalty = 0.1) %>% set_engine("glmnet"), + control = ctrl, + x = hpc[, 1:4], + y = hpc$class + ) + + uni_pred <- + predict(xy_fit$fit, + newx = as.matrix(hpc[rows, 1:4]), + s = xy_fit$spec$args$penalty, type = "class") + uni_pred <- factor(uni_pred[,1], levels = levels(hpc$class)) + uni_pred <- unname(uni_pred) + + 
expect_equal(uni_pred, predict(xy_fit, hpc[rows, 1:4], type = "class")$.pred_class) + + res_form <- fit( + multinom_reg(penalty = 0.1) %>% set_engine("glmnet"), + class ~ log(compounds) + input_fields, + data = hpc, + control = ctrl + ) + + form_mat <- model.matrix(class ~ log(compounds) + input_fields, data = hpc) + form_mat <- form_mat[rows, -1] + + form_pred <- + predict(res_form$fit, + newx = form_mat, + s = res_form$spec$args$penalty, + type = "class") + form_pred <- factor(form_pred[,1], levels = levels(hpc$class)) + expect_equal(form_pred, parsnip:::predict_class.model_fit(res_form, hpc[rows, c("compounds", "input_fields")])) + expect_equal(form_pred, predict(res_form, hpc[rows, c("compounds", "input_fields")], type = "class")$.pred_class) + +}) + + +test_that('glmnet probabilities, mulitiple lambda', { + + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + lams <- c(0.01, 0.1) + + xy_fit <- fit_xy( + multinom_reg(penalty = lams) %>% set_engine("glmnet"), + control = ctrl, + x = hpc[, 1:4], + y = hpc$class + ) + + expect_error(predict(xy_fit, hpc[rows, 1:4], type = "class")) + expect_error(predict(xy_fit, hpc[rows, 1:4], type = "prob")) + + mult_pred <- + predict(xy_fit$fit, + newx = as.matrix(hpc[rows, 1:4]), + s = lams, type = "response") + mult_pred <- apply(mult_pred, 3, as_tibble) + mult_pred <- dplyr:::bind_rows(mult_pred) + mult_probs <- mult_pred + names(mult_pred) <- paste0(".pred_", names(mult_pred)) + mult_pred$penalty <- rep(lams, each = 3) + mult_pred$row <- rep(1:3, 2) + mult_pred <- mult_pred[order(mult_pred$row, mult_pred$penalty),] + mult_pred <- split(mult_pred[, -5], mult_pred$row) + names(mult_pred) <- NULL + mult_pred <- tibble(.pred = mult_pred) + + multi_pred_res <- multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams, type = "prob") + + for (i in seq_along(multi_pred_res$.pred)) { + expect_equal( + mult_pred %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")), + multi_pred_res %>% 
dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")) + ) + } + + mult_class <- factor(names(mult_probs)[apply(mult_probs, 1, which.max)], + levels = xy_fit$lvl) + mult_class <- tibble( + .pred_class = mult_class, + penalty = rep(lams, each = 3), + row = rep(1:3, 2) + ) + mult_class <- mult_class[order(mult_class$row, mult_class$penalty),] + mult_class <- split(mult_class[, -3], mult_class$row) + names(mult_class) <- NULL + mult_class <- tibble(.pred = mult_class) + + mult_class_res <- multi_predict(xy_fit, hpc[rows, 1:4], penalty = lams) + + for (i in seq_along(mult_class_res$.pred)) { + expect_equal( + mult_class %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")), + mult_class_res %>% dplyr::slice(i) %>% pull(.pred) %>% purrr::pluck(1) %>% dplyr::select(starts_with(".pred")) + ) + } + + expect_error( + multi_predict(xy_fit, newdata = hpc[rows, 1:4], penalty = lams), + "Did you mean" + ) + + # Can predict probs with default penalty. 
See #108 + expect_error( + multi_predict(xy_fit, new_data = hpc[rows, 1:4], type = "prob"), + NA + ) + +}) + +test_that("class predictions are factors with all levels", { + skip_if_not_installed("glmnet") + skip_if(run_glmnet) + + basic <- multinom_reg() %>% set_engine("glmnet") %>% fit(class ~ ., data = hpc) + nd <- hpc[hpc$class == "VF", ] + yhat <- predict(basic, new_data = nd, penalty = .1) + yhat_multi <- multi_predict(basic, new_data = nd, penalty = .1)$.pred + expect_is(yhat_multi[[1]]$.pred_class, "factor") + expect_equal(levels(yhat_multi[[1]]$.pred_class), levels(hpc$class)) +}) diff --git a/tests/testthat/test-glmnet-tidy.R b/tests/testthat/test-glmnet-tidy.R new file mode 100644 index 00000000..38184b14 --- /dev/null +++ b/tests/testthat/test-glmnet-tidy.R @@ -0,0 +1,63 @@ +context("engine - glmnet - tidy method") + +## ----------------------------------------------------------------------------- + +test_that('linear regression', { + skip_if_not_installed("glmnet") + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + ps_mod <- + linear_reg(penalty = .1) %>% + set_engine("glmnet") %>% + fit(mpg ~ ., data = mtcars) + + ps_coefs <- tidy(ps_mod) + gn_coefs <- as.matrix(coef(ps_mod$fit, s = .1)) + for(i in ps_coefs$term) { + expect_equal(ps_coefs$estimate[ps_coefs$term == i], gn_coefs[i,1]) + } +}) + +test_that('logistic regression', { + skip_if_not_installed("glmnet") + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + data(two_class_dat, package = "modeldata") + + ps_mod <- + logistic_reg(penalty = .1) %>% + set_engine("glmnet") %>% + fit(Class ~ ., data = two_class_dat) + + ps_coefs <- tidy(ps_mod) + gn_coefs <- as.matrix(coef(ps_mod$fit, s = .1)) + for(i in ps_coefs$term) { + expect_equal(ps_coefs$estimate[ps_coefs$term == i], gn_coefs[i,1]) + } +}) + +test_that('multinomial regression', { + skip_if_not_installed("glmnet") + skip_if(utils::packageVersion("parsnip") <= "0.1.2") + + data(penguins, package = "modeldata") + + ps_mod <- + 
multinom_reg(penalty = .01) %>% + set_engine("glmnet") %>% + fit(species ~ ., data = penguins) + + ps_coefs <- tidy(ps_mod) + gn_coefs <- coef(ps_mod$fit, s = .01) + gn_coefs <- purrr::map(gn_coefs, as.matrix) + for(i in unique(ps_coefs$term)) { + for(j in unique(ps_coefs$class)) { + expect_equal( + ps_coefs$estimate[ps_coefs$term == i & ps_coefs$class == j], + gn_coefs[[j]][i,1] + ) + } + } +}) + + diff --git a/tests/testthat/test-parsnip-boost-tree-spark.R b/tests/testthat/test-parsnip-boost-tree-spark.R new file mode 100644 index 00000000..4edeab4e --- /dev/null +++ b/tests/testthat/test-parsnip-boost-tree-spark.R @@ -0,0 +1,182 @@ +library(testthat) +library(parsnip) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("boosted tree execution with spark") +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + +test_that('spark execution', { + + skip_if_not_installed("sparklyr") + library(sparklyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + hpc_bt_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_bt_tr", overwrite = TRUE) + hpc_bt_te <- copy_to(sc, hpc[ 1:4 , -1], "hpc_bt_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_reg_fit <- + fit( + boost_tree(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), + control = ctrl, + compounds ~ ., + data = hpc_bt_tr + ), + regexp = NA + ) + + # check for reproducibility and passing extra arguments + expect_error( + spark_reg_fit_dup <- + fit( + boost_tree(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), + control = ctrl, + compounds ~ ., + data = hpc_bt_tr + ), + regexp = NA + ) + + expect_error( + spark_reg_pred <- predict(spark_reg_fit, hpc_bt_te), + regexp = NA + ) + + expect_error( + 
spark_reg_pred_num <- parsnip:::predict_numeric.model_fit(spark_reg_fit, hpc_bt_te), + regexp = NA + ) + + expect_error( + spark_reg_dup <- predict(spark_reg_fit_dup, hpc_bt_te), + regexp = NA + ) + + expect_error( + spark_reg_num_dup <- parsnip:::predict_numeric.model_fit(spark_reg_fit_dup, hpc_bt_te), + regexp = NA + ) + + expect_equal(colnames(spark_reg_pred), "pred") + + expect_equal( + as.data.frame(spark_reg_pred)$pred, + as.data.frame(spark_reg_dup)$pred + ) + expect_equal( + as.data.frame(spark_reg_pred_num)$pred, + as.data.frame(spark_reg_num_dup)$pred + ) + + + # ---------------------------------------------------------------------------- + + # same for classification + + churn_bt_tr <- copy_to(sc, wa_churn[ 5:100, ], "churn_bt_tr", overwrite = TRUE) + churn_bt_te <- copy_to(sc, wa_churn[ 1:4, -1], "churn_bt_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_class_fit <- + fit( + boost_tree(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), + control = ctrl, + churn ~ ., + data = churn_bt_tr + ), + regexp = NA + ) + + # check for reproducibility and passing extra arguments + expect_error( + spark_class_fit_dup <- + fit( + boost_tree(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), + control = ctrl, + churn ~ ., + data = churn_bt_tr + ), + regexp = NA + ) + + expect_error( + spark_class_pred <- predict(spark_class_fit, churn_bt_te), + regexp = NA + ) + + expect_error( + spark_class_pred_class <- parsnip:::predict_class.model_fit(spark_class_fit, churn_bt_te), + regexp = NA + ) + + expect_error( + spark_class_dup <- predict(spark_class_fit_dup, churn_bt_te), + regexp = NA + ) + + expect_error( + spark_class_dup_class <- parsnip:::predict_class.model_fit(spark_class_fit_dup, churn_bt_te), + regexp = NA + ) + + expect_equal(colnames(spark_class_pred), "pred_class") + + expect_equal( + as.data.frame(spark_class_pred)$pred_class, + 
as.data.frame(spark_class_dup)$pred_class + ) + expect_equal( + as.data.frame(spark_class_pred_class)$pred_class, + as.data.frame(spark_class_dup_class)$pred_class + ) + + + expect_error( + spark_class_prob <- predict(spark_class_fit, churn_bt_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_prob_classprob <- parsnip:::predict_classprob.model_fit(spark_class_fit, churn_bt_te), + regexp = NA + ) + + expect_error( + spark_class_dup <- predict(spark_class_fit_dup, churn_bt_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_dup_classprob <- parsnip:::predict_classprob.model_fit(spark_class_fit_dup, churn_bt_te), + regexp = NA + ) + + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) + + expect_equivalent( + as.data.frame(spark_class_prob), + as.data.frame(spark_class_dup) + ) + expect_equal( + as.data.frame(spark_class_prob_classprob), + as.data.frame(spark_class_dup_classprob) + ) + +}) + diff --git a/tests/testthat/test-parsnip-linear-reg-spark.R b/tests/testthat/test-parsnip-linear-reg-spark.R new file mode 100644 index 00000000..39560545 --- /dev/null +++ b/tests/testthat/test-parsnip-linear-reg-spark.R @@ -0,0 +1,56 @@ +library(testthat) +library(parsnip) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("linear regression execution with spark") +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + +test_that('spark execution', { + + skip_if_not_installed("sparklyr") + + library(sparklyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + hpc_linreg_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_linreg_tr", overwrite = TRUE) + hpc_linreg_te <- copy_to(sc, hpc[ 1:4 , -1], "hpc_linreg_te", overwrite = TRUE) + + expect_error( + spark_fit <- + fit( + linear_reg() %>% set_engine("spark"), + 
control = ctrl, + compounds ~ ., + data = hpc_linreg_tr + ), + regexp = NA + ) + + expect_false(has_multi_predict(spark_fit)) + expect_equal(multi_predict_args(spark_fit), NA_character_) + + expect_error( + spark_pred <- predict(spark_fit, hpc_linreg_te), + regexp = NA + ) + + expect_error( + spark_pred_num <- predict(spark_fit, hpc_linreg_te), + regexp = NA + ) + + lm_fit <- lm(compounds ~ ., data = hpc[-(1:4), ]) + lm_pred <- unname(predict(lm_fit, hpc[ 1:4 , -1])) + + expect_equal(as.data.frame(spark_pred)$pred, lm_pred) + expect_equal(as.data.frame(spark_pred_num)$pred, lm_pred) +}) + diff --git a/tests/testthat/test-parsnip-logistic-reg-spark.R b/tests/testthat/test-parsnip-logistic-reg-spark.R new file mode 100644 index 00000000..8b6e593e --- /dev/null +++ b/tests/testthat/test-parsnip-logistic-reg-spark.R @@ -0,0 +1,88 @@ +library(testthat) +library(parsnip) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("logistic regression execution with spark") +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + +test_that('spark execution', { + + skip_if_not_installed("sparklyr") + + library(sparklyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + churn_logit_tr <- copy_to(sc, wa_churn[ 5:100, ], "churn_logit_tr", overwrite = TRUE) + churn_logit_te <- copy_to(sc, wa_churn[ 1:4, -1], "churn_logit_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_class_fit <- + fit( + logistic_reg() %>% set_engine("spark"), + control = ctrl, + churn ~ ., + data = churn_logit_tr + ), + regexp = NA + ) + + # check for reproducibility and passing extra arguments + expect_error( + spark_class_fit_dup <- + fit( + logistic_reg() %>% set_engine("spark"), + control = ctrl, + churn ~ 
., + data = churn_logit_tr + ), + regexp = NA + ) + + expect_error( + spark_class_pred <- predict(spark_class_fit, churn_logit_te), + regexp = NA + ) + + expect_error( + spark_class_pred_class <- predict(spark_class_fit, churn_logit_te), + regexp = NA + ) + + expect_equal(colnames(spark_class_pred), "pred_class") + + expect_equal( + as.data.frame(spark_class_pred)$pred_class, + as.data.frame(spark_class_pred_class)$pred_class + ) + + expect_error( + spark_class_prob <- predict(spark_class_fit, churn_logit_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_prob_classprob <- predict(spark_class_fit, churn_logit_te, type = "prob"), + regexp = NA + ) + + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) + + expect_equivalent( + as.data.frame(spark_class_prob), + as.data.frame(spark_class_prob_classprob) + ) + + +}) + + diff --git a/tests/testthat/test-parsnip-multinom-reg-spark.R b/tests/testthat/test-parsnip-multinom-reg-spark.R new file mode 100644 index 00000000..232b864e --- /dev/null +++ b/tests/testthat/test-parsnip-multinom-reg-spark.R @@ -0,0 +1,77 @@ +library(testthat) +library(parsnip) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("multinomial regression execution with spark") +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + +test_that('spark execution', { + + skip_if_not_installed("sparklyr") + + library(sparklyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + hpc_rows <- c(1, 51, 101) + hpc_tr <- copy_to(sc, hpc[-hpc_rows, ], "hpc_tr", overwrite = TRUE) + hpc_te <- copy_to(sc, hpc[ hpc_rows, -5], "hpc_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_class_fit <- + fit( + multinom_reg() %>% 
set_engine("spark"), + control = ctrl, + class ~ ., + data = hpc_tr + ), + regexp = NA + ) + + expect_error( + spark_class_pred <- predict(spark_class_fit, hpc_te), + regexp = NA + ) + + expect_error( + spark_class_pred_class <- predict(spark_class_fit, hpc_te), + regexp = NA + ) + + expect_equal(colnames(spark_class_pred), "pred_class") + + expect_equal( + as.data.frame(spark_class_pred)$pred_class, + as.data.frame(spark_class_pred_class)$pred_class + ) + + expect_error( + spark_class_prob <- predict(spark_class_fit, hpc_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_prob_classprob <- predict(spark_class_fit, hpc_te, type = "prob"), + regexp = NA + ) + + expect_equal( + colnames(spark_class_prob), + c("pred_VF", "pred_F", "pred_L", "pred_M") + ) + + expect_equivalent( + as.data.frame(spark_class_prob), + as.data.frame(spark_class_prob_classprob) + ) +}) + diff --git a/tests/testthat/test-parsnip-rand-forest-spark.R b/tests/testthat/test-parsnip-rand-forest-spark.R new file mode 100644 index 00000000..2f769065 --- /dev/null +++ b/tests/testthat/test-parsnip-rand-forest-spark.R @@ -0,0 +1,182 @@ +library(testthat) +library(parsnip) +library(dplyr) + +# ------------------------------------------------------------------------------ + +context("random forest execution with spark") +source(test_path("parsnip-helper-objects.R")) +hpc <- hpc_data[1:150, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + +test_that('spark execution', { + + skip_if_not_installed("sparklyr") + + library(sparklyr) + + sc <- try(spark_connect(master = "local"), silent = TRUE) + + skip_if(inherits(sc, "try-error")) + + hpc_rf_tr <- copy_to(sc, hpc[-(1:4), ], "hpc_rf_tr", overwrite = TRUE) + hpc_rf_te <- copy_to(sc, hpc[ 1:4 , -1], "hpc_rf_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_reg_fit <- + fit( + rand_forest(trees = 5, mode = 
"regression") %>% + set_engine("spark", seed = 12), + control = ctrl, + compounds ~ ., + data = hpc_rf_tr + ), + regexp = NA + ) + + # check for reproducibility and passing extra arguments + expect_error( + spark_reg_fit_dup <- + fit( + rand_forest(trees = 5, mode = "regression") %>% + set_engine("spark", seed = 12), + control = ctrl, + compounds ~ ., + data = hpc_rf_tr + ), + regexp = NA + ) + + expect_error( + spark_reg_pred <- predict(spark_reg_fit, hpc_rf_te), + regexp = NA + ) + + expect_error( + spark_reg_pred_num <- predict(spark_reg_fit, hpc_rf_te), + regexp = NA + ) + + expect_error( + spark_reg_dup <- predict(spark_reg_fit_dup, hpc_rf_te), + regexp = NA + ) + + expect_error( + spark_reg_num_dup <- predict(spark_reg_fit_dup, hpc_rf_te), + regexp = NA + ) + + expect_equal(colnames(spark_reg_pred), "pred") + + expect_equal( + as.data.frame(spark_reg_pred)$pred, + as.data.frame(spark_reg_dup)$pred + ) + expect_equal( + as.data.frame(spark_reg_pred_num)$pred, + as.data.frame(spark_reg_num_dup)$pred + ) + + + # ---------------------------------------------------------------------------- + + # same for classification + + churn_rf_tr <- copy_to(sc, wa_churn[ 5:100, ], "churn_rf_tr", overwrite = TRUE) + churn_rf_te <- copy_to(sc, wa_churn[ 1:4, -1], "churn_rf_te", overwrite = TRUE) + + # ---------------------------------------------------------------------------- + + expect_error( + spark_class_fit <- + fit( + rand_forest(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), + control = ctrl, + churn ~ ., + data = churn_rf_tr + ), + regexp = NA + ) + + # check for reproducibility and passing extra arguments + expect_error( + spark_class_fit_dup <- + fit( + rand_forest(trees = 5, mode = "classification") %>% + set_engine("spark", seed = 12), + control = ctrl, + churn ~ ., + data = churn_rf_tr + ), + regexp = NA + ) + + expect_error( + spark_class_pred <- predict(spark_class_fit, churn_rf_te), + regexp = NA + ) + + expect_error( + 
spark_class_pred_class <- predict(spark_class_fit, churn_rf_te), + regexp = NA + ) + + expect_error( + spark_class_dup <- predict(spark_class_fit_dup, churn_rf_te), + regexp = NA + ) + + expect_error( + spark_class_dup_class <- predict(spark_class_fit_dup, churn_rf_te), + regexp = NA + ) + + expect_equal(colnames(spark_class_pred), "pred_class") + + expect_equal( + as.data.frame(spark_class_pred)$pred_class, + as.data.frame(spark_class_dup)$pred_class + ) + expect_equal( + as.data.frame(spark_class_pred_class)$pred_class, + as.data.frame(spark_class_dup_class)$pred_class + ) + + + expect_error( + spark_class_prob <- predict(spark_class_fit, churn_rf_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_dup <- predict(spark_class_fit_dup, churn_rf_te, type = "prob"), + regexp = NA + ) + + expect_error( + spark_class_dup_classprob <- predict(spark_class_fit_dup, churn_rf_te, type = "prob"), + regexp = NA + ) + expect_error( + spark_class_prob_classprob <- predict(spark_class_fit, churn_rf_te, type = "prob"), + regexp = NA + ) + + expect_equal(colnames(spark_class_prob), c("pred_No", "pred_Yes")) + + expect_equivalent( + as.data.frame(spark_class_prob), + as.data.frame(spark_class_dup) + ) + expect_equal( + as.data.frame(spark_class_prob_classprob), + as.data.frame(spark_class_dup_classprob) + ) + +}) + diff --git a/tests/testthat/test-stan-linear.R b/tests/testthat/test-stan-linear.R new file mode 100644 index 00000000..0ea839fa --- /dev/null +++ b/tests/testthat/test-stan-linear.R @@ -0,0 +1,149 @@ +library(testthat) +library(parsnip) +library(rlang) +library(modeldata) + +context("engine - stan - linear regression") + +## ----------------------------------------------------------------------------- + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +## 
----------------------------------------------------------------------------- + +data("hpc_data") +hpc <- hpc_data[, c(2:5, 8)] + +# ------------------------------------------------------------------------------ + + +num_pred <- c("compounds", "iterations", "num_pending") +hpc_bad_form <- as.formula(class ~ term) +hpc_basic <- linear_reg() %>% + set_engine("stan", seed = 10, chains = 1) + +ctrl <- control_parsnip(verbosity = 0L, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 0L, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0L, catch = TRUE) + +# ------------------------------------------------------------------------------ + +test_that('stan_glm execution', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + expect_error( + res <- fit( + hpc_basic, + compounds ~ log(input_fields) + class, + data = hpc, + control = ctrl + ), + regexp = NA + ) + expect_error( + res <- fit_xy( + hpc_basic, + x = hpc[, num_pred], + y = hpc$input_fields, + control = ctrl + ), + regexp = NA + ) + + expect_false(has_multi_predict(res)) + expect_equal(multi_predict_args(res), NA_character_) + + expect_error( + res <- fit( + hpc_basic, + class ~ term, + data = hpc, + control = ctrl + ) + ) + +}) + + +test_that('stan prediction', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + uni_pred <- c(1691.46306020449, 1494.27323520418, 1522.36011539284, 1493.39683598195, + 1494.93053462084) + inl_pred <- c(429.164145548939, 256.32488428038, 254.949927688403, 255.007333947447, + 255.336665165556) + + res_xy <- fit_xy( + linear_reg() %>% + set_engine("stan", seed = 10, chains = 1), + x = hpc[, num_pred], + y = hpc$input_fields, + control = quiet_ctrl + ) + + set.seed(383) + expect_equal(uni_pred, predict(res_xy, hpc[1:5, num_pred])$.pred, tolerance = 0.1) + + res_form <- fit( + hpc_basic, + compounds ~ log(input_fields) + class, + data = hpc, + control = quiet_ctrl + ) + expect_equal(inl_pred, predict(res_form, hpc[1:5, ])$.pred, tolerance = 0.1) +}) + + 
+test_that('stan intervals', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + res_xy <- fit_xy( + linear_reg() %>% + set_engine("stan", seed = 1333, chains = 10, iter = 1000), + x = hpc[, num_pred], + y = hpc$input_fields, + control = quiet_ctrl + ) + + set.seed(1231) + confidence_parsnip <- + predict(res_xy, + new_data = hpc[1:5,], + type = "conf_int", + level = 0.93) + + set.seed(1231) + prediction_parsnip <- + predict(res_xy, + new_data = hpc[1:5,], + type = "pred_int", + level = 0.93) + + ci_lower <- c(1577.25718753727, 1382.58210286254, 1399.96490471468, 1381.56774986889, + 1383.25519963864) + ci_upper <- c(1809.28331613624, 1609.11912475981, 1646.44852457781, 1608.3327281785, + 1609.4796390366) + + pi_lower <- c(-4960.33135373564, -5123.82860109357, -5063.60881734505, -5341.21637448872, + -5184.63627366821) + pi_upper <- c(8345.56815544477, 7954.98392035813, 7890.10036321417, 7970.64062851536, + 8247.10241974192) + + expect_equivalent(confidence_parsnip$.pred_lower, ci_lower, tolerance = 1e-2) + expect_equivalent(confidence_parsnip$.pred_upper, ci_upper, tolerance = 1e-2) + + expect_equivalent(prediction_parsnip$.pred_lower, + pi_lower, + tolerance = 1e-2) + expect_equivalent(prediction_parsnip$.pred_upper, + pi_upper, + tolerance = 1e-2) +}) + + + diff --git a/tests/testthat/test-stan-logistic.R b/tests/testthat/test-stan-logistic.R new file mode 100644 index 00000000..5b463c16 --- /dev/null +++ b/tests/testthat/test-stan-logistic.R @@ -0,0 +1,206 @@ +library(testthat) +library(parsnip) +library(rlang) +library(tibble) +library(modeldata) + +context("engine - stan - logistic regression") + +## ----------------------------------------------------------------------------- + +ctrl <- control_parsnip(verbosity = 1, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 1, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +## ----------------------------------------------------------------------------- + +data("hpc_data") 
+hpc <- hpc_data[1:150, c(2:5, 8)] + +data("lending_club") +lending_club <- head(lending_club, 200) +lc_form <- as.formula(Class ~ log(funded_amnt) + int_rate) +num_pred <- c("funded_amnt", "annual_inc", "num_il_tl") +lc_basic <- + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1) + +ctrl <- control_parsnip(verbosity = 0, catch = FALSE) +caught_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) +quiet_ctrl <- control_parsnip(verbosity = 0, catch = TRUE) + +# ------------------------------------------------------------------------------ + +test_that('stan_glm execution', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + expect_error( + res <- fit( + lc_basic, + funded_amnt ~ term, + data = lending_club, + control = ctrl + ) + ) + + expect_error( + fit_xy( + lc_basic, + control = caught_ctrl, + x = lending_club[, num_pred], + y = lending_club$total_bal_il + ) + ) + +}) + + +test_that('stan_glm prediction', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + xy_fit <- fit_xy( + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), + control = ctrl, + x = lending_club[, num_pred], + y = lending_club$Class + ) + + xy_pred <- structure(c(2L, 2L, 2L, 2L, 2L, 2L, 2L), .Label = c("bad", "good"), class = "factor") + + expect_equal(xy_pred, parsnip:::predict_class.model_fit(xy_fit, lending_club[1:7, num_pred])) + + res_form <- fit( + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_pred <- structure(c(2L, 2L, 2L, 2L, 2L, 2L, 2L), + .Label = c("bad", "good"), + class = "factor") + + expect_equal(form_pred, parsnip:::predict_class.model_fit(res_form, lending_club[1:7, c("funded_amnt", "int_rate")])) + +}) + + + +test_that('stan_glm probability', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + xy_fit <- fit_xy( + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), + control = ctrl, + x = lending_club[, 
num_pred], + y = lending_club$Class + ) + + xy_pred <- + tibble::tribble( + ~bad, ~good, + 0.0173511241321764, 0.982648875867824, + 0.0550090130462705, 0.94499098695373, + 0.0292445716644468, 0.970755428335553, + 0.0516116810109397, 0.94838831898906, + 0.0142530690940691, 0.985746930905931, + 0.0184806465081366, 0.981519353491863, + 0.0253642111906806, 0.974635788809319 + ) + + expect_equivalent( + xy_pred %>% as.data.frame(), + parsnip:::predict_classprob.model_fit(xy_fit, lending_club[1:7, num_pred]) %>% as.data.frame(), + tolerance = 0.1 + ) + + res_form <- fit( + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + form_pred <- + tibble::tribble( + ~bad, ~good, + 0.0451516541621074, 0.954848345837893, + 0.0663232780491584, 0.933676721950842, + 0.0425128897715562, 0.957487110228444, + 0.0442197030195933, 0.955780296980407, + 0.00135166763321781, 0.998648332366782, + 0.013776487556396, 0.986223512443604, + 0.00359938202445076, 0.996400617975549 + ) + expect_equivalent( + form_pred %>% as.data.frame(), + parsnip:::predict_classprob.model_fit(res_form, lending_club[1:7, c("funded_amnt", "int_rate")]) %>% + as.data.frame(), + tolerance = 0.1 + ) +}) + + +test_that('stan intervals', { + skip_if_not_installed("rstanarm") + skip_on_cran() + + res_form <- fit( + logistic_reg() %>% + set_engine("stan", seed = 1333, chains = 1), + Class ~ log(funded_amnt) + int_rate, + data = lending_club, + control = ctrl + ) + + set.seed(555) + confidence_parsnip <- + predict(res_form, + new_data = lending_club[1:5,], + type = "conf_int", + level = 0.93, + std_error = TRUE) + + set.seed(555) + prediction_parsnip <- + predict(res_form, + new_data = lending_club[1:5,], + type = "pred_int", + level = 0.93, + std_error = TRUE) + + stan_lower <- + c(`1` = 0.913925483690233, `2` = 0.841801274737206, `3` = 0.91056642931229, + `4` = 0.913619668586545, `5` = 0.987780279394871) + stan_upper <- + 
c(`1` = 0.978674663115785, `2` = 0.975178762720162, `3` = 0.984417491942267, + `4` = 0.979606072215269, `5` = 0.9999049778978) + stan_std <- + c(`1` = 0.0181025303127182, `2` = 0.0388665155739319, `3` = 0.0205886091162274, + `4` = 0.0181715224502082, `5` = 0.00405145389896896) + + expect_equivalent(confidence_parsnip$.pred_lower_good, stan_lower, tolerance = 0.01) + expect_equivalent(confidence_parsnip$.pred_upper_good, stan_upper, tolerance = 0.01) + expect_equivalent(confidence_parsnip$.pred_lower_bad, 1 - stan_upper, tolerance = 0.01) + expect_equivalent(confidence_parsnip$.pred_upper_bad, 1 - stan_lower, tolerance = 0.01) + expect_equivalent(confidence_parsnip$.std_error, stan_std, tolerance = 0.001) + + stan_pred_lower <- c(`1` = 0, `2` = 0, `3` = 0, `4` = 0, `5` = 1) + stan_pred_upper <- c(`1` = 1, `2` = 1, `3` = 1, `4` = 1, `5` = 1) + stan_pred_std <- + c(`1` = 0.211744742168102, `2` = 0.265130711714607, `3` = 0.209589904165081, + `4` = 0.198389410902796, `5` = 0.0446989708829856) + expect_equivalent(prediction_parsnip$.pred_lower_good, stan_pred_lower) + expect_equivalent(prediction_parsnip$.pred_upper_good, stan_pred_upper) + expect_equivalent(prediction_parsnip$.std_error, stan_pred_std, tolerance = 0.1) +}) + + +