diff --git a/NEWS.md b/NEWS.md index 556bbc24c..2f17164a6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,9 @@ * A bug was fixed related to the column names generated by `multi_predict()`. The top-level tibble will always have a column named `.pred` and this list column contains tibbles across sub-models. The column names for these sub-model tibbles will have names consistent with `predict()` (which was previously incorrect). See [43c15db](https://github.com/tidymodels/parsnip/commit/43c15db377ea9ef27483ff209f6bd0e98cb830d2). +# [A bug](https://github.com/tidymodels/parsnip/issues/174) was fixed +standardizing the column names of `nnet` class probability predictions. + # parsnip 0.0.3.1 Test case update due to CRAN running extra tests [(#202)](https://github.com/tidymodels/parsnip/issues/202) diff --git a/R/mlp.R b/R/mlp.R index 44422ca9d..eb1fb6973 100644 --- a/R/mlp.R +++ b/R/mlp.R @@ -381,7 +381,7 @@ nnet_softmax <- function(results, object) { results <- apply(results, 1, function(x) exp(x)/sum(exp(x))) results <- t(results) - names(results) <- paste0(".pred_", object$lvl) + colnames(results) <- object$lvl results <- as_tibble(results) results } diff --git a/tests/testthat/test_mlp.R b/tests/testthat/test_mlp.R index 46006c600..944a2664f 100644 --- a/tests/testthat/test_mlp.R +++ b/tests/testthat/test_mlp.R @@ -157,3 +157,11 @@ test_that('bad input', { expect_error(translate(mlp(mode = "regression", formula = y ~ x) %>% set_engine())) }) +test_that("nnet_softmax", { + obj <- mlp(mode = 'classification') + obj$lvls <- c("a", "b") + res <- nnet_softmax(matrix(c(.8, .2)), obj) + expect_equal(names(res), obj$lvls) + expect_equal(res$b, 1 - res$a) +}) +