Skip to content

Commit decbf6c

Browse files
committed
add test for rand seed
1 parent c9d3dc4 commit decbf6c

File tree

1 file changed

+10
-0
lines changed

1 file changed

+10
-0
lines changed

R/pkg/inst/tests/testthat/test_mllib.R

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,16 @@ test_that("spark.mlp", {
392392
expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.")
393393
expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.")
394394
expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.")
395+
396+
# Test random seed
397+
# default seed
398+
model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10)
399+
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
400+
expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1))
401+
# seed equals 10
402+
model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10)
403+
mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction"))
404+
expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1))
395405
})
396406

397407
test_that("spark.naiveBayes", {

0 commit comments

Comments
 (0)