diff --git a/.github/workflows/master.yml b/.github/workflows/master.yml
index 5eb00c4aba0f..d2b7dca3684f 100644
--- a/.github/workflows/master.yml
+++ b/.github/workflows/master.yml
@@ -50,7 +50,7 @@ jobs:
lint:
runs-on: ubuntu-latest
- name: Linters
+ name: Linters (Java/Scala/Python), licenses, dependencies
steps:
- uses: actions/checkout@master
- uses: actions/setup-java@v1
@@ -72,3 +72,26 @@ jobs:
run: ./dev/check-license
- name: Dependencies
run: ./dev/test-dependencies.sh
+
+ lintr:
+ runs-on: ubuntu-latest
+ name: Linter (R)
+ steps:
+ - uses: actions/checkout@master
+ - uses: actions/setup-java@v1
+ with:
+ java-version: '11'
+ - name: install R
+ run: |
+ echo 'deb https://cloud.r-project.org/bin/linux/ubuntu bionic-cran35/' | sudo tee -a /etc/apt/sources.list
+ sudo apt-key adv --keyserver keyserver.ubuntu.com --recv-keys E298A3A825C0D65DFD57CBB651716619E084DAB9
+ sudo apt-get update
+ sudo apt-get install -y r-base r-base-dev libcurl4-openssl-dev
+ - name: install R packages
+ run: |
+ sudo Rscript -e "install.packages(c('curl', 'xml2', 'httr', 'devtools', 'testthat', 'knitr', 'rmarkdown', 'roxygen2', 'e1071', 'survival'), repos='https://cloud.r-project.org/')"
+ sudo Rscript -e "devtools::install_github('jimhester/lintr@v2.0.0')"
+ - name: package and install SparkR
+ run: ./R/install-dev.sh
+ - name: lint-r
+ run: ./dev/lint-r
diff --git a/R/pkg/.lintr b/R/pkg/.lintr
index c83ad2adfe0e..67dc1218ea55 100644
--- a/R/pkg/.lintr
+++ b/R/pkg/.lintr
@@ -1,2 +1,2 @@
-linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
+linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), object_usage_linter = NULL, cyclocomp_linter = NULL)
exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 6f3c7c120ba3..593d3ca16220 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -2252,7 +2252,7 @@ setMethod("mutate",
# The last column of the same name in the specific columns takes effect
deDupCols <- list()
- for (i in 1:length(cols)) {
+ for (i in seq_len(length(cols))) {
deDupCols[[ns[[i]]]] <- alias(cols[[i]], ns[[i]])
}
@@ -2416,7 +2416,7 @@ setMethod("arrange",
# builds a list of columns of type Column
# example: [[1]] Column Species ASC
# [[2]] Column Petal_Length DESC
- jcols <- lapply(seq_len(length(decreasing)), function(i){
+ jcols <- lapply(seq_len(length(decreasing)), function(i) {
if (decreasing[[i]]) {
desc(getColumn(x, by[[i]]))
} else {
@@ -2749,7 +2749,7 @@ genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) {
col <- getColumn(x, colName)
if (colName %in% intersectedColNames) {
newJoin <- paste(colName, suffix, sep = "")
- if (newJoin %in% allColNames){
+ if (newJoin %in% allColNames) {
stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.",
"Please use different suffixes for the intersected columns.")
}
@@ -3475,7 +3475,7 @@ setMethod("str",
cat(paste0("'", class(object), "': ", length(names), " variables:\n"))
if (nrow(localDF) > 0) {
- for (i in 1 : ncol(localDF)) {
+ for (i in seq_len(ncol(localDF))) {
# Get the first elements for each column
firstElements <- if (types[i] == "character") {
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index f27ef4ee28f1..f48a334ed676 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -166,9 +166,9 @@ writeToFileInArrow <- function(fileName, rdf, numPartitions) {
for (rdf_slice in rdf_slices) {
batch <- arrow::record_batch(rdf_slice)
if (is.null(stream_writer)) {
- stream <- arrow::FileOutputStream(fileName)
+ stream <- arrow::FileOutputStream$create(fileName)
schema <- batch$schema
- stream_writer <- arrow::RecordBatchStreamWriter(stream, schema)
+ stream_writer <- arrow::RecordBatchStreamWriter$create(stream, schema)
}
stream_writer$write_batch(batch)
@@ -197,7 +197,7 @@ getSchema <- function(schema, firstRow = NULL, rdd = NULL) {
as.list(schema)
}
if (is.null(names)) {
- names <- lapply(1:length(firstRow), function(x) {
+ names <- lapply(seq_len(length(firstRow)), function(x) {
paste0("_", as.character(x))
})
}
@@ -213,7 +213,7 @@ getSchema <- function(schema, firstRow = NULL, rdd = NULL) {
})
types <- lapply(firstRow, infer_type)
- fields <- lapply(1:length(firstRow), function(i) {
+ fields <- lapply(seq_len(length(firstRow)), function(i) {
structField(names[[i]], types[[i]], TRUE)
})
schema <- do.call(structType, fields)
diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R
index 93ba1307043a..d96a287f818a 100644
--- a/R/pkg/R/context.R
+++ b/R/pkg/R/context.R
@@ -416,7 +416,7 @@ spark.getSparkFiles <- function(fileName) {
#' @examples
#'\dontrun{
#' sparkR.session()
-#' doubled <- spark.lapply(1:10, function(x){2 * x})
+#' doubled <- spark.lapply(1:10, function(x) {2 * x})
#'}
#' @note spark.lapply since 2.0.0
spark.lapply <- function(list, func) {
diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R
index a6febb1cbd13..ca4a6e342d77 100644
--- a/R/pkg/R/deserialize.R
+++ b/R/pkg/R/deserialize.R
@@ -242,7 +242,7 @@ readDeserializeInArrow <- function(inputCon) {
# for now.
dataLen <- readInt(inputCon)
arrowData <- readBin(inputCon, raw(), as.integer(dataLen), endian = "big")
- batches <- arrow::RecordBatchStreamReader(arrowData)$batches()
+ batches <- arrow::RecordBatchStreamReader$create(arrowData)$batches()
if (useAsTibble) {
as_tibble <- get("as_tibble", envir = asNamespace("arrow"))
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 6e8f4dc3a790..2b7995e1e37f 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -162,7 +162,7 @@ methods <- c("avg", "max", "mean", "min", "sum")
#' @note pivot since 2.0.0
setMethod("pivot",
signature(x = "GroupedData", colname = "character"),
- function(x, colname, values = list()){
+ function(x, colname, values = list()) {
stopifnot(length(colname) == 1)
if (length(values) == 0) {
result <- callJMethod(x@sgd, "pivot", colname)
diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R
index c3501977e64b..a8c1ddb3dd20 100644
--- a/R/pkg/R/utils.R
+++ b/R/pkg/R/utils.R
@@ -131,7 +131,7 @@ hashCode <- function(key) {
} else {
asciiVals <- sapply(charToRaw(key), function(x) { strtoi(x, 16L) })
hashC <- 0
- for (k in 1:length(asciiVals)) {
+ for (k in seq_len(length(asciiVals))) {
hashC <- mult31AndAdd(hashC, asciiVals[k])
}
as.integer(hashC)
@@ -543,10 +543,14 @@ processClosure <- function(node, oldEnv, defVars, checkedFuncs, newEnv) {
funcList <- mget(nodeChar, envir = checkedFuncs, inherits = F,
ifnotfound = list(list(NULL)))[[1]]
found <- sapply(funcList, function(func) {
- ifelse(identical(func, obj), TRUE, FALSE)
+ ifelse(
+ identical(func, obj) &&
+ # Also check if the parent environment is identical to current parent
+ identical(parent.env(environment(func)), func.env),
+ TRUE, FALSE)
})
if (sum(found) > 0) {
- # If function has been examined, ignore.
+ # If function has been examined, ignore it.
break
}
# Function has not been examined, record it and recursively clean its closure.
@@ -724,7 +728,7 @@ assignNewEnv <- function(data) {
stopifnot(length(cols) > 0)
env <- new.env()
- for (i in 1:length(cols)) {
+ for (i in seq_len(length(cols))) {
assign(x = cols[i], value = data[, cols[i], drop = F], envir = env)
}
env
@@ -750,7 +754,7 @@ launchScript <- function(script, combinedArgs, wait = FALSE, stdout = "", stderr
if (.Platform$OS.type == "windows") {
scriptWithArgs <- paste(script, combinedArgs, sep = " ")
# on Windows, intern = F seems to mean output to the console. (documentation on this is missing)
- shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait) # nolint
+ shell(scriptWithArgs, translate = TRUE, wait = wait, intern = wait)
} else {
# http://stat.ethz.ch/R-manual/R-devel/library/base/html/system2.html
# stdout = F means discard output
diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R
index dfe69b7f4f1f..1ef05ea621e8 100644
--- a/R/pkg/inst/worker/worker.R
+++ b/R/pkg/inst/worker/worker.R
@@ -194,7 +194,7 @@ if (isEmpty != 0) {
} else {
# gapply mode
outputs <- list()
- for (i in 1:length(data)) {
+ for (i in seq_len(length(data))) {
# Timing reading input data for execution
inputElap <- elapsedSecs()
output <- compute(mode, partition, serializer, deserializer, keys[[i]],
diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R
index c2b2458ec064..cb47353d600d 100644
--- a/R/pkg/tests/fulltests/test_sparkSQL.R
+++ b/R/pkg/tests/fulltests/test_sparkSQL.R
@@ -172,7 +172,7 @@ test_that("structField type strings", {
typeList <- c(primitiveTypes, complexTypes)
typeStrings <- names(typeList)
- for (i in seq_along(typeStrings)){
+ for (i in seq_along(typeStrings)) {
typeString <- typeStrings[i]
expected <- typeList[[i]]
testField <- structField("_col", typeString)
@@ -203,7 +203,7 @@ test_that("structField type strings", {
errorList <- c(primitiveErrors, complexErrors)
typeStrings <- names(errorList)
- for (i in seq_along(typeStrings)){
+ for (i in seq_along(typeStrings)) {
typeString <- typeStrings[i]
expected <- paste0("Unsupported type for SparkDataframe: ", errorList[[i]])
expect_error(structField("_col", typeString), expected)
diff --git a/R/pkg/tests/fulltests/test_utils.R b/R/pkg/tests/fulltests/test_utils.R
index b2b6f34aaa08..c4fcbecee18e 100644
--- a/R/pkg/tests/fulltests/test_utils.R
+++ b/R/pkg/tests/fulltests/test_utils.R
@@ -110,6 +110,15 @@ test_that("cleanClosure on R functions", {
actual <- get("y", envir = env, inherits = FALSE)
expect_equal(actual, y)
+ # Test for a combination of nested and sequential functions in a closure
+ f1 <- function(x) x + 1
+ f2 <- function(x) f1(x) + 2
+ userFunc <- function(x) { f1(x); f2(x) }
+ cUserFuncEnv <- environment(cleanClosure(userFunc))
+ expect_equal(length(cUserFuncEnv), 2)
+ innerCUserFuncEnv <- environment(cUserFuncEnv$f2)
+ expect_equal(length(innerCUserFuncEnv), 1)
+
# Test for function (and variable) definitions.
f <- function(x) {
g <- function(y) { y * 2 }
diff --git a/R/run-tests.sh b/R/run-tests.sh
index 86bd8aad5f11..51ca7d600caf 100755
--- a/R/run-tests.sh
+++ b/R/run-tests.sh
@@ -23,7 +23,7 @@ FAILED=0
LOGFILE=$FWDIR/unit-tests.out
rm -f $LOGFILE
-SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
+SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" --conf spark.driver.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" --conf spark.executor.extraJavaOptions="-Dio.netty.tryReflectionSetAccessible=true" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE
FAILED=$((PIPESTATUS[0]||$FAILED))
NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)"
diff --git a/appveyor.yml b/appveyor.yml
index b36175a787ae..00c688ba18eb 100644
--- a/appveyor.yml
+++ b/appveyor.yml
@@ -42,10 +42,7 @@ install:
# Install maven and dependencies
- ps: .\dev\appveyor-install-dependencies.ps1
# Required package for R unit tests
- - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival'), repos='https://cloud.r-project.org/')"
- # Use Arrow R 0.14.1 for now. 0.15.0 seems not working for now. See SPARK-29378.
- - cmd: R -e "install.packages(c('assertthat', 'bit64', 'fs', 'purrr', 'R6', 'tidyselect'), repos='https://cloud.r-project.org/')"
- - cmd: R -e "install.packages('https://cran.r-project.org/src/contrib/Archive/arrow/arrow_0.14.1.tar.gz', repos=NULL, type='source')"
+ - cmd: R -e "install.packages(c('knitr', 'rmarkdown', 'e1071', 'survival', 'arrow'), repos='https://cloud.r-project.org/')"
# Here, we use the fixed version of testthat. For more details, please see SPARK-22817.
# As of devtools 2.1.0, it requires testthat higher then 2.1.1 as a dependency. SparkR test requires testthat 1.0.2.
# Therefore, we don't use devtools but installs it directly from the archive including its dependencies.
diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
index 6397f26c02f3..01bf7eb2438a 100644
--- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
+++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CalendarIntervalSuite.java
@@ -46,36 +46,6 @@ public void equalsTest() {
assertEquals(i1, i6);
}
- @Test
- public void toStringTest() {
- CalendarInterval i;
-
- i = new CalendarInterval(0, 0, 0);
- assertEquals("0 seconds", i.toString());
-
- i = new CalendarInterval(34, 0, 0);
- assertEquals("2 years 10 months", i.toString());
-
- i = new CalendarInterval(-34, 0, 0);
- assertEquals("-2 years -10 months", i.toString());
-
- i = new CalendarInterval(0, 31, 0);
- assertEquals("31 days", i.toString());
-
- i = new CalendarInterval(0, -31, 0);
- assertEquals("-31 days", i.toString());
-
- i = new CalendarInterval(0, 0, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123);
- assertEquals("3 hours 13 minutes 0.000123 seconds", i.toString());
-
- i = new CalendarInterval(0, 0, -3 * MICROS_PER_HOUR - 13 * MICROS_PER_MINUTE - 123);
- assertEquals("-3 hours -13 minutes -0.000123 seconds", i.toString());
-
- i = new CalendarInterval(34, 31, 3 * MICROS_PER_HOUR + 13 * MICROS_PER_MINUTE + 123);
- assertEquals("2 years 10 months 31 days 3 hours 13 minutes 0.000123 seconds",
- i.toString());
- }
-
@Test
public void periodAndDurationTest() {
CalendarInterval interval = new CalendarInterval(120, -40, 123456);
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
index a13037b5e24d..77564f48015f 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -89,7 +89,12 @@ private[ui] class ExecutorThreadDumpPage(
<th>Thread ID</th>
<th>Thread Name</th>
<th>Thread State</th>
- <th>Thread Locks</th>
+ <th>
+ <span data-toggle="tooltip" data-placement="top" title={...}>
+ Thread Locks
+ </span>
+ </th>
{dumpRows}
diff --git a/dev/lint-r b/dev/lint-r
index bfda0bca15eb..b08f5efecd5d 100755
--- a/dev/lint-r
+++ b/dev/lint-r
@@ -17,6 +17,9 @@
# limitations under the License.
#
+set -o pipefail
+set -e
+
SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
SPARK_ROOT_DIR="$(dirname $SCRIPT_DIR)"
LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log"
@@ -24,7 +27,7 @@ LINT_R_REPORT_FILE_NAME="$SPARK_ROOT_DIR/dev/lint-r-report.log"
if ! type "Rscript" > /dev/null; then
echo "ERROR: You should install R"
- exit
+ exit 1
fi
`which Rscript` --vanilla "$SPARK_ROOT_DIR/dev/lint-r.R" "$SPARK_ROOT_DIR" | tee "$LINT_R_REPORT_FILE_NAME"
diff --git a/dev/lint-r.R b/dev/lint-r.R
index a4261d266bbc..7e165319e316 100644
--- a/dev/lint-r.R
+++ b/dev/lint-r.R
@@ -27,7 +27,7 @@ if (! library(SparkR, lib.loc = LOCAL_LIB_LOC, logical.return = TRUE)) {
# Installs lintr from Github in a local directory.
# NOTE: The CRAN's version is too old to adapt to our rules.
if ("lintr" %in% row.names(installed.packages()) == FALSE) {
- devtools::install_github("jimhester/lintr@5431140")
+ devtools::install_github("jimhester/lintr@v2.0.0")
}
library(lintr)
diff --git a/dev/run-tests.py b/dev/run-tests.py
index 82277720bb52..fc8b7251a85f 100755
--- a/dev/run-tests.py
+++ b/dev/run-tests.py
@@ -43,15 +43,20 @@ def determine_modules_for_files(filenames):
"""
Given a list of filenames, return the set of modules that contain those files.
If a file is not associated with a more specific submodule, then this method will consider that
- file to belong to the 'root' module.
+ file to belong to the 'root' module. GitHub Action and Appveyor files are ignored.
>>> sorted(x.name for x in determine_modules_for_files(["python/pyspark/a.py", "sql/core/foo"]))
['pyspark-core', 'sql']
>>> [x.name for x in determine_modules_for_files(["file_not_matched_by_any_subproject"])]
['root']
+ >>> [x.name for x in determine_modules_for_files( \
+ [".github/workflows/master.yml", "appveyor.yml"])]
+ []
"""
changed_modules = set()
for filename in filenames:
+ if filename in (".github/workflows/master.yml", "appveyor.yml"):
+ continue
matched_at_least_one_module = False
for module in modules.all_modules:
if module.contains_file(filename):
diff --git a/docs/configuration.md b/docs/configuration.md
index 97ea1fb4ba04..0c7cc6022eb0 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1857,6 +1857,51 @@ Apart from these, the following properties are also available, and may be useful
driver using more memory.
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.shared.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the shared event queue in the Spark listener bus, which holds events for external listener(s)
+    that register to the listener bus. Consider increasing this value if the listener events corresponding
+    to the shared queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.appStatus.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the appStatus event queue, which holds events for internal application status listeners.
+    Consider increasing this value if the listener events corresponding to the appStatus queue are dropped.
+    Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.executorManagement.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the executorManagement event queue in the Spark listener bus, which holds events for internal
+    executor management listeners. Consider increasing this value if the listener events corresponding to the
+    executorManagement queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.eventLog.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the eventLog queue in the Spark listener bus, which holds events for event logging listeners
+    that write events to the event log. Consider increasing this value if the listener events corresponding to
+    the eventLog queue are dropped. Increasing this value may result in the driver using more memory.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.scheduler.listenerbus.eventqueue.streams.capacity</code></td>
+  <td><code>spark.scheduler.listenerbus.eventqueue.capacity</code></td>
+  <td>
+    Capacity for the streams queue in the Spark listener bus, which holds events for internal streaming listeners.
+    Consider increasing this value if the listener events corresponding to the streams queue are dropped. Increasing
+    this value may result in the driver using more memory.
+  </td>
+</tr>
<td><code>spark.scheduler.blacklist.unschedulableTaskSetTimeout</code></td>
<td>120s</td>
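
For context on the new entries above, a minimal sketch of how these per-queue capacities are set (the application name and values are illustrative, not defaults):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

// Per-queue capacities override the shared default
// spark.scheduler.listenerbus.eventqueue.capacity for that queue only.
val conf = new SparkConf()
  .setAppName("listener-bus-capacity-example")  // illustrative name
  .set("spark.scheduler.listenerbus.eventqueue.capacity", "20000")
  .set("spark.scheduler.listenerbus.eventqueue.appStatus.capacity", "40000")
  .set("spark.scheduler.listenerbus.eventqueue.eventLog.capacity", "40000")

val spark = SparkSession.builder().config(conf).getOrCreate()
```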
diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md
index b83b4ba08a5f..d8c7d8a72962 100644
--- a/docs/ml-classification-regression.md
+++ b/docs/ml-classification-regression.md
@@ -478,15 +478,16 @@ it computes the conditional probability distribution of each feature given each
For prediction, it applies Bayes' theorem to compute the conditional probability distribution
of each label given an observation.
-MLlib supports both [multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes)
-and [Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html).
+MLlib supports [Multinomial naive Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes),
+[Bernoulli naive Bayes](http://nlp.stanford.edu/IR-book/html/htmledition/the-bernoulli-model-1.html)
+and [Gaussian naive Bayes](https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Gaussian_naive_Bayes).
*Input data*:
-These models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
+The Multinomial and Bernoulli models are typically used for [document classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
Within that context, each observation is a document and each feature represents a term.
A feature's value is the frequency of the term (in multinomial Naive Bayes) or
a zero or one indicating whether the term was found in the document (in Bernoulli Naive Bayes).
-Feature values must be *non-negative*. The model type is selected with an optional parameter
+Feature values for Multinomial and Bernoulli models must be *non-negative*. The model type is selected with an optional parameter
"multinomial" or "bernoulli" with "multinomial" as the default.
For document classification, the input feature vectors should usually be sparse vectors.
Since the training data is only used once, it is not necessary to cache it.
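
To make the new Gaussian option concrete, here is a minimal sketch against the Spark ML API (assumes an active SparkSession named `spark`; the toy data and column names are illustrative):

```scala
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.linalg.Vectors

// Continuous (possibly negative) feature values are acceptable for "gaussian",
// unlike "multinomial"/"bernoulli", which require non-negative features.
val train = spark.createDataFrame(Seq(
  (0.0, Vectors.dense(1.2, -0.5, 3.1)),
  (0.0, Vectors.dense(1.0, -0.3, 2.8)),
  (1.0, Vectors.dense(-0.7, 2.2, 0.4)),
  (1.0, Vectors.dense(-0.9, 1.8, 0.1))
)).toDF("label", "features")

val model = new NaiveBayes()
  .setModelType("gaussian")  // default remains "multinomial"
  .fit(train)

model.transform(train).select("features", "probability", "prediction").show(false)
```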
diff --git a/docs/sql-keywords.md b/docs/sql-keywords.md
index 81d7ce37af17..79bc13459623 100644
--- a/docs/sql-keywords.md
+++ b/docs/sql-keywords.md
@@ -19,15 +19,15 @@ license: |
limitations under the License.
---
-When `spark.sql.ansi.enabled` is true, Spark SQL has two kinds of keywords:
+When `spark.sql.dialect.spark.ansi.enabled` is true, Spark SQL has two kinds of keywords:
* Reserved keywords: Keywords that are reserved and can't be used as identifiers for table, view, column, function, alias, etc.
* Non-reserved keywords: Keywords that have a special meaning only in particular contexts and can be used as identifiers in other contexts. For example, `SELECT 1 WEEK` is an interval literal, but WEEK can be used as identifiers in other places.
-When `spark.sql.ansi.enabled` is false, Spark SQL has two kinds of keywords:
-* Non-reserved keywords: Same definition as the one when `spark.sql.ansi.enabled=true`.
+When `spark.sql.dialect.spark.ansi.enabled` is false, Spark SQL has two kinds of keywords:
+* Non-reserved keywords: Same definition as the one when `spark.sql.dialect.spark.ansi.enabled=true`.
* Strict-non-reserved keywords: A strict version of non-reserved keywords, which can not be used as table alias.
-By default `spark.sql.ansi.enabled` is false.
+By default `spark.sql.dialect.spark.ansi.enabled` is false.
Below is a list of all the keywords in Spark SQL.
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
index f1cd3343b792..efd7ca74c796 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MsSqlServerIntegrationSuite.scala
@@ -59,7 +59,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
"""
|INSERT INTO numbers VALUES (
|0,
- |127, 32767, 2147483647, 9223372036854775807,
+ |255, 32767, 2147483647, 9223372036854775807,
|123456789012345.123456789012345, 123456789012345.123456789012345,
|123456789012345.123456789012345,
|123, 12345.12,
@@ -119,7 +119,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
val types = row.toSeq.map(x => x.getClass.toString)
assert(types.length == 12)
assert(types(0).equals("class java.lang.Boolean"))
- assert(types(1).equals("class java.lang.Byte"))
+ assert(types(1).equals("class java.lang.Integer"))
assert(types(2).equals("class java.lang.Short"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Long"))
@@ -131,7 +131,7 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types(10).equals("class java.math.BigDecimal"))
assert(types(11).equals("class java.math.BigDecimal"))
assert(row.getBoolean(0) == false)
- assert(row.getByte(1) == 127)
+ assert(row.getInt(1) == 255)
assert(row.getShort(2) == 32767)
assert(row.getInt(3) == 2147483647)
assert(row.getLong(4) == 9223372036854775807L)
@@ -202,46 +202,4 @@ class MsSqlServerIntegrationSuite extends DockerJDBCIntegrationSuite {
df2.write.jdbc(jdbcUrl, "datescopy", new Properties)
df3.write.jdbc(jdbcUrl, "stringscopy", new Properties)
}
-
- test("SPARK-29644: Write tables with ShortType") {
- import testImplicits._
- val df = Seq(-32768.toShort, 0.toShort, 1.toShort, 38.toShort, 32768.toShort).toDF("a")
- val tablename = "shorttable"
- df.write
- .format("jdbc")
- .mode("overwrite")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .save()
- val df2 = spark.read
- .format("jdbc")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .load()
- assert(df.count == df2.count)
- val rows = df2.collect()
- val colType = rows(0).toSeq.map(x => x.getClass.toString)
- assert(colType(0) == "class java.lang.Short")
- }
-
- test("SPARK-29644: Write tables with ByteType") {
- import testImplicits._
- val df = Seq(-127.toByte, 0.toByte, 1.toByte, 38.toByte, 128.toByte).toDF("a")
- val tablename = "bytetable"
- df.write
- .format("jdbc")
- .mode("overwrite")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .save()
- val df2 = spark.read
- .format("jdbc")
- .option("url", jdbcUrl)
- .option("dbtable", tablename)
- .load()
- assert(df.count == df2.count)
- val rows = df2.collect()
- val colType = rows(0).toSeq.map(x => x.getClass.toString)
- assert(colType(0) == "class java.lang.Byte")
- }
}
diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
index 8401b0a8a752..bba1b5275269 100644
--- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
+++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/MySQLIntegrationSuite.scala
@@ -84,7 +84,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types.length == 9)
assert(types(0).equals("class java.lang.Boolean"))
assert(types(1).equals("class java.lang.Long"))
- assert(types(2).equals("class java.lang.Short"))
+ assert(types(2).equals("class java.lang.Integer"))
assert(types(3).equals("class java.lang.Integer"))
assert(types(4).equals("class java.lang.Integer"))
assert(types(5).equals("class java.lang.Long"))
@@ -93,7 +93,7 @@ class MySQLIntegrationSuite extends DockerJDBCIntegrationSuite {
assert(types(8).equals("class java.lang.Double"))
assert(rows(0).getBoolean(0) == false)
assert(rows(0).getLong(1) == 0x225)
- assert(rows(0).getShort(2) == 17)
+ assert(rows(0).getInt(2) == 17)
assert(rows(0).getInt(3) == 77777)
assert(rows(0).getInt(4) == 123456789)
assert(rows(0).getLong(5) == 123456789012345L)
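
The assertion changes in these two suites reflect the read-side JDBC type mapping after dropping the SPARK-29644 tests below: unsigned small integer columns (for example SQL Server TINYINT, whose range is 0-255) no longer fit the signed Byte/Short accessors and surface as wider Java types. A hypothetical sketch (the URL, table and column index are placeholders, not taken from the suites):

```scala
// Placeholder connection details; substitute a real JDBC URL and table.
val df = spark.read
  .format("jdbc")
  .option("url", "jdbc:sqlserver://localhost;databaseName=test")
  .option("dbtable", "numbers")
  .load()

// An unsigned TINYINT column is now exposed as IntegerType, so values such as
// 255 round-trip correctly; row.getByte(...) would overflow for them.
df.printSchema()
val firstRow = df.head()
println(firstRow.getInt(1))  // column index 1 is illustrative only
```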
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
index 8e29e38b2a64..56c0fdd7c35b 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaBatchWrite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{BatchWrite, DataWriter, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.types.StructType
@@ -40,7 +40,7 @@ private[kafka010] class KafkaBatchWrite(
validateQuery(schema.toAttributes, producerParams, topic)
- override def createBatchWriterFactory(): KafkaBatchWriterFactory =
+ override def createBatchWriterFactory(info: PhysicalWriteInfo): KafkaBatchWriterFactory =
KafkaBatchWriterFactory(topic, producerParams, schema)
override def commit(messages: Array[WriterCommitMessage]): Unit = {}
diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
index 2b50b771e694..bcf9e3416f84 100644
--- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
+++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.kafka010
import java.{util => ju}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.connector.write.{DataWriter, WriterCommitMessage}
+import org.apache.spark.sql.connector.write.{DataWriter, PhysicalWriteInfo, WriterCommitMessage}
import org.apache.spark.sql.connector.write.streaming.{StreamingDataWriterFactory, StreamingWrite}
import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery
import org.apache.spark.sql.types.StructType
@@ -41,7 +41,8 @@ private[kafka010] class KafkaStreamingWrite(
validateQuery(schema.toAttributes, producerParams, topic)
- override def createStreamingWriterFactory(): KafkaStreamWriterFactory =
+ override def createStreamingWriterFactory(
+ info: PhysicalWriteInfo): KafkaStreamWriterFactory =
KafkaStreamWriterFactory(topic, producerParams, schema)
override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {}
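
The Kafka changes above track the data source V2 API change in which the writer-factory methods now receive a PhysicalWriteInfo argument. A hedged sketch of how a third-party BatchWrite would adapt (MyBatchWrite and the use of info.numPartitions() are assumptions for illustration, not part of Spark or the Kafka connector):

```scala
import org.apache.spark.sql.connector.write.{BatchWrite, DataWriterFactory, PhysicalWriteInfo, WriterCommitMessage}

class MyBatchWrite extends BatchWrite {
  // Previously this method took no arguments; the new PhysicalWriteInfo parameter
  // carries physical-plan details such as the number of writing partitions.
  override def createBatchWriterFactory(info: PhysicalWriteInfo): DataWriterFactory = {
    val numPartitions = info.numPartitions()  // assumed accessor, shown only as an example
    ???  // return a DataWriterFactory sized/configured for numPartitions
  }

  override def commit(messages: Array[WriterCommitMessage]): Unit = {}
  override def abort(messages: Array[WriterCommitMessage]): Unit = {}
}
```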
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index bcca40d159c9..806287079441 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -18,18 +18,22 @@
package org.apache.spark.ml.classification
import org.apache.hadoop.fs.Path
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.annotation.Since
import org.apache.spark.ml.PredictorParams
-import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.HasWeightCol
+import org.apache.spark.ml.stat.Summarizer
import org.apache.spark.ml.util._
import org.apache.spark.ml.util.Instrumentation.instrumented
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.VersionUtils
/**
* Params for Naive Bayes Classifiers.
@@ -49,12 +53,13 @@ private[classification] trait NaiveBayesParams extends PredictorParams with HasW
/**
* The model type which is a string (case-sensitive).
- * Supported options: "multinomial" and "bernoulli".
+ * Supported options: "multinomial", "bernoulli", "gaussian".
* (default = multinomial)
* @group param
*/
final val modelType: Param[String] = new Param[String](this, "modelType", "The model type " +
- "which is a string (case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ "which is a string (case-sensitive). Supported options: multinomial (default), bernoulli" +
+ " and gaussian.",
ParamValidators.inArray[String](NaiveBayes.supportedModelTypes.toArray))
/** @group getParam */
@@ -72,7 +77,11 @@ private[classification] trait NaiveBayesParams extends PredictorParams with HasW
* binary (0/1) data, it can also be used as Bernoulli NB
* (see
* here).
- * The input feature values must be nonnegative.
+ * The input feature values for Multinomial NB and Bernoulli NB must be nonnegative.
+ * Since 3.0.0, it also supports Gaussian NB
+ * (see
+ * here)
+ * which can handle continuous data.
*/
// scalastyle:on line.size.limit
@Since("1.5.0")
@@ -103,7 +112,7 @@ class NaiveBayes @Since("1.5.0") (
*/
@Since("1.5.0")
def setModelType(value: String): this.type = set(modelType, value)
- setDefault(modelType -> NaiveBayes.Multinomial)
+ setDefault(modelType -> Multinomial)
/**
* Sets the value of param [[weightCol]].
@@ -130,6 +139,9 @@ class NaiveBayes @Since("1.5.0") (
positiveLabel: Boolean): NaiveBayesModel = instrumented { instr =>
instr.logPipelineStage(this)
instr.logDataset(dataset)
+ instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
+ probabilityCol, modelType, smoothing, thresholds)
+
if (positiveLabel && isDefined(thresholds)) {
val numClasses = getNumClasses(dataset)
instr.logNumClasses(numClasses)
@@ -138,44 +150,55 @@ class NaiveBayes @Since("1.5.0") (
s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
}
- val validateInstance = $(modelType) match {
- case Multinomial =>
- (instance: Instance) => requireNonnegativeValues(instance.features)
- case Bernoulli =>
- (instance: Instance) => requireZeroOneBernoulliValues(instance.features)
+ $(modelType) match {
+ case Bernoulli | Multinomial =>
+ trainDiscreteImpl(dataset, instr)
+ case Gaussian =>
+ trainGaussianImpl(dataset, instr)
case _ =>
// This should never happen.
throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
+ }
- instr.logParams(this, labelCol, featuresCol, weightCol, predictionCol, rawPredictionCol,
- probabilityCol, modelType, smoothing, thresholds)
+ private def trainDiscreteImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
- val numFeatures = dataset.select(col($(featuresCol))).head().getAs[Vector](0).size
- instr.logNumFeatures(numFeatures)
+ val validateUDF = $(modelType) match {
+ case Multinomial =>
+ udf { vector: Vector => requireNonnegativeValues(vector); vector }
+ case Bernoulli =>
+ udf { vector: Vector => requireZeroOneBernoulliValues(vector); vector }
+ }
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
// Aggregates term frequencies per label.
- // TODO: Calling aggregateByKey and collect creates two stages, we can implement something
- // TODO: similar to reduceByKeyLocally to save one stage.
- val aggregated = extractInstances(dataset, validateInstance).map { instance =>
- (instance.label, (instance.weight, instance.features))
- }.aggregateByKey[(Double, DenseVector, Long)]((0.0, Vectors.zeros(numFeatures).toDense, 0L))(
- seqOp = {
- case ((weightSum, featureSum, count), (weight, features)) =>
- BLAS.axpy(weight, features, featureSum)
- (weightSum + weight, featureSum, count + 1)
- },
- combOp = {
- case ((weightSum1, featureSum1, count1), (weightSum2, featureSum2, count2)) =>
- BLAS.axpy(1.0, featureSum2, featureSum1)
- (weightSum1 + weightSum2, featureSum1, count1 + count2)
- }).collect().sortBy(_._1)
-
- val numSamples = aggregated.map(_._2._3).sum
+ // TODO: have Summarizer return the sum vector directly.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "count")
+ .summary(validateUDF(col($(featuresCol))), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.count")
+ .as[(Double, Double, Vector, Long)]
+ .map { case (label, weightSum, mean, count) =>
+ BLAS.scal(weightSum, mean)
+ (label, weightSum, mean, count)
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+ val numSamples = aggregated.map(_._4).sum
instr.logNumExamples(numSamples)
val numLabels = aggregated.length
instr.logNumClasses(numLabels)
- val numDocuments = aggregated.map(_._2._1).sum
+ val numDocuments = aggregated.map(_._2).sum
val labelArray = new Array[Double](numLabels)
val piArray = new Array[Double](numLabels)
@@ -184,19 +207,17 @@ class NaiveBayes @Since("1.5.0") (
val lambda = $(smoothing)
val piLogDenom = math.log(numDocuments + numLabels * lambda)
var i = 0
- aggregated.foreach { case (label, (n, sumTermFreqs, _)) =>
+ aggregated.foreach { case (label, n, sumTermFreqs, _) =>
labelArray(i) = label
piArray(i) = math.log(n + lambda) - piLogDenom
val thetaLogDenom = $(modelType) match {
- case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
+ case Multinomial => math.log(sumTermFreqs.toArray.sum + numFeatures * lambda)
case Bernoulli => math.log(n + 2.0 * lambda)
- case _ =>
- // This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
}
var j = 0
+ val offset = i * numFeatures
while (j < numFeatures) {
- thetaArray(i * numFeatures + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
+ thetaArray(offset + j) = math.log(sumTermFreqs(j) + lambda) - thetaLogDenom
j += 1
}
i += 1
@@ -204,7 +225,86 @@ class NaiveBayes @Since("1.5.0") (
val pi = Vectors.dense(piArray)
val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
- new NaiveBayesModel(uid, pi, theta).setOldLabels(labelArray)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, null)
+ .setOldLabels(labelArray)
+ }
+
+ private def trainGaussianImpl(
+ dataset: Dataset[_],
+ instr: Instrumentation): NaiveBayesModel = {
+ val spark = dataset.sparkSession
+ import spark.implicits._
+
+ val w = if (isDefined(weightCol) && $(weightCol).nonEmpty) {
+ col($(weightCol)).cast(DoubleType)
+ } else {
+ lit(1.0)
+ }
+
+ // Aggregates mean vector and square-sum vector per label.
+ // TODO: have Summarizer return the square-sum vector directly.
+ val aggregated = dataset.groupBy(col($(labelCol)))
+ .agg(sum(w).as("weightSum"), Summarizer.metrics("mean", "normL2")
+ .summary(col($(featuresCol)), w).as("summary"))
+ .select($(labelCol), "weightSum", "summary.mean", "summary.normL2")
+ .as[(Double, Double, Vector, Vector)]
+ .map { case (label, weightSum, mean, normL2) =>
+ (label, weightSum, mean, Vectors.dense(normL2.toArray.map(v => v * v)))
+ }.collect().sortBy(_._1)
+
+ val numFeatures = aggregated.head._3.size
+ instr.logNumFeatures(numFeatures)
+
+ val numLabels = aggregated.length
+ instr.logNumClasses(numLabels)
+
+ val numInstances = aggregated.map(_._2).sum
+
+ // If the ratio of data variance between dimensions is too small, it
+ // will cause numerical errors. To address this, we artificially
+ // boost the variance by epsilon, a small fraction of the standard
+ // deviation of the largest dimension.
+ // Refer to scikit-learn's implementation
+ // [https://github.com/scikit-learn/scikit-learn/blob/0.21.X/sklearn/naive_bayes.py#L348]
+ // and discussion [https://github.com/scikit-learn/scikit-learn/pull/5349] for details.
+ val epsilon = Iterator.range(0, numFeatures).map { j =>
+ var globalSum = 0.0
+ var globalSqrSum = 0.0
+ aggregated.foreach { case (_, weightSum, mean, squareSum) =>
+ globalSum += mean(j) * weightSum
+ globalSqrSum += squareSum(j)
+ }
+ globalSqrSum / numInstances -
+ globalSum * globalSum / numInstances / numInstances
+ }.max * 1e-9
+
+ val piArray = new Array[Double](numLabels)
+
+ // thetaArray in Gaussian NB stores the means of features per label
+ val thetaArray = new Array[Double](numLabels * numFeatures)
+
+ // sigmaArray in Gaussian NB stores the variances of features per label
+ val sigmaArray = new Array[Double](numLabels * numFeatures)
+
+ var i = 0
+ val logNumInstances = math.log(numInstances)
+ aggregated.foreach { case (_, weightSum, mean, squareSum) =>
+ piArray(i) = math.log(weightSum) - logNumInstances
+ var j = 0
+ val offset = i * numFeatures
+ while (j < numFeatures) {
+ val m = mean(j)
+ thetaArray(offset + j) = m
+ sigmaArray(offset + j) = epsilon + squareSum(j) / weightSum - m * m
+ j += 1
+ }
+ i += 1
+ }
+
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(numLabels, numFeatures, thetaArray, true)
+ val sigma = new DenseMatrix(numLabels, numFeatures, sigmaArray, true)
+ new NaiveBayesModel(uid, pi.compressed, theta.compressed, sigma.compressed)
}
@Since("1.5.0")
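
Spelled out, the quantities that trainGaussianImpl aggregates above are, per class c and feature j (with w_i the instance weights, W_c their sum within class c, and W the total weight); this is a reference restatement of the code, not an addition to it:

```latex
\log \pi_c = \log W_c - \log W, \qquad
\theta_{cj} = \frac{1}{W_c}\sum_{i \in c} w_i\, x_{ij}, \qquad
\sigma^2_{cj} = \frac{1}{W_c}\sum_{i \in c} w_i\, x_{ij}^2 - \theta_{cj}^2 + \varepsilon,
\qquad
\varepsilon = 10^{-9}\,\max_j\Big(\tfrac{1}{W}\textstyle\sum_i w_i x_{ij}^2 - \big(\tfrac{1}{W}\sum_i w_i x_{ij}\big)^2\Big)
```

The epsilon term is the variance-smoothing trick borrowed from scikit-learn that the comment above references.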
@@ -219,8 +319,11 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/** String name for Bernoulli model type. */
private[classification] val Bernoulli: String = "bernoulli"
+ /** String name for Gaussian model type. */
+ private[classification] val Gaussian: String = "gaussian"
+
/* Set of modelTypes that NaiveBayes supports */
- private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
+ private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli, Gaussian)
private[NaiveBayes] def requireNonnegativeValues(v: Vector): Unit = {
val values = v match {
@@ -248,19 +351,24 @@ object NaiveBayes extends DefaultParamsReadable[NaiveBayes] {
/**
* Model produced by [[NaiveBayes]]
- * @param pi log of class priors, whose dimension is C (number of classes)
+ *
+ * @param pi log of class priors, whose dimension is C (number of classes)
* @param theta log of class conditional probabilities, whose dimension is C (number of classes)
* by D (number of features)
+ * @param sigma variance of each feature, whose dimension is C (number of classes)
+ * by D (number of features). This matrix is only available when modelType
+ * is set to Gaussian.
*/
@Since("1.5.0")
class NaiveBayesModel private[ml] (
@Since("1.5.0") override val uid: String,
@Since("2.0.0") val pi: Vector,
- @Since("2.0.0") val theta: Matrix)
+ @Since("2.0.0") val theta: Matrix,
+ @Since("3.0.0") val sigma: Matrix)
extends ProbabilisticClassificationModel[Vector, NaiveBayesModel]
with NaiveBayesParams with MLWritable {
- import NaiveBayes.{Bernoulli, Multinomial}
+ import NaiveBayes.{Bernoulli, Multinomial, Gaussian}
/**
* mllib NaiveBayes is a wrapper of ml implementation currently.
@@ -280,18 +388,36 @@ class NaiveBayesModel private[ml] (
* This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
* application of this condition (in predict function).
*/
- private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
- case Multinomial => (None, None)
+ @transient private lazy val (thetaMinusNegTheta, negThetaSum) = $(modelType) match {
case Bernoulli =>
val negTheta = theta.map(value => math.log1p(-math.exp(value)))
val ones = new DenseVector(Array.fill(theta.numCols) {1.0})
val thetaMinusNegTheta = theta.map { value =>
value - math.log1p(-math.exp(value))
}
- (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
+ (thetaMinusNegTheta, negTheta.multiply(ones))
+ case _ =>
+ // This should never happen.
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
+ "Variables thetaMinusNegTheta and negThetaSum should only be precomputed in Bernoulli NB.")
+ }
+
+ /**
+ * Gaussian scoring requires sum of log(Variance).
+ * This precomputes sum of log(Variance) which are used for the linear algebra
+ * application of this condition (in predict function).
+ */
+ @transient private lazy val logVarSum = $(modelType) match {
+ case Gaussian =>
+ Array.tabulate(numClasses) { i =>
+ Iterator.range(0, numFeatures).map { j =>
+ math.log(sigma(i, j))
+ }.sum
+ }
case _ =>
// This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
+ throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}. " +
+ "Variable logVarSum should only be precomputed in Gaussian NB.")
}
@Since("1.6.0")
@@ -311,24 +437,42 @@ class NaiveBayesModel private[ml] (
require(value == 0.0 || value == 1.0,
s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
)
- val prob = thetaMinusNegTheta.get.multiply(features)
+ val prob = thetaMinusNegTheta.multiply(features)
BLAS.axpy(1.0, pi, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
+ BLAS.axpy(1.0, negThetaSum, prob)
prob
}
- override protected def predictRaw(features: Vector): Vector = {
+ private def gaussianCalculation(features: Vector) = {
+ val prob = Array.ofDim[Double](numClasses)
+ var i = 0
+ while (i < numClasses) {
+ var s = 0.0
+ var j = 0
+ while (j < numFeatures) {
+ val d = features(j) - theta(i, j)
+ s += d * d / sigma(i, j)
+ j += 1
+ }
+ prob(i) = pi(i) - (s + logVarSum(i)) / 2
+ i += 1
+ }
+ Vectors.dense(prob)
+ }
+
+ @transient private lazy val predictRawFunc = {
$(modelType) match {
case Multinomial =>
- multinomialCalculation(features)
+ features: Vector => multinomialCalculation(features)
case Bernoulli =>
- bernoulliCalculation(features)
- case _ =>
- // This should never happen.
- throw new IllegalArgumentException(s"Invalid modelType: ${$(modelType)}.")
+ features: Vector => bernoulliCalculation(features)
+ case Gaussian =>
+ features: Vector => gaussianCalculation(features)
}
}
+ override protected def predictRaw(features: Vector): Vector = predictRawFunc(features)
+
override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
rawPrediction match {
case dv: DenseVector =>
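
For reference, gaussianCalculation above computes the class score as the Gaussian log joint likelihood, dropping terms that are constant across classes (the -(D/2)·log(2π) term and the evidence), since raw2probabilityInPlace normalizes afterwards:

```latex
\mathrm{rawPrediction}_c(x) = \log \pi_c
  - \frac{1}{2}\sum_{j=1}^{D}\left(\frac{(x_j - \theta_{cj})^2}{\sigma^2_{cj}} + \log \sigma^2_{cj}\right)
```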
@@ -354,7 +498,7 @@ class NaiveBayesModel private[ml] (
@Since("1.5.0")
override def copy(extra: ParamMap): NaiveBayesModel = {
- copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
+ copyValues(new NaiveBayesModel(uid, pi, theta, sigma).setParent(this.parent), extra)
}
@Since("1.5.0")
@@ -378,34 +522,61 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
/** [[MLWriter]] instance for [[NaiveBayesModel]] */
private[NaiveBayesModel] class NaiveBayesModelWriter(instance: NaiveBayesModel) extends MLWriter {
+ import NaiveBayes._
private case class Data(pi: Vector, theta: Matrix)
+ private case class GaussianData(pi: Vector, theta: Matrix, sigma: Matrix)
override protected def saveImpl(path: String): Unit = {
// Save metadata and Params
DefaultParamsWriter.saveMetadata(instance, path, sc)
- // Save model data: pi, theta
- val data = Data(instance.pi, instance.theta)
val dataPath = new Path(path, "data").toString
- sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+
+ instance.getModelType match {
+ case Multinomial | Bernoulli =>
+ // Save model data: pi, theta
+ require(instance.sigma == null)
+ val data = Data(instance.pi, instance.theta)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+
+ case Gaussian =>
+ require(instance.sigma != null)
+ val data = GaussianData(instance.pi, instance.theta, instance.sigma)
+ sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+ }
}
}
private class NaiveBayesModelReader extends MLReader[NaiveBayesModel] {
+ import NaiveBayes._
/** Checked against metadata when loading model */
private val className = classOf[NaiveBayesModel].getName
override def load(path: String): NaiveBayesModel = {
+ implicit val format = DefaultFormats
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+ val modelTypeJson = metadata.getParamValue("modelType")
+ val modelType = Param.jsonDecode[String](compact(render(modelTypeJson)))
val dataPath = new Path(path, "data").toString
val data = sparkSession.read.parquet(dataPath)
val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
- val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
- .select("pi", "theta")
- .head()
- val model = new NaiveBayesModel(metadata.uid, pi, theta)
+
+ val model = if (major.toInt < 3 || modelType != Gaussian) {
+ val Row(pi: Vector, theta: Matrix) =
+ MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+ .select("pi", "theta")
+ .head()
+ new NaiveBayesModel(metadata.uid, pi, theta, null)
+ } else {
+ val Row(pi: Vector, theta: Matrix, sigma: Matrix) =
+ MLUtils.convertMatrixColumnsToML(vecConverted, "theta", "sigma")
+ .select("pi", "theta", "sigma")
+ .head()
+ new NaiveBayesModel(metadata.uid, pi, theta, sigma)
+ }
metadata.getAndSetParams(model)
model
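
As a small usage sketch of the persistence changes above (the path is illustrative, and `train` is assumed to be a DataFrame with "label"/"features" columns as in the earlier sketch):

```scala
import org.apache.spark.ml.classification.{NaiveBayes, NaiveBayesModel}

val gnb = new NaiveBayes().setModelType("gaussian").fit(train)

// The writer now persists sigma alongside pi and theta for gaussian models.
gnb.write.overwrite().save("/tmp/gaussian-nb-model")

val reloaded = NaiveBayesModel.load("/tmp/gaussian-nb-model")
// sigma is only populated for modelType "gaussian"; it stays null for
// multinomial/bernoulli models and for models written by earlier Spark versions.
assert(reloaded.sigma != null)
```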
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 9100ef1db6e1..9e4844ff8907 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -22,15 +22,15 @@ import scala.util.Random
import breeze.linalg.{DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis}
-import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
+import org.apache.spark.SparkException
+import org.apache.spark.ml.classification.NaiveBayes._
import org.apache.spark.ml.classification.NaiveBayesSuite._
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{Dataset, Row}
class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
@@ -38,6 +38,8 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
@transient var dataset: Dataset[_] = _
@transient var bernoulliDataset: Dataset[_] = _
+ @transient var gaussianDataset: Dataset[_] = _
+ @transient var gaussianDataset2: Dataset[_] = _
private val seed = 42
@@ -53,6 +55,23 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
dataset = generateNaiveBayesInput(pi, theta, 100, seed).toDF()
bernoulliDataset = generateNaiveBayesInput(pi, theta, 100, seed, "bernoulli").toDF()
+
+ // theta for gaussian nb
+ val theta2 = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0: mean
+ Array(0.10, 0.70, 0.10, 0.10), // label 1: mean
+ Array(0.10, 0.10, 0.70, 0.10) // label 2: mean
+ )
+
+ // sigma for gaussian nb
+ val sigma = Array(
+ Array(0.10, 0.10, 0.50, 0.10), // label 0: variance
+ Array(0.50, 0.10, 0.10, 0.10), // label 1: variance
+ Array(0.10, 0.10, 0.10, 0.50) // label 2: variance
+ )
+ gaussianDataset = generateGaussianNaiveBayesInput(pi, theta2, sigma, 1000, seed).toDF()
+ gaussianDataset2 = spark.read.format("libsvm")
+ .load("../data/mllib/sample_multiclass_classification_data.txt")
}
def validatePrediction(predictionAndLabels: Seq[Row]): Unit = {
@@ -67,10 +86,17 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
def validateModelFit(
piData: Vector,
thetaData: Matrix,
+ sigmaData: Matrix,
model: NaiveBayesModel): Unit = {
assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~==
Vectors.dense(piData.toArray.map(math.exp)) absTol 0.05, "pi mismatch")
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
+ if (sigmaData == null) {
+ assert(model.sigma == null, "sigma mismatch")
+ } else {
+ assert(model.sigma.map(math.exp) ~== sigmaData.map(math.exp) absTol 0.05,
+ "sigma mismatch")
+ }
}
def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
@@ -90,6 +116,19 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
Vectors.dense(classProbs.map(_ / classProbsSum))
}
+ def expectedGaussianProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val pi = model.pi.toArray.map(math.exp)
+ val classProbs = pi.indices.map { i =>
+ feature.toArray.zipWithIndex.map { case (v, j) =>
+ val mean = model.theta(i, j)
+ val variance = model.sigma(i, j)
+ math.exp(- (v - mean) * (v - mean) / variance / 2) / math.sqrt(variance * math.Pi * 2)
+ }.product * pi(i)
+ }.toArray
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
def validateProbabilities(
featureAndProbabilities: Seq[Row],
model: NaiveBayesModel,
@@ -102,6 +141,8 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
expectedMultinomialProbabilities(model, features)
case Bernoulli =>
expectedBernoulliProbabilities(model, features)
+ case Gaussian =>
+ expectedGaussianProbabilities(model, features)
case _ =>
throw new IllegalArgumentException(s"Invalid modelType: $modelType.")
}
@@ -112,12 +153,14 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
test("model types") {
assert(Multinomial === "multinomial")
assert(Bernoulli === "bernoulli")
+ assert(Gaussian === "gaussian")
}
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
- theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)))
+ theta = new DenseMatrix(2, 3, Array(0.1, 0.2, 0.3, 0.4, 0.6, 0.4)),
+ sigma = null)
ParamsSuite.checkParams(model)
}
@@ -146,7 +189,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("multinomial")
val model = nb.fit(testDataset)
- validateModelFit(pi, theta, model)
+ validateModelFit(pi, theta, null, model)
assert(model.hasParent)
MLTestingUtils.checkCopyAndUids(nb, model)
@@ -192,12 +235,17 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
test("Naive Bayes with weighted samples") {
val numClasses = 3
def modelEquals(m1: NaiveBayesModel, m2: NaiveBayesModel): Unit = {
+ assert(m1.getModelType === m2.getModelType)
assert(m1.pi ~== m2.pi relTol 0.01)
assert(m1.theta ~== m2.theta relTol 0.01)
+ if (m1.getModelType == Gaussian) {
+ assert(m1.sigma ~== m2.sigma relTol 0.01)
+ }
}
val testParams = Seq[(String, Dataset[_])](
("bernoulli", bernoulliDataset),
- ("multinomial", dataset)
+ ("multinomial", dataset),
+ ("gaussian", gaussianDataset)
)
testParams.foreach { case (family, dataset) =>
// NaiveBayes is sensitive to constant scaling of the weights unless smoothing is set to 0
@@ -228,7 +276,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
val nb = new NaiveBayes().setSmoothing(1.0).setModelType("bernoulli")
val model = nb.fit(testDataset)
- validateModelFit(pi, theta, model)
+ validateModelFit(pi, theta, null, model)
assert(model.hasParent)
val validationDataset =
@@ -308,14 +356,112 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
}
}
+ test("Naive Bayes Gaussian") {
+ val piArray = Array(0.5, 0.1, 0.4).map(math.log)
+
+ val thetaArray = Array(
+ Array(0.70, 0.10, 0.10, 0.10), // label 0: mean
+ Array(0.10, 0.70, 0.10, 0.10), // label 1: mean
+ Array(0.10, 0.10, 0.70, 0.10) // label 2: mean
+ )
+
+ val sigmaArray = Array(
+ Array(0.10, 0.10, 0.50, 0.10), // label 0: variance
+ Array(0.50, 0.10, 0.10, 0.10), // label 1: variance
+ Array(0.10, 0.10, 0.10, 0.50) // label 2: variance
+ )
+
+ val pi = Vectors.dense(piArray)
+ val theta = new DenseMatrix(3, 4, thetaArray.flatten, true)
+ val sigma = new DenseMatrix(3, 4, sigmaArray.flatten, true)
+
+ val nPoints = 10000
+ val testDataset =
+ generateGaussianNaiveBayesInput(piArray, thetaArray, sigmaArray, nPoints, 42).toDF()
+ val gnb = new NaiveBayes().setModelType("gaussian")
+ val model = gnb.fit(testDataset)
+
+ validateModelFit(pi, theta, sigma, model)
+ assert(model.hasParent)
+
+ val validationDataset =
+ generateGaussianNaiveBayesInput(piArray, thetaArray, sigmaArray, nPoints, 17).toDF()
+
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ validatePrediction(predictionAndLabels.collect())
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities.collect(), model, "gaussian")
+ }
+
+ test("Naive Bayes Gaussian - Model Coefficients") {
+ /*
+ Using the following Python code to verify the correctness.
+
+ import numpy as np
+ from sklearn.naive_bayes import GaussianNB
+ from sklearn.datasets import load_svmlight_file
+
+ path = "./data/mllib/sample_multiclass_classification_data.txt"
+ X, y = load_svmlight_file(path)
+ X = X.toarray()
+ clf = GaussianNB()
+ clf.fit(X, y)
+
+ >>> clf.class_prior_
+ array([0.33333333, 0.33333333, 0.33333333])
+ >>> clf.theta_
+ array([[ 0.27111101, -0.18833335, 0.54305072, 0.60500005],
+ [-0.60777778, 0.18166667, -0.84271174, -0.88000014],
+ [-0.09111114, -0.35833336, 0.10508474, 0.0216667 ]])
+ >>> clf.sigma_
+ array([[0.12230125, 0.07078052, 0.03430001, 0.05133607],
+ [0.03758145, 0.0988028 , 0.0033903 , 0.00782224],
+ [0.08058764, 0.06701387, 0.02486641, 0.02661392]])
+ */
+
+ val gnb = new NaiveBayes().setModelType(Gaussian)
+ val model = gnb.fit(gaussianDataset2)
+ assert(Vectors.dense(model.pi.toArray.map(math.exp)) ~=
+ Vectors.dense(0.33333333, 0.33333333, 0.33333333) relTol 1E-5)
+
+ val thetaRows = model.theta.rowIter.toArray
+ assert(thetaRows(0) ~=
+ Vectors.dense(0.27111101, -0.18833335, 0.54305072, 0.60500005) relTol 1E-5)
+ assert(thetaRows(1) ~=
+ Vectors.dense(-0.60777778, 0.18166667, -0.84271174, -0.88000014) relTol 1E-5)
+ assert(thetaRows(2) ~=
+ Vectors.dense(-0.09111114, -0.35833336, 0.10508474, 0.0216667) relTol 1E-5)
+
+ val sigmaRows = model.sigma.rowIter.toArray
+ assert(sigmaRows(0) ~=
+ Vectors.dense(0.12230125, 0.07078052, 0.03430001, 0.05133607) relTol 1E-5)
+ assert(sigmaRows(1) ~=
+ Vectors.dense(0.03758145, 0.0988028, 0.0033903, 0.00782224) relTol 1E-5)
+ assert(sigmaRows(2) ~=
+ Vectors.dense(0.08058764, 0.06701387, 0.02486641, 0.02661392) relTol 1E-5)
+ }
+
test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
+ assert(model.getModelType === model2.getModelType)
assert(model.pi === model2.pi)
assert(model.theta === model2.theta)
+ if (model.getModelType == "gaussian") {
+ assert(model.sigma === model2.sigma)
+ } else {
+ assert(model.sigma === null && model2.sigma === null)
+ }
}
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings,
NaiveBayesSuite.allParamSettings, checkModelData)
+
+ val gnb = new NaiveBayes().setModelType("gaussian")
+ testEstimatorAndModelReadWrite(gnb, gaussianDataset,
+ NaiveBayesSuite.allParamSettingsForGaussian,
+ NaiveBayesSuite.allParamSettingsForGaussian, checkModelData)
}
test("should support all NumericType labels and weights, and not support other types") {
@@ -324,6 +470,7 @@ class NaiveBayesSuite extends MLTest with DefaultReadWriteTest {
nb, spark) { (expected, actual) =>
assert(expected.pi === actual.pi)
assert(expected.theta === actual.theta)
+ assert(expected.sigma === null && actual.sigma === null)
}
}
}
@@ -340,6 +487,16 @@ object NaiveBayesSuite {
"smoothing" -> 0.1
)
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allParamSettingsForGaussian: Map[String, Any] = Map(
+ "predictionCol" -> "myPrediction",
+ "modelType" -> "gaussian"
+ )
+
private def calcLabel(p: Double, pi: Array[Double]): Int = {
var sum = 0.0
for (j <- 0 until pi.length) {
@@ -384,4 +541,26 @@ object NaiveBayesSuite {
LabeledPoint(y, Vectors.dense(xi))
}
}
+
+ // Generate Gaussian naive Bayes input: labels are drawn from pi, features from per-label Gaussians.
+ def generateGaussianNaiveBayesInput(
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
+ sigma: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ val D = theta(0).length
+ val rnd = new Random(seed)
+ val _pi = pi.map(math.exp)
+
+ for (i <- 0 until nPoints) yield {
+ val y = calcLabel(rnd.nextDouble(), _pi)
+ val xi = Array.tabulate[Double] (D) { j =>
+ val mean = theta(y)(j)
+ val variance = sigma(y)(j)
+ mean + rnd.nextGaussian() * math.sqrt(variance)
+ }
+ LabeledPoint(y, Vectors.dense(xi))
+ }
+ }
}
diff --git a/pom.xml b/pom.xml
index a6a82b3339d0..44593b78c9a0 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2326,7 +2326,7 @@
<include>**/*Suite.java</include>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
- <argLine>-ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+ <argLine>-ea -Xmx4g -Xss4m -XX:ReservedCodeCacheSize=${CodeCacheSize} -Dio.netty.tryReflectionSetAccessible=true</argLine>
- <argLine>-da -Xmx4g -XX:ReservedCodeCacheSize=${CodeCacheSize}</argLine>
+ <argLine>-da -Xmx4g -XX:ReservedCodeCacheSize=${CodeCacheSize} -Dio.netty.tryReflectionSetAccessible=true</argLine>