diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index f8d1095a493d..234b208166b5 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -720,8 +720,9 @@ setMethod("predict", signature(object = "MultilayerPerceptronClassificationModel # Returns the summary of a Multilayer Perceptron Classification Model produced by \code{spark.mlp} #' @param object a Multilayer Perceptron Classification Model fitted by \code{spark.mlp} -#' @return \code{summary} returns a list containing \code{layers}, the label distribution, and -#' \code{tables}, conditional probabilities given the target label. +#' @return \code{summary} returns a list containing \code{labelCount}, \code{layers}, and +#' \code{weights}. For \code{weights}, it is a numeric vector with length equal to +#' the expected given the architecture (i.e., for 8-10-2 network, 100 connection weights). #' @rdname spark.mlp #' @export #' @aliases summary,MultilayerPerceptronClassificationModel-method @@ -732,7 +733,6 @@ setMethod("summary", signature(object = "MultilayerPerceptronClassificationModel labelCount <- callJMethod(jobj, "labelCount") layers <- unlist(callJMethod(jobj, "layers")) weights <- callJMethod(jobj, "weights") - weights <- matrix(weights, nrow = length(weights)) list(labelCount = labelCount, layers = layers, weights = weights) }) diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index ac896cfbcfff..5b1404c621bd 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -369,6 +369,8 @@ test_that("spark.mlp", { expect_equal(summary$labelCount, 3) expect_equal(summary$layers, c(4, 5, 4, 3)) expect_equal(length(summary$weights), 64) + expect_equal(head(summary$weights, 5), list(-0.878743, 0.2154151, -1.16304, -0.6583214, 1.009825), + tolerance = 1e-6) # Test predict method mlpTestDF <- df