diff --git a/LICENSE b/LICENSE index 7950dd6ceb6d..c21032a1fd27 100644 --- a/LICENSE +++ b/LICENSE @@ -297,3 +297,4 @@ The text of each license is also included at licenses/LICENSE-[project].txt. (MIT License) RowsGroup (http://datatables.net/license/mit) (MIT License) jsonFormatter (http://www.jqueryscript.net/other/jQuery-Plugin-For-Pretty-JSON-Formatting-jsonFormatter.html) (MIT License) modernizr (https://github.com/Modernizr/Modernizr/blob/master/LICENSE) + (MIT License) machinist (https://github.com/typelevel/machinist) diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 879c1f80f2c5..cfa49b94c952 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 2.2.0 +Version: 2.2.1 Title: R Frontend for Apache Spark Description: The SparkR package provides an R Frontend for Apache Spark. Authors@R: c(person("Shivaram", "Venkataraman", role = c("aut", "cre"), diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ca45c6f9b0a9..44e39c4abb47 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -122,6 +122,7 @@ exportMethods("arrange", "group_by", "groupBy", "head", + "hint", "insertInto", "intersect", "isLocal", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 88a138fd8eb1..a7b1e3b6ae32 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3642,3 +3642,33 @@ setMethod("checkpoint", df <- callJMethod(x@sdf, "checkpoint", as.logical(eager)) dataFrame(df) }) + +#' hint +#' +#' Specifies execution plan hint and return a new SparkDataFrame. +#' +#' @param x a SparkDataFrame. +#' @param name a name of the hint. +#' @param ... optional parameters for the hint. +#' @return A SparkDataFrame. +#' @family SparkDataFrame functions +#' @aliases hint,SparkDataFrame,character-method +#' @rdname hint +#' @name hint +#' @export +#' @examples +#' \dontrun{ +#' df <- createDataFrame(mtcars) +#' avg_mpg <- mean(groupBy(createDataFrame(mtcars), "cyl"), "mpg") +#' +#' head(join(df, hint(avg_mpg, "broadcast"), df$cyl == avg_mpg$cyl)) +#' } +#' @note hint since 2.2.0 +setMethod("hint", + signature(x = "SparkDataFrame", name = "character"), + function(x, name, ...) { + parameters <- list(...) + stopifnot(all(sapply(parameters, is.character))) + jdf <- callJMethod(x@sdf, "hint", name, parameters) + dataFrame(jdf) + }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 945676c7f10b..f8ae5526bc72 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -572,6 +572,10 @@ setGeneric("group_by", function(x, ...) { standardGeneric("group_by") }) #' @export setGeneric("groupBy", function(x, ...) { standardGeneric("groupBy") }) +#' @rdname hint +#' @export +setGeneric("hint", function(x, name, ...) { standardGeneric("hint") }) + #' @rdname insertInto #' @export setGeneric("insertInto", function(x, tableName, ...) { standardGeneric("insertInto") }) @@ -1469,7 +1473,7 @@ setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") #' @rdname awaitTermination #' @export -setGeneric("awaitTermination", function(x, timeout) { standardGeneric("awaitTermination") }) +setGeneric("awaitTermination", function(x, timeout = NULL) { standardGeneric("awaitTermination") }) #' @rdname isActive #' @export diff --git a/R/pkg/R/streaming.R b/R/pkg/R/streaming.R index e353d2dd07c3..8390bd5e6de7 100644 --- a/R/pkg/R/streaming.R +++ b/R/pkg/R/streaming.R @@ -169,8 +169,10 @@ setMethod("isActive", #' immediately. #' #' @param x a StreamingQuery. -#' @param timeout time to wait in milliseconds -#' @return TRUE if query has terminated within the timeout period. +#' @param timeout time to wait in milliseconds, if omitted, wait indefinitely until \code{stopQuery} +#' is called or an error has occured. +#' @return TRUE if query has terminated within the timeout period; nothing if timeout is not +#' specified. #' @rdname awaitTermination #' @name awaitTermination #' @aliases awaitTermination,StreamingQuery-method @@ -182,8 +184,12 @@ setMethod("isActive", #' @note experimental setMethod("awaitTermination", signature(x = "StreamingQuery"), - function(x, timeout) { - handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + function(x, timeout = NULL) { + if (is.null(timeout)) { + invisible(handledCallJMethod(x@ssq, "awaitTermination")) + } else { + handledCallJMethod(x@ssq, "awaitTermination", as.integer(timeout)) + } }) #' stopQuery diff --git a/R/pkg/inst/tests/testthat/test_Serde.R b/R/pkg/inst/tests/testthat/test_Serde.R index b5f6f1b54fa8..518fb7bd9404 100644 --- a/R/pkg/inst/tests/testthat/test_Serde.R +++ b/R/pkg/inst/tests/testthat/test_Serde.R @@ -20,6 +20,8 @@ context("SerDe functionality") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("SerDe of primitive types", { + skip_on_cran() + x <- callJStatic("SparkRHandler", "echo", 1L) expect_equal(x, 1L) expect_equal(class(x), "integer") @@ -38,6 +40,8 @@ test_that("SerDe of primitive types", { }) test_that("SerDe of list of primitive types", { + skip_on_cran() + x <- list(1L, 2L, 3L) y <- callJStatic("SparkRHandler", "echo", x) expect_equal(x, y) @@ -65,6 +69,8 @@ test_that("SerDe of list of primitive types", { }) test_that("SerDe of list of lists", { + skip_on_cran() + x <- list(list(1L, 2L, 3L), list(1, 2, 3), list(TRUE, FALSE), list("a", "b", "c")) y <- callJStatic("SparkRHandler", "echo", x) diff --git a/R/pkg/inst/tests/testthat/test_Windows.R b/R/pkg/inst/tests/testthat/test_Windows.R index 1d777ddb286d..919b063bf069 100644 --- a/R/pkg/inst/tests/testthat/test_Windows.R +++ b/R/pkg/inst/tests/testthat/test_Windows.R @@ -17,6 +17,8 @@ context("Windows-specific tests") test_that("sparkJars tag in SparkContext", { + skip_on_cran() + if (.Platform$OS.type != "windows") { skip("This test is only for Windows, skipped") } diff --git a/R/pkg/inst/tests/testthat/test_binaryFile.R b/R/pkg/inst/tests/testthat/test_binaryFile.R index b5c279e3156e..63f54e1af02b 100644 --- a/R/pkg/inst/tests/testthat/test_binaryFile.R +++ b/R/pkg/inst/tests/testthat/test_binaryFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("saveAsObjectFile()/objectFile() following textFile() works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -38,6 +40,8 @@ test_that("saveAsObjectFile()/objectFile() following textFile() works", { }) test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) @@ -50,6 +54,8 @@ test_that("saveAsObjectFile()/objectFile() works on a parallelized list", { }) test_that("saveAsObjectFile()/objectFile() following RDD transformations works", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -74,6 +80,8 @@ test_that("saveAsObjectFile()/objectFile() following RDD transformations works", }) test_that("saveAsObjectFile()/objectFile() works with multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") diff --git a/R/pkg/inst/tests/testthat/test_binary_function.R b/R/pkg/inst/tests/testthat/test_binary_function.R index 59cb2e620440..25bb2b84266d 100644 --- a/R/pkg/inst/tests/testthat/test_binary_function.R +++ b/R/pkg/inst/tests/testthat/test_binary_function.R @@ -29,6 +29,8 @@ rdd <- parallelize(sc, nums, 2L) mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("union on two RDDs", { + skip_on_cran() + actual <- collectRDD(unionRDD(rdd, rdd)) expect_equal(actual, as.list(rep(nums, 2))) @@ -51,6 +53,8 @@ test_that("union on two RDDs", { }) test_that("cogroup on two RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) cogroup.rdd <- cogroup(rdd1, rdd2, numPartitions = 2L) @@ -69,6 +73,8 @@ test_that("cogroup on two RDDs", { }) test_that("zipPartitions() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 diff --git a/R/pkg/inst/tests/testthat/test_broadcast.R b/R/pkg/inst/tests/testthat/test_broadcast.R index 65f204d096f4..504ded4fc862 100644 --- a/R/pkg/inst/tests/testthat/test_broadcast.R +++ b/R/pkg/inst/tests/testthat/test_broadcast.R @@ -26,6 +26,8 @@ nums <- 1:2 rrdd <- parallelize(sc, nums, 2L) test_that("using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) randomMatBr <- broadcast(sc, randomMat) @@ -38,6 +40,8 @@ test_that("using broadcast variable", { }) test_that("without using broadcast variable", { + skip_on_cran() + randomMat <- matrix(nrow = 10, ncol = 10, data = rnorm(100)) useBroadcast <- function(x) { diff --git a/R/pkg/inst/tests/testthat/test_client.R b/R/pkg/inst/tests/testthat/test_client.R index 0cf25fe1dbf3..3d53bebab630 100644 --- a/R/pkg/inst/tests/testthat/test_client.R +++ b/R/pkg/inst/tests/testthat/test_client.R @@ -18,6 +18,8 @@ context("functions in client.R") test_that("adding spark-testing-base as a package works", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "holdenk:spark-testing-base:1.3.0_0.0.5") expect_equal(gsub("[[:space:]]", "", args), @@ -26,16 +28,22 @@ test_that("adding spark-testing-base as a package works", { }) test_that("no package specified doesn't add packages flag", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", "", "", "") expect_equal(gsub("[[:space:]]", "", args), "") }) test_that("multiple packages don't produce a warning", { + skip_on_cran() + expect_warning(generateSparkSubmitArgs("", "", "", "", c("A", "B")), NA) }) test_that("sparkJars sparkPackages as character vectors", { + skip_on_cran() + args <- generateSparkSubmitArgs("", "", c("one.jar", "two.jar", "three.jar"), "", c("com.databricks:spark-avro_2.10:2.0.1")) expect_match(args, "--jars one.jar,two.jar,three.jar") diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index c84711349111..9ec79ade5610 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -18,6 +18,8 @@ context("test functions in sparkR.R") test_that("Check masked functions", { + skip_on_cran() + # Check that we are not masking any new function from base, stats, testthat unexpectedly # NOTE: We should avoid adding entries to *namesOfMaskedCompletely* as masked functions make it # hard for users to use base R functions. Please check when in doubt. @@ -55,6 +57,8 @@ test_that("Check masked functions", { }) test_that("repeatedly starting and stopping SparkR", { + skip_on_cran() + for (i in 1:4) { sc <- suppressWarnings(sparkR.init()) rdd <- parallelize(sc, 1:20, 2L) @@ -73,6 +77,8 @@ test_that("repeatedly starting and stopping SparkSession", { }) test_that("rdd GC across sparkR.stop", { + skip_on_cran() + sc <- sparkR.sparkContext() # sc should get id 0 rdd1 <- parallelize(sc, 1:20, 2L) # rdd1 should get id 1 rdd2 <- parallelize(sc, 1:10, 2L) # rdd2 should get id 2 @@ -96,6 +102,8 @@ test_that("rdd GC across sparkR.stop", { }) test_that("job group functions can be called", { + skip_on_cran() + sc <- sparkR.sparkContext() setJobGroup("groupId", "job description", TRUE) cancelJobGroup("groupId") @@ -108,12 +116,16 @@ test_that("job group functions can be called", { }) test_that("utility function can be called", { + skip_on_cran() + sparkR.sparkContext() setLogLevel("ERROR") sparkR.session.stop() }) test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whitelist", { + skip_on_cran() + e <- new.env() e[["spark.driver.memory"]] <- "512m" ops <- getClientModeSparkSubmitOpts("sparkrmain", e) @@ -141,6 +153,8 @@ test_that("getClientModeSparkSubmitOpts() returns spark-submit args from whiteli }) test_that("sparkJars sparkPackages as comma-separated strings", { + skip_on_cran() + expect_warning(processSparkJars(" a, b ")) jars <- suppressWarnings(processSparkJars(" a, b ")) expect_equal(lapply(jars, basename), list("a", "b")) @@ -168,6 +182,8 @@ test_that("spark.lapply should perform simple transforms", { }) test_that("add and get file to be downloaded with Spark job on every node", { + skip_on_cran() + sparkR.sparkContext() # Test add file. path <- tempfile(pattern = "hello", fileext = ".txt") diff --git a/R/pkg/inst/tests/testthat/test_includePackage.R b/R/pkg/inst/tests/testthat/test_includePackage.R index 563ea298c2dd..f823ad8e9c98 100644 --- a/R/pkg/inst/tests/testthat/test_includePackage.R +++ b/R/pkg/inst/tests/testthat/test_includePackage.R @@ -26,6 +26,8 @@ nums <- 1:2 rdd <- parallelize(sc, nums, 2L) test_that("include inside function", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) @@ -42,6 +44,8 @@ test_that("include inside function", { }) test_that("use include package", { + skip_on_cran() + # Only run the test if plyr is installed. if ("plyr" %in% rownames(installed.packages())) { suppressPackageStartupMessages(library(plyr)) diff --git a/R/pkg/inst/tests/testthat/test_mllib_classification.R b/R/pkg/inst/tests/testthat/test_mllib_classification.R index 459254d271a5..cbc708718286 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_classification.R +++ b/R/pkg/inst/tests/testthat/test_mllib_classification.R @@ -284,22 +284,11 @@ test_that("spark.mlp", { c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # test initialWeights - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = + model <- spark.mlp(df, label ~ features, layers = c(4, 3), initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2, initialWeights = - c(0.0, 0.0, 0.0, 0.0, 0.0, 5.0, 5.0, 5.0, 5.0, 5.0, 9.0, 9.0, 9.0, 9.0, 9.0)) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "2.0", "1.0", "2.0", "2.0", "1.0", "0.0")) - - model <- spark.mlp(df, label ~ features, layers = c(4, 3), maxIter = 2) - mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) - expect_equal(head(mlpPredictions$prediction, 10), - c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "0.0", "2.0", "1.0", "0.0")) + c("1.0", "1.0", "1.0", "1.0", "0.0", "1.0", "2.0", "2.0", "1.0", "0.0")) # Test formula works well df <- suppressWarnings(createDataFrame(iris)) @@ -310,8 +299,6 @@ test_that("spark.mlp", { expect_equal(summary$numOfOutputs, 3) expect_equal(summary$layers, c(4, 3)) expect_equal(length(summary$weights), 15) - expect_equal(head(summary$weights, 5), list(-1.1957257, -5.2693685, 7.4489734, -6.3751413, - -10.2376130), tolerance = 1e-6) }) test_that("spark.naiveBayes", { diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R index 1661e987b730..478012e8828c 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R +++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R @@ -255,6 +255,8 @@ test_that("spark.lda with libsvm", { }) test_that("spark.lda with text input", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, optimizer = "online", features = "value") @@ -297,6 +299,8 @@ test_that("spark.lda with text input", { }) test_that("spark.posterior and spark.perplexity", { + skip_on_cran() + text <- read.text(absoluteSparkPath("data/mllib/sample_lda_data.txt")) model <- spark.lda(text, features = "value", k = 3) diff --git a/R/pkg/inst/tests/testthat/test_mllib_regression.R b/R/pkg/inst/tests/testthat/test_mllib_regression.R index 3e9ad7719807..58924f952c6b 100644 --- a/R/pkg/inst/tests/testthat/test_mllib_regression.R +++ b/R/pkg/inst/tests/testthat/test_mllib_regression.R @@ -23,6 +23,8 @@ context("MLlib regression algorithms, except for tree-based algorithms") sparkSession <- sparkR.session(enableHiveSupport = FALSE) test_that("formula of spark.glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # directly calling the spark API # dot minus and intercept vs native glm @@ -195,6 +197,8 @@ test_that("spark.glm summary", { }) test_that("spark.glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- spark.glm(training, Sepal_Width ~ Sepal_Length + Species) s <- summary(m) @@ -222,6 +226,8 @@ test_that("spark.glm save/load", { }) test_that("formula of glm", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # dot minus and intercept vs native glm model <- glm(Sepal_Width ~ . - Species + 0, data = training) @@ -248,6 +254,8 @@ test_that("formula of glm", { }) test_that("glm and predict", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) # gaussian family model <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) @@ -292,6 +300,8 @@ test_that("glm and predict", { }) test_that("glm summary", { + skip_on_cran() + # gaussian family training <- suppressWarnings(createDataFrame(iris)) stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) @@ -341,6 +351,8 @@ test_that("glm summary", { }) test_that("glm save/load", { + skip_on_cran() + training <- suppressWarnings(createDataFrame(iris)) m <- glm(Sepal_Width ~ Sepal_Length + Species, data = training) s <- summary(m) diff --git a/R/pkg/inst/tests/testthat/test_parallelize_collect.R b/R/pkg/inst/tests/testthat/test_parallelize_collect.R index 55972e1ba469..1f7f387de08c 100644 --- a/R/pkg/inst/tests/testthat/test_parallelize_collect.R +++ b/R/pkg/inst/tests/testthat/test_parallelize_collect.R @@ -39,6 +39,8 @@ jsc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", # Tests test_that("parallelize() on simple vectors and lists returns an RDD", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 1) numVectorRDD2 <- parallelize(jsc, numVector, 10) numListRDD <- parallelize(jsc, numList, 1) @@ -66,6 +68,8 @@ test_that("parallelize() on simple vectors and lists returns an RDD", { }) test_that("collect(), following a parallelize(), gives back the original collections", { + skip_on_cran() + numVectorRDD <- parallelize(jsc, numVector, 10) expect_equal(collectRDD(numVectorRDD), as.list(numVector)) @@ -86,6 +90,8 @@ test_that("collect(), following a parallelize(), gives back the original collect }) test_that("regression: collect() following a parallelize() does not drop elements", { + skip_on_cran() + # 10 %/% 6 = 1, ceiling(10 / 6) = 2 collLen <- 10 numPart <- 6 @@ -95,6 +101,8 @@ test_that("regression: collect() following a parallelize() does not drop element }) test_that("parallelize() and collect() work for lists of pairs (pairwise data)", { + skip_on_cran() + # use the pairwise logical to indicate pairwise data numPairsRDDD1 <- parallelize(jsc, numPairs, 1) numPairsRDDD2 <- parallelize(jsc, numPairs, 2) diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index b72c801dd958..a3b1631e1d11 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -29,22 +29,30 @@ intPairs <- list(list(1L, -1), list(2L, 100), list(2L, 1), list(1L, 200)) intRdd <- parallelize(sc, intPairs, 2L) test_that("get number of partitions in RDD", { + skip_on_cran() + expect_equal(getNumPartitionsRDD(rdd), 2) expect_equal(getNumPartitionsRDD(intRdd), 2) }) test_that("first on RDD", { + skip_on_cran() + expect_equal(firstRDD(rdd), 1) newrdd <- lapply(rdd, function(x) x + 1) expect_equal(firstRDD(newrdd), 2) }) test_that("count and length on RDD", { - expect_equal(countRDD(rdd), 10) - expect_equal(lengthRDD(rdd), 10) + skip_on_cran() + + expect_equal(countRDD(rdd), 10) + expect_equal(lengthRDD(rdd), 10) }) test_that("count by values and keys", { + skip_on_cran() + mods <- lapply(rdd, function(x) { x %% 3 }) actual <- countByValue(mods) expected <- list(list(0, 3L), list(1, 4L), list(2, 3L)) @@ -56,30 +64,40 @@ test_that("count by values and keys", { }) test_that("lapply on RDD", { + skip_on_cran() + multiples <- lapply(rdd, function(x) { 2 * x }) actual <- collectRDD(multiples) expect_equal(actual, as.list(nums * 2)) }) test_that("lapplyPartition on RDD", { + skip_on_cran() + sums <- lapplyPartition(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("mapPartitions on RDD", { + skip_on_cran() + sums <- mapPartitions(rdd, function(part) { sum(unlist(part)) }) actual <- collectRDD(sums) expect_equal(actual, list(15, 40)) }) test_that("flatMap() on RDDs", { + skip_on_cran() + flat <- flatMap(intRdd, function(x) { list(x, x) }) actual <- collectRDD(flat) expect_equal(actual, rep(intPairs, each = 2)) }) test_that("filterRDD on RDD", { + skip_on_cran() + filtered.rdd <- filterRDD(rdd, function(x) { x %% 2 == 0 }) actual <- collectRDD(filtered.rdd) expect_equal(actual, list(2, 4, 6, 8, 10)) @@ -95,6 +113,8 @@ test_that("filterRDD on RDD", { }) test_that("lookup on RDD", { + skip_on_cran() + vals <- lookup(intRdd, 1L) expect_equal(vals, list(-1, 200)) @@ -103,6 +123,8 @@ test_that("lookup on RDD", { }) test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { + skip_on_cran() + rdd2 <- rdd for (i in 1:12) rdd2 <- lapplyPartitionsWithIndex( @@ -117,6 +139,8 @@ test_that("several transformations on RDD (a benchmark on PipelinedRDD)", { }) test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkpoint()", { + skip_on_cran() + # RDD rdd2 <- rdd # PipelinedRDD @@ -158,6 +182,8 @@ test_that("PipelinedRDD support actions: cache(), persist(), unpersist(), checkp }) test_that("reduce on RDD", { + skip_on_cran() + sum <- reduce(rdd, "+") expect_equal(sum, 55) @@ -167,6 +193,8 @@ test_that("reduce on RDD", { }) test_that("lapply with dependency", { + skip_on_cran() + fa <- 5 multiples <- lapply(rdd, function(x) { fa * x }) actual <- collectRDD(multiples) @@ -175,6 +203,8 @@ test_that("lapply with dependency", { }) test_that("lapplyPartitionsWithIndex on RDDs", { + skip_on_cran() + func <- function(partIndex, part) { list(partIndex, Reduce("+", part)) } actual <- collectRDD(lapplyPartitionsWithIndex(rdd, func), flatten = FALSE) expect_equal(actual, list(list(0, 15), list(1, 40))) @@ -191,10 +221,14 @@ test_that("lapplyPartitionsWithIndex on RDDs", { }) test_that("sampleRDD() on RDDs", { + skip_on_cran() + expect_equal(unlist(collectRDD(sampleRDD(rdd, FALSE, 1.0, 2014L))), nums) }) test_that("takeSample() on RDDs", { + skip_on_cran() + # ported from RDDSuite.scala, modified seeds data <- parallelize(sc, 1:100, 2L) for (seed in 4:5) { @@ -237,6 +271,8 @@ test_that("takeSample() on RDDs", { }) test_that("mapValues() on pairwise RDDs", { + skip_on_cran() + multiples <- mapValues(intRdd, function(x) { x * 2 }) actual <- collectRDD(multiples) expected <- lapply(intPairs, function(x) { @@ -246,6 +282,8 @@ test_that("mapValues() on pairwise RDDs", { }) test_that("flatMapValues() on pairwise RDDs", { + skip_on_cran() + l <- parallelize(sc, list(list(1, c(1, 2)), list(2, c(3, 4)))) actual <- collectRDD(flatMapValues(l, function(x) { x })) expect_equal(actual, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -258,6 +296,8 @@ test_that("flatMapValues() on pairwise RDDs", { }) test_that("reduceByKeyLocally() on PairwiseRDDs", { + skip_on_cran() + pairs <- parallelize(sc, list(list(1, 2), list(1.1, 3), list(1, 4)), 2L) actual <- reduceByKeyLocally(pairs, "+") expect_equal(sortKeyValueList(actual), @@ -271,6 +311,8 @@ test_that("reduceByKeyLocally() on PairwiseRDDs", { }) test_that("distinct() on RDDs", { + skip_on_cran() + nums.rep2 <- rep(1:10, 2) rdd.rep2 <- parallelize(sc, nums.rep2, 2L) uniques <- distinctRDD(rdd.rep2) @@ -279,21 +321,29 @@ test_that("distinct() on RDDs", { }) test_that("maximum() on RDDs", { + skip_on_cran() + max <- maximum(rdd) expect_equal(max, 10) }) test_that("minimum() on RDDs", { + skip_on_cran() + min <- minimum(rdd) expect_equal(min, 1) }) test_that("sumRDD() on RDDs", { + skip_on_cran() + sum <- sumRDD(rdd) expect_equal(sum, 55) }) test_that("keyBy on RDDs", { + skip_on_cran() + func <- function(x) { x * x } keys <- keyBy(rdd, func) actual <- collectRDD(keys) @@ -301,6 +351,8 @@ test_that("keyBy on RDDs", { }) test_that("repartition/coalesce on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:20, 4L) # each partition contains 5 elements # repartition @@ -322,6 +374,8 @@ test_that("repartition/coalesce on RDDs", { }) test_that("sortBy() on RDDs", { + skip_on_cran() + sortedRdd <- sortBy(rdd, function(x) { x * x }, ascending = FALSE) actual <- collectRDD(sortedRdd) expect_equal(actual, as.list(sort(nums, decreasing = TRUE))) @@ -333,6 +387,8 @@ test_that("sortBy() on RDDs", { }) test_that("takeOrdered() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- takeOrdered(rdd, 6L) @@ -345,6 +401,8 @@ test_that("takeOrdered() on RDDs", { }) test_that("top() on RDDs", { + skip_on_cran() + l <- list(10, 1, 2, 9, 3, 4, 5, 6, 7) rdd <- parallelize(sc, l) actual <- top(rdd, 6L) @@ -357,6 +415,8 @@ test_that("top() on RDDs", { }) test_that("fold() on RDDs", { + skip_on_cran() + actual <- fold(rdd, 0, "+") expect_equal(actual, Reduce("+", nums, 0)) @@ -366,6 +426,8 @@ test_that("fold() on RDDs", { }) test_that("aggregateRDD() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list(1, 2, 3, 4)) zeroValue <- list(0, 0) seqOp <- function(x, y) { list(x[[1]] + y, x[[2]] + 1) } @@ -379,6 +441,8 @@ test_that("aggregateRDD() on RDDs", { }) test_that("zipWithUniqueId() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithUniqueId(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 4), @@ -393,6 +457,8 @@ test_that("zipWithUniqueId() on RDDs", { }) test_that("zipWithIndex() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) actual <- collectRDD(zipWithIndex(rdd)) expected <- list(list("a", 0), list("b", 1), list("c", 2), @@ -407,24 +473,32 @@ test_that("zipWithIndex() on RDDs", { }) test_that("glom() on RDD", { + skip_on_cran() + rdd <- parallelize(sc, as.list(1:4), 2L) actual <- collectRDD(glom(rdd)) expect_equal(actual, list(list(1, 2), list(3, 4))) }) test_that("keys() on RDDs", { + skip_on_cran() + keys <- keys(intRdd) actual <- collectRDD(keys) expect_equal(actual, lapply(intPairs, function(x) { x[[1]] })) }) test_that("values() on RDDs", { + skip_on_cran() + values <- values(intRdd) actual <- collectRDD(values) expect_equal(actual, lapply(intPairs, function(x) { x[[2]] })) }) test_that("pipeRDD() on RDDs", { + skip_on_cran() + actual <- collectRDD(pipeRDD(rdd, "more")) expected <- as.list(as.character(1:10)) expect_equal(actual, expected) @@ -442,6 +516,8 @@ test_that("pipeRDD() on RDDs", { }) test_that("zipRDD() on RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, 0:4, 2) rdd2 <- parallelize(sc, 1000:1004, 2) actual <- collectRDD(zipRDD(rdd1, rdd2)) @@ -471,6 +547,8 @@ test_that("zipRDD() on RDDs", { }) test_that("cartesian() on RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:3) actual <- collectRDD(cartesian(rdd, rdd)) expect_equal(sortKeyValueList(actual), @@ -514,6 +592,8 @@ test_that("cartesian() on RDDs", { }) test_that("subtract() on RDDs", { + skip_on_cran() + l <- list(1, 1, 2, 2, 3, 4) rdd1 <- parallelize(sc, l) @@ -541,6 +621,8 @@ test_that("subtract() on RDDs", { }) test_that("subtractByKey() on pairwise RDDs", { + skip_on_cran() + l <- list(list("a", 1), list("b", 4), list("b", 5), list("a", 2)) rdd1 <- parallelize(sc, l) @@ -570,6 +652,8 @@ test_that("subtractByKey() on pairwise RDDs", { }) test_that("intersection() on RDDs", { + skip_on_cran() + # intersection with self actual <- collectRDD(intersection(rdd, rdd)) expect_equal(sort(as.integer(actual)), nums) @@ -586,6 +670,8 @@ test_that("intersection() on RDDs", { }) test_that("join() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(joinRDD(rdd1, rdd2, 2L)) @@ -610,6 +696,8 @@ test_that("join() on pairwise RDDs", { }) test_that("leftOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) actual <- collectRDD(leftOuterJoin(rdd1, rdd2, 2L)) @@ -640,6 +728,8 @@ test_that("leftOuterJoin() on pairwise RDDs", { }) test_that("rightOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(rightOuterJoin(rdd1, rdd2, 2L)) @@ -667,6 +757,8 @@ test_that("rightOuterJoin() on pairwise RDDs", { }) test_that("fullOuterJoin() on pairwise RDDs", { + skip_on_cran() + rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) rdd2 <- parallelize(sc, list(list(1, 1), list(2, 4))) actual <- collectRDD(fullOuterJoin(rdd1, rdd2, 2L)) @@ -698,6 +790,8 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { + skip_on_cran() + numPairsRdd <- map(rdd, function(x) { list (x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) @@ -747,6 +841,8 @@ test_that("sortByKey() on pairwise RDDs", { }) test_that("collectAsMap() on a pairwise RDD", { + skip_on_cran() + rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) vals <- collectAsMap(rdd) expect_equal(vals, list(`1` = 2, `3` = 4)) @@ -765,11 +861,15 @@ test_that("collectAsMap() on a pairwise RDD", { }) test_that("show()", { + skip_on_cran() + rdd <- parallelize(sc, list(1:10)) expect_output(showRDD(rdd), "ParallelCollectionRDD\\[\\d+\\] at parallelize at RRDD\\.scala:\\d+") }) test_that("sampleByKey() on pairwise RDDs", { + skip_on_cran() + rdd <- parallelize(sc, 1:2000) pairsRDD <- lapply(rdd, function(x) { if (x %% 2 == 0) list("a", x) else list("b", x) }) fractions <- list(a = 0.2, b = 0.1) @@ -794,6 +894,8 @@ test_that("sampleByKey() on pairwise RDDs", { }) test_that("Test correct concurrency of RRDD.compute()", { + skip_on_cran() + rdd <- parallelize(sc, 1:1000, 100) jrdd <- getJRDD(lapply(rdd, function(x) { x }), "row") zrdd <- callJMethod(jrdd, "zip", jrdd) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index d38efab0fd1d..cedf4f100c6c 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -37,6 +37,8 @@ strList <- list("Dexter Morgan: Blood. Sometimes it sets my teeth on edge and ", strListRDD <- parallelize(sc, strList, 4) test_that("groupByKey for integers", { + skip_on_cran() + grouped <- groupByKey(intRdd, 2L) actual <- collectRDD(grouped) @@ -46,6 +48,8 @@ test_that("groupByKey for integers", { }) test_that("groupByKey for doubles", { + skip_on_cran() + grouped <- groupByKey(doubleRdd, 2L) actual <- collectRDD(grouped) @@ -55,6 +59,8 @@ test_that("groupByKey for doubles", { }) test_that("reduceByKey for ints", { + skip_on_cran() + reduced <- reduceByKey(intRdd, "+", 2L) actual <- collectRDD(reduced) @@ -64,6 +70,8 @@ test_that("reduceByKey for ints", { }) test_that("reduceByKey for doubles", { + skip_on_cran() + reduced <- reduceByKey(doubleRdd, "+", 2L) actual <- collectRDD(reduced) @@ -72,6 +80,8 @@ test_that("reduceByKey for doubles", { }) test_that("combineByKey for ints", { + skip_on_cran() + reduced <- combineByKey(intRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -81,6 +91,8 @@ test_that("combineByKey for ints", { }) test_that("combineByKey for doubles", { + skip_on_cran() + reduced <- combineByKey(doubleRdd, function(x) { x }, "+", "+", 2L) actual <- collectRDD(reduced) @@ -89,6 +101,8 @@ test_that("combineByKey for doubles", { }) test_that("combineByKey for characters", { + skip_on_cran() + stringKeyRDD <- parallelize(sc, list(list("max", 1L), list("min", 2L), list("other", 3L), list("max", 4L)), 2L) @@ -101,6 +115,8 @@ test_that("combineByKey for characters", { }) test_that("aggregateByKey", { + skip_on_cran() + # test aggregateByKey for int keys rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -129,6 +145,8 @@ test_that("aggregateByKey", { }) test_that("foldByKey", { + skip_on_cran() + # test foldByKey for int keys folded <- foldByKey(intRdd, 0, "+", 2L) @@ -172,6 +190,8 @@ test_that("foldByKey", { }) test_that("partitionBy() partitions data correctly", { + skip_on_cran() + # Partition by magnitude partitionByMagnitude <- function(key) { if (key >= 3) 1 else 0 } @@ -187,6 +207,8 @@ test_that("partitionBy() partitions data correctly", { }) test_that("partitionBy works with dependencies", { + skip_on_cran() + kOne <- 1 partitionByParity <- function(key) { if (key %% 2 == kOne) 7 else 4 } @@ -205,6 +227,8 @@ test_that("partitionBy works with dependencies", { }) test_that("test partitionBy with string keys", { + skip_on_cran() + words <- flatMap(strListRDD, function(line) { strsplit(line, " ")[[1]] }) wordCount <- lapply(words, function(word) { list(word, 1L) }) diff --git a/R/pkg/inst/tests/testthat/test_sparkR.R b/R/pkg/inst/tests/testthat/test_sparkR.R index f73fc6baecce..a40981c188f7 100644 --- a/R/pkg/inst/tests/testthat/test_sparkR.R +++ b/R/pkg/inst/tests/testthat/test_sparkR.R @@ -18,6 +18,8 @@ context("functions in sparkR.R") test_that("sparkCheckInstall", { + skip_on_cran() + # "local, yarn-client, mesos-client" mode, SPARK_HOME was set correctly, # and the SparkR job was submitted by "spark-submit" sparkHome <- paste0(tempdir(), "/", "sparkHome") diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 6a6c9a809ab1..de36d5cc5b08 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -97,15 +97,21 @@ mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesMapType, mapTypeJsonPath) test_that("calling sparkRSQL.init returns existing SQL context", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) }) test_that("calling sparkRSQL.init returns existing SparkSession", { + skip_on_cran() + expect_equal(suppressWarnings(sparkRSQL.init(sc)), sparkSession) }) test_that("calling sparkR.session returns existing SparkSession", { + skip_on_cran() + expect_equal(sparkR.session(), sparkSession) }) @@ -194,6 +200,8 @@ test_that("structField type strings", { }) test_that("create DataFrame from RDD", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- createDataFrame(rdd, list("a", "b")) dfAsDF <- as.DataFrame(rdd, list("a", "b")) @@ -291,6 +299,8 @@ test_that("create DataFrame from RDD", { }) test_that("createDataFrame uses files for large objects", { + skip_on_cran() + # To simulate a large file scenario, we set spark.r.maxAllocationLimit to a smaller value conf <- callJMethod(sparkSession, "conf") callJMethod(conf, "set", "spark.r.maxAllocationLimit", "100") @@ -351,6 +361,8 @@ test_that("read/write csv as DataFrame", { }) test_that("Support other types for options", { + skip_on_cran() + csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv") mockLinesCsv <- c("year,make,model,comment,blank", "\"2012\",\"Tesla\",\"S\",\"No comment\",", @@ -405,6 +417,8 @@ test_that("convert NAs to null type in DataFrames", { }) test_that("toDF", { + skip_on_cran() + rdd <- lapply(parallelize(sc, 1:10), function(x) { list(x, as.character(x)) }) df <- toDF(rdd, list("a", "b")) expect_is(df, "SparkDataFrame") @@ -516,6 +530,8 @@ test_that("create DataFrame with complex types", { }) test_that("create DataFrame from a data.frame with complex types", { + skip_on_cran() + ldf <- data.frame(row.names = 1:2) ldf$a_list <- list(list(1, 2), list(3, 4)) ldf$an_envir <- c(as.environment(list(a = 1, b = 2)), as.environment(list(c = 3))) @@ -528,6 +544,8 @@ test_that("create DataFrame from a data.frame with complex types", { }) test_that("Collect DataFrame with complex types", { + skip_on_cran() + # ArrayType df <- read.json(complexTypeJsonPath) ldf <- collect(df) @@ -615,6 +633,8 @@ test_that("read/write json files", { }) test_that("read/write json files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "json") jsonPath <- tempfile(pattern = "jsonPath", fileext = ".json") @@ -628,6 +648,8 @@ test_that("read/write json files - compression option", { }) test_that("jsonRDD() on a RDD with json string", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) rdd <- parallelize(sc, mockLines) expect_equal(countRDD(rdd), 3) @@ -684,6 +706,8 @@ test_that( }) test_that("test cache, uncache and clearCache", { + skip_on_cran() + df <- read.json(jsonPath) createOrReplaceTempView(df, "table1") cacheTable("table1") @@ -737,6 +761,8 @@ test_that("tableToDF() returns a new DataFrame", { }) test_that("toRDD() returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- toRDD(df) expect_is(testRDD, "RDD") @@ -744,6 +770,8 @@ test_that("toRDD() returns an RRDD", { }) test_that("union on two RDDs created from DataFrames returns an RRDD", { + skip_on_cran() + df <- read.json(jsonPath) RDD1 <- toRDD(df) RDD2 <- toRDD(df) @@ -754,6 +782,8 @@ test_that("union on two RDDs created from DataFrames returns an RRDD", { }) test_that("union on mixed serialization types correctly returns a byte RRDD", { + skip_on_cran() + # Byte RDD nums <- 1:10 rdd <- parallelize(sc, nums, 2L) @@ -783,6 +813,8 @@ test_that("union on mixed serialization types correctly returns a byte RRDD", { }) test_that("objectFile() works with row serialization", { + skip_on_cran() + objectPath <- tempfile(pattern = "spark-test", fileext = ".tmp") df <- read.json(jsonPath) dfRDD <- toRDD(df) @@ -795,6 +827,8 @@ test_that("objectFile() works with row serialization", { }) test_that("lapply() on a DataFrame returns an RDD with the correct columns", { + skip_on_cran() + df <- read.json(jsonPath) testRDD <- lapply(df, function(row) { row$newCol <- row$age + 5 @@ -863,6 +897,8 @@ test_that("collect() support Unicode characters", { }) test_that("multiple pipeline transformations result in an RDD with the correct values", { + skip_on_cran() + df <- read.json(jsonPath) first <- lapply(df, function(row) { row$age <- row$age + 5 @@ -1890,6 +1926,18 @@ test_that("join(), crossJoin() and merge() on a DataFrame", { unlink(jsonPath2) unlink(jsonPath3) + + # Join with broadcast hint + df1 <- sql("SELECT * FROM range(10e10)") + df2 <- sql("SELECT * FROM range(10e10)") + + execution_plan <- capture.output(explain(join(df1, df2, df1$id == df2$id))) + expect_false(any(grepl("BroadcastHashJoin", execution_plan))) + + execution_plan_hint <- capture.output( + explain(join(df1, hint(df2, "broadcast"), df1$id == df2$id)) + ) + expect_true(any(grepl("BroadcastHashJoin", execution_plan_hint))) }) test_that("toJSON() on DataFrame", { @@ -2049,6 +2097,8 @@ test_that("mutate(), transform(), rename() and names()", { }) test_that("read/write ORC files", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2070,6 +2120,8 @@ test_that("read/write ORC files", { }) test_that("read/write ORC files - compression option", { + skip_on_cran() + setHiveContext(sc) df <- read.df(jsonPath, "json") @@ -2116,6 +2168,8 @@ test_that("read/write Parquet files", { }) test_that("read/write Parquet files - compression option/mode", { + skip_on_cran() + df <- read.df(jsonPath, "json") tempPath <- tempfile(pattern = "tempPath", fileext = ".parquet") @@ -2133,6 +2187,8 @@ test_that("read/write Parquet files - compression option/mode", { }) test_that("read/write text files", { + skip_on_cran() + # Test write.df and read.df df <- read.df(jsonPath, "text") expect_is(df, "SparkDataFrame") @@ -2154,6 +2210,8 @@ test_that("read/write text files", { }) test_that("read/write text files - compression option", { + skip_on_cran() + df <- read.df(jsonPath, "text") textPath <- tempfile(pattern = "textPath", fileext = ".txt") @@ -2387,6 +2445,8 @@ test_that("approxQuantile() on a DataFrame", { }) test_that("SQL error message is returned from JVM", { + skip_on_cran() + retError <- tryCatch(sql("select * from blah"), error = function(e) e) expect_equal(grepl("Table or view not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) @@ -2395,6 +2455,8 @@ test_that("SQL error message is returned from JVM", { irisDF <- suppressWarnings(createDataFrame(iris)) test_that("Method as.data.frame as a synonym for collect()", { + skip_on_cran() + expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) @@ -2812,6 +2874,8 @@ test_that("Window functions on a DataFrame", { }) test_that("createDataFrame sqlContext parameter backward compatibility", { + skip_on_cran() + sqlContext <- suppressWarnings(sparkRSQL.init(sc)) a <- 1:3 b <- c("a", "b", "c") @@ -2891,6 +2955,8 @@ test_that("Setting and getting config on SparkSession, sparkR.conf(), sparkR.uiW }) test_that("enableHiveSupport on SparkSession", { + skip_on_cran() + setHiveContext(sc) unsetHiveContext() # if we are still here, it must be built with hive @@ -2906,6 +2972,8 @@ test_that("Spark version from SparkSession", { }) test_that("Call DataFrameWriter.save() API in Java without path and check argument types", { + skip_on_cran() + df <- read.df(jsonPath, "json") # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in write.df API and then it calls @@ -2932,6 +3000,8 @@ test_that("Call DataFrameWriter.save() API in Java without path and check argume }) test_that("Call DataFrameWriter.load() API in Java without path and check argument types", { + skip_on_cran() + # This tests if the exception is thrown from JVM not from SparkR side. # It makes sure that we can omit path argument in read.df API and then it calls # DataFrameWriter.load() without path. @@ -3056,6 +3126,8 @@ compare_list <- function(list1, list2) { # This should always be the **very last test** in this test file. test_that("No extra files are created in SPARK_HOME by starting session and making calls", { + skip_on_cran() + # Check that it is not creating any extra file. # Does not check the tempdir which would be cleaned up after. filesAfter <- list.files(path = sparkRDir, all.files = TRUE) diff --git a/R/pkg/inst/tests/testthat/test_streaming.R b/R/pkg/inst/tests/testthat/test_streaming.R index 03b1bd3dc1f4..91df7ac6f984 100644 --- a/R/pkg/inst/tests/testthat/test_streaming.R +++ b/R/pkg/inst/tests/testthat/test_streaming.R @@ -47,29 +47,37 @@ schema <- structType(structField("name", "string"), structField("count", "double")) test_that("read.stream, write.stream, awaitTermination, stopQuery", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) q <- write.stream(counts, "memory", queryName = "people", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 3) writeLines(mockLinesNa, jsonPathNa) awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people"))[[1]], 6) stopQuery(q) expect_true(awaitTermination(q, 1)) + expect_error(awaitTermination(q), NA) }) test_that("print from explain, lastProgress, status, isActive", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema) expect_true(isStreaming(df)) counts <- count(group_by(df, "name")) q <- write.stream(counts, "memory", queryName = "people2", outputMode = "complete") awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") expect_equal(capture.output(explain(q))[[1]], "== Physical Plan ==") expect_true(any(grepl("\"description\" : \"MemorySink\"", capture.output(lastProgress(q))))) @@ -82,6 +90,8 @@ test_that("print from explain, lastProgress, status, isActive", { }) test_that("Stream other format", { + skip_on_cran() + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") df <- read.df(jsonPath, "json", schema) write.df(df, parquetPath, "parquet", "overwrite") @@ -92,6 +102,7 @@ test_that("Stream other format", { q <- write.stream(counts, "memory", queryName = "people3", outputMode = "complete") expect_false(awaitTermination(q, 5 * 1000)) + callJMethod(q@ssq, "processAllAvailable") expect_equal(head(sql("SELECT count(*) FROM people3"))[[1]], 3) expect_equal(queryName(q), "people3") @@ -107,6 +118,8 @@ test_that("Stream other format", { }) test_that("Non-streaming DataFrame", { + skip_on_cran() + c <- as.DataFrame(cars) expect_false(isStreaming(c)) @@ -116,6 +129,8 @@ test_that("Non-streaming DataFrame", { }) test_that("Unsupported operation", { + skip_on_cran() + # memory sink without aggregation df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) expect_error(write.stream(df, "memory", queryName = "people", outputMode = "complete"), @@ -124,6 +139,8 @@ test_that("Unsupported operation", { }) test_that("Terminated by error", { + skip_on_cran() + df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = -1) counts <- count(group_by(df, "name")) # This would not fail before returning with a StreamingQuery, @@ -131,7 +148,7 @@ test_that("Terminated by error", { expect_error(q <- write.stream(counts, "memory", queryName = "people4", outputMode = "complete"), NA) - expect_error(awaitTermination(q, 1), + expect_error(awaitTermination(q, 5 * 1000), paste0(".*(awaitTermination : streaming query error - Invalid value '-1' for option", " 'maxFilesPerTrigger', must be a positive integer).*")) diff --git a/R/pkg/inst/tests/testthat/test_take.R b/R/pkg/inst/tests/testthat/test_take.R index aaa532856c3d..e2130eaac78d 100644 --- a/R/pkg/inst/tests/testthat/test_take.R +++ b/R/pkg/inst/tests/testthat/test_take.R @@ -34,6 +34,8 @@ sparkSession <- sparkR.session(enableHiveSupport = FALSE) sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession) test_that("take() gives back the original elements in correct count and order", { + skip_on_cran() + numVectorRDD <- parallelize(sc, numVector, 10) # case: number of elements to take is less than the size of the first partition expect_equal(takeRDD(numVectorRDD, 1), as.list(head(numVector, n = 1))) diff --git a/R/pkg/inst/tests/testthat/test_textFile.R b/R/pkg/inst/tests/testthat/test_textFile.R index 3b466066e939..28b7e8e3183f 100644 --- a/R/pkg/inst/tests/testthat/test_textFile.R +++ b/R/pkg/inst/tests/testthat/test_textFile.R @@ -24,6 +24,8 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", mockFile <- c("Spark is pretty.", "Spark is awesome.") test_that("textFile() on a local file returns an RDD", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -36,6 +38,8 @@ test_that("textFile() on a local file returns an RDD", { }) test_that("textFile() followed by a collect() returns the same content", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -46,6 +50,8 @@ test_that("textFile() followed by a collect() returns the same content", { }) test_that("textFile() word count works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -64,6 +70,8 @@ test_that("textFile() word count works as expected", { }) test_that("several transformations on RDD created by textFile()", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) @@ -78,6 +86,8 @@ test_that("several transformations on RDD created by textFile()", { }) test_that("textFile() followed by a saveAsTextFile() returns the same content", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -92,6 +102,8 @@ test_that("textFile() followed by a saveAsTextFile() returns the same content", }) test_that("saveAsTextFile() on a parallelized list works as expected", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") l <- list(1, 2, 3) rdd <- parallelize(sc, l, 1L) @@ -103,6 +115,8 @@ test_that("saveAsTextFile() on a parallelized list works as expected", { }) test_that("textFile() and saveAsTextFile() word count works as expected", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName1) @@ -128,6 +142,8 @@ test_that("textFile() and saveAsTextFile() word count works as expected", { }) test_that("textFile() on multiple paths", { + skip_on_cran() + fileName1 <- tempfile(pattern = "spark-test", fileext = ".tmp") fileName2 <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines("Spark is pretty.", fileName1) @@ -141,6 +157,8 @@ test_that("textFile() on multiple paths", { }) test_that("Pipelined operations on RDDs created using textFile", { + skip_on_cran() + fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") writeLines(mockFile, fileName) diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 6d006eccf665..bda479214e9c 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -23,6 +23,7 @@ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", test_that("convertJListToRList() gives back (deserializes) the original JLists of strings and integers", { + skip_on_cran() # It's hard to manually create a Java List using rJava, since it does not # support generics well. Instead, we rely on collectRDD() returning a # JList. @@ -40,6 +41,7 @@ test_that("convertJListToRList() gives back (deserializes) the original JLists }) test_that("serializeToBytes on RDD", { + skip_on_cran() # File content mockFile <- c("Spark is pretty.", "Spark is awesome.") fileName <- tempfile(pattern = "spark-test", fileext = ".tmp") @@ -167,6 +169,7 @@ test_that("convertToJSaveMode", { }) test_that("captureJVMException", { + skip_on_cran() method <- "getSQLDataType" expect_error(tryCatch(callJStatic("org.apache.spark.sql.api.r.SQLUtils", method, "unknown"), @@ -177,6 +180,8 @@ test_that("captureJVMException", { }) test_that("hashCode", { + skip_on_cran() + expect_error(hashCode("bc53d3605e8a5b7de1e8e271c2317645"), NA) }) diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index a6ff650c33fe..b933c59a8456 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -182,7 +182,7 @@ head(df) ``` ### Data Sources -SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL programming guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. +SparkR supports operating on a variety of data sources through the `SparkDataFrame` interface. You can check the Spark SQL Programming Guide for more [specific options](https://spark.apache.org/docs/latest/sql-programming-guide.html#manually-specifying-options) that are available for the built-in data sources. The general method for creating `SparkDataFrame` from data sources is `read.df`. This method takes in the path for the file to load and the type of data source, and the currently active Spark Session will be used automatically. SparkR supports reading CSV, JSON and Parquet files natively and through Spark Packages you can find data source connectors for popular file formats like Avro. These packages can be added with `sparkPackages` parameter when initializing SparkSession using `sparkR.session`. @@ -232,7 +232,7 @@ write.df(people, path = "people.parquet", source = "parquet", mode = "overwrite" ``` ### Hive Tables -You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL programming guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). +You can also create SparkDataFrames from Hive tables. To do this we will need to create a SparkSession with Hive support which can access tables in the Hive MetaStore. Note that Spark should have been built with Hive support and more details can be found in the [SQL Programming Guide](https://spark.apache.org/docs/latest/sql-programming-guide.html). In SparkR, by default it will attempt to create a SparkSession with Hive support enabled (`enableHiveSupport = TRUE`). ```{r, eval=FALSE} sql("CREATE TABLE IF NOT EXISTS src (key INT, value STRING)") @@ -505,6 +505,10 @@ SparkR supports the following machine learning models and algorithms. * Alternating Least Squares (ALS) +#### Frequent Pattern Mining + +* FP-growth + #### Statistics * Kolmogorov-Smirnov Test @@ -653,6 +657,7 @@ head(select(naiveBayesPrediction, "Class", "Sex", "Age", "Survived", "prediction Survival analysis studies the expected duration of time until an event happens, and often the relationship with risk factors or treatment taken on the subject. In contrast to standard regression analysis, survival modeling has to deal with special characteristics in the data including non-negative survival time and censoring. Accelerated Failure Time (AFT) model is a parametric survival model for censored data that assumes the effect of a covariate is to accelerate or decelerate the life course of an event by some constant. For more information, refer to the Wikipedia page [AFT Model](https://en.wikipedia.org/wiki/Accelerated_failure_time_model) and the references there. Different from a [Proportional Hazards Model](https://en.wikipedia.org/wiki/Proportional_hazards_model) designed for the same purpose, the AFT model is easier to parallelize because each instance contributes to the objective function independently. + ```{r, warning=FALSE} library(survival) ovarianDF <- createDataFrame(ovarian) @@ -707,7 +712,7 @@ summary(tweedieGLM1) ``` We can try other distributions in the tweedie family, for example, a compound Poisson distribution with a log link: ```{r} -tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", +tweedieGLM2 <- spark.glm(carsDF, mpg ~ wt + hp, family = "tweedie", var.power = 1.2, link.power = 0.0) summary(tweedieGLM2) ``` @@ -883,7 +888,7 @@ perplexity There are multiple options that can be configured in `spark.als`, including `rank`, `reg`, `nonnegative`. For a complete list, refer to the help file. -```{r} +```{r, eval=FALSE} ratings <- list(list(0, 0, 4.0), list(0, 1, 2.0), list(1, 1, 3.0), list(1, 2, 4.0), list(2, 1, 1.0), list(2, 2, 5.0)) df <- createDataFrame(ratings, c("user", "item", "rating")) @@ -891,7 +896,7 @@ model <- spark.als(df, "rating", "user", "item", rank = 10, reg = 0.1, nonnegati ``` Extract latent factors. -```{r} +```{r, eval=FALSE} stats <- summary(model) userFactors <- stats$userFactors itemFactors <- stats$itemFactors @@ -901,11 +906,42 @@ head(itemFactors) Make predictions. -```{r} +```{r, eval=FALSE} predicted <- predict(model, df) head(predicted) ``` +#### FP-growth + +`spark.fpGrowth` executes FP-growth algorithm to mine frequent itemsets on a `SparkDataFrame`. `itemsCol` should be an array of values. + +```{r} +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "T,R,U", "T,S", "V,R", "R,U,T,V", "R,S", "V,S,U", "U,R", "S,T", "V,R", "V,U,S", + "T,V,U", "R,V", "T,S", "T,S", "S,T", "S,U", "T,R", "V,R", "S,V", "T,S,U" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, minSupport = 0.2, minConfidence = 0.5) +``` + +`spark.freqItemsets` method can be used to retrieve a `SparkDataFrame` with the frequent itemsets. + +```{r} +head(spark.freqItemsets(fpm)) +``` + +`spark.associationRules` returns a `SparkDataFrame` with the association rules. + +```{r} +head(spark.associationRules(fpm)) +``` + +We can make predictions based on the `antecedent`. + +```{r} +head(predict(fpm, df)) +``` + #### Kolmogorov-Smirnov Test `spark.kstest` runs a two-sided, one-sample [Kolmogorov-Smirnov (KS) test](https://en.wikipedia.org/wiki/Kolmogorov%E2%80%93Smirnov_test). @@ -952,6 +988,72 @@ unlink(modelPath) ``` +## Structured Streaming + +SparkR supports the Structured Streaming API (experimental). + +You can check the Structured Streaming Programming Guide for [an introduction](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#programming-model) to its programming model and basic concepts. + +### Simple Source and Sink + +Spark has a few built-in input sources. As an example, to test with a socket source reading text into words and displaying the computed word counts: + +```{r, eval=FALSE} +# Create DataFrame representing the stream of input lines from connection +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") +``` + +### Kafka Source + +It is simple to read data from Kafka. For more information, see [Input Sources](https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html#input-sources) supported by Structured Streaming. + +```{r, eval=FALSE} +topic <- read.stream("kafka", + kafka.bootstrap.servers = "host1:port1,host2:port2", + subscribe = "topic1") +keyvalue <- selectExpr(topic, "CAST(key AS STRING)", "CAST(value AS STRING)") +``` + +### Operations and Sinks + +Most of the common operations on `SparkDataFrame` are supported for streaming, including selection, projection, and aggregation. Once you have defined the final result, to start the streaming computation, you will call the `write.stream` method setting a sink and `outputMode`. + +A streaming `SparkDataFrame` can be written for debugging to the console, to a temporary in-memory table, or for further processing in a fault-tolerant manner to a File Sink in different formats. + +```{r, eval=FALSE} +noAggDF <- select(where(deviceDataStreamingDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# Aggregate +aggDF <- count(groupBy(noAggDF, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +head(sql("select * from aggregates")) +``` + + ## Advanced Topics ### SparkR Object Classes diff --git a/R/run-tests.sh b/R/run-tests.sh index 742a2c5ed76d..29764f48bd15 100755 --- a/R/run-tests.sh +++ b/R/run-tests.sh @@ -23,7 +23,7 @@ FAILED=0 LOGFILE=$FWDIR/unit-tests.out rm -f $LOGFILE -SPARK_TESTING=1 $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE +SPARK_TESTING=1 NOT_CRAN=true $FWDIR/../bin/spark-submit --driver-java-options "-Dlog4j.configuration=file:$FWDIR/log4j.properties" --conf spark.hadoop.fs.defaultFS="file:///" $FWDIR/pkg/tests/run-all.R 2>&1 | tee -a $LOGFILE FAILED=$((PIPESTATUS[0]||$FAILED)) NUM_TEST_WARNING="$(grep -c -e 'Warnings ----------------' $LOGFILE)" diff --git a/assembly/pom.xml b/assembly/pom.xml index 9d8607d9137c..da7b0c9d1b93 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 8657af744c06..7577253dd039 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 24c10fb1ddb9..558864ae4faa 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java index 6daf9609d76d..c0f1da50f5e6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandler.java @@ -21,7 +21,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; -import java.util.List; +import java.util.Iterator; import java.util.Map; import com.codahale.metrics.Gauge; @@ -30,7 +30,6 @@ import com.codahale.metrics.MetricSet; import com.codahale.metrics.Timer; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.Lists; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -93,14 +92,25 @@ protected void handleMessage( OpenBlocks msg = (OpenBlocks) msgObj; checkAuth(client, msg.appId); - List blocks = Lists.newArrayList(); - long totalBlockSize = 0; - for (String blockId : msg.blockIds) { - final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, blockId); - totalBlockSize += block != null ? block.size() : 0; - blocks.add(block); - } - long streamId = streamManager.registerStream(client.getClientId(), blocks.iterator()); + Iterator iter = new Iterator() { + private int index = 0; + + @Override + public boolean hasNext() { + return index < msg.blockIds.length; + } + + @Override + public ManagedBuffer next() { + final ManagedBuffer block = blockManager.getBlockData(msg.appId, msg.execId, + msg.blockIds[index]); + index++; + metrics.blockTransferRateBytes.mark(block != null ? block.size() : 0); + return block; + } + }; + + long streamId = streamManager.registerStream(client.getClientId(), iter); if (logger.isTraceEnabled()) { logger.trace("Registered streamId {} with {} buffers for client {} from host {}", streamId, @@ -109,7 +119,6 @@ protected void handleMessage( getRemoteAddress(client.getChannel())); } callback.onSuccess(new StreamHandle(streamId, msg.blockIds.length).toByteBuffer()); - metrics.blockTransferRateBytes.mark(totalBlockSize); } finally { responseDelayContext.stop(); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java index e47a72c9d16c..4d48b1897038 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockHandlerSuite.java @@ -88,8 +88,6 @@ public void testOpenShuffleBlocks() { ByteBuffer openBlocks = new OpenBlocks("app0", "exec1", new String[] { "b0", "b1" }) .toByteBuffer(); handler.receive(client, openBlocks, callback); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); - verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); ArgumentCaptor response = ArgumentCaptor.forClass(ByteBuffer.class); verify(callback, times(1)).onSuccess(response.capture()); @@ -107,6 +105,8 @@ public void testOpenShuffleBlocks() { assertEquals(block0Marker, buffers.next()); assertEquals(block1Marker, buffers.next()); assertFalse(buffers.hasNext()); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b0"); + verify(blockResolver, times(1)).getBlockData("app0", "exec1", "b1"); // Verify open block request latency metrics Timer openBlockRequestLatencyMillis = (Timer) ((ExternalShuffleBlockHandler) handler) diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index b8ae04eefb97..7a33b6821792 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -216,9 +216,8 @@ public void testFetchWrongExecutor() throws Exception { registerExecutor("exec-0", dataContext0.createExecutorInfo(SORT_MANAGER)); FetchResult execFetch = fetchBlocks("exec-0", new String[] { "shuffle_0_0_0" /* right */, "shuffle_1_0_0" /* wrong */ }); - // Both still fail, as we start by checking for all block. - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_0_0_0", "shuffle_1_0_0"), execFetch.failedBlocks); + assertEquals(Sets.newHashSet("shuffle_0_0_0"), execFetch.successBlocks); + assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 5e5a80bd4446..70fed65b0255 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index c7620d0fe128..fd50e3a4bfb9 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -21,7 +21,6 @@ import java.io.IOException; import java.nio.charset.StandardCharsets; import java.nio.ByteBuffer; -import java.nio.file.Files; import java.util.List; import java.util.Map; @@ -340,9 +339,9 @@ protected Path getRecoveryPath(String fileName) { * when it previously was not. If YARN NM recovery is enabled it uses that path, otherwise * it will uses a YARN local dir. */ - protected File initRecoveryDb(String dbFileName) { + protected File initRecoveryDb(String dbName) { if (_recoveryPath != null) { - File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbFileName); + File recoveryFile = new File(_recoveryPath.toUri().getPath(), dbName); if (recoveryFile.exists()) { return recoveryFile; } @@ -350,7 +349,7 @@ protected File initRecoveryDb(String dbFileName) { // db doesn't exist in recovery path go check local dirs for it String[] localDirs = _conf.getTrimmedStrings("yarn.nodemanager.local-dirs"); for (String dir : localDirs) { - File f = new File(new Path(dir).toUri().getPath(), dbFileName); + File f = new File(new Path(dir).toUri().getPath(), dbName); if (f.exists()) { if (_recoveryPath == null) { // If NM recovery is not enabled, we should specify the recovery path using NM local @@ -363,17 +362,21 @@ protected File initRecoveryDb(String dbFileName) { // make sure to move all DBs to the recovery path from the old NM local dirs. // If another DB was initialized first just make sure all the DBs are in the same // location. - File newLoc = new File(_recoveryPath.toUri().getPath(), dbFileName); - if (!newLoc.equals(f)) { + Path newLoc = new Path(_recoveryPath, dbName); + Path copyFrom = new Path(f.toURI()); + if (!newLoc.equals(copyFrom)) { + logger.info("Moving " + copyFrom + " to: " + newLoc); try { - Files.move(f.toPath(), newLoc.toPath()); + // The move here needs to handle moving non-empty directories across NFS mounts + FileSystem fs = FileSystem.getLocal(_conf); + fs.rename(copyFrom, newLoc); } catch (Exception e) { // Fail to move recovery file to new path, just continue on with new DB location logger.error("Failed to move recovery file {} to the path {}", - dbFileName, _recoveryPath.toString(), e); + dbName, _recoveryPath.toString(), e); } } - return newLoc; + return new File(newLoc.toUri().getPath()); } } } @@ -381,7 +384,7 @@ protected File initRecoveryDb(String dbFileName) { _recoveryPath = new Path(localDirs[0]); } - return new File(_recoveryPath.toUri().getPath(), dbFileName); + return new File(_recoveryPath.toUri().getPath(), dbName); } /** diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index 1356c4723b66..076d98af834d 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index 9345dc8f0cc4..e74d84a5b3b9 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index f03a4da5e715..76783abe36a2 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java index 1321b8318115..4ab5b6889c21 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/Platform.java @@ -48,7 +48,8 @@ public final class Platform { boolean _unaligned; String arch = System.getProperty("os.arch", ""); if (arch.equals("ppc64le") || arch.equals("ppc64")) { - // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but ppc64 and ppc64le support it + // Since java.nio.Bits.unaligned() doesn't return true on ppc (See JDK-8165231), but + // ppc64 and ppc64le support it _unaligned = true; } else { try { diff --git a/core/pom.xml b/core/pom.xml index 24ce36deeb16..254a9b9ac318 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java index aa0b37323132..5f9141174916 100644 --- a/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java +++ b/core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java @@ -155,7 +155,8 @@ public long acquireExecutionMemory(long required, MemoryConsumer consumer) { for (MemoryConsumer c: consumers) { if (c != consumer && c.getUsed() > 0 && c.getMode() == mode) { long key = c.getUsed(); - List list = sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); + List list = + sortedConsumers.computeIfAbsent(key, k -> new ArrayList<>(1)); list.add(c); } } diff --git a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js index 930a0698928d..cb9922d23c44 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/executorspage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/executorspage.js @@ -253,10 +253,14 @@ $(document).ready(function () { var deadTotalBlacklisted = 0; response.forEach(function (exec) { - exec.onHeapMemoryUsed = exec.hasOwnProperty('onHeapMemoryUsed') ? exec.onHeapMemoryUsed : 0; - exec.maxOnHeapMemory = exec.hasOwnProperty('maxOnHeapMemory') ? exec.maxOnHeapMemory : 0; - exec.offHeapMemoryUsed = exec.hasOwnProperty('offHeapMemoryUsed') ? exec.offHeapMemoryUsed : 0; - exec.maxOffHeapMemory = exec.hasOwnProperty('maxOffHeapMemory') ? exec.maxOffHeapMemory : 0; + var memoryMetrics = { + usedOnHeapStorageMemory: 0, + usedOffHeapStorageMemory: 0, + totalOnHeapStorageMemory: 0, + totalOffHeapStorageMemory: 0 + }; + + exec.memoryMetrics = exec.hasOwnProperty('memoryMetrics') ? exec.memoryMetrics : memoryMetrics; }); response.forEach(function (exec) { @@ -264,10 +268,10 @@ $(document).ready(function () { allRDDBlocks += exec.rddBlocks; allMemoryUsed += exec.memoryUsed; allMaxMemory += exec.maxMemory; - allOnHeapMemoryUsed += exec.onHeapMemoryUsed; - allOnHeapMaxMemory += exec.maxOnHeapMemory; - allOffHeapMemoryUsed += exec.offHeapMemoryUsed; - allOffHeapMaxMemory += exec.maxOffHeapMemory; + allOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + allOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + allOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + allOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; allDiskUsed += exec.diskUsed; allTotalCores += exec.totalCores; allMaxTasks += exec.maxTasks; @@ -286,10 +290,10 @@ $(document).ready(function () { activeRDDBlocks += exec.rddBlocks; activeMemoryUsed += exec.memoryUsed; activeMaxMemory += exec.maxMemory; - activeOnHeapMemoryUsed += exec.onHeapMemoryUsed; - activeOnHeapMaxMemory += exec.maxOnHeapMemory; - activeOffHeapMemoryUsed += exec.offHeapMemoryUsed; - activeOffHeapMaxMemory += exec.maxOffHeapMemory; + activeOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + activeOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + activeOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + activeOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; activeDiskUsed += exec.diskUsed; activeTotalCores += exec.totalCores; activeMaxTasks += exec.maxTasks; @@ -308,10 +312,10 @@ $(document).ready(function () { deadRDDBlocks += exec.rddBlocks; deadMemoryUsed += exec.memoryUsed; deadMaxMemory += exec.maxMemory; - deadOnHeapMemoryUsed += exec.onHeapMemoryUsed; - deadOnHeapMaxMemory += exec.maxOnHeapMemory; - deadOffHeapMemoryUsed += exec.offHeapMemoryUsed; - deadOffHeapMaxMemory += exec.maxOffHeapMemory; + deadOnHeapMemoryUsed += exec.memoryMetrics.usedOnHeapStorageMemory; + deadOnHeapMaxMemory += exec.memoryMetrics.totalOnHeapStorageMemory; + deadOffHeapMemoryUsed += exec.memoryMetrics.usedOffHeapStorageMemory; + deadOffHeapMaxMemory += exec.memoryMetrics.totalOffHeapStorageMemory; deadDiskUsed += exec.diskUsed; deadTotalCores += exec.totalCores; deadMaxTasks += exec.maxTasks; @@ -431,10 +435,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.onHeapMemoryUsed; + return row.memoryMetrics.usedOnHeapStorageMemory; else - return (formatBytes(row.onHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOnHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOnHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOnHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('on_heap_memory') @@ -443,10 +447,10 @@ $(document).ready(function () { { data: function (row, type) { if (type !== 'display') - return row.offHeapMemoryUsed; + return row.memoryMetrics.usedOffHeapStorageMemory; else - return (formatBytes(row.offHeapMemoryUsed, type) + ' / ' + - formatBytes(row.maxOffHeapMemory, type)); + return (formatBytes(row.memoryMetrics.usedOffHeapStorageMemory, type) + ' / ' + + formatBytes(row.memoryMetrics.totalOffHeapStorageMemory, type)); }, "fnCreatedCell": function (nTd, sData, oData, iRow, iCol) { $(nTd).addClass('off_heap_memory') diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html index 42e2d9abdeb5..6ba3b092dc65 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage-template.html @@ -77,7 +77,7 @@ {{duration}} {{sparkUser}} {{lastUpdated}} - Download + Download {{/attempts}} {{/applications}} diff --git a/core/src/main/resources/org/apache/spark/ui/static/historypage.js b/core/src/main/resources/org/apache/spark/ui/static/historypage.js index 54810edaf146..1f89306403cd 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/historypage.js +++ b/core/src/main/resources/org/apache/spark/ui/static/historypage.js @@ -120,6 +120,9 @@ $(document).ready(function() { attempt["startTime"] = formatDate(attempt["startTime"]); attempt["endTime"] = formatDate(attempt["endTime"]); attempt["lastUpdated"] = formatDate(attempt["lastUpdated"]); + attempt["log"] = uiRoot + "/api/v1/applications/" + id + "/" + + (attempt.hasOwnProperty("attemptId") ? attempt["attemptId"] + "/" : "") + "logs"; + var app_clone = {"id" : id, "name" : name, "num" : num, "attempts" : [attempt]}; array.push(app_clone); } diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 261b3329a7b9..fcc72ff49276 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -331,7 +331,7 @@ private[spark] class ExecutorAllocationManager( val delta = addExecutors(maxNeeded) logDebug(s"Starting timer to add more executors (to " + s"expire in $sustainedSchedulerBacklogTimeoutS seconds)") - addTime += sustainedSchedulerBacklogTimeoutS * 1000 + addTime = now + (sustainedSchedulerBacklogTimeoutS * 1000) delta } else { 0 diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 99efc4893fda..7dbceb9c5c1a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -1350,7 +1350,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulator[T](initialValue: T, name: String)(implicit param: AccumulatorParam[T]) : Accumulator[T] = { - val acc = new Accumulator(initialValue, param, Some(name)) + val acc = new Accumulator(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1379,7 +1379,7 @@ class SparkContext(config: SparkConf) extends Logging { @deprecated("use AccumulatorV2", "2.0.0") def accumulable[R, T](initialValue: R, name: String)(implicit param: AccumulableParam[R, T]) : Accumulable[R, T] = { - val acc = new Accumulable(initialValue, param, Some(name)) + val acc = new Accumulable(initialValue, param, Option(name)) cleaner.foreach(_.registerAccumulatorForCleanup(acc.newAcc)) acc } @@ -1414,7 +1414,7 @@ class SparkContext(config: SparkConf) extends Logging { * @note Accumulators must be registered before use, or it will throw exception. */ def register(acc: AccumulatorV2[_, _], name: String): Unit = { - acc.register(this, name = Some(name)) + acc.register(this, name = Option(name)) } /** @@ -1734,6 +1734,7 @@ class SparkContext(config: SparkConf) extends Logging { * Return information about blocks stored in all of the slaves */ @DeveloperApi + @deprecated("This method may change or be removed in a future release.", "2.2.0") def getExecutorStorageStatus: Array[StorageStatus] = { assertNotStopped() env.blockManager.master.getStorageStatus @@ -1938,6 +1939,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this + // `SparkContext` is stopped. + localProperties.remove() // Unset YARN mode system env variable, to allow switching between cluster types. System.clearProperty("SPARK_YARN_MODE") SparkContext.clearActiveContext() diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index bae7a3f307f5..9cc321af4bde 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -28,6 +28,7 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.{Token, TokenIdentifier} @@ -353,6 +354,28 @@ class SparkHadoopUtil extends Logging { } buffer.toString } + + private[spark] def checkAccessPermission(status: FileStatus, mode: FsAction): Boolean = { + val perm = status.getPermission + val ugi = UserGroupInformation.getCurrentUser + + if (ugi.getShortUserName == status.getOwner) { + if (perm.getUserAction.implies(mode)) { + return true + } + } else if (ugi.getGroupNames.contains(status.getGroup)) { + if (perm.getGroupAction.implies(mode)) { + return true + } + } else if (perm.getOtherAction.implies(mode)) { + return true + } + + logDebug(s"Permission denied: user=${ugi.getShortUserName}, " + + s"path=${status.getPath}:${status.getOwner}:${status.getGroup}" + + s"${if (status.isDirectory) "d" else "-"}$perm") + false + } } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index d7d82800b8b5..6d8758a3d3b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -86,7 +86,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return Count of application event logs that are currently under process */ def getEventLogsUnderProcess(): Int = { - return 0; + 0 } /** @@ -95,7 +95,7 @@ private[history] abstract class ApplicationHistoryProvider { * @return 0 if this is undefined or unsupported, otherwise the last updated time in millis */ def getLastUpdatedTime(): Long = { - return 0; + 0 } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 9012736bc274..f4235df24512 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -27,7 +27,8 @@ import scala.xml.Node import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.fs.permission.FsAction import org.apache.hadoop.hdfs.DistributedFileSystem import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException @@ -318,21 +319,14 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // scan for modified applications, replay and merge them val logInfos: Seq[FileStatus] = statusList .filter { entry => - try { - val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) - !entry.isDirectory() && - // FsHistoryProvider generates a hidden file which can't be read. Accidentally - // reading a garbage file is safe, but we would log an error which can be scary to - // the end-user. - !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() - } catch { - case e: AccessControlException => - // Do not use "logInfo" since these messages can get pretty noisy if printed on - // every poll. - logDebug(s"No permission to read $entry, ignoring.") - false - } + val prevFileSize = fileToAppInfo.get(entry.getPath()).map{_.fileSize}.getOrElse(0L) + !entry.isDirectory() && + // FsHistoryProvider generates a hidden file which can't be read. Accidentally + // reading a garbage file is safe, but we would log an error which can be scary to + // the end-user. + !entry.getPath().getName().startsWith(".") && + prevFileSize < entry.getLen() && + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) } .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => @@ -445,7 +439,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) /** * Replay the log files in the list and merge the list of old applications with new ones */ - private def mergeApplicationListing(fileStatus: FileStatus): Unit = { + protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { val newAttempts = try { val eventsFilter: ReplayEventsFilter = { eventString => eventString.startsWith(APPL_START_EVENT_PREFIX) || diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 54f39f7620e5..d9c8fda99ef9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -301,6 +301,14 @@ object HistoryServer extends Logging { logDebug(s"Clearing ${SecurityManager.SPARK_AUTH_CONF}") config.set(SecurityManager.SPARK_AUTH_CONF, "false") } + + if (config.getBoolean("spark.acls.enable", config.getBoolean("spark.ui.acls.enable", false))) { + logInfo("Either spark.acls.enable or spark.ui.acls.enable is configured, clearing it and " + + "only using spark.history.ui.acl.enable") + config.set("spark.acls.enable", "false") + config.set("spark.ui.acls.enable", "false") + } + new SecurityManager(config) } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 946a92882141..a8d721f3e0d4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -83,7 +83,7 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") Executor Memory: {Utils.megabytesToString(app.desc.memoryPerExecutorMB)} -
  • Submit Date: {app.submitDate}
  • +
  • Submit Date: {UIUtils.formatDate(app.submitDate)}
  • State: {app.state}
  • { if (!app.isFinished) { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index e722a24d4a89..9351c72094e3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -252,7 +252,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { } {driver.id} {killLink} - {driver.submitDate} + {UIUtils.formatDate(driver.submitDate)} {driver.worker.map(w => if (w.isAlive()) { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 83469c5ff060..51b6c373c4da 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -23,13 +23,15 @@ import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent._ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, Map} import scala.util.control.NonFatal +import com.google.common.util.concurrent.ThreadFactoryBuilder + import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging @@ -84,7 +86,20 @@ private[spark] class Executor( } // Start worker thread pool - private val threadPool = ThreadUtils.newDaemonCachedThreadPool("Executor task launch worker") + private val threadPool = { + val threadFactory = new ThreadFactoryBuilder() + .setDaemon(true) + .setNameFormat("Executor task launch worker-%d") + .setThreadFactory(new ThreadFactory { + override def newThread(r: Runnable): Thread = + // Use UninterruptibleThread to run tasks so that we can allow running codes without being + // interrupted by `Thread.interrupt()`. Some issues, such as KAFKA-1894, HADOOP-10622, + // will hang forever if some methods are interrupted. + new UninterruptibleThread(r, "unused") // thread name will be set by ThreadFactoryBuilder + }) + .build() + Executors.newCachedThreadPool(threadFactory).asInstanceOf[ThreadPoolExecutor] + } private val executorSource = new ExecutorSource(threadPool, executorId) // Pool used for threads that supervise task killing / cancellation private val taskReaperPool = ThreadUtils.newDaemonCachedThreadPool("Task reaper") @@ -432,7 +447,8 @@ private[spark] class Executor( setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled(t.reason))) - case NonFatal(_) if task != null && task.reasonIfKilled.isDefined => + case _: InterruptedException | NonFatal(_) if + task != null && task.reasonIfKilled.isDefined => val killReason = task.reasonIfKilled.getOrElse("unknown reason") logInfo(s"Executor interrupted and killed $taskName (TID $taskId), reason: $killReason") setTaskFinishedAndClearInterruptStatus() diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index dfd2f818acda..a3ce3d1ccc5e 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -251,13 +251,10 @@ class TaskMetrics private[spark] () extends Serializable { private[spark] def accumulators(): Seq[AccumulatorV2[_, _]] = internalAccums ++ externalAccums - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - accumulators.find { acc => - acc.name.isDefined && acc.name.get == name - } + private[spark] def nonZeroInternalAccums(): Seq[AccumulatorV2[_, _]] = { + // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its + // value will be updated at driver side. + internalAccums.filter(a => !a.isZero || a == _resultSize) } } @@ -308,16 +305,16 @@ private[spark] object TaskMetrics extends Logging { */ def fromAccumulators(accums: Seq[AccumulatorV2[_, _]]): TaskMetrics = { val tm = new TaskMetrics - val (internalAccums, externalAccums) = - accums.partition(a => a.name.isDefined && tm.nameToAccums.contains(a.name.get)) - - internalAccums.foreach { acc => - val tmAcc = tm.nameToAccums(acc.name.get).asInstanceOf[AccumulatorV2[Any, Any]] - tmAcc.metadata = acc.metadata - tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + for (acc <- accums) { + val name = acc.name + if (name.isDefined && tm.nameToAccums.contains(name.get)) { + val tmAcc = tm.nameToAccums(name.get).asInstanceOf[AccumulatorV2[Any, Any]] + tmAcc.metadata = acc.metadata + tmAcc.merge(acc.asInstanceOf[AccumulatorV2[Any, Any]]) + } else { + tm.externalAccums += acc + } } - - tm.externalAccums ++= externalAccums tm } } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 89aeea493908..7f7921d56f49 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -244,8 +244,8 @@ package object config { ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + "driver and executor environments contain sensitive information. When this regex matches " + - "a property, its value is redacted from the environment UI and various logs like YARN " + - "and event logs.") + "a property key or value, the value is redacted from the environment UI and various logs " + + "like YARN and event logs.") .regexConf .createWithDefault("(?i)secret|password".r) @@ -272,4 +272,10 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val CHECKPOINT_COMPRESS = + ConfigBuilder("spark.checkpoint.compress") + .doc("Whether to compress RDD checkpoints. Generally a good idea. Compression will use " + + "spark.io.compression.codec.") + .booleanConf + .createWithDefault(false) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala index 2ed8a00df702..305fd9a6de10 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala @@ -56,11 +56,12 @@ class NettyBlockRpcServer( message match { case openBlocks: OpenBlocks => - val blocks: Seq[ManagedBuffer] = - openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData) + val blocksNum = openBlocks.blockIds.length + val blocks = for (i <- (0 until blocksNum).view) + yield blockManager.getBlockData(BlockId.apply(openBlocks.blockIds(i))) val streamId = streamManager.registerStream(appId, blocks.iterator.asJava) - logTrace(s"Registered streamId $streamId with ${blocks.size} buffers") - responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteBuffer) + logTrace(s"Registered streamId $streamId with $blocksNum buffers") + responseContext.onSuccess(new StreamHandle(streamId, blocksNum).toByteBuffer) case uploadBlock: UploadBlock => // StorageLevel and ClassTag are serialized as bytes using our JavaSerializer. diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e524675332d1..63a87e7f09d8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.{RDDBlockId, StorageLevel} import org.apache.spark.util.{BoundedPriorityQueue, Utils} -import org.apache.spark.util.collection.OpenHashMap +import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils} import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag]( val mapRDDs = mapPartitions { items => // Priority keeps the largest elements, so let's reverse the ordering. val queue = new BoundedPriorityQueue[T](num)(ord.reverse) - queue ++= util.collection.Utils.takeOrdered(items, num)(ord) + queue ++= collectionUtils.takeOrdered(items, num)(ord) Iterator.single(queue) } if (mapRDDs.partitions.length == 0) { diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index e0a29b48314f..37c67cee55f9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{FileNotFoundException, IOException} +import java.util.concurrent.TimeUnit import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -27,6 +28,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CHECKPOINT_COMPRESS +import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{SerializableConfiguration, Utils} /** @@ -119,6 +122,7 @@ private[spark] object ReliableCheckpointRDD extends Logging { originalRDD: RDD[T], checkpointDir: String, blockSize: Int = -1): ReliableCheckpointRDD[T] = { + val checkpointStartTimeNs = System.nanoTime() val sc = originalRDD.sparkContext @@ -140,6 +144,10 @@ private[spark] object ReliableCheckpointRDD extends Logging { writePartitionerToCheckpointDir(sc, originalRDD.partitioner.get, checkpointDirPath) } + val checkpointDurationMs = + TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - checkpointStartTimeNs) + logInfo(s"Checkpointing took $checkpointDurationMs ms.") + val newRDD = new ReliableCheckpointRDD[T]( sc, checkpointDirPath.toString, originalRDD.partitioner) if (newRDD.partitions.length != originalRDD.partitions.length) { @@ -169,7 +177,12 @@ private[spark] object ReliableCheckpointRDD extends Logging { val bufferSize = env.conf.getInt("spark.buffer.size", 65536) val fileOutputStream = if (blockSize < 0) { - fs.create(tempOutputPath, false, bufferSize) + val fileStream = fs.create(tempOutputPath, false, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedOutputStream(fileStream) + } else { + fileStream + } } else { // This is mainly for testing purpose fs.create(tempOutputPath, false, bufferSize, @@ -273,7 +286,14 @@ private[spark] object ReliableCheckpointRDD extends Logging { val env = SparkEnv.get val fs = path.getFileSystem(broadcastedConf.value.value) val bufferSize = env.conf.getInt("spark.buffer.size", 65536) - val fileInputStream = fs.open(path, bufferSize) + val fileInputStream = { + val fileStream = fs.open(path, bufferSize) + if (env.conf.get(CHECKPOINT_COMPRESS)) { + CompressionCodec.createCodec(env.conf).compressedInputStream(fileStream) + } else { + fileStream + } + } val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala similarity index 97% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala rename to core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala index 145dc22b7428..ab72addb2466 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/rdd/util/PeriodicRDDCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.rdd.util import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index aecb3a980e7c..a7dbf87915b2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -252,11 +252,17 @@ private[spark] class EventLoggingListener( private[spark] def redactEvent( event: SparkListenerEnvironmentUpdate): SparkListenerEnvironmentUpdate = { - // "Spark Properties" entry will always exist because the map is always populated with it. - val redactedProps = Utils.redact(sparkConf, event.environmentDetails("Spark Properties")) - val redactedEnvironmentDetails = event.environmentDetails + - ("Spark Properties" -> redactedProps) - SparkListenerEnvironmentUpdate(redactedEnvironmentDetails) + // environmentDetails maps a string descriptor to a set of properties + // Similar to: + // "JVM Information" -> jvmInformation, + // "Spark Properties" -> sparkProperties, + // ... + // where jvmInformation, sparkProperties, etc. are sequence of tuples. + // We go through the various of properties and redact sensitive information from them. + val redactedProps = event.environmentDetails.map{ case (name, props) => + name -> Utils.redact(sparkConf, props) + } + SparkListenerEnvironmentUpdate(redactedProps) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7fd2918960cd..5c337b992c84 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -182,14 +182,11 @@ private[spark] abstract class Task[T]( */ def collectAccumulatorUpdates(taskFailed: Boolean = false): Seq[AccumulatorV2[_, _]] = { if (context != null) { - context.taskMetrics.internalAccums.filter { a => - // RESULT_SIZE accumulator is always zero at executor, we need to send it back as its - // value will be updated at driver side. - // Note: internal accumulators representing task metrics always count failed values - !a.isZero || a.name == Some(InternalAccumulator.RESULT_SIZE) - // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not filter - // them out. - } ++ context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) + // Note: internal accumulators representing task metrics always count failed values + context.taskMetrics.nonZeroInternalAccums() ++ + // zero value external accumulators may still be useful, e.g. SQLMetrics, we should not + // filter them out. + context.taskMetrics.externalAccums.filter(a => !taskFailed || a.countFailedValues) } else { Seq.empty } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4eedaaea6119..dc82bb770472 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -69,6 +69,10 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // `CoarseGrainedSchedulerBackend.this`. private val executorDataMap = new HashMap[String, ExecutorData] + // Number of executors requested by the cluster manager, [[ExecutorAllocationManager]] + @GuardedBy("CoarseGrainedSchedulerBackend.this") + private var requestedTotalExecutors = 0 + // Number of executors requested from the cluster manager that have not registered yet @GuardedBy("CoarseGrainedSchedulerBackend.this") private var numPendingExecutors = 0 @@ -413,6 +417,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * */ protected def reset(): Unit = { val executors = synchronized { + requestedTotalExecutors = 0 numPendingExecutors = 0 executorsPendingToRemove.clear() Set() ++ executorDataMap.keys @@ -487,12 +492,21 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager") val response = synchronized { + requestedTotalExecutors += numAdditionalExecutors numPendingExecutors += numAdditionalExecutors logDebug(s"Number of pending executors is now $numPendingExecutors") + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""requestExecutors($numAdditionalExecutors): Executor request doesn't match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } // Account for executors pending to be added or removed - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + doRequestTotalExecutors(requestedTotalExecutors) } defaultAskTimeout.awaitResult(response) @@ -524,6 +538,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp } val response = synchronized { + this.requestedTotalExecutors = numExecutors this.localityAwareTasks = localityAwareTasks this.hostToLocalTaskCount = hostToLocalTaskCount @@ -589,8 +604,17 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp // take into account executors that are pending to be added or removed. val adjustTotalExecutors = if (!replace) { - doRequestTotalExecutors( - numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size) + requestedTotalExecutors = math.max(requestedTotalExecutors - executorsToKill.size, 0) + if (requestedTotalExecutors != + (numExistingExecutors + numPendingExecutors - executorsPendingToRemove.size)) { + logDebug( + s"""killExecutors($executorIds, $replace, $force): Executor counts do not match: + |requestedTotalExecutors = $requestedTotalExecutors + |numExistingExecutors = $numExistingExecutors + |numPendingExecutors = $numPendingExecutors + |executorsPendingToRemove = ${executorsPendingToRemove.size}""".stripMargin) + } + doRequestTotalExecutors(requestedTotalExecutors) } else { numPendingExecutors += knownExecutors.size Future.successful(true) diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index 00f918c09c66..f17b63775482 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -184,14 +184,27 @@ private[v1] class ApiRootResource extends ApiRequestContext { @Path("applications/{appId}/logs") def getEventLogs( @PathParam("appId") appId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, None) + try { + // withSparkUI will throw NotFoundException if attemptId exists for this application. + // So we need to try again with attempt id "1". + withSparkUI(appId, None) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } catch { + case _: NotFoundException => + withSparkUI(appId, Some("1")) { _ => + new EventLogDownloadResource(uiRoot, appId, None) + } + } } @Path("applications/{appId}/{attemptId}/logs") def getEventLogs( @PathParam("appId") appId: String, @PathParam("attemptId") attemptId: String): EventLogDownloadResource = { - new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + withSparkUI(appId, Some(attemptId)) { _ => + new EventLogDownloadResource(uiRoot, appId, Some(attemptId)) + } } @Path("version") @@ -291,7 +304,6 @@ private[v1] trait ApiRequestContext { case None => throw new NotFoundException("no such app: " + appId) } } - } private[v1] class ForbiddenException(msg: String) extends WebApplicationException( diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index d159b9450ef5..56d8e51732ff 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -76,10 +76,13 @@ class ExecutorSummary private[spark]( val isBlacklisted: Boolean, val maxMemory: Long, val executorLogs: Map[String, String], - val onHeapMemoryUsed: Option[Long], - val offHeapMemoryUsed: Option[Long], - val maxOnHeapMemory: Option[Long], - val maxOffHeapMemory: Option[Long]) + val memoryMetrics: Option[MemoryMetrics]) + +class MemoryMetrics private[spark]( + val usedOnHeapStorageMemory: Long, + val usedOffHeapStorageMemory: Long, + val totalOnHeapStorageMemory: Long, + val totalOffHeapStorageMemory: Long) class JobData private[spark]( val jobId: Int, diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 467c3e0e6b51..6f85b9e4d6c7 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -497,11 +497,17 @@ private[spark] class BlockManagerInfo( updateLastSeenMs() - if (_blocks.containsKey(blockId)) { + val blockExists = _blocks.containsKey(blockId) + var originalMemSize: Long = 0 + var originalDiskSize: Long = 0 + var originalLevel: StorageLevel = StorageLevel.NONE + + if (blockExists) { // The block exists on the slave already. val blockStatus: BlockStatus = _blocks.get(blockId) - val originalLevel: StorageLevel = blockStatus.storageLevel - val originalMemSize: Long = blockStatus.memSize + originalLevel = blockStatus.storageLevel + originalMemSize = blockStatus.memSize + originalDiskSize = blockStatus.diskSize if (originalLevel.useMemory) { _remainingMem += originalMemSize @@ -520,32 +526,44 @@ private[spark] class BlockManagerInfo( blockStatus = BlockStatus(storageLevel, memSize = memSize, diskSize = 0) _blocks.put(blockId, blockStatus) _remainingMem -= memSize - logInfo("Added %s in memory on %s (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(memSize), - Utils.bytesToString(_remainingMem))) + if (blockExists) { + logInfo(s"Updated $blockId in memory on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(memSize)}," + + s" original size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } else { + logInfo(s"Added $blockId in memory on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(memSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") + } } if (storageLevel.useDisk) { blockStatus = BlockStatus(storageLevel, memSize = 0, diskSize = diskSize) _blocks.put(blockId, blockStatus) - logInfo("Added %s on disk on %s (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(diskSize))) + if (blockExists) { + logInfo(s"Updated $blockId on disk on ${blockManagerId.hostPort}" + + s" (current size: ${Utils.bytesToString(diskSize)}," + + s" original size: ${Utils.bytesToString(originalDiskSize)})") + } else { + logInfo(s"Added $blockId on disk on ${blockManagerId.hostPort}" + + s" (size: ${Utils.bytesToString(diskSize)})") + } } if (!blockId.isBroadcast && blockStatus.isCached) { _cachedBlocks += blockId } - } else if (_blocks.containsKey(blockId)) { + } else if (blockExists) { // If isValid is not true, drop the block. - val blockStatus: BlockStatus = _blocks.get(blockId) _blocks.remove(blockId) _cachedBlocks -= blockId - if (blockStatus.storageLevel.useMemory) { - logInfo("Removed %s on %s in memory (size: %s, free: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.memSize), - Utils.bytesToString(_remainingMem))) + if (originalLevel.useMemory) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} in memory" + + s" (size: ${Utils.bytesToString(originalMemSize)}," + + s" free: ${Utils.bytesToString(_remainingMem)})") } - if (blockStatus.storageLevel.useDisk) { - logInfo("Removed %s on %s on disk (size: %s)".format( - blockId, blockManagerId.hostPort, Utils.bytesToString(blockStatus.diskSize))) + if (originalLevel.useDisk) { + logInfo(s"Removed $blockId on ${blockManagerId.hostPort} on disk" + + s" (size: ${Utils.bytesToString(originalDiskSize)})") } } } diff --git a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala index 1b30d4fa93bc..ac60f795915a 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageStatusListener.scala @@ -30,6 +30,7 @@ import org.apache.spark.scheduler._ * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageStatusListener(conf: SparkConf) extends SparkListener { // This maintains only blocks that are cached (i.e. storage level is not StorageLevel.NONE) private[storage] val executorIdToStorageStatus = mutable.Map[String, StorageStatus]() diff --git a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala index 8f0d181fc8fe..e9694fdbca2d 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageUtils.scala @@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging * class cannot mutate the source of the information. Accesses are not thread-safe. */ @DeveloperApi +@deprecated("This class may be removed or made private in a future release.", "2.2.0") class StorageStatus( val blockManagerId: BlockManagerId, val maxMemory: Long, diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index bdbdba578085..edf328b5ae53 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -29,8 +29,8 @@ import org.eclipse.jetty.client.api.Response import org.eclipse.jetty.proxy.ProxyServlet import org.eclipse.jetty.server._ import org.eclipse.jetty.server.handler._ +import org.eclipse.jetty.server.handler.gzip.GzipHandler import org.eclipse.jetty.servlet._ -import org.eclipse.jetty.servlets.gzip.GzipHandler import org.eclipse.jetty.util.component.LifeCycle import org.eclipse.jetty.util.thread.{QueuedThreadPool, ScheduledExecutorScheduler} import org.json4s.JValue diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 7d31ac54a717..bf4cf79e9faa 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -117,7 +117,7 @@ private[spark] class SparkUI private ( endTime = new Date(-1), duration = 0, lastUpdated = new Date(startTime), - sparkUser = "", + sparkUser = getSparkUser, completed = false )) )) diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index e53d6907bc40..79b0d81af52b 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -446,7 +446,7 @@ private[spark] object UIUtils extends Logging { val xml = XML.loadString(s"""$desc""") // Verify that this has only anchors and span (we are wrapping in span) - val allowedNodeLabels = Set("a", "span") + val allowedNodeLabels = Set("a", "span", "br") val illegalNodes = xml \\ "_" filterNot { case node: Node => allowedNodeLabels.contains(node.label) } diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala index 70b3ffd95e60..8c18464e6477 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentTab.scala @@ -32,6 +32,7 @@ private[ui] class EnvironmentTab(parent: SparkUI) extends SparkUITab(parent, "en * A SparkListener that prepares information to be displayed on the EnvironmentTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class EnvironmentListener extends SparkListener { var jvmInformation = Seq[(String, String)]() var sparkProperties = Seq[(String, String)]() diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index 0a3c63d14ca8..b7cbed468517 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.status.api.v1.ExecutorSummary +import org.apache.spark.status.api.v1.{ExecutorSummary, MemoryMetrics} import org.apache.spark.ui.{UIUtils, WebUIPage} // This isn't even used anymore -- but we need to keep it b/c of a MiMa false positive @@ -114,10 +114,16 @@ private[spark] object ExecutorsPage { val rddBlocks = status.numBlocks val memUsed = status.memUsed val maxMem = status.maxMem - val onHeapMemUsed = status.onHeapMemUsed - val offHeapMemUsed = status.offHeapMemUsed - val maxOnHeapMem = status.maxOnHeapMem - val maxOffHeapMem = status.maxOffHeapMem + val memoryMetrics = for { + onHeapUsed <- status.onHeapMemUsed + offHeapUsed <- status.offHeapMemUsed + maxOnHeap <- status.maxOnHeapMem + maxOffHeap <- status.maxOffHeapMem + } yield { + new MemoryMetrics(onHeapUsed, offHeapUsed, maxOnHeap, maxOffHeap) + } + + val diskUsed = status.diskUsed val taskSummary = listener.executorToTaskSummary.getOrElse(execId, ExecutorTaskSummary(execId)) @@ -142,10 +148,7 @@ private[spark] object ExecutorsPage { taskSummary.isBlacklisted, maxMem, taskSummary.executorLogs, - onHeapMemUsed, - offHeapMemUsed, - maxOnHeapMem, - maxOffHeapMem + memoryMetrics ) } } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 03851293eb2f..aabf6e0c63c0 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -62,6 +62,7 @@ private[ui] case class ExecutorTaskSummary( * A SparkListener that prepares information to be displayed on the ExecutorsTab */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: SparkConf) extends SparkListener { val executorToTaskSummary = LinkedHashMap[String, ExecutorTaskSummary]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f78db5ab80d1..8870187f2219 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -41,6 +41,7 @@ import org.apache.spark.ui.jobs.UIData._ * updating the internal data structures concurrently. */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { // Define a handful of type aliases so that data structures' types can serve as documentation. diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index c212362557be..148efb134e14 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -39,6 +39,7 @@ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storag * This class is thread-safe (unlike JobProgressListener) */ @DeveloperApi +@deprecated("This class will be removed in a future release.", "2.2.0") class StorageListener(storageStatusListener: StorageStatusListener) extends BlockStatusListener { private[ui] val _rddInfoMap = mutable.Map[Int, RDDInfo]() // exposed for testing diff --git a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala index 7479de55140e..a65ec75cc5db 100644 --- a/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala +++ b/core/src/main/scala/org/apache/spark/util/AccumulatorV2.scala @@ -84,8 +84,12 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { * Returns the name of this accumulator, can only be called after registration. */ final def name: Option[String] = { - assertMetadataNotNull() - metadata.name + if (atDriverSide) { + AccumulatorContext.get(id).flatMap(_.metadata.name) + } else { + assertMetadataNotNull() + metadata.name + } } /** @@ -161,7 +165,15 @@ abstract class AccumulatorV2[IN, OUT] extends Serializable { } val copyAcc = copyAndReset() assert(copyAcc.isZero, "copyAndReset must return a zero value copy") - copyAcc.metadata = metadata + val isInternalAcc = + (name.isDefined && name.get.startsWith(InternalAccumulator.METRICS_PREFIX)) || + getClass.getSimpleName == "SQLMetric" + if (isInternalAcc) { + // Do not serialize the name of internal accumulator and send it to executor. + copyAcc.metadata = metadata.copy(name = None) + } else { + copyAcc.metadata = metadata + } copyAcc } else { this @@ -263,16 +275,6 @@ private[spark] object AccumulatorContext { originals.clear() } - /** - * Looks for a registered accumulator by accumulator name. - */ - private[spark] def lookForAccumulatorByName(name: String): Option[AccumulatorV2[_, _]] = { - originals.values().asScala.find { ref => - val acc = ref.get - acc != null && acc.name.isDefined && acc.name.get == name - }.map(_.get) - } - // Identifier for distinguishing SQL metrics from other accumulators private[spark] val SQL_ACCUM_IDENTIFIER = "sql" } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala similarity index 95% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala rename to core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala index 4dd498cd91b4..ce06e18879a4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/core/src/main/scala/org/apache/spark/util/PeriodicCheckpointer.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.util import scala.collection.mutable @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel * @param sc SparkContext for the Datasets given to this checkpointer * @tparam T Dataset type, such as RDD[Double] */ -private[mllib] abstract class PeriodicCheckpointer[T]( +private[spark] abstract class PeriodicCheckpointer[T]( val checkpointInterval: Int, val sc: SparkContext) extends Logging { @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T]( /** Get list of checkpoint files for this given Dataset */ protected def getCheckpointFiles(data: T): Iterable[String] + /** + * Call this to unpersist the Dataset. + */ + def unpersistDataSet(): Unit = { + while (persistedQueue.nonEmpty) { + val dataToUnpersist = persistedQueue.dequeue() + unpersist(dataToUnpersist) + } + } + /** * Call this at the end to delete any remaining checkpoint files. */ diff --git a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala index f0b68f0cb7e2..27922b31949b 100644 --- a/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala +++ b/core/src/main/scala/org/apache/spark/util/UninterruptibleThread.scala @@ -27,7 +27,13 @@ import javax.annotation.concurrent.GuardedBy * * Note: "runUninterruptibly" should be called only in `this` thread. */ -private[spark] class UninterruptibleThread(name: String) extends Thread(name) { +private[spark] class UninterruptibleThread( + target: Runnable, + name: String) extends Thread(target, name) { + + def this(name: String) { + this(null, name) + } /** A monitor to protect "uninterruptible" and "interrupted" */ private val uninterruptibleLock = new Object diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 943dde072327..4d37db96dfc3 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -740,7 +740,11 @@ private[spark] object Utils extends Logging { * always return a single directory. */ def getLocalDir(conf: SparkConf): String = { - getOrCreateLocalRootDirs(conf)(0) + getOrCreateLocalRootDirs(conf).headOption.getOrElse { + val configuredLocalDirs = getConfiguredLocalDirs(conf) + throw new IOException( + s"Failed to get a temp directory under [${configuredLocalDirs.mkString(",")}].") + } } private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { @@ -2606,10 +2610,24 @@ private[spark] object Utils extends Logging { } private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { - kvs.map { kv => - redactionPattern.findFirstIn(kv._1) - .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } - .getOrElse(kv) + // If the sensitive information regex matches with either the key or the value, redact the value + // While the original intent was to only redact the value if the key matched with the regex, + // we've found that especially in verbose mode, the value of the property may contain sensitive + // information like so: + // "sun.java.command":"org.apache.spark.deploy.SparkSubmit ... \ + // --conf spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password ... + // + // And, in such cases, simply searching for the sensitive information regex in the key name is + // not sufficient. The values themselves have to be searched as well and redacted if matched. + // This does mean we may be accounting more false positives - for example, if the value of an + // arbitrary property contained the term 'password', we may redact the value from the UI and + // logs. In order to work around it, user would have to make the spark.redaction.regex property + // more specific. + kvs.map { case (key, value) => + redactionPattern.findFirstIn(key) + .orElse(redactionPattern.findFirstIn(value)) + .map { _ => (key, REDACTION_REPLACEMENT_TEXT) } + .getOrElse((key, value)) } } diff --git a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_memory_usage_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json index e732af266350..0f94e3b255db 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_node_blacklisting_expectation.json @@ -22,10 +22,12 @@ "stdout" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stdout", "stderr" : "http://172.22.0.167:51469/logPage/?appId=app-20161116163331-0000&executorId=2&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "driver", "hostPort" : "172.22.0.167:51475", @@ -47,10 +49,12 @@ "isBlacklisted" : true, "maxMemory" : 908381388, "executorLogs" : { }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "1", "hostPort" : "172.22.0.167:51490", @@ -75,11 +79,12 @@ "stdout" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stdout", "stderr" : "http://172.22.0.167:51467/logPage/?appId=app-20161116163331-0000&executorId=1&logType=stderr" }, - - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "0", "hostPort" : "172.22.0.167:51491", @@ -104,10 +109,12 @@ "stdout" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stdout", "stderr" : "http://172.22.0.167:51465/logPage/?appId=app-20161116163331-0000&executorId=0&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } }, { "id" : "3", "hostPort" : "172.22.0.167:51485", @@ -132,8 +139,10 @@ "stdout" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stdout", "stderr" : "http://172.22.0.167:51466/logPage/?appId=app-20161116163331-0000&executorId=3&logType=stderr" }, - "onHeapMemoryUsed" : 0, - "offHeapMemoryUsed" : 0, - "maxOnHeapMemory" : 384093388, - "maxOffHeapMemory" : 524288000 + "memoryMetrics": { + "usedOnHeapStorageMemory": 0, + "usedOffHeapStorageMemory": 0, + "totalOnHeapStorageMemory": 384093388, + "totalOffHeapStorageMemory": 524288000 + } } ] diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index b117c7709b46..ee70a3399efe 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -21,8 +21,10 @@ import java.io.File import scala.reflect.ClassTag +import com.google.common.io.ByteStreams import org.apache.hadoop.fs.Path +import org.apache.spark.io.CompressionCodec import org.apache.spark.rdd._ import org.apache.spark.storage.{BlockId, StorageLevel, TestBlockId} import org.apache.spark.util.Utils @@ -580,3 +582,42 @@ object CheckpointSuite { ).asInstanceOf[RDD[(K, Array[Iterable[V]])]] } } + +class CheckpointCompressionSuite extends SparkFunSuite with LocalSparkContext { + + test("checkpoint compression") { + val checkpointDir = Utils.createTempDir() + try { + val conf = new SparkConf() + .set("spark.checkpoint.compress", "true") + .set("spark.ui.enabled", "false") + sc = new SparkContext("local", "test", conf) + sc.setCheckpointDir(checkpointDir.toString) + val rdd = sc.makeRDD(1 to 20, numSlices = 1) + rdd.checkpoint() + assert(rdd.collect().toSeq === (1 to 20)) + + // Verify that RDD is checkpointed + assert(rdd.firstParent.isInstanceOf[ReliableCheckpointRDD[_]]) + + val checkpointPath = new Path(rdd.getCheckpointFile.get) + val fs = checkpointPath.getFileSystem(sc.hadoopConfiguration) + val checkpointFile = + fs.listStatus(checkpointPath).map(_.getPath).find(_.getName.startsWith("part-")).get + + // Verify the checkpoint file is compressed, in other words, can be decompressed + val compressedInputStream = CompressionCodec.createCodec(conf) + .compressedInputStream(fs.open(checkpointFile)) + try { + ByteStreams.toByteArray(compressedInputStream) + } finally { + compressedInputStream.close() + } + + // Verify that the compressed content can be read back + assert(rdd.collect().toSeq === (1 to 20)) + } finally { + Utils.deleteRecursively(checkpointDir) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index 735f4454e299..7e26139a2bea 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -540,10 +540,24 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu } } - // Launches one task that will run forever. Once the SparkListener detects the task has + testCancellingTasks("that raise interrupted exception on cancel") { + Thread.sleep(9999999) + } + + // SPARK-20217 should not fail stage if task throws non-interrupted exception + testCancellingTasks("that raise runtime exception on cancel") { + try { + Thread.sleep(9999999) + } catch { + case t: Throwable => + throw new RuntimeException("killed") + } + } + + // Launches one task that will block forever. Once the SparkListener detects the task has // started, kill and re-schedule it. The second run of the task will complete immediately. // If this test times out, then the first version of the task wasn't killed successfully. - test("Killing tasks") { + def testCancellingTasks(desc: String)(blockFn: => Unit): Unit = test(s"Killing tasks $desc") { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) SparkContextSuite.isTaskStarted = false @@ -572,13 +586,7 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext with Eventu // first attempt will hang if (!SparkContextSuite.isTaskStarted) { SparkContextSuite.isTaskStarted = true - try { - Thread.sleep(9999999) - } catch { - case t: Throwable => - // SPARK-20217 should not fail stage if task throws non-interrupted exception - throw new RuntimeException("killed") - } + blockFn } // second attempt succeeds immediately } diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala new file mode 100644 index 000000000000..ab24a76e20a3 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/SparkHadoopUtilSuite.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.security.PrivilegedExceptionAction + +import scala.util.Random + +import org.apache.hadoop.fs.FileStatus +import org.apache.hadoop.fs.permission.{FsAction, FsPermission} +import org.apache.hadoop.security.UserGroupInformation +import org.scalatest.Matchers + +import org.apache.spark.SparkFunSuite + +class SparkHadoopUtilSuite extends SparkFunSuite with Matchers { + test("check file permission") { + import FsAction._ + val testUser = s"user-${Random.nextInt(100)}" + val testGroups = Array(s"group-${Random.nextInt(100)}") + val testUgi = UserGroupInformation.createUserForTesting(testUser, testGroups) + + testUgi.doAs(new PrivilegedExceptionAction[Void] { + override def run(): Void = { + val sparkHadoopUtil = new SparkHadoopUtil + + // If file is owned by user and user has access permission + var status = fileStatus(testUser, testGroups.head, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user but user has no access permission + status = fileStatus(testUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + val otherUser = s"test-${Random.nextInt(100)}" + val otherGroup = s"test-${Random.nextInt(100)}" + + // If file is owned by user's group and user's group has access permission + status = fileStatus(otherUser, testGroups.head, NONE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by user's group but user's group has no access permission + status = fileStatus(otherUser, testGroups.head, READ_WRITE, NONE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + // If file is owned by other user and this user has access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, READ_WRITE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(true) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(true) + + // If file is owned by other user but this user has no access permission + status = fileStatus(otherUser, otherGroup, READ_WRITE, READ_WRITE, NONE) + sparkHadoopUtil.checkAccessPermission(status, READ) should be(false) + sparkHadoopUtil.checkAccessPermission(status, WRITE) should be(false) + + null + } + }) + } + + private def fileStatus( + owner: String, + group: String, + userAction: FsAction, + groupAction: FsAction, + otherAction: FsAction): FileStatus = { + new FileStatus(0L, + false, + 0, + 0L, + 0L, + 0L, + new FsPermission(userAction, groupAction, otherAction), + owner, + group, + null) + } +} diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index 7c2ec01a03d0..a43839a8815f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -21,8 +21,10 @@ import java.io._ import java.nio.charset.StandardCharsets import scala.collection.mutable.ArrayBuffer +import scala.io.Source import com.google.common.io.ByteStreams +import org.apache.hadoop.fs.Path import org.scalatest.{BeforeAndAfterEach, Matchers} import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ @@ -34,6 +36,7 @@ import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate import org.apache.spark.internal.config._ import org.apache.spark.internal.Logging import org.apache.spark.TestUtils.JavaSourceFromString +import org.apache.spark.scheduler.EventLoggingListener import org.apache.spark.util.{CommandLineUtils, ResetSystemProperties, Utils} @@ -404,6 +407,37 @@ class SparkSubmitSuite runSparkSubmit(args) } + test("launch simple application with spark-submit with redaction") { + val testDir = Utils.createTempDir() + testDir.deleteOnExit() + val testDirPath = new Path(testDir.getAbsolutePath()) + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val fileSystem = Utils.getHadoopFileSystem("/", + SparkHadoopUtil.get.newConfiguration(new SparkConf())) + try { + val args = Seq( + "--class", SimpleApplicationTest.getClass.getName.stripSuffix("$"), + "--name", "testApp", + "--master", "local", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD=secret_password", + "--conf", "spark.eventLog.enabled=true", + "--conf", "spark.eventLog.testing=true", + "--conf", s"spark.eventLog.dir=${testDirPath.toUri.toString}", + "--conf", "spark.hadoop.fs.defaultFS=unsupported://example.com", + unusedJar.toString) + runSparkSubmit(args) + val listStatus = fileSystem.listStatus(testDirPath) + val logData = EventLoggingListener.openEventLog(listStatus.last.getPath, fileSystem) + Source.fromInputStream(logData).getLines().foreach { line => + assert(!line.contains("secret_password")) + } + } finally { + Utils.deleteRecursively(testDir) + } + } + test("includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 9839dcf8535d..bf7480d79f8a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -356,12 +356,13 @@ class StandaloneDynamicAllocationSuite test("kill the same executor twice (SPARK-9795)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about @@ -380,12 +381,13 @@ class StandaloneDynamicAllocationSuite test("the pending replacement executors should not be lost (SPARK-10515)") { sc = new SparkContext(appConf) val appId = sc.applicationId + sc.requestExecutors(2) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.size === 1) assert(apps.head.id === appId) assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === Int.MaxValue) + assert(apps.head.getExecutorLimit === 2) } // sync executors between the Master and the driver, needed because // the driver refuses to kill executors it does not know about diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index ec580a44b8e7..456158d41b93 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -27,6 +27,7 @@ import scala.concurrent.duration._ import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} +import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.hdfs.DistributedFileSystem import org.json4s.jackson.JsonMethods._ import org.mockito.Matchers.any @@ -130,9 +131,19 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc } } - test("SPARK-3697: ignore directories that cannot be read.") { + test("SPARK-3697: ignore files that cannot be read.") { // setReadable(...) does not work on Windows. Please refer JDK-6728842. assume(!Utils.isWindows) + + class TestFsHistoryProvider extends FsHistoryProvider(createTestConf()) { + var mergeApplicationListingCall = 0 + override protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { + super.mergeApplicationListing(fileStatus) + mergeApplicationListingCall += 1 + } + } + val provider = new TestFsHistoryProvider + val logFile1 = newLogFile("new1", None, inProgress = false) writeFile(logFile1, true, None, SparkListenerApplicationStart("app1-1", Some("app1-1"), 1L, "test", None), @@ -145,10 +156,11 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) logFile2.setReadable(false, false) - val provider = new FsHistoryProvider(createTestConf()) updateAndCheck(provider) { list => list.size should be (1) } + + provider.mergeApplicationListingCall should be (1) } test("history file is renamed from inprogress to completed") { diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 764156c3edc4..95acb9a54440 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -565,13 +565,12 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers assert(jobcount === getNumJobs("/jobs")) // no need to retain the test dir now the tests complete - logDir.deleteOnExit(); - + logDir.deleteOnExit() } test("ui and api authorization checks") { - val appId = "app-20161115172038-0000" - val owner = "jose" + val appId = "local-1430917381535" + val owner = "irashid" val admin = "root" val other = "alice" @@ -590,8 +589,11 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val port = server.boundPort val testUrls = Seq( - s"http://localhost:$port/api/v1/applications/$appId/jobs", - s"http://localhost:$port/history/$appId/jobs/") + s"http://localhost:$port/api/v1/applications/$appId/1/jobs", + s"http://localhost:$port/history/$appId/1/jobs/", + s"http://localhost:$port/api/v1/applications/$appId/logs", + s"http://localhost:$port/api/v1/applications/$appId/1/logs", + s"http://localhost:$port/api/v1/applications/$appId/2/logs") tests.foreach { case (user, expectedCode) => testUrls.foreach { url => diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index f47e574b4fc4..efcad140350b 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -44,6 +44,7 @@ import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.UninterruptibleThread class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -158,6 +159,18 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(failReason.isInstanceOf[FetchFailed]) } + test("Executor's worker threads should be UninterruptibleThread") { + val conf = new SparkConf() + .setMaster("local") + .setAppName("executor thread test") + .set("spark.ui.enabled", "false") + sc = new SparkContext(conf) + val executorThread = sc.parallelize(Seq(1), 1).map { _ => + Thread.currentThread.getClass.getName + }.collect().head + assert(executorThread === classOf[UninterruptibleThread].getName) + } + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it // may be a false positive. And we should call the uncaught exception handler. diff --git a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala index f9a7f151823a..7f20206202cb 100644 --- a/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/SortingSuite.scala @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w } test("get a range of elements in an array not partitioned by a range partitioner") { - val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) + val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x)) val pairs = sc.parallelize(pairArr, 10) val range = pairs.filterByRange(200, 800).collect() assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 8f576daa77d1..b22da565d86e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -198,7 +198,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark sc = new SparkContext("local", "test") // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe63..db14c9acfdce 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1070,11 +1070,12 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg sched.dagScheduler = mockDAGScheduler val taskSet = FakeTask.createTaskSet(numTasks = 1, stageId = 0, stageAttemptId = 0) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = new ManualClock(1)) - when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).then(new Answer[Unit] { - override def answer(invocationOnMock: InvocationOnMock): Unit = { - assert(manager.isZombie === true) - } - }) + when(mockDAGScheduler.taskEnded(any(), any(), any(), any(), any())).thenAnswer( + new Answer[Unit] { + override def answer(invocationOnMock: InvocationOnMock): Unit = { + assert(manager.isZombie) + } + }) val taskOption = manager.resourceOffer("exec1", "host1", NO_PREF) assert(taskOption.isDefined) // this would fail, inside our mock dag scheduler, if it calls dagScheduler.taskEnded() too soon diff --git a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala index dfecd04c1b96..4000218e71a8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockReplicationPolicySuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import scala.collection.mutable +import scala.language.implicitConversions import scala.util.Random import org.scalatest.{BeforeAndAfter, Matchers} diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index c7074078d8fd..f7b3a2754f0e 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.File +import java.io.{File, IOException} import org.scalatest.BeforeAndAfter @@ -33,9 +33,13 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { Utils.clearLocalRootDirs() } + after { + Utils.clearLocalRootDirs() + } + test("Utils.getLocalDir() returns a valid directory, even if some local dirs are missing") { // Regression test for SPARK-2974 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) val conf = new SparkConf(false) .set("spark.local.dir", s"/NONEXISTENT_PATH,${System.getProperty("java.io.tmpdir")}") assert(new File(Utils.getLocalDir(conf)).exists()) @@ -43,7 +47,7 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { test("SPARK_LOCAL_DIRS override also affects driver") { // Regression test for SPARK-2975 - assert(!new File("/NONEXISTENT_DIR").exists()) + assert(!new File("/NONEXISTENT_PATH").exists()) // spark.local.dir only contains invalid directories, but that's not a problem since // SPARK_LOCAL_DIRS will override it on both the driver and workers: val conf = new SparkConfWithEnv(Map("SPARK_LOCAL_DIRS" -> System.getProperty("java.io.tmpdir"))) @@ -51,4 +55,17 @@ class LocalDirsSuite extends SparkFunSuite with BeforeAndAfter { assert(new File(Utils.getLocalDir(conf)).exists()) } + test("Utils.getLocalDir() throws an exception if any temporary directory cannot be retrieved") { + val path1 = "/NONEXISTENT_PATH_ONE" + val path2 = "/NONEXISTENT_PATH_TWO" + assert(!new File(path1).exists()) + assert(!new File(path2).exists()) + val conf = new SparkConf(false).set("spark.local.dir", s"$path1,$path2") + val message = intercept[IOException] { + Utils.getLocalDir(conf) + }.getMessage + // If any temporary directory could not be retrieved under the given paths above, it should + // throw an exception with the message that includes the paths. + assert(message.contains(s"$path1,$path2")) + } } diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 93964a2d5674..48be3be81755 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -293,7 +293,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val taskMetrics = TaskMetrics.empty + val taskMetrics = TaskMetrics.registered val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics val inputMetrics = taskMetrics.inputMetrics diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a64dbeae4729..a77c8e3cab4e 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -830,7 +830,7 @@ private[spark] object JsonProtocolSuite extends Assertions { hasHadoopInput: Boolean, hasOutput: Boolean, hasRecords: Boolean = true) = { - val t = TaskMetrics.empty + val t = TaskMetrics.registered // Set CPU times same as wall times for testing purpose t.setExecutorDeserializeTime(a) t.setExecutorDeserializeCpuTime(a) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala similarity index 96% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala rename to core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala index 14adf8c29fc6..f9e1b791c86e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicRDDCheckpointerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/PeriodicRDDCheckpointerSuite.scala @@ -15,18 +15,18 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.utils import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext { import PeriodicRDDCheckpointerSuite._ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 8ed09749ffd5..3339d5b35d3b 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -1010,15 +1010,19 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD", "spark.my.password", "spark.my.sECreT") - secretKeys.foreach { key => sparkConf.set(key, "secret_password") } + secretKeys.foreach { key => sparkConf.set(key, "sensitive_value") } // Set a non-secret key - sparkConf.set("spark.regular.property", "not_a_secret") + sparkConf.set("spark.regular.property", "regular_value") + // Set a property with a regular key but secret in the value + sparkConf.set("spark.sensitive.property", "has_secret_in_value") // Redact sensitive information val redactedConf = Utils.redact(sparkConf, sparkConf.getAll).toMap // Assert that secret information got redacted while the regular property remained the same secretKeys.foreach { key => assert(redactedConf(key) === Utils.REDACTION_REPLACEMENT_TEXT) } - assert(redactedConf("spark.regular.property") === "not_a_secret") + assert(redactedConf("spark.regular.property") === "regular_value") + assert(redactedConf("spark.sensitive.property") === Utils.REDACTION_REPLACEMENT_TEXT) + } } diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 31656ca0e5a6..bb7d31cad7be 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -44,4 +44,8 @@ files="src/main/java/org/apache/hive/service/server/ThreadWithGarbageCleanup.java"/> + + diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 73dc1f9a1398..9287bd47cf11 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -129,6 +129,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -162,13 +164,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 6bf0923a1d75..ab1de3d3dd8a 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -19,8 +19,8 @@ avro-mapred-1.7.7-hadoop2.jar base64-2.3.8.jar bcprov-jdk15on-1.51.jar bonecp-0.8.0.RELEASE.jar -breeze-macros_2.11-0.12.jar -breeze_2.11-0.12.jar +breeze-macros_2.11-0.13.1.jar +breeze_2.11-0.13.1.jar calcite-avatica-1.2.0-incubating.jar calcite-core-1.2.0-incubating.jar calcite-linq4j-1.2.0-incubating.jar @@ -130,6 +130,8 @@ libfb303-0.9.3.jar libthrift-0.9.3.jar log4j-1.2.17.jar lz4-1.3.0.jar +machinist_2.11-0.6.1.jar +macro-compat_2.11-1.1.1.jar mail-1.4.7.jar mesos-1.0.0-shaded-protobuf.jar metrics-core-3.1.2.jar @@ -163,13 +165,13 @@ scala-parser-combinators_2.11-1.0.4.jar scala-reflect-2.11.8.jar scala-xml_2.11-1.0.2.jar scalap-2.11.8.jar -shapeless_2.11-2.0.0.jar +shapeless_2.11-2.3.2.jar slf4j-api-1.7.16.jar slf4j-log4j12-1.7.16.jar snappy-0.2.jar snappy-java-1.1.2.6.jar -spire-macros_2.11-0.7.4.jar -spire_2.11-0.7.4.jar +spire-macros_2.11-0.13.0.jar +spire_2.11-0.13.0.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar diff --git a/dev/run-tests.py b/dev/run-tests.py index 450b68123e1f..818a0c9f4841 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -365,8 +365,16 @@ def build_spark_assembly_sbt(hadoop_version): print("[info] Building Spark assembly (w/Hive 1.2.1) using SBT with these arguments: ", " ".join(profiles_and_goals)) exec_sbt(profiles_and_goals) - # Make sure that Java and Scala API documentation can be generated - build_spark_unidoc_sbt(hadoop_version) + + # Note that we skip Unidoc build only if Hadoop 2.6 is explicitly set in this SBT build. + # Due to a different dependency resolution in SBT & Unidoc by an unknown reason, the + # documentation build fails on a specific machine & environment in Jenkins but it was unable + # to reproduce. Please see SPARK-20343. This is a band-aid fix that should be removed in + # the future. + is_hadoop_version_2_6 = os.environ.get("AMPLAB_JENKINS_BUILD_PROFILE") == "hadoop2.6" + if not is_hadoop_version_2_6: + # Make sure that Java and Scala API documentation can be generated + build_spark_unidoc_sbt(hadoop_version) def build_apache_spark(build_tool, hadoop_version): diff --git a/docs/_config.yml b/docs/_config.yml index 83bb30598d15..4b356053b086 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 2.2.0-SNAPSHOT -SPARK_VERSION_SHORT: 2.2.0 +SPARK_VERSION: 2.2.1-SNAPSHOT +SPARK_VERSION_SHORT: 2.2.1 SCALA_BINARY_VERSION: "2.11" SCALA_VERSION: "2.11.7" MESOS_VERSION: 1.0.0 diff --git a/docs/_data/menu-ml.yaml b/docs/_data/menu-ml.yaml index 0c6b9b20a6e4..047423f75aec 100644 --- a/docs/_data/menu-ml.yaml +++ b/docs/_data/menu-ml.yaml @@ -8,6 +8,8 @@ url: ml-clustering.html - text: Collaborative filtering url: ml-collaborative-filtering.html +- text: Frequent Pattern Mining + url: ml-frequent-pattern-mining.html - text: Model selection and tuning url: ml-tuning.html - text: Advanced topics diff --git a/docs/building-spark.md b/docs/building-spark.md index e99b70f7a8b4..0f551bc66b8c 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -232,7 +232,7 @@ Once installed, the `docker` service needs to be started, if not already running On Linux, this can be done by `sudo service docker start`. ./build/mvn install -DskipTests - ./build/mvn -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 + ./build/mvn test -Pdocker-integration-tests -pl :spark-docker-integration-tests_2.11 or diff --git a/docs/configuration.md b/docs/configuration.md index 2687f542b8bd..1d8d963016c7 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -213,6 +213,14 @@ of the most common options to set are: and typically can have up to 50 characters. + + spark.driver.supervise + false + + If true, restarts the driver automatically if it fails with a non-zero exit status. + Only has effect in Spark standalone mode or Mesos cluster deploy mode. + + Apart from these, the following properties are also available, and may be useful in some situations: @@ -364,8 +372,8 @@ Apart from these, the following properties are also available, and may be useful (?i)secret|password Regex to decide which Spark configuration properties and environment variables in driver and - executor environments contain sensitive information. When this regex matches a property, its - value is redacted from the environment UI and various logs like YARN and event logs. + executor environments contain sensitive information. When this regex matches a property key or + value, the value is redacted from the environment UI and various logs like YARN and event logs. @@ -2141,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE) +### GraphX + + + + + + + + +
    Property NameDefaultMeaning
    spark.graphx.pregel.checkpointInterval-1 + Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains + after lots of iterations. The checkpoint is disabled by default. +
    + ### Deploy @@ -2248,8 +2270,8 @@ should be included on Spark's classpath: * `hdfs-site.xml`, which provides default behaviors for the HDFS client. * `core-site.xml`, which sets the default filesystem name. -The location of these configuration files varies across CDH and HDP versions, but -a common location is inside of `/etc/hadoop/conf`. Some tools, such as Cloudera Manager, create +The location of these configuration files varies across Hadoop versions, but +a common location is inside of `/etc/hadoop/conf`. Some tools create configurations on-the-fly, but offer a mechanisms to download copies of them. To make these files visible to Spark, set `HADOOP_CONF_DIR` in `$SPARK_HOME/spark-env.sh` diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md index e271b28fb4f2..76aa7b405e18 100644 --- a/docs/graphx-programming-guide.md +++ b/docs/graphx-programming-guide.md @@ -708,7 +708,9 @@ messages remaining. > messaging function. These constraints allow additional optimization within GraphX. The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch* -of its implementation (note calls to graph.cache have been removed): +of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally +checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number, +say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)): {% highlight scala %} class GraphOps[VD, ED] { @@ -722,6 +724,7 @@ class GraphOps[VD, ED] { : Graph[VD, ED] = { // Receive the initial message at each vertex var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache() + // compute the messages var messages = g.mapReduceTriplets(sendMsg, mergeMsg) var activeMessages = messages.count() @@ -734,8 +737,8 @@ class GraphOps[VD, ED] { // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. - messages = g.mapReduceTriplets( - sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + messages = GraphXUtils.mapReduceTriplets( + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() activeMessages = messages.count() i += 1 } diff --git a/docs/ml-frequent-pattern-mining.md b/docs/ml-frequent-pattern-mining.md new file mode 100644 index 000000000000..81634de8aade --- /dev/null +++ b/docs/ml-frequent-pattern-mining.md @@ -0,0 +1,87 @@ +--- +layout: global +title: Frequent Pattern Mining +displayTitle: Frequent Pattern Mining +--- + +Mining frequent items, itemsets, subsequences, or other substructures is usually among the +first steps to analyze a large-scale dataset, which has been an active research topic in +data mining for years. +We refer users to Wikipedia's [association rule learning](http://en.wikipedia.org/wiki/Association_rule_learning) +for more information. + +**Table of Contents** + +* This will become a table of contents (this text will be scraped). +{:toc} + +## FP-Growth + +The FP-growth algorithm is described in the paper +[Han et al., Mining frequent patterns without candidate generation](http://dx.doi.org/10.1145/335191.335372), +where "FP" stands for frequent pattern. +Given a dataset of transactions, the first step of FP-growth is to calculate item frequencies and identify frequent items. +Different from [Apriori-like](http://en.wikipedia.org/wiki/Apriori_algorithm) algorithms designed for the same purpose, +the second step of FP-growth uses a suffix tree (FP-tree) structure to encode transactions without generating candidate sets +explicitly, which are usually expensive to generate. +After the second step, the frequent itemsets can be extracted from the FP-tree. +In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, +as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). +PFP distributes the work of growing FP-trees based on the suffixes of transactions, +and hence is more scalable than a single-machine implementation. +We refer users to the papers for more details. + +`spark.ml`'s FP-growth implementation takes the following (hyper-)parameters: + +* `minSupport`: the minimum support for an itemset to be identified as frequent. + For example, if an item appears 3 out of 5 transactions, it has a support of 3/5=0.6. +* `minConfidence`: minimum confidence for generating Association Rule. Confidence is an indication of how often an + association rule has been found to be true. For example, if in the transactions itemset `X` appears 4 times, `X` + and `Y` co-occur only 2 times, the confidence for the rule `X => Y` is then 2/4 = 0.5. The parameter will not + affect the mining for frequent itemsets, but specify the minimum confidence for generating association rules + from frequent itemsets. +* `numPartitions`: the number of partitions used to distribute the work. By default the param is not set, and + number of partitions of the input dataset is used. + +The `FPGrowthModel` provides: + +* `freqItemsets`: frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) +* `associationRules`: association rules generated with confidence above `minConfidence`, in the format of + DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]). +* `transform`: For each transaction in `itemsCol`, the `transform` method will compare its items against the antecedents + of each association rule. If the record contains all the antecedents of a specific association rule, the rule + will be considered as applicable and its consequents will be added to the prediction result. The transform + method will summarize the consequents from all the applicable rules as prediction. The prediction column has + the same data type as `itemsCol` and does not contain existing items in the `itemsCol`. + + +**Examples** + +
    + +
    +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.fpm.FPGrowth) for more details. + +{% include_example scala/org/apache/spark/examples/ml/FPGrowthExample.scala %} +
    + +
    +Refer to the [Java API docs](api/java/org/apache/spark/ml/fpm/FPGrowth.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaFPGrowthExample.java %} +
    + +
    +Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.fpm.FPGrowth) for more details. + +{% include_example python/ml/fpgrowth_example.py %} +
    + +
    + +Refer to the [R API docs](api/R/spark.fpGrowth.html) for more details. + +{% include_example r/ml/fpm.R %} +
    + +
    diff --git a/docs/mllib-dimensionality-reduction.md b/docs/mllib-dimensionality-reduction.md index 539cbc1b3163..a72680d52a26 100644 --- a/docs/mllib-dimensionality-reduction.md +++ b/docs/mllib-dimensionality-reduction.md @@ -76,13 +76,14 @@ Refer to the [`SingularValueDecomposition` Java docs](api/java/org/apache/spark/ The same code applies to `IndexedRowMatrix` if `U` is defined as an `IndexedRowMatrix`. + +
    +Refer to the [`SingularValueDecomposition` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.SingularValueDecomposition) for details on the API. -In order to run the above application, follow the instructions -provided in the [Self-Contained -Applications](quick-start.html#self-contained-applications) section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +{% include_example python/mllib/svd_example.py %} +The same code applies to `IndexedRowMatrix` if `U` is defined as an +`IndexedRowMatrix`.
    @@ -118,17 +119,21 @@ Refer to the [`PCA` Scala docs](api/scala/index.html#org.apache.spark.mllib.feat The following code demonstrates how to compute principal components on a `RowMatrix` and use them to project the vectors into a low-dimensional space. -The number of columns should be small, e.g, less than 1000. Refer to the [`RowMatrix` Java docs](api/java/org/apache/spark/mllib/linalg/distributed/RowMatrix.html) for details on the API. {% include_example java/org/apache/spark/examples/mllib/JavaPCAExample.java %} - -In order to run the above application, follow the instructions -provided in the [Self-Contained Applications](quick-start.html#self-contained-applications) -section of the Spark -quick-start guide. Be sure to also include *spark-mllib* to your build file as -a dependency. +
    + +The following code demonstrates how to compute principal components on a `RowMatrix` +and use them to project the vectors into a low-dimensional space. + +Refer to the [`RowMatrix` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.linalg.distributed.RowMatrix) for details on the API. + +{% include_example python/mllib/pca_rowmatrix_example.py %} + +
    + diff --git a/docs/mllib-frequent-pattern-mining.md b/docs/mllib-frequent-pattern-mining.md index 93e3f0b2d226..c9cd7cc85e75 100644 --- a/docs/mllib-frequent-pattern-mining.md +++ b/docs/mllib-frequent-pattern-mining.md @@ -24,7 +24,7 @@ explicitly, which are usually expensive to generate. After the second step, the frequent itemsets can be extracted from the FP-tree. In `spark.mllib`, we implemented a parallel version of FP-growth called PFP, as described in [Li et al., PFP: Parallel FP-growth for query recommendation](http://dx.doi.org/10.1145/1454008.1454027). -PFP distributes the work of growing FP-trees based on the suffices of transactions, +PFP distributes the work of growing FP-trees based on the suffixes of transactions, and hence more scalable than a single-machine implementation. We refer users to the papers for more details. diff --git a/docs/monitoring.md b/docs/monitoring.md index da954385dc45..3e577c5f3677 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -27,8 +27,8 @@ in the UI to persisted storage. ## Viewing After the Fact -If Spark is run on Mesos or YARN, it is still possible to construct the UI of an -application through Spark's history server, provided that the application's event logs exist. +It is still possible to construct the UI of an application through Spark's history server, +provided that the application's event logs exist. You can start the history server by executing: ./sbin/start-history-server.sh diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1c0b60f7b934..34ced9ed7b46 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -242,7 +242,7 @@ SPARK_WORKER_OPTS supports the following system properties: - +
    spark.worker.cleanup.appDataTtl7 * 24 * 3600 (7 days)604800 (7 days, 7 * 24 * 3600) The number of seconds to retain application work directories on each worker. This is a Time To Live and should depend on the amount of available disk space you have. Application logs and jars are diff --git a/docs/sparkr.md b/docs/sparkr.md index a1a35a7757e5..40395202fde3 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -452,6 +452,7 @@ SparkR supports the following machine learning algorithms currently: * [`spark.logit`](api/R/spark.logit.html): [`Logistic Regression`](ml-classification-regression.html#logistic-regression) * [`spark.mlp`](api/R/spark.mlp.html): [`Multilayer Perceptron (MLP)`](ml-classification-regression.html#multilayer-perceptron-classifier) * [`spark.naiveBayes`](api/R/spark.naiveBayes.html): [`Naive Bayes`](ml-classification-regression.html#naive-bayes) +* [`spark.svmLinear`](api/R/spark.svmLinear.html): [`Linear Support Vector Machine`](ml-classification-regression.html#linear-support-vector-machine) #### Regression @@ -466,6 +467,7 @@ SparkR supports the following machine learning algorithms currently: #### Clustering +* [`spark.bisectingKmeans`](api/R/spark.bisectingKmeans.html): [`Bisecting k-means`](ml-clustering.html#bisecting-k-means) * [`spark.gaussianMixture`](api/R/spark.gaussianMixture.html): [`Gaussian Mixture Model (GMM)`](ml-clustering.html#gaussian-mixture-model-gmm) * [`spark.kmeans`](api/R/spark.kmeans.html): [`K-Means`](ml-clustering.html#k-means) * [`spark.lda`](api/R/spark.lda.html): [`Latent Dirichlet Allocation (LDA)`](ml-clustering.html#latent-dirichlet-allocation-lda) @@ -557,6 +559,10 @@ The following example shows how to save/load a MLlib model by SparkR.
    +# Structured Streaming + +SparkR supports the Structured Streaming API (experimental). Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. For more information see the R API on the [Structured Streaming Programming Guide](structured-streaming-programming-guide.html) + # R Function Name Conflicts When loading and attaching a new package in R, it is possible to have a name [conflict](https://stat.ethz.ch/R-manual/R-devel/library/base/html/library.html), where a @@ -608,3 +614,11 @@ You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-ma ## Upgrading to SparkR 2.1.0 - `join` no longer performs Cartesian Product by default, use `crossJoin` instead. + +## Upgrading to SparkR 2.2.0 + + - A `numPartitions` parameter has been added to `createDataFrame` and `as.DataFrame`. When splitting the data, the partition position calculation has been made to match the one in Scala. + - The method `createExternalTable` has been deprecated to be replaced by `createTable`. Either methods can be called to create external or managed table. Additional catalog methods have also been added. + - By default, derby.log is now saved to `tempdir()`. This will be created when instantiating the SparkSession with `enableHiveSupport` set to `TRUE`. + - `spark.lda` was not setting the optimizer correctly. It has been corrected. + - Several model summary outputs are updated to have `coefficients` as `matrix`. This includes `spark.logit`, `spark.kmeans`, `spark.glm`. Model summary outputs for `spark.gaussianMixture` have added log-likelihood as `loglik`. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 28942b68fa20..490c1ce8a7cc 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -571,7 +571,7 @@ be created by calling the `table` method on a `SparkSession` with the name of th For file-based data source, e.g. text, parquet, json, etc. you can specify a custom table path via the `path` option, e.g. `df.write.option("path", "/some/path").saveAsTable("t")`. When the table is dropped, the custom table path will not be removed and the table data is still there. If no custom table path is -specifed, Spark will write data to a default table path under the warehouse directory. When the table is +specified, Spark will write data to a default table path under the warehouse directory. When the table is dropped, the default table path will be removed too. Starting from Spark 2.1, persistent datasource tables have per-partition metadata stored in the Hive metastore. This brings several benefits: diff --git a/docs/streaming-kafka-0-10-integration.md b/docs/streaming-kafka-0-10-integration.md index e3837013168d..92c296a9e6bd 100644 --- a/docs/streaming-kafka-0-10-integration.md +++ b/docs/streaming-kafka-0-10-integration.md @@ -12,6 +12,8 @@ For Scala/Java applications using SBT/Maven project definitions, link your strea artifactId = spark-streaming-kafka-0-10_{{site.SCALA_BINARY_VERSION}} version = {{site.SPARK_VERSION_SHORT}} +**Do not** manually add dependencies on `org.apache.kafka` artifacts (e.g. `kafka-clients`). The `spark-streaming-kafka-0-10` artifact has the appropriate transitive dependencies already, and different versions may be incompatible in hard to diagnose ways. + ### Creating a Direct Stream Note that the namespace for the import includes the version, org.apache.spark.streaming.kafka010 diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index 3cf7151819e2..53b3db21da76 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -8,13 +8,13 @@ title: Structured Streaming Programming Guide {:toc} # Overview -Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* +Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java, Python or R to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.* -**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. +**Structured Streaming is still ALPHA in Spark 2.1** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count. # Quick Example -Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in -[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py). +Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in +[Scala]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/[Java]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/[Python]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/sql/streaming/structured_network_wordcount.py)/[R]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/r/streaming/structured_network_wordcount.R). And if you [download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionalities related to Spark.
    @@ -63,6 +63,13 @@ spark = SparkSession \ .getOrCreate() {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(appName = "StructuredNetworkWordCount") +{% endhighlight %} +
    @@ -136,6 +143,22 @@ wordCounts = words.groupBy("word").count() This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column as "word". Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream. + +
    + +{% highlight r %} +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = "localhost", port = 9999) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(group_by(words, "word")) +{% endhighlight %} + +This `lines` SparkDataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named "value", and each line in the streaming text data becomes a row in the table. Note, that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have a SQL expression with two SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we name the new column as "word". Finally, we have defined the `wordCounts` SparkDataFrame by grouping by the unique values in the SparkDataFrame and counting them. Note that this is a streaming SparkDataFrame which represents the running word counts of the stream. +
    @@ -181,10 +204,20 @@ query = wordCounts \ query.awaitTermination() {% endhighlight %} + +
    + +{% highlight r %} +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) +{% endhighlight %} +
    -After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active. +After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `awaitTermination()` to prevent the process from exiting while the query is active. To actually execute this example code, you can either compile the code in your own [Spark application](quick-start.html#self-contained-applications), or simply @@ -211,6 +244,11 @@ $ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetwor $ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999 {% endhighlight %} +
    +{% highlight bash %} +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 +{% endhighlight %} +
    Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following. @@ -325,6 +363,35 @@ Batch: 0 | spark| 1| +------+-----+ +------------------------------------------- +Batch: 1 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 2| +| spark| 1| +|hadoop| 1| ++------+-----+ +... +{% endhighlight %} + +
    +{% highlight bash %} +# TERMINAL 2: RUNNING structured_network_wordcount.R + +$ ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +------------------------------------------- +Batch: 0 +------------------------------------------- ++------+-----+ +| value|count| ++------+-----+ +|apache| 1| +| spark| 1| ++------+-----+ + ------------------------------------------- Batch: 1 ------------------------------------------- @@ -409,14 +476,14 @@ to track the read position in the stream. The engine uses checkpointing and writ # API using Datasets and DataFrames Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` -([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) +([Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/[Java](api/java/org/apache/spark/sql/SparkSession.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession)/[R](api/R/sparkR.session.html) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the [DataFrame/Dataset Programming Guide](sql-programming-guide.html). ## Creating streaming DataFrames and streaming Datasets -Streaming DataFrames can be created through the `DataStreamReader` interface +Streaming DataFrames can be created through the `DataStreamReader` interface ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) -returned by `SparkSession.readStream()`. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. +returned by `SparkSession.readStream()`. In [R](api/R/read.stream.html), with the `read.stream()` method. Similar to the read interface for creating static DataFrame, you can specify the details of the source – data format, schema, options, etc. #### Input Sources In Spark 2.0, there are a few built-in sources. @@ -445,7 +512,8 @@ Here are the details of all the sources in Spark. path: path to the input directory, and common to all file formats.

    For file-format-specific options, see the related methods in DataStreamReader - (
    Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataStreamReader.parquet() Yes Supports glob paths, but does not support multiple comma-separated paths/globs. @@ -483,7 +551,7 @@ Here are some examples. {% highlight scala %} val spark: SparkSession = ... -// Read text from socket +// Read text from socket val socketDF = spark .readStream .format("socket") @@ -493,7 +561,7 @@ val socketDF = spark socketDF.isStreaming // Returns True for DataFrames that have streaming sources -socketDF.printSchema +socketDF.printSchema // Read all the csv files written atomically in a directory val userSchema = new StructType().add("name", "string").add("age", "integer") @@ -510,7 +578,7 @@ val csvDF = spark {% highlight java %} SparkSession spark = ... -// Read text from socket +// Read text from socket Dataset socketDF = spark .readStream() .format("socket") @@ -537,7 +605,7 @@ Dataset csvDF = spark {% highlight python %} spark = SparkSession. ... -# Read text from socket +# Read text from socket socketDF = spark \ .readStream \ .format("socket") \ @@ -547,7 +615,7 @@ socketDF = spark \ socketDF.isStreaming() # Returns True for DataFrames that have streaming sources -socketDF.printSchema() +socketDF.printSchema() # Read all the csv files written atomically in a directory userSchema = StructType().add("name", "string").add("age", "integer") @@ -558,6 +626,25 @@ csvDF = spark \ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory") {% endhighlight %} +
    +
    + +{% highlight r %} +sparkR.session(...) + +# Read text from socket +socketDF <- read.stream("socket", host = hostname, port = port) + +isStreaming(socketDF) # Returns TRUE for SparkDataFrames that have streaming sources + +printSchema(socketDF) + +# Read all the csv files written atomically in a directory +schema <- structType(structField("name", "string"), + structField("age", "integer")) +csvDF <- read.stream("csv", path = "/path/to/directory", schema = schema, sep = ";") +{% endhighlight %} +
    @@ -638,12 +725,24 @@ ds.groupByKey((MapFunction) value -> value.getDeviceType(), df = ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } # Select the devices which have signal more than 10 -df.select("device").where("signal > 10") +df.select("device").where("signal > 10") # Running count of the number of updates for each device type df.groupBy("deviceType").count() {% endhighlight %} +
    + +{% highlight r %} +df <- ... # streaming DataFrame with IOT device data with schema { device: string, deviceType: string, signal: double, time: DateType } + +# Select the devices which have signal more than 10 +select(where(df, "signal > 10"), "device") + +# Running count of the number of updates for each device type +count(groupBy(df, "deviceType")) +{% endhighlight %} +
    ### Window Operations on Event Time @@ -778,7 +877,7 @@ windowedCounts = words \ In this example, we are defining the watermark of the query on the value of the column "timestamp", and also defining "10 minutes" as the threshold of how late is the data allowed to be. If this query is run in Update output mode (discussed later in [Output Modes](#output-modes) section), -the engine will keep updating counts of a window in the Resule Table until the window is older +the engine will keep updating counts of a window in the Result Table until the window is older than the watermark, which lags behind the current event time in column "timestamp" by 10 minutes. Here is an illustration. @@ -840,7 +939,7 @@ Streaming DataFrames can be joined with static DataFrames to create new streamin {% highlight scala %} val staticDf = spark.read. ... -val streamingDf = spark.readStream. ... +val streamingDf = spark.readStream. ... streamingDf.join(staticDf, "type") // inner equi-join with a static DF streamingDf.join(staticDf, "type", "right_join") // right outer join with a static DF @@ -972,7 +1071,7 @@ Once you have defined the final result DataFrame/Dataset, all that is left is fo ([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface. -- *Details of the output sink:* Data format, location, etc. +- *Details of the output sink:* Data format, location, etc. - *Output mode:* Specify what gets written to the output sink. @@ -1077,7 +1176,7 @@ Here is the compatibility matrix. #### Output Sinks There are a few types of built-in output sinks. -- **File sink** - Stores the output to a directory. +- **File sink** - Stores the output to a directory. {% highlight scala %} writeStream @@ -1145,7 +1244,8 @@ Here are the details of all the sinks in Spark. · "s3a://a/b/c/dataset.txt"

    For file-format-specific options, see the related methods in DataFrameWriter - (Scala/Java/Python). + (Scala/Java/Python/R). E.g. for "parquet" format options see DataFrameWriter.parquet() Yes @@ -1208,7 +1308,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start() - + // ========== DF with aggregation ========== val aggDF = df.groupBy("device").count() @@ -1219,7 +1319,7 @@ aggDF .format("console") .start() -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream .queryName("aggregates") // this query name will be the table name @@ -1250,7 +1350,7 @@ noAggDF .option("checkpointLocation", "path/to/checkpoint/dir") .option("path", "path/to/destination/dir") .start(); - + // ========== DF with aggregation ========== Dataset aggDF = df.groupBy("device").count(); @@ -1261,7 +1361,7 @@ aggDF .format("console") .start(); -// Have all the aggregates in an in-memory table +// Have all the aggregates in an in-memory table aggDF .writeStream() .queryName("aggregates") // this query name will be the table name @@ -1292,7 +1392,7 @@ noAggDF \ .option("checkpointLocation", "path/to/checkpoint/dir") \ .option("path", "path/to/destination/dir") \ .start() - + # ========== DF with aggregation ========== aggDF = df.groupBy("device").count() @@ -1314,6 +1414,35 @@ aggDF \ spark.sql("select * from aggregates").show() # interactively query in-memory table {% endhighlight %} + +
    + +{% highlight r %} +# ========== DF with no aggregations ========== +noAggDF <- select(where(deviceDataDf, "signal > 10"), "device") + +# Print new data to console +write.stream(noAggDF, "console") + +# Write new data to Parquet files +write.stream(noAggDF, + "parquet", + path = "path/to/destination/dir", + checkpointLocation = "path/to/checkpoint/dir") + +# ========== DF with aggregation ========== +aggDF <- count(groupBy(df, "device")) + +# Print updated aggregations to console +write.stream(aggDF, "console", outputMode = "complete") + +# Have all the aggregates in an in memory table. The query name will be the table name +write.stream(aggDF, "memory", queryName = "aggregates", outputMode = "complete") + +# Interactively query in-memory table +head(sql("select * from aggregates")) +{% endhighlight %} +
    @@ -1351,7 +1480,7 @@ query.name // get the name of the auto-generated or user-specified name query.explain() // print detailed explanations of the query -query.stop() // stop the query +query.stop() // stop the query query.awaitTermination() // block until query is terminated, with stop() or with error @@ -1403,7 +1532,7 @@ query.name() # get the name of the auto-generated or user-specified name query.explain() # print detailed explanations of the query -query.stop() # stop the query +query.stop() # stop the query query.awaitTermination() # block until query is terminated, with stop() or with error @@ -1415,6 +1544,24 @@ query.lastProgress() # the most recent progress update of this streaming quer {% endhighlight %} + +
    + +{% highlight r %} +query <- write.stream(df, "console") # get the query object + +queryName(query) # get the name of the auto-generated or user-specified name + +explain(query) # print detailed explanations of the query + +stopQuery(query) # stop the query + +awaitTermination(query) # block until query is terminated, with stop() or with error + +lastProgress(query) # the most recent progress update of this streaming query + +{% endhighlight %} +
    @@ -1461,6 +1608,12 @@ spark.streams().get(id) # get a query object by its unique id spark.streams().awaitAnyTermination() # block until any one of them terminates {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    @@ -1644,6 +1797,58 @@ Will print something like the following. ''' {% endhighlight %} + +
    + +{% highlight r %} +query <- ... # a StreamingQuery +lastProgress(query) + +''' +Will print something like the following. + +{ + "id" : "8c57e1ec-94b5-4c99-b100-f694162df0b9", + "runId" : "ae505c5a-a64e-4896-8c28-c7cbaf926f16", + "name" : null, + "timestamp" : "2017-04-26T08:27:28.835Z", + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0, + "durationMs" : { + "getOffset" : 0, + "triggerExecution" : 1 + }, + "stateOperators" : [ { + "numRowsTotal" : 4, + "numRowsUpdated" : 0 + } ], + "sources" : [ { + "description" : "TextSocketSource[host: localhost, port: 9999]", + "startOffset" : 1, + "endOffset" : 1, + "numInputRows" : 0, + "inputRowsPerSecond" : 0.0, + "processedRowsPerSecond" : 0.0 + } ], + "sink" : { + "description" : "org.apache.spark.sql.execution.streaming.ConsoleSink@76b37531" + } +} +''' + +status(query) +''' +Will print something like the following. + +{ + "message" : "Waiting for data to arrive", + "isDataAvailable" : false, + "isTriggerActive" : false +} +''' +{% endhighlight %} +
    @@ -1703,11 +1908,17 @@ spark.streams().addListener(new StreamingQueryListener() { Not available in Python. {% endhighlight %} + +
    +{% highlight bash %} +Not available in R. +{% endhighlight %} +
    ## Recovering from Failures with Checkpointing -In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries). +In case of a failure or intentional shutdown, you can recover the previous progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the [quick example](#quick-example)) to the checkpoint location. This checkpoint location has to be a path in an HDFS compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
    @@ -1745,20 +1956,18 @@ aggDF \ .start() {% endhighlight %} +
    +
    + +{% highlight r %} +write.stream(aggDF, "memory", outputMode = "complete", checkpointLocation = "path/to/HDFS/dir") +{% endhighlight %} +
    # Where to go from here -- Examples: See and run the -[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming) +- Examples: See and run the +[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)/[R]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/r/streaming) examples. - Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/) - - - - - - - - - diff --git a/examples/pom.xml b/examples/pom.xml index 91c2e81ebed2..aa91e98b28ae 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java new file mode 100644 index 000000000000..717ec21c8b20 --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaFPGrowthExample.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml; + +// $example on$ +import java.util.Arrays; +import java.util.List; + +import org.apache.spark.ml.fpm.FPGrowth; +import org.apache.spark.ml.fpm.FPGrowthModel; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.SparkSession; +import org.apache.spark.sql.types.*; +// $example off$ + +/** + * An example demonstrating FPGrowth. + * Run with + *
    + * bin/run-example ml.JavaFPGrowthExample
    + * 
    + */ +public class JavaFPGrowthExample { + public static void main(String[] args) { + SparkSession spark = SparkSession + .builder() + .appName("JavaFPGrowthExample") + .getOrCreate(); + + // $example on$ + List data = Arrays.asList( + RowFactory.create(Arrays.asList("1 2 5".split(" "))), + RowFactory.create(Arrays.asList("1 2 3 5".split(" "))), + RowFactory.create(Arrays.asList("1 2".split(" "))) + ); + StructType schema = new StructType(new StructField[]{ new StructField( + "items", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) + }); + Dataset itemsDF = spark.createDataFrame(data, schema); + + FPGrowthModel model = new FPGrowth() + .setItemsCol("items") + .setMinSupport(0.5) + .setMinConfidence(0.6) + .fit(itemsDF); + + // Display frequent itemsets. + model.freqItemsets().show(); + + // Display generated association rules. + model.associationRules().show(); + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(itemsDF).show(); + // $example off$ + + spark.stop(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java index 3077f557ef88..0a7dc621e111 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPCAExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -39,21 +40,25 @@ public class JavaPCAExample { public static void main(String[] args) { SparkConf conf = new SparkConf().setAppName("PCA Example"); SparkContext sc = new SparkContext(conf); + JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = JavaSparkContext.fromSparkContext(sc).parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 principal components. - Matrix pc = mat.computePrincipalComponents(3); + // Compute the top 4 principal components. + // Principal components are stored in a local dense matrix. + Matrix pc = mat.computePrincipalComponents(4); + + // Project the rows to the linear space spanned by the top 4 principal components. RowMatrix projected = mat.multiply(pc); // $example off$ Vector[] collectPartitions = (Vector[])projected.rows().collect(); @@ -61,6 +66,6 @@ public static void main(String[] args) { for (Vector vector : collectPartitions) { System.out.println("\t" + vector); } - sc.stop(); + jsc.stop(); } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java index 3730e60f6880..802be3960a33 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaSVDExample.java @@ -18,7 +18,8 @@ package org.apache.spark.examples.mllib; // $example on$ -import java.util.LinkedList; +import java.util.Arrays; +import java.util.List; // $example off$ import org.apache.spark.SparkConf; @@ -43,22 +44,22 @@ public static void main(String[] args) { JavaSparkContext jsc = JavaSparkContext.fromSparkContext(sc); // $example on$ - double[][] array = {{1.12, 2.05, 3.12}, {5.56, 6.28, 8.94}, {10.2, 8.0, 20.5}}; - LinkedList rowsList = new LinkedList<>(); - for (int i = 0; i < array.length; i++) { - Vector currentRow = Vectors.dense(array[i]); - rowsList.add(currentRow); - } - JavaRDD rows = jsc.parallelize(rowsList); + List data = Arrays.asList( + Vectors.sparse(5, new int[] {1, 3}, new double[] {1.0, 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ); + + JavaRDD rows = jsc.parallelize(data); // Create a RowMatrix from JavaRDD. RowMatrix mat = new RowMatrix(rows.rdd()); - // Compute the top 3 singular values and corresponding singular vectors. - SingularValueDecomposition svd = mat.computeSVD(3, true, 1.0E-9d); - RowMatrix U = svd.U(); - Vector s = svd.s(); - Matrix V = svd.V(); + // Compute the top 5 singular values and corresponding singular vectors. + SingularValueDecomposition svd = mat.computeSVD(5, true, 1.0E-9d); + RowMatrix U = svd.U(); // The U factor is a RowMatrix. + Vector s = svd.s(); // The singular values are stored in a local dense vector. + Matrix V = svd.V(); // The V factor is a local dense matrix. // $example off$ Vector[] collectPartitions = (Vector[]) U.rows().collect(); System.out.println("U factor is:"); diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java index da3a5dfe8628..6b8e6554f1bb 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredSessionization.java @@ -28,8 +28,6 @@ import java.sql.Timestamp; import java.util.*; -import scala.Tuple2; - /** * Counts words in UTF8 encoded, '\n' delimited text received from the network. *

    @@ -76,8 +74,6 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio for (String word : lineWithTimestamp.getLine().split(" ")) { eventList.add(new Event(word, lineWithTimestamp.getTimestamp())); } - System.out.println( - "Number of events from " + lineWithTimestamp.getLine() + " = " + eventList.size()); return eventList.iterator(); } }; @@ -100,7 +96,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // If timed out, then remove session and send final update if (state.hasTimedOut()) { SessionUpdate finalUpdate = new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), true); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), true); state.remove(); return finalUpdate; @@ -133,7 +129,7 @@ public Iterator call(LineWithTimestamp lineWithTimestamp) throws Exceptio // Set timeout such that the session will be expired if no data received for 10 seconds state.setTimeoutDuration("10 seconds"); return new SessionUpdate( - sessionId, state.get().getDurationMs(), state.get().getNumEvents(), false); + sessionId, state.get().calculateDuration(), state.get().getNumEvents(), false); } } }; @@ -215,7 +211,8 @@ public void setStartTimestampMs(long startTimestampMs) { public long getEndTimestampMs() { return endTimestampMs; } public void setEndTimestampMs(long endTimestampMs) { this.endTimestampMs = endTimestampMs; } - public long getDurationMs() { return endTimestampMs - startTimestampMs; } + public long calculateDuration() { return endTimestampMs - startTimestampMs; } + @Override public String toString() { return "SessionInfo(numEvents = " + numEvents + ", timestamps = " + startTimestampMs + " to " + endTimestampMs + ")"; diff --git a/examples/src/main/python/ml/fpgrowth_example.py b/examples/src/main/python/ml/fpgrowth_example.py new file mode 100644 index 000000000000..c92c3c27abb2 --- /dev/null +++ b/examples/src/main/python/ml/fpgrowth_example.py @@ -0,0 +1,56 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# $example on$ +from pyspark.ml.fpm import FPGrowth +# $example off$ +from pyspark.sql import SparkSession + +""" +An example demonstrating FPGrowth. +Run with: + bin/spark-submit examples/src/main/python/ml/fpgrowth_example.py +""" + +if __name__ == "__main__": + spark = SparkSession\ + .builder\ + .appName("FPGrowthExample")\ + .getOrCreate() + + # $example on$ + df = spark.createDataFrame([ + (0, [1, 2, 5]), + (1, [1, 2, 3, 5]), + (2, [1, 2]) + ], ["id", "items"]) + + fpGrowth = FPGrowth(itemsCol="items", minSupport=0.5, minConfidence=0.6) + model = fpGrowth.fit(df) + + # Display frequent itemsets. + model.freqItemsets.show() + + # Display generated association rules. + model.associationRules.show() + + # transform examines the input items against all the association rules and summarize the + # consequents as prediction + model.transform(df).show() + # $example off$ + + spark.stop() diff --git a/examples/src/main/python/mllib/pca_rowmatrix_example.py b/examples/src/main/python/mllib/pca_rowmatrix_example.py new file mode 100644 index 000000000000..49b9b1bbe08e --- /dev/null +++ b/examples/src/main/python/mllib/pca_rowmatrix_example.py @@ -0,0 +1,46 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonPCAOnRowMatrixExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + # Compute the top 4 principal components. + # Principal components are stored in a local dense matrix. + pc = mat.computePrincipalComponents(4) + + # Project the rows to the linear space spanned by the top 4 principal components. + projected = mat.multiply(pc) + # $example off$ + collected = projected.rows.collect() + print("Projected Row Matrix of principal component:") + for vector in collected: + print(vector) + sc.stop() diff --git a/examples/src/main/python/mllib/svd_example.py b/examples/src/main/python/mllib/svd_example.py new file mode 100644 index 000000000000..5b220fdb3fd6 --- /dev/null +++ b/examples/src/main/python/mllib/svd_example.py @@ -0,0 +1,48 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark import SparkContext +# $example on$ +from pyspark.mllib.linalg import Vectors +from pyspark.mllib.linalg.distributed import RowMatrix +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="PythonSVDExample") + + # $example on$ + rows = sc.parallelize([ + Vectors.sparse(5, {1: 1.0, 3: 7.0}), + Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), + Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0) + ]) + + mat = RowMatrix(rows) + + # Compute the top 5 singular values and corresponding singular vectors. + svd = mat.computeSVD(5, computeU=True) + U = svd.U # The U factor is a RowMatrix. + s = svd.s # The singular values are stored in a local dense vector. + V = svd.V # The V factor is a local dense matrix. + # $example off$ + collected = U.rows.collect() + print("U factor is:") + for vector in collected: + print(vector) + print("Singular values are: %s" % s) + print("V factor is:\n%s" % V) + sc.stop() diff --git a/examples/src/main/r/ml/fpm.R b/examples/src/main/r/ml/fpm.R new file mode 100644 index 000000000000..89c4564457d9 --- /dev/null +++ b/examples/src/main/r/ml/fpm.R @@ -0,0 +1,50 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# To run this example use +# ./bin/spark-submit examples/src/main/r/ml/fpm.R + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-ML-fpm-example") + +# $example on$ +# Load training data + +df <- selectExpr(createDataFrame(data.frame(rawItems = c( + "1,2,5", "1,2,3,5", "1,2" +))), "split(rawItems, ',') AS items") + +fpm <- spark.fpGrowth(df, itemsCol="items", minSupport=0.5, minConfidence=0.6) + +# Extracting frequent itemsets + +spark.freqItemsets(fpm) + +# Extracting association rules + +spark.associationRules(fpm) + +# Predict uses association rules to and combines possible consequents + +predict(fpm, df) + +# $example off$ + +sparkR.session.stop() diff --git a/examples/src/main/r/streaming/structured_network_wordcount.R b/examples/src/main/r/streaming/structured_network_wordcount.R new file mode 100644 index 000000000000..cda18ebc072e --- /dev/null +++ b/examples/src/main/r/streaming/structured_network_wordcount.R @@ -0,0 +1,57 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Counts words in UTF8 encoded, '\n' delimited text received from the network. + +# To run this on your local machine, you need to first run a Netcat server +# $ nc -lk 9999 +# and then run the example +# ./bin/spark-submit examples/src/main/r/streaming/structured_network_wordcount.R localhost 9999 + +# Load SparkR library into your R session +library(SparkR) + +# Initialize SparkSession +sparkR.session(appName = "SparkR-Streaming-structured-network-wordcount-example") + +args <- commandArgs(trailing = TRUE) + +if (length(args) != 2) { + print("Usage: structured_network_wordcount.R ") + print(" and describe the TCP server that Structured Streaming") + print("would connect to receive data.") + q("no") +} + +hostname <- args[[1]] +port <- as.integer(args[[2]]) + +# Create DataFrame representing the stream of input lines from connection to localhost:9999 +lines <- read.stream("socket", host = hostname, port = port) + +# Split the lines into words +words <- selectExpr(lines, "explode(split(value, ' ')) as word") + +# Generate running word count +wordCounts <- count(groupBy(words, "word")) + +# Start running the query that prints the running counts to the console +query <- write.stream(wordCounts, "console", outputMode = "complete") + +awaitTermination(query) + +sparkR.session.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala new file mode 100644 index 000000000000..59110d70de55 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/FPGrowthExample.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.examples.ml + +// scalastyle:off println + +// $example on$ +import org.apache.spark.ml.fpm.FPGrowth +// $example off$ +import org.apache.spark.sql.SparkSession + +/** + * An example demonstrating FP-Growth. + * Run with + * {{{ + * bin/run-example ml.FPGrowthExample + * }}} + */ +object FPGrowthExample { + + def main(args: Array[String]): Unit = { + val spark = SparkSession + .builder + .appName(s"${this.getClass.getSimpleName}") + .getOrCreate() + import spark.implicits._ + + // $example on$ + val dataset = spark.createDataset(Seq( + "1 2 5", + "1 2 3 5", + "1 2") + ).map(t => t.split(" ")).toDF("items") + + val fpgrowth = new FPGrowth().setItemsCol("items").setMinSupport(0.5).setMinConfidence(0.6) + val model = fpgrowth.fit(dataset) + + // Display frequent itemsets. + model.freqItemsets.show() + + // Display generated association rules. + model.associationRules.show() + + // transform examines the input items against all the association rules and summarize the + // consequents as prediction + model.transform(dataset).show() + // $example off$ + + spark.stop() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala index a137ba2a2f9d..da43a8d9c7e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAOnRowMatrixExample.scala @@ -39,9 +39,9 @@ object PCAOnRowMatrixExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 4 principal components. // Principal components are stored in a local dense matrix. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala index b286a3f7b909..769ae2a3a88b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVDExample.scala @@ -28,6 +28,9 @@ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.RowMatrix // $example off$ +/** + * Example for SingularValueDecomposition. + */ object SVDExample { def main(args: Array[String]): Unit = { @@ -41,15 +44,15 @@ object SVDExample { Vectors.dense(2.0, 0.0, 3.0, 4.0, 5.0), Vectors.dense(4.0, 0.0, 0.0, 6.0, 7.0)) - val dataRDD = sc.parallelize(data, 2) + val rows = sc.parallelize(data) - val mat: RowMatrix = new RowMatrix(dataRDD) + val mat: RowMatrix = new RowMatrix(rows) // Compute the top 5 singular values and corresponding singular vectors. val svd: SingularValueDecomposition[RowMatrix, Matrix] = mat.computeSVD(5, computeU = true) val U: RowMatrix = svd.U // The U factor is a RowMatrix. - val s: Vector = svd.s // The singular values are stored in a local dense vector. - val V: Matrix = svd.V // The V factor is a local dense matrix. + val s: Vector = svd.s // The singular values are stored in a local dense vector. + val V: Matrix = svd.V // The V factor is a local dense matrix. // $example off$ val collect = U.rows.collect() println("U factor is:") diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 8948df2da89e..04afe28fb788 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index f8ef8a991316..47e03419d3df 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 6d547c46d6a2..f961a8f54d9a 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 46901d64eda9..d8bc7dcf7524 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml index 295142cbfdff..6d46430d6e96 100644 --- a/external/kafka-0-10-assembly/pom.xml +++ b/external/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/pom.xml b/external/kafka-0-10-sql/pom.xml index 6cf448e65e8b..5d979ddf2f74 100644 --- a/external/kafka-0-10-sql/pom.xml +++ b/external/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala index 6d76904fb0e5..7c4f38e02fb2 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/CachedKafkaConsumer.scala @@ -28,6 +28,7 @@ import org.apache.kafka.common.TopicPartition import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.sql.kafka010.KafkaSource._ +import org.apache.spark.util.UninterruptibleThread /** @@ -62,11 +63,20 @@ private[kafka010] case class CachedKafkaConsumer private( case class AvailableOffsetRange(earliest: Long, latest: Long) + private def runUninterruptiblyIfPossible[T](body: => T): T = Thread.currentThread match { + case ut: UninterruptibleThread => + ut.runUninterruptibly(body) + case _ => + logWarning("CachedKafkaConsumer is not running in UninterruptibleThread. " + + "It may hang when CachedKafkaConsumer's methods are interrupted because of KAFKA-1894") + body + } + /** * Return the available offset range of the current partition. It's a pair of the earliest offset * and the latest offset. */ - def getAvailableOffsetRange(): AvailableOffsetRange = { + def getAvailableOffsetRange(): AvailableOffsetRange = runUninterruptiblyIfPossible { consumer.seekToBeginning(Set(topicPartition).asJava) val earliestOffset = consumer.position(topicPartition) consumer.seekToEnd(Set(topicPartition).asJava) @@ -92,7 +102,8 @@ private[kafka010] case class CachedKafkaConsumer private( offset: Long, untilOffset: Long, pollTimeoutMs: Long, - failOnDataLoss: Boolean): ConsumerRecord[Array[Byte], Array[Byte]] = { + failOnDataLoss: Boolean): + ConsumerRecord[Array[Byte], Array[Byte]] = runUninterruptiblyIfPossible { require(offset < untilOffset, s"offset must always be less than untilOffset [offset: $offset, untilOffset: $untilOffset]") logDebug(s"Get $groupId $topicPartition nextOffset $nextOffsetInFetchedData requested $offset") @@ -276,7 +287,7 @@ private[kafka010] case class CachedKafkaConsumer private( reportDataLoss0(failOnDataLoss, finalMessage, cause) } - private def close(): Unit = consumer.close() + def close(): Unit = consumer.close() private def seek(offset: Long): Unit = { logDebug(s"Seeking to $groupId $topicPartition $offset") @@ -371,7 +382,7 @@ private[kafka010] object CachedKafkaConsumer extends Logging { // If this is reattempt at running the task, then invalidate cache and start with // a new consumer - if (TaskContext.get != null && TaskContext.get.attemptNumber > 1) { + if (TaskContext.get != null && TaskContext.get.attemptNumber >= 1) { removeKafkaConsumer(topic, partition, kafkaParams) val consumer = new CachedKafkaConsumer(topicPartition, kafkaParams) consumer.inuse = true @@ -387,6 +398,14 @@ private[kafka010] object CachedKafkaConsumer extends Logging { } } + /** Create an [[CachedKafkaConsumer]] but don't put it into cache. */ + def createUncached( + topic: String, + partition: Int, + kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer = { + new CachedKafkaConsumer(new TopicPartition(topic, partition), kafkaParams) + } + private def reportDataLoss0( failOnDataLoss: Boolean, finalMessage: String, diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala index 2696d6f089d2..3e65949a6fd1 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetReader.scala @@ -95,8 +95,10 @@ private[kafka010] class KafkaOffsetReader( * Closes the connection to Kafka, and cleans up state. */ def close(): Unit = { - consumer.close() - kafkaReaderThread.shutdownNow() + runUninterruptibly { + consumer.close() + } + kafkaReaderThread.shutdown() } /** diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala index f180bbad6e36..97bd28316932 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaRelation.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} +import java.util.UUID import org.apache.kafka.common.TopicPartition @@ -33,9 +34,9 @@ import org.apache.spark.unsafe.types.UTF8String private[kafka010] class KafkaRelation( override val sqlContext: SQLContext, - kafkaReader: KafkaOffsetReader, - executorKafkaParams: ju.Map[String, Object], + strategy: ConsumerStrategy, sourceOptions: Map[String, String], + specifiedKafkaParams: Map[String, String], failOnDataLoss: Boolean, startingOffsets: KafkaOffsetRangeLimit, endingOffsets: KafkaOffsetRangeLimit) @@ -53,9 +54,27 @@ private[kafka010] class KafkaRelation( override def schema: StructType = KafkaOffsetReader.kafkaSchema override def buildScan(): RDD[Row] = { + // Each running query should use its own group id. Otherwise, the query may be only assigned + // partial data since Kafka will assign partitions to multiple consumers having the same group + // id. Hence, we should generate a unique id for each query. + val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" + + val kafkaOffsetReader = new KafkaOffsetReader( + strategy, + KafkaSourceProvider.kafkaParamsForDriver(specifiedKafkaParams), + sourceOptions, + driverGroupIdPrefix = s"$uniqueGroupId-driver") + // Leverage the KafkaReader to obtain the relevant partition offsets - val fromPartitionOffsets = getPartitionOffsets(startingOffsets) - val untilPartitionOffsets = getPartitionOffsets(endingOffsets) + val (fromPartitionOffsets, untilPartitionOffsets) = { + try { + (getPartitionOffsets(kafkaOffsetReader, startingOffsets), + getPartitionOffsets(kafkaOffsetReader, endingOffsets)) + } finally { + kafkaOffsetReader.close() + } + } + // Obtain topicPartitions in both from and until partition offset, ignoring // topic partitions that were added and/or deleted between the two above calls. if (fromPartitionOffsets.keySet != untilPartitionOffsets.keySet) { @@ -82,6 +101,8 @@ private[kafka010] class KafkaRelation( offsetRanges.sortBy(_.topicPartition.toString).mkString(", ")) // Create an RDD that reads from Kafka and get the (key, value) pair as byte arrays. + val executorKafkaParams = + KafkaSourceProvider.kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId) val rdd = new KafkaSourceRDD( sqlContext.sparkContext, executorKafkaParams, offsetRanges, pollTimeoutMs, failOnDataLoss, reuseKafkaConsumer = false).map { cr => @@ -98,6 +119,7 @@ private[kafka010] class KafkaRelation( } private def getPartitionOffsets( + kafkaReader: KafkaOffsetReader, kafkaOffsets: KafkaOffsetRangeLimit): Map[TopicPartition, Long] = { def validateTopicPartitions(partitions: Set[TopicPartition], partitionOffsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index ab1ce347cbe3..3cb4d8cad12c 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -111,10 +111,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { validateBatchOptions(parameters) - // Each running query should use its own group id. Otherwise, the query may be only assigned - // partial data since Kafka will assign partitions to multiple consumers having the same group - // id. Hence, we should generate a unique id for each query. - val uniqueGroupId = s"spark-kafka-relation-${UUID.randomUUID}" val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase(Locale.ROOT), v) } val specifiedKafkaParams = parameters @@ -131,20 +127,14 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ENDING_OFFSETS_OPTION_KEY, LatestOffsetRangeLimit) assert(endingRelationOffsets != EarliestOffsetRangeLimit) - val kafkaOffsetReader = new KafkaOffsetReader( - strategy(caseInsensitiveParams), - kafkaParamsForDriver(specifiedKafkaParams), - parameters, - driverGroupIdPrefix = s"$uniqueGroupId-driver") - new KafkaRelation( sqlContext, - kafkaOffsetReader, - kafkaParamsForExecutors(specifiedKafkaParams, uniqueGroupId), - parameters, - failOnDataLoss(caseInsensitiveParams), - startingRelationOffsets, - endingRelationOffsets) + strategy(caseInsensitiveParams), + sourceOptions = parameters, + specifiedKafkaParams = specifiedKafkaParams, + failOnDataLoss = failOnDataLoss(caseInsensitiveParams), + startingOffsets = startingRelationOffsets, + endingOffsets = endingRelationOffsets) } override def createSink( @@ -213,46 +203,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) } - private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = - ConfigUpdater("source", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial - // offsets by itself instead of counting on KafkaConsumer. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") - - // So that consumers in the driver does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // So that the driver does not pull too much data - .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - - private def kafkaParamsForExecutors( - specifiedKafkaParams: Map[String, String], uniqueGroupId: String) = - ConfigUpdater("executor", specifiedKafkaParams) - .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) - .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) - - // Make sure executors do only what the driver tells them. - .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") - - // So that consumers in executors do not mess with any existing group id - .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") - - // So that consumers in executors does not commit offsets unnecessarily - .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") - - // If buffer config is not set, set it to reasonable value to work around - // buffer issues (see KAFKA-3135) - .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) - .build() - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -414,30 +364,9 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister logWarning("maxOffsetsPerTrigger option ignored in batch queries") } } - - /** Class to conveniently update Kafka config params, while logging the changes */ - private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { - private val map = new ju.HashMap[String, Object](kafkaParams.asJava) - - def set(key: String, value: Object): this.type = { - map.put(key, value) - logInfo(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") - this - } - - def setIfUnset(key: String, value: Object): ConfigUpdater = { - if (!map.containsKey(key)) { - map.put(key, value) - logInfo(s"$module: Set $key to $value") - } - this - } - - def build(): ju.Map[String, Object] = map - } } -private[kafka010] object KafkaSourceProvider { +private[kafka010] object KafkaSourceProvider extends Logging { private val STRATEGY_OPTION_KEYS = Set("subscribe", "subscribepattern", "assign") private[kafka010] val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private[kafka010] val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" @@ -459,4 +388,66 @@ private[kafka010] object KafkaSourceProvider { case None => defaultOffsets } } + + def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]): ju.Map[String, Object] = + ConfigUpdater("source", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Set to "earliest" to avoid exceptions. However, KafkaSource will fetch the initial + // offsets by itself instead of counting on KafkaConsumer. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest") + + // So that consumers in the driver does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // So that the driver does not pull too much data + .set(ConsumerConfig.MAX_POLL_RECORDS_CONFIG, new java.lang.Integer(1)) + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + def kafkaParamsForExecutors( + specifiedKafkaParams: Map[String, String], + uniqueGroupId: String): ju.Map[String, Object] = + ConfigUpdater("executor", specifiedKafkaParams) + .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) + .set(ConsumerConfig.VALUE_DESERIALIZER_CLASS_CONFIG, deserClassName) + + // Make sure executors do only what the driver tells them. + .set(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none") + + // So that consumers in executors do not mess with any existing group id + .set(ConsumerConfig.GROUP_ID_CONFIG, s"$uniqueGroupId-executor") + + // So that consumers in executors does not commit offsets unnecessarily + .set(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, "false") + + // If buffer config is not set, set it to reasonable value to work around + // buffer issues (see KAFKA-3135) + .setIfUnset(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer) + .build() + + /** Class to conveniently update Kafka config params, while logging the changes */ + private case class ConfigUpdater(module: String, kafkaParams: Map[String, String]) { + private val map = new ju.HashMap[String, Object](kafkaParams.asJava) + + def set(key: String, value: Object): this.type = { + map.put(key, value) + logDebug(s"$module: Set $key to $value, earlier value: ${kafkaParams.getOrElse(key, "")}") + this + } + + def setIfUnset(key: String, value: Object): ConfigUpdater = { + if (!map.containsKey(key)) { + map.put(key, value) + logDebug(s"$module: Set $key to $value") + } + this + } + + def build(): ju.Map[String, Object] = map + } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala index 6fb3473eb75f..9d9e2aaba807 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceRDD.scala @@ -125,16 +125,15 @@ private[kafka010] class KafkaSourceRDD( context: TaskContext): Iterator[ConsumerRecord[Array[Byte], Array[Byte]]] = { val sourcePartition = thePart.asInstanceOf[KafkaSourceRDDPartition] val topic = sourcePartition.offsetRange.topic - if (!reuseKafkaConsumer) { - // if we can't reuse CachedKafkaConsumers, let's reset the groupId to something unique - // to each task (i.e., append the task's unique partition id), because we will have - // multiple tasks (e.g., in the case of union) reading from the same topic partitions - val old = executorKafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String] - val id = TaskContext.getPartitionId() - executorKafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, old + "-" + id) - } val kafkaPartition = sourcePartition.offsetRange.partition - val consumer = CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + val consumer = + if (!reuseKafkaConsumer) { + // If we can't reuse CachedKafkaConsumers, creating a new CachedKafkaConsumer. As here we + // uses `assign`, we don't need to worry about the "group.id" conflicts. + CachedKafkaConsumer.createUncached(topic, kafkaPartition, executorKafkaParams) + } else { + CachedKafkaConsumer.getOrCreate(topic, kafkaPartition, executorKafkaParams) + } val range = resolveRange(consumer, sourcePartition.offsetRange) assert( range.fromOffset <= range.untilOffset, @@ -170,7 +169,7 @@ private[kafka010] class KafkaSourceRDD( override protected def close(): Unit = { if (!reuseKafkaConsumer) { // Don't forget to close non-reuse KafkaConsumers. You may take down your cluster! - CachedKafkaConsumer.removeKafkaConsumer(topic, kafkaPartition, executorKafkaParams) + consumer.close() } else { // Indicate that we're no longer using this consumer CachedKafkaConsumer.releaseKafkaConsumer(topic, kafkaPartition, executorKafkaParams) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala index a637d52c933a..61936e32fd83 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -47,7 +47,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( if (topic == None) { throw new AnalysisException(s"topic option required when no " + @@ -84,7 +84,7 @@ private[kafka010] object KafkaWriter extends Logging { queryExecution: QueryExecution, kafkaParameters: ju.Map[String, Object], topic: Option[String] = None): Unit = { - val schema = queryExecution.logical.output + val schema = queryExecution.analyzed.output validateQuery(queryExecution, kafkaParameters, topic) SQLExecution.withNewExecutionId(sparkSession, queryExecution) { queryExecution.toRdd.foreachPartition { iter => diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala index 4bd052d249ec..2ab336c7ac47 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -28,6 +28,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.functions._ import org.apache.spark.sql.streaming._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, DataType} @@ -108,6 +109,21 @@ class KafkaSinkSuite extends StreamTest with SharedSQLContext { s"save mode overwrite not allowed for kafka")) } + test("SPARK-20496: batch - enforce analyzed plans") { + val inputEvents = + spark.range(1, 1000) + .select(to_json(struct("*")) as 'value) + + val topic = newTopic() + testUtils.createTopic(topic) + // used to throw UnresolvedException + inputEvents.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + } + test("streaming - write to kafka with topic field") { val input = MemoryStream[String] val topic = newTopic() diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml index 88499240cd56..e4336ecb07da 100644 --- a/external/kafka-0-10/pom.xml +++ b/external/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala index 4c6e2ce87e29..62cdf5b1134e 100644 --- a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala +++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala @@ -199,7 +199,7 @@ private[spark] class KafkaRDD[K, V]( val consumer = if (useConsumerCache) { CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor) - if (context.attemptNumber > 1) { + if (context.attemptNumber >= 1) { // just in case the prior attempt failures were cache related CachedKafkaConsumer.remove(groupId, part.topic, part.partition) } diff --git a/external/kafka-0-8-assembly/pom.xml b/external/kafka-0-8-assembly/pom.xml index 3fedd9eda195..2489d29ebe16 100644 --- a/external/kafka-0-8-assembly/pom.xml +++ b/external/kafka-0-8-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kafka-0-8/pom.xml b/external/kafka-0-8/pom.xml index 8368a1f12218..98f81aee376a 100644 --- a/external/kafka-0-8/pom.xml +++ b/external/kafka-0-8/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl-assembly/pom.xml b/external/kinesis-asl-assembly/pom.xml index 90bb0e4987c8..88515f853edb 100644 --- a/external/kinesis-asl-assembly/pom.xml +++ b/external/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/kinesis-asl/pom.xml b/external/kinesis-asl/pom.xml index daa79e79163b..28797e3fe432 100644 --- a/external/kinesis-asl/pom.xml +++ b/external/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/external/spark-ganglia-lgpl/pom.xml b/external/spark-ganglia-lgpl/pom.xml index 7da27817ebaf..701455f22609 100644 --- a/external/spark-ganglia-lgpl/pom.xml +++ b/external/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8df33660ea9d..1ed38a794f44 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 646462b4a835..755c6febc48e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -19,7 +19,10 @@ package org.apache.spark.graphx import scala.reflect.ClassTag +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer /** * Implements a Pregel-like bulk-synchronous message-passing API. @@ -122,27 +125,39 @@ object Pregel extends Logging { require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," + s" but got ${maxIterations}") - var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache() + val checkpointInterval = graph.vertices.sparkContext.getConf + .getInt("spark.graphx.pregel.checkpointInterval", -1) + var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)) + val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED]( + checkpointInterval, graph.vertices.sparkContext) + graphCheckpointer.update(g) + // compute the messages var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg) + val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)]( + checkpointInterval, graph.vertices.sparkContext) + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) var activeMessages = messages.count() + // Loop var prevG: Graph[VD, ED] = null var i = 0 while (activeMessages > 0 && i < maxIterations) { // Receive the messages and update the vertices. prevG = g - g = g.joinVertices(messages)(vprog).cache() + g = g.joinVertices(messages)(vprog) + graphCheckpointer.update(g) val oldMessages = messages // Send new messages, skipping edges where neither side received a message. We must cache // messages so it can be materialized on the next line, allowing us to uncache the previous // iteration. messages = GraphXUtils.mapReduceTriplets( - g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache() + g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))) // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages // and the vertices of g). + messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]]) activeMessages = messages.count() logInfo("Pregel finished iteration " + i) @@ -154,7 +169,9 @@ object Pregel extends Logging { // count the iteration i += 1 } - messages.unpersist(blocking = false) + messageCheckpointer.unpersistDataSet() + graphCheckpointer.deleteAllCheckpoints() + messageCheckpointer.deleteAllCheckpoints() g } // end of apply diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 13b2b5771918..fd7b7f7c1c48 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -226,18 +226,18 @@ object PageRank extends Logging { // Propagates the message along outbound edges // and adding start nodes back in with activation resetProb val rankUpdates = rankGraph.aggregateMessages[BV[Double]]( - ctx => ctx.sendToDst(ctx.srcAttr :* ctx.attr), - (a : BV[Double], b : BV[Double]) => a :+ b, TripletFields.Src) + ctx => ctx.sendToDst(ctx.srcAttr *:* ctx.attr), + (a : BV[Double], b : BV[Double]) => a +:+ b, TripletFields.Src) rankGraph = rankGraph.outerJoinVertices(rankUpdates) { (vid, oldRank, msgSumOpt) => - val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) :* (1.0 - resetProb) + val popActivations: BV[Double] = msgSumOpt.getOrElse(zero) *:* (1.0 - resetProb) val resetActivations = if (sourcesInitMapBC.value contains vid) { - sourcesInitMapBC.value(vid) :* resetProb + sourcesInitMapBC.value(vid) *:* resetProb } else { zero } - popActivations :+ resetActivations + popActivations +:+ resetActivations }.cache() rankGraph.edges.foreachPartition(x => {}) // also materializes rankGraph.vertices @@ -250,9 +250,9 @@ object PageRank extends Logging { } // SPARK-18847 If the graph has sinks (vertices with no outgoing edges) correct the sum of ranks - val rankSums = rankGraph.vertices.values.fold(zero)(_ :+ _) + val rankSums = rankGraph.vertices.values.fold(zero)(_ +:+ _) rankGraph.mapVertices { (vid, attr) => - Vectors.fromBreeze(attr :/ rankSums) + Vectors.fromBreeze(attr /:/ rankSums) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala similarity index 91% rename from mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala rename to graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala index 80074897567e..fda501aa757d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointer.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointer.scala @@ -15,11 +15,12 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.spark.SparkContext import org.apache.spark.graphx.Graph import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.PeriodicCheckpointer /** @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel * @tparam VD Vertex descriptor type * @tparam ED Edge descriptor type * - * TODO: Move this out of MLlib? */ -private[mllib] class PeriodicGraphCheckpointer[VD, ED]( +private[spark] class PeriodicGraphCheckpointer[VD, ED]( checkpointInterval: Int, sc: SparkContext) extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) { @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED]( override protected def persist(data: Graph[VD, ED]): Unit = { if (data.vertices.getStorageLevel == StorageLevel.NONE) { - data.vertices.persist() + /* We need to use cache because persist does not honor the default storage level requested + * when constructing the graph. Only cache does that. + */ + data.vertices.cache() } if (data.edges.getStorageLevel == StorageLevel.NONE) { - data.edges.persist() + data.edges.cache() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala similarity index 70% rename from mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala rename to graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala index a13e7f63a929..e0c65e6940f6 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/impl/PeriodicGraphCheckpointerSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/util/PeriodicGraphCheckpointerSuite.scala @@ -15,77 +15,81 @@ * limitations under the License. */ -package org.apache.spark.mllib.impl +package org.apache.spark.graphx.util import org.apache.hadoop.fs.Path import org.apache.spark.{SparkContext, SparkFunSuite} -import org.apache.spark.graphx.{Edge, Graph} -import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.graphx.{Edge, Graph, LocalSparkContext} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils -class PeriodicGraphCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext { +class PeriodicGraphCheckpointerSuite extends SparkFunSuite with LocalSparkContext { import PeriodicGraphCheckpointerSuite._ test("Persisting") { var graphsToCheck = Seq.empty[GraphToCheck] - val graph1 = createGraph(sc) - val checkpointer = - new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkPersistence(graphsToCheck, 1) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkPersistence(graphsToCheck, iteration) - iteration += 1 + withSpark { sc => + val graph1 = createGraph(sc) + val checkpointer = + new PeriodicGraphCheckpointer[Double, Double](10, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkPersistence(graphsToCheck, 1) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkPersistence(graphsToCheck, iteration) + iteration += 1 + } } } test("Checkpointing") { - val tempDir = Utils.createTempDir() - val path = tempDir.toURI.toString - val checkpointInterval = 2 - var graphsToCheck = Seq.empty[GraphToCheck] - sc.setCheckpointDir(path) - val graph1 = createGraph(sc) - val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( - checkpointInterval, graph1.vertices.sparkContext) - checkpointer.update(graph1) - graph1.edges.count() - graph1.vertices.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) - checkCheckpoint(graphsToCheck, 1, checkpointInterval) - - var iteration = 2 - while (iteration < 9) { - val graph = createGraph(sc) - checkpointer.update(graph) - graph.vertices.count() - graph.edges.count() - graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) - checkCheckpoint(graphsToCheck, iteration, checkpointInterval) - iteration += 1 - } + withSpark { sc => + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + val checkpointInterval = 2 + var graphsToCheck = Seq.empty[GraphToCheck] + sc.setCheckpointDir(path) + val graph1 = createGraph(sc) + val checkpointer = new PeriodicGraphCheckpointer[Double, Double]( + checkpointInterval, graph1.vertices.sparkContext) + checkpointer.update(graph1) + graph1.edges.count() + graph1.vertices.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph1, 1) + checkCheckpoint(graphsToCheck, 1, checkpointInterval) + + var iteration = 2 + while (iteration < 9) { + val graph = createGraph(sc) + checkpointer.update(graph) + graph.vertices.count() + graph.edges.count() + graphsToCheck = graphsToCheck :+ GraphToCheck(graph, iteration) + checkCheckpoint(graphsToCheck, iteration, checkpointInterval) + iteration += 1 + } - checkpointer.deleteAllCheckpoints() - graphsToCheck.foreach { graph => - confirmCheckpointRemoved(graph.graph) - } + checkpointer.deleteAllCheckpoints() + graphsToCheck.foreach { graph => + confirmCheckpointRemoved(graph.graph) + } - Utils.deleteRecursively(tempDir) + Utils.deleteRecursively(tempDir) + } } } private object PeriodicGraphCheckpointerSuite { + private val defaultStorageLevel = StorageLevel.MEMORY_ONLY_SER case class GraphToCheck(graph: Graph[Double, Double], gIndex: Int) @@ -96,7 +100,8 @@ private object PeriodicGraphCheckpointerSuite { Edge[Double](3, 4, 0)) def createGraph(sc: SparkContext): Graph[Double, Double] = { - Graph.fromEdges[Double, Double](sc.parallelize(edges), 0) + Graph.fromEdges[Double, Double]( + sc.parallelize(edges), 0, defaultStorageLevel, defaultStorageLevel) } def checkPersistence(graphs: Seq[GraphToCheck], iteration: Int): Unit = { @@ -116,8 +121,8 @@ private object PeriodicGraphCheckpointerSuite { assert(graph.vertices.getStorageLevel == StorageLevel.NONE) assert(graph.edges.getStorageLevel == StorageLevel.NONE) } else { - assert(graph.vertices.getStorageLevel != StorageLevel.NONE) - assert(graph.edges.getStorageLevel != StorageLevel.NONE) + assert(graph.vertices.getStorageLevel == defaultStorageLevel) + assert(graph.edges.getStorageLevel == defaultStorageLevel) } } catch { case _: AssertionError => diff --git a/launcher/pom.xml b/launcher/pom.xml index 025cd84f20f0..a4bb50ce7dda 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index 663f7fb0b010..16cce0a49653 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index 82f840b0fc26..fec1be909946 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala index 32d78e9b226e..3aea568cd652 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala @@ -56,7 +56,7 @@ private[ann] class SigmoidLayerModelWithSquaredError extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction { override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - val error = Bsum(delta :* delta) / 2 / output.cols + val error = Bsum(delta *:* delta) / 2 / output.cols ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o)) error } @@ -119,6 +119,6 @@ private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = { ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t) - -Bsum( target :* brzlog(output)) / output.cols + -Bsum( target *:* brzlog(output)) / output.cols } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index f76b14eeeb54..7507c7539d4e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -458,9 +458,7 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") + if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -512,6 +510,7 @@ private class LinearSVCAggregator( * @return This LinearSVCAggregator object. */ def merge(other: LinearSVCAggregator): this.type = { + if (other.weightSum != 0.0) { weightSum += other.weightSum lossSum += other.lossSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 965ce3d6f275..42dc7fbebe4c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -22,7 +22,7 @@ import java.util.Locale import scala.collection.mutable import breeze.linalg.{DenseVector => BDV} -import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN} +import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, LBFGSB => BreezeLBFGSB, OWLQN => BreezeOWLQN} import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -178,11 +178,90 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } + /** + * The lower bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val lowerBoundsOnCoefficients: Param[Matrix] = new Param(this, "lowerBoundsOnCoefficients", + "The lower bounds on coefficients if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getLowerBoundsOnCoefficients: Matrix = $(lowerBoundsOnCoefficients) + + /** + * The upper bounds on coefficients if fitting under bound constrained optimization. + * The bound matrix must be compatible with the shape (1, number of features) for binomial + * regression, or (number of classes, number of features) for multinomial regression. + * Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val upperBoundsOnCoefficients: Param[Matrix] = new Param(this, "upperBoundsOnCoefficients", + "The upper bounds on coefficients if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getUpperBoundsOnCoefficients: Matrix = $(upperBoundsOnCoefficients) + + /** + * The lower bounds on intercepts if fitting under bound constrained optimization. + * The bounds vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val lowerBoundsOnIntercepts: Param[Vector] = new Param(this, "lowerBoundsOnIntercepts", + "The lower bounds on intercepts if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getLowerBoundsOnIntercepts: Vector = $(lowerBoundsOnIntercepts) + + /** + * The upper bounds on intercepts if fitting under bound constrained optimization. + * The bound vector size must be equal with 1 for binomial regression, or the number + * of classes for multinomial regression. Otherwise, it throws exception. + * Default is none. + * + * @group expertParam + */ + @Since("2.2.0") + val upperBoundsOnIntercepts: Param[Vector] = new Param(this, "upperBoundsOnIntercepts", + "The upper bounds on intercepts if fitting under bound constrained optimization.") + + /** @group expertGetParam */ + @Since("2.2.0") + def getUpperBoundsOnIntercepts: Vector = $(upperBoundsOnIntercepts) + + protected def usingBoundConstrainedOptimization: Boolean = { + isSet(lowerBoundsOnCoefficients) || isSet(upperBoundsOnCoefficients) || + isSet(lowerBoundsOnIntercepts) || isSet(upperBoundsOnIntercepts) + } + override protected def validateAndTransformSchema( schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { checkThresholdConsistency() + if (usingBoundConstrainedOptimization) { + require($(elasticNetParam) == 0.0, "Fitting under bound constrained optimization only " + + s"supports L2 regularization, but got elasticNetParam = $getElasticNetParam.") + } + if (!$(fitIntercept)) { + require(!isSet(lowerBoundsOnIntercepts) && !isSet(upperBoundsOnIntercepts), + "Please don't set bounds on intercepts if fitting without intercept.") + } super.validateAndTransformSchema(schema, fitting, featuresDataType) } } @@ -217,6 +296,9 @@ class LogisticRegression @Since("1.2.0") ( * For alpha in (0,1), the penalty is a combination of L1 and L2. * Default is 0.0 which is an L2 penalty. * + * Note: Fitting under bound constrained optimization only supports L2 regularization, + * so throws exception if this param is non-zero value. + * * @group setParam */ @Since("1.4.0") @@ -312,6 +394,83 @@ class LogisticRegression @Since("1.2.0") ( def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value) setDefault(aggregationDepth -> 2) + /** + * Set the lower bounds on coefficients if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setLowerBoundsOnCoefficients(value: Matrix): this.type = set(lowerBoundsOnCoefficients, value) + + /** + * Set the upper bounds on coefficients if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setUpperBoundsOnCoefficients(value: Matrix): this.type = set(upperBoundsOnCoefficients, value) + + /** + * Set the lower bounds on intercepts if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setLowerBoundsOnIntercepts(value: Vector): this.type = set(lowerBoundsOnIntercepts, value) + + /** + * Set the upper bounds on intercepts if fitting under bound constrained optimization. + * + * @group expertSetParam + */ + @Since("2.2.0") + def setUpperBoundsOnIntercepts(value: Vector): this.type = set(upperBoundsOnIntercepts, value) + + private def assertBoundConstrainedOptimizationParamsValid( + numCoefficientSets: Int, + numFeatures: Int): Unit = { + if (isSet(lowerBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).numRows == numCoefficientSets && + $(lowerBoundsOnCoefficients).numCols == numFeatures, + "The shape of LowerBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getLowerBoundsOnCoefficients.numRows}, ${getLowerBoundsOnCoefficients.numCols}).") + } + if (isSet(upperBoundsOnCoefficients)) { + require($(upperBoundsOnCoefficients).numRows == numCoefficientSets && + $(upperBoundsOnCoefficients).numCols == numFeatures, + "The shape of upperBoundsOnCoefficients must be compatible with (1, number of features) " + + "for binomial regression, or (number of classes, number of features) for multinomial " + + "regression, but found: " + + s"(${getUpperBoundsOnCoefficients.numRows}, ${getUpperBoundsOnCoefficients.numCols}).") + } + if (isSet(lowerBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "lowerBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getLowerBoundsOnIntercepts.size}.") + } + if (isSet(upperBoundsOnIntercepts)) { + require($(upperBoundsOnIntercepts).size == numCoefficientSets, "The size of " + + "upperBoundsOnIntercepts must be equal with 1 for binomial regression, or the number of " + + s"classes for multinomial regression, but found: ${getUpperBoundsOnIntercepts.size}.") + } + if (isSet(lowerBoundsOnCoefficients) && isSet(upperBoundsOnCoefficients)) { + require($(lowerBoundsOnCoefficients).toArray.zip($(upperBoundsOnCoefficients).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnCoefficients should always be " + + "less than or equal to upperBoundsOnCoefficients, but found: " + + s"lowerBoundsOnCoefficients = $getLowerBoundsOnCoefficients, " + + s"upperBoundsOnCoefficients = $getUpperBoundsOnCoefficients.") + } + if (isSet(lowerBoundsOnIntercepts) && isSet(upperBoundsOnIntercepts)) { + require($(lowerBoundsOnIntercepts).toArray.zip($(upperBoundsOnIntercepts).toArray) + .forall(x => x._1 <= x._2), "LowerBoundsOnIntercepts should always be " + + "less than or equal to upperBoundsOnIntercepts, but found: " + + s"lowerBoundsOnIntercepts = $getLowerBoundsOnIntercepts, " + + s"upperBoundsOnIntercepts = $getUpperBoundsOnIntercepts.") + } + } + private var optInitialModel: Option[LogisticRegressionModel] = None private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = { @@ -378,6 +537,11 @@ class LogisticRegression @Since("1.2.0") ( } val numCoefficientSets = if (isMultinomial) numClasses else 1 + // Check params interaction is valid if fitting under bound constrained optimization. + if (usingBoundConstrainedOptimization) { + assertBoundConstrainedOptimizationParamsValid(numCoefficientSets, numFeatures) + } + if (isDefined(thresholds)) { require($(thresholds).length == numClasses, this.getClass.getSimpleName + ".train() called with non-matching numClasses and thresholds.length." + @@ -397,7 +561,7 @@ class LogisticRegression @Since("1.2.0") ( val isConstantLabel = histogram.count(_ != 0.0) == 1 - if ($(fitIntercept) && isConstantLabel) { + if ($(fitIntercept) && isConstantLabel && !usingBoundConstrainedOptimization) { logWarning(s"All labels are the same value and fitIntercept=true, so the coefficients " + s"will be zeros. Training is not needed.") val constantLabelIndex = Vectors.dense(histogram).argmax @@ -434,8 +598,53 @@ class LogisticRegression @Since("1.2.0") ( $(standardization), bcFeaturesStd, regParamL2, multinomial = isMultinomial, $(aggregationDepth)) + val numCoeffsPlusIntercepts = numFeaturesPlusIntercept * numCoefficientSets + + val (lowerBounds, upperBounds): (Array[Double], Array[Double]) = { + if (usingBoundConstrainedOptimization) { + val lowerBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.NegativeInfinity) + val upperBounds = Array.fill[Double](numCoeffsPlusIntercepts)(Double.PositiveInfinity) + val isSetLowerBoundsOnCoefficients = isSet(lowerBoundsOnCoefficients) + val isSetUpperBoundsOnCoefficients = isSet(upperBoundsOnCoefficients) + val isSetLowerBoundsOnIntercepts = isSet(lowerBoundsOnIntercepts) + val isSetUpperBoundsOnIntercepts = isSet(upperBoundsOnIntercepts) + + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (featureIndex < numFeatures) { + if (isSetLowerBoundsOnCoefficients) { + lowerBounds(i) = $(lowerBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + if (isSetUpperBoundsOnCoefficients) { + upperBounds(i) = $(upperBoundsOnCoefficients)( + coefficientSetIndex, featureIndex) * featuresStd(featureIndex) + } + } else { + if (isSetLowerBoundsOnIntercepts) { + lowerBounds(i) = $(lowerBoundsOnIntercepts)(coefficientSetIndex) + } + if (isSetUpperBoundsOnIntercepts) { + upperBounds(i) = $(upperBoundsOnIntercepts)(coefficientSetIndex) + } + } + i += 1 + } + (lowerBounds, upperBounds) + } else { + (null, null) + } + } + val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { - new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + if (lowerBounds != null && upperBounds != null) { + new BreezeLBFGSB( + BDV[Double](lowerBounds), BDV[Double](upperBounds), $(maxIter), 10, $(tol)) + } else { + new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) + } } else { val standardizationParam = $(standardization) def regParamL1Fun = (index: Int) => { @@ -546,6 +755,26 @@ class LogisticRegression @Since("1.2.0") ( math.log(histogram(1) / histogram(0))) } + if (usingBoundConstrainedOptimization) { + // Make sure all initial values locate in the corresponding bound. + var i = 0 + while (i < numCoeffsPlusIntercepts) { + val coefficientSetIndex = i % numCoefficientSets + val featureIndex = i / numCoefficientSets + if (initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) < lowerBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, lowerBounds(i)) + } else if ( + initialCoefWithInterceptMatrix(coefficientSetIndex, featureIndex) > upperBounds(i)) + { + initialCoefWithInterceptMatrix.update( + coefficientSetIndex, featureIndex, upperBounds(i)) + } + i += 1 + } + } + val states = optimizer.iterations(new CachedDiffFunction(costFun), new BDV[Double](initialCoefWithInterceptMatrix.toArray)) @@ -599,7 +828,7 @@ class LogisticRegression @Since("1.2.0") ( if (isIntercept) interceptVec.toArray(classIndex) = value } - if ($(regParam) == 0.0 && isMultinomial) { + if ($(regParam) == 0.0 && isMultinomial && !usingBoundConstrainedOptimization) { /* When no regularization is applied, the multinomial coefficients lack identifiability because we do not use a pivot class. We can add any constant value to the coefficients @@ -609,13 +838,18 @@ class LogisticRegression @Since("1.2.0") ( Friedman, et al. "Regularization Paths for Generalized Linear Models via Coordinate Descent," https://core.ac.uk/download/files/153/6287975.pdf */ - val denseValues = denseCoefficientMatrix.values - val coefficientMean = denseValues.sum / denseValues.length - denseCoefficientMatrix.update(_ - coefficientMean) + val centers = Array.fill(numFeatures)(0.0) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + centers(j) += v + } + centers.transform(_ / numCoefficientSets) + denseCoefficientMatrix.foreachActive { case (i, j, v) => + denseCoefficientMatrix.update(i, j, v - centers(j)) + } } // center the intercepts when using multinomial algorithm - if ($(fitIntercept) && isMultinomial) { + if ($(fitIntercept) && isMultinomial && !usingBoundConstrainedOptimization) { val interceptArray = interceptVec.toArray val interceptMean = interceptArray.sum / interceptArray.length (0 until interceptVec.size).foreach { i => interceptArray(i) -= interceptMean } @@ -1566,9 +1800,6 @@ private class LogisticAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + - s" Expecting $numFeatures but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1591,8 +1822,6 @@ private class LogisticAggregator( * @return This LogisticAggregator object. */ def merge(other: LogisticAggregator): this.type = { - require(numFeatures == other.numFeatures, s"Dimensions mismatch when merging with another " + - s"LogisticAggregator. Expecting $numFeatures but got ${other.numFeatures}.") if (other.weightSum != 0.0) { weightSum += other.weightSum diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index a9c1a7ba0bc8..5259ee419445 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -472,7 +472,7 @@ class GaussianMixture @Since("2.0.0") ( */ val cov = { val ss = new DenseVector(new Array[Double](numFeatures)).asBreeze - slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) :^ 2.0) + slice.foreach(xi => ss += (xi.asBreeze - mean.asBreeze) ^:^ 2.0) val diagVec = Vectors.fromBreeze(ss) BLAS.scal(1.0 / numSamples, diagVec) val covVec = new DenseVector(Array.fill[Double]( diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 2f50dc7c85f3..e3026c8efa82 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -36,7 +36,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedL EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.impl.PeriodicCheckpointer import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} import org.apache.spark.mllib.linalg.MatrixImplicits._ import org.apache.spark.mllib.linalg.VectorImplicits._ @@ -45,9 +44,9 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession} import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.PeriodicCheckpointer import org.apache.spark.util.VersionUtils - private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter with HasSeed with HasCheckpointInterval { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index d1f3b2af1e48..bb8f2a3aa5f7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -116,7 +116,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid) } - val newCol = bucketizer(filteredDataset($(inputCol))) + val newCol = bucketizer(filteredDataset($(inputCol)).cast(DoubleType)) val newField = prepOutputField(filteredDataset.schema) filteredDataset.withColumn($(outputCol), newCol, newField.metadata) } @@ -130,7 +130,7 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { - SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(inputCol)) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala index d604c1ac001a..8f00daa59f1a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.fpm -import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.hadoop.fs.Path @@ -54,7 +53,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { /** * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears - * more than (minSupport * size-of-the-dataset) times will be output + * more than (minSupport * size-of-the-dataset) times will be output in the frequent itemsets. * Default: 0.3 * @group param */ @@ -82,8 +81,8 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { def getNumPartitions: Int = $(numPartitions) /** - * Minimal confidence for generating Association Rule. - * Note that minConfidence has no effect during fitting. + * Minimal confidence for generating Association Rule. minConfidence will not affect the mining + * for frequent itemsets, but will affect the association rules generation. * Default: 0.8 * @group param */ @@ -118,7 +117,7 @@ private[fpm] trait FPGrowthParams extends Params with HasPredictionCol { * Recommendation. PFP distributes computation in such a way that each worker executes an * independent group of mining tasks. The FP-Growth algorithm is described in * Han et al., Mining frequent patterns without - * candidate generation. Note null values in the feature column are ignored during fit(). + * candidate generation. Note null values in the itemsCol column are ignored during fit(). * * @see * Association rule learning (Wikipedia) @@ -167,7 +166,6 @@ class FPGrowth @Since("2.2.0") ( } val parentModel = mllibFP.run(items) val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq)) - val schema = StructType(Seq( StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false), StructField("freq", LongType, nullable = false))) @@ -196,7 +194,7 @@ object FPGrowth extends DefaultParamsReadable[FPGrowth] { * :: Experimental :: * Model fitted by FPGrowth. * - * @param freqItemsets frequent items in the format of DataFrame("items"[Seq], "freq"[Long]) + * @param freqItemsets frequent itemsets in the format of DataFrame("items"[Array], "freq"[Long]) */ @Since("2.2.0") @Experimental @@ -244,10 +242,13 @@ class FPGrowthModel private[ml] ( /** * The transform method first generates the association rules according to the frequent itemsets. - * Then for each association rule, it will examine the input items against antecedents and - * summarize the consequents as prediction. The prediction column has the same data type as the - * input column(Array[T]) and will not contain existing items in the input column. The null - * values in the feature columns are treated as empty sets. + * Then for each transaction in itemsCol, the transform method will compare its items against the + * antecedents of each association rule. If the record contains all the antecedents of a + * specific association rule, the rule will be considered as applicable and its consequents + * will be added to the prediction result. The transform method will summarize the consequents + * from all the applicable rules as prediction. The prediction column has the same data type as + * the input column(Array[T]) and will not contain existing items in the input column. The null + * values in the itemsCol columns are treated as empty sets. * WARNING: internally it collects association rules to the driver and uses broadcast for * efficiency. This may bring pressure to driver memory for large set of association rules. */ @@ -335,13 +336,13 @@ private[fpm] object AssociationRules { /** * Computes the association rules with confidence above minConfidence. - * @param dataset DataFrame("items", "freq") containing frequent itemset obtained from - * algorithms like [[FPGrowth]]. + * @param dataset DataFrame("items"[Array], "freq"[Long]) containing frequent itemsets obtained + * from algorithms like [[FPGrowth]]. * @param itemsCol column name for frequent itemsets - * @param freqCol column name for frequent itemsets count - * @param minConfidence minimum confidence for the result association rules - * @return a DataFrame("antecedent", "consequent", "confidence") containing the association - * rules. + * @param freqCol column name for appearance count of the frequent itemsets + * @param minConfidence minimum confidence for generating the association rules + * @return a DataFrame("antecedent"[Array], "consequent"[Array], "confidence"[Double]) + * containing the association rules. */ def getAssociationRulesFromFP[T: ClassTag]( dataset: Dataset[_], diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index d6093a01c671..bff0d9bbb46f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -894,10 +894,10 @@ object GeneralizedLinearRegression extends DefaultParamsReadable[GeneralizedLine private[regression] object Probit extends Link("probit") { - override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).icdf(mu) + override def link(mu: Double): Double = dist.Gaussian(0.0, 1.0).inverseCdf(mu) override def deriv(mu: Double): Double = { - 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).icdf(mu)) + 1.0 / dist.Gaussian(0.0, 1.0).pdf(dist.Gaussian(0.0, 1.0).inverseCdf(mu)) } override def unlink(eta: Double): Double = dist.Gaussian(0.0, 1.0).cdf(eta) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index f7e3c8fa5b6e..eaad54985229 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -971,9 +971,6 @@ private class LeastSquaresAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => - require(dim == features.size, s"Dimensions mismatch when adding new sample." + - s" Expecting $dim but got ${features.size}.") - require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") if (weight == 0.0) return this @@ -1005,8 +1002,6 @@ private class LeastSquaresAggregator( * @return This LeastSquaresAggregator object. */ def merge(other: LeastSquaresAggregator): this.type = { - require(dim == other.dim, s"Dimensions mismatch when merging with another " + - s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.") if (other.weightSum != 0) { totalCnt += other.totalCnt diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index 4c525c0714ec..ce2bd7b430f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -21,12 +21,12 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.LabeledPoint import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} -import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD +import org.apache.spark.rdd.util.PeriodicRDDCheckpointer import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 051ec2404fb6..4d952ac88c9b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -271,7 +271,7 @@ class GaussianMixture private ( private def initCovariance(x: IndexedSeq[BV[Double]]): BreezeMatrix[Double] = { val mu = vectorMean(x) val ss = BDV.zeros[Double](x(0).length) - x.foreach(xi => ss += (xi - mu) :^ 2.0) + x.foreach(xi => ss += (xi - mu) ^:^ 2.0) diag(ss / x.length.toDouble) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7fd722a33292..663f63c25a94 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -314,7 +314,7 @@ class LocalLDAModel private[spark] ( docBound += count * LDAUtils.logSumExp(Elogthetad + localElogbeta(idx, ::).t) } // E[log p(theta | alpha) - log q(theta | gamma)] - docBound += sum((brzAlpha - gammad) :* Elogthetad) + docBound += sum((brzAlpha - gammad) *:* Elogthetad) docBound += sum(lgamma(gammad) - lgamma(brzAlpha)) docBound += lgamma(sum(brzAlpha)) - lgamma(sum(gammad)) @@ -324,7 +324,7 @@ class LocalLDAModel private[spark] ( // Bound component for prob(topic-term distributions): // E[log p(beta | eta) - log q(beta | lambda)] val sumEta = eta * vocabSize - val topicsPart = sum((eta - lambda) :* Elogbeta) + + val topicsPart = sum((eta - lambda) *:* Elogbeta) + sum(lgamma(lambda) - lgamma(eta)) + sum(lgamma(sumEta) - lgamma(sum(lambda(::, breeze.linalg.*)))) @@ -721,7 +721,7 @@ class DistributedLDAModel private[clustering] ( val N_wj = edgeContext.attr val smoothed_N_wk: TopicCounts = edgeContext.dstAttr + (eta - 1.0) val smoothed_N_kj: TopicCounts = edgeContext.srcAttr + (alpha - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k val theta_kj: TopicCounts = normalize(smoothed_N_kj, 1.0) val tokenLogLikelihood = N_wj * math.log(phi_wk.dot(theta_kj)) edgeContext.sendToDst(tokenLogLikelihood) @@ -748,7 +748,7 @@ class DistributedLDAModel private[clustering] ( if (isTermVertex(vertex)) { val N_wk = vertex._2 val smoothed_N_wk: TopicCounts = N_wk + (eta - 1.0) - val phi_wk: TopicCounts = smoothed_N_wk :/ smoothed_N_k + val phi_wk: TopicCounts = smoothed_N_wk /:/ smoothed_N_k sumPrior + (eta - 1.0) * sum(phi_wk.map(math.log)) } else { val N_kj = vertex._2 @@ -788,20 +788,14 @@ class DistributedLDAModel private[clustering] ( @Since("1.5.0") def topTopicsPerDocument(k: Int): RDD[(Long, Array[Int], Array[Double])] = { graph.vertices.filter(LDA.isDocumentVertex).map { case (docID, topicCounts) => - // TODO: Remove work-around for the breeze bug. - // https://github.com/scalanlp/breeze/issues/561 - val topIndices = if (k == topicCounts.length) { - Seq.range(0, k) - } else { - argtopk(topicCounts, k) - } + val topIndices = argtopk(topicCounts, k) val sumCounts = sum(topicCounts) val weights = if (sumCounts != 0) { - topicCounts(topIndices) / sumCounts + topicCounts(topIndices).toArray.map(_ / sumCounts) } else { - topicCounts(topIndices) + topicCounts(topIndices).toArray } - (docID.toLong, topIndices.toArray, weights.toArray) + (docID.toLong, topIndices.toArray, weights) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 48bae4276c48..d633893e55f5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ -import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer +import org.apache.spark.graphx.util.PeriodicGraphCheckpointer import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -482,7 +482,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) stats.unpersist() expElogbetaBc.destroy(false) - val batchResult = statsSum :* expElogbeta.t + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) @@ -522,7 +522,7 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val dalpha = -(gradf - b) / q - if (all((weight * dalpha + alpha) :> 0D)) { + if (all((weight * dalpha + alpha) >:> 0D)) { alpha :+= weight * dalpha this.alpha = Vectors.dense(alpha.toArray) } @@ -584,7 +584,7 @@ private[clustering] object OnlineLDAOptimizer { val expElogthetad: BDV[Double] = exp(LDAUtils.dirichletExpectation(gammad)) // K val expElogbetad = expElogbeta(ids, ::).toDenseMatrix // ids * K - val phiNorm: BDV[Double] = expElogbetad * expElogthetad :+ 1e-100 // ids + val phiNorm: BDV[Double] = expElogbetad * expElogthetad +:+ 1e-100 // ids var meanGammaChange = 1D val ctsVector = new BDV[Double](cts) // ids @@ -592,14 +592,14 @@ private[clustering] object OnlineLDAOptimizer { while (meanGammaChange > 1e-3) { val lastgamma = gammad.copy // K K * ids ids - gammad := (expElogthetad :* (expElogbetad.t * (ctsVector :/ phiNorm))) :+ alpha + gammad := (expElogthetad *:* (expElogbetad.t * (ctsVector /:/ phiNorm))) +:+ alpha expElogthetad := exp(LDAUtils.dirichletExpectation(gammad)) // TODO: Keep more values in log space, and only exponentiate when needed. - phiNorm := expElogbetad * expElogthetad :+ 1e-100 + phiNorm := expElogbetad * expElogthetad +:+ 1e-100 meanGammaChange = sum(abs(gammad - lastgamma)) / k } - val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector :/ phiNorm).asDenseMatrix + val sstatsd = expElogthetad.asDenseMatrix.t * (ctsVector /:/ phiNorm).asDenseMatrix (gammad, sstatsd, ids) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index 1f6e1a077f92..c4bbe51a46c3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -29,7 +29,7 @@ private[clustering] object LDAUtils { */ private[clustering] def logSumExp(x: BDV[Double]): Double = { val a = max(x) - a + log(sum(exp(x :- a))) + a + log(sum(exp(x -:- a))) } /** diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index c858b9bbfc25..bf6bfe30bfe2 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.classification.LogisticRegressionSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, SparseMatrix, Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix, Vector, Vectors} import org.apache.spark.ml.param.{ParamMap, ParamsSuite} import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ @@ -150,6 +150,54 @@ class LogisticRegressionSuite assert(!model.hasSummary) } + test("logistic regression: illegal params") { + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnCoefficients1 = Matrices.dense(1, 4, Array(0.0, 1.0, 1.0, 0.0)) + val upperBoundsOnCoefficients2 = Matrices.dense(1, 3, Array(1.0, 0.0, 1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(1.0) + + // Work well when only set bound in one side. + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .fit(binaryDataset) + + withClue("bound constrained optimization only supports L2 regularization") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setElasticNetParam(1.0) + .fit(binaryDataset) + } + } + + withClue("lowerBoundsOnCoefficients should less than or equal to upperBoundsOnCoefficients") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients1) + .fit(binaryDataset) + } + } + + withClue("the coefficients bound matrix mismatched with shape (1, number of features)") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients2) + .fit(binaryDataset) + } + } + + withClue("bounds on intercepts should not be set if fitting without intercept") { + intercept[IllegalArgumentException] { + new LogisticRegression() + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(false) + .fit(binaryDataset) + } + } + } + test("empty probabilityCol") { val lr = new LogisticRegression().setProbabilityCol("") val model = lr.fit(smallBinaryDataset) @@ -610,6 +658,107 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = Vectors.dense(0.06079437, 0.0, -0.26351059, -0.59102199) + val interceptExpected1 = 1.0 + + assert(model1.intercept ~== interceptExpected1 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== interceptExpected1 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected1 relTol 1E-3) + + // Bound constrained optimization with bound on both side. + val lowerBoundsOnCoefficients = Matrices.dense(1, 4, Array(0.0, -1.0, 0.0, -1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(0.0) + + val trainer3 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(binaryDataset) + val model4 = trainer4.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = Vectors.dense(0.0, 0.0, 0.0, -0.71708632) + val interceptExpected3 = 0.58776113 + + assert(model3.intercept ~== interceptExpected3 relTol 1E-3) + assert(model3.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model4.intercept ~== interceptExpected3 relTol 1E-3) + assert(model4.coefficients ~= coefficientsExpected3 relTol 1E-3) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setUpperBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Double.PositiveInfinity)) + .setLowerBoundsOnCoefficients(Matrices.dense(1, 4, Array.fill(4)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Double.NegativeInfinity)) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(binaryDataset) + val model6 = trainer6.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = Vectors.dense(-0.5734389, 0.8911736, -0.3878645, -0.8060570) + val interceptExpected5 = 2.7355261 + + assert(model5.intercept ~== interceptExpected5 relTol 1E-3) + assert(model5.coefficients ~= coefficientsExpected5 relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model6.intercept ~== interceptExpected5 relTol 1E-3) + assert(model6.coefficients ~= coefficientsExpected5 relTol 1E-3) + } + test("binary logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false).setStandardization(true) .setWeightCol("weight") @@ -650,6 +799,34 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept without regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)).toSparse + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = Vectors.dense(0.20847553, 0.0, -0.24240289, -0.55568071) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpected relTol 1E-3) + + // Without regularization, with or without standardization will converge to the same solution. + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with L1 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(1.0).setRegParam(0.12).setStandardization(true).setWeightCol("weight") @@ -815,6 +992,40 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-3) } + test("binary logistic regression with intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + val upperBoundsOnIntercepts = Vectors.dense(1.0) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setRegParam(1.37) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.06985003, 0.0, -0.04794278, -0.10168595) + val interceptExpectedWithStd = 0.45750141 + val coefficientsExpected = Vectors.dense(-0.0494524, 0.0, -0.11360797, -0.06313577) + val interceptExpected = 0.53722967 + + assert(model1.intercept ~== interceptExpectedWithStd relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== interceptExpected relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(1.37).setStandardization(true).setWeightCol("weight") @@ -864,6 +1075,35 @@ class LogisticRegressionSuite assert(model2.coefficients ~= coefficientsR relTol 1E-2) } + test("binary logistic regression without intercept with L2 regularization with bound") { + val upperBoundsOnCoefficients = Matrices.dense(1, 4, Array(1.0, 0.0, 1.0, 0.0)) + + val trainer1 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setRegParam(1.37) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(binaryDataset) + val model2 = trainer2.fit(binaryDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = Vectors.dense(-0.00796538, 0.0, -0.0394228, -0.0873314) + val coefficientsExpected = Vectors.dense(0.01105972, 0.0, -0.08574949, -0.05079558) + + assert(model1.intercept ~== 0.0 relTol 1E-3) + assert(model1.coefficients ~= coefficientsExpectedWithStd relTol 1E-3) + assert(model2.intercept ~== 0.0 relTol 1E-3) + assert(model2.coefficients ~= coefficientsExpected relTol 1E-3) + } + test("binary logistic regression with intercept with ElasticNet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setMaxIter(200) .setElasticNetParam(0.38).setRegParam(0.21).setStandardization(true).setWeightCol("weight") @@ -1084,7 +1324,6 @@ class LogisticRegressionSuite } test("multinomial logistic regression with intercept without regularization") { - val trainer1 = (new LogisticRegression).setFitIntercept(true) .setElasticNetParam(0.0).setRegParam(0.0).setStandardization(true).setWeightCol("weight") val trainer2 = (new LogisticRegression).setFitIntercept(true) @@ -1139,6 +1378,9 @@ class LogisticRegressionSuite 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) val interceptsR = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector ~== interceptsR relTol 0.05) @@ -1149,6 +1391,110 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept without regularization with bound") { + // Bound constrained optimization with bound on one side. + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected1 = new DenseMatrix(3, 4, Array( + 2.52076464, 2.73596057, 1.87984904, 2.73264492, + 1.93302281, 3.71363303, 1.50681746, 1.93398782, + 2.37839917, 1.93601818, 1.81924758, 2.45191255), isTransposed = true) + val interceptsExpected1 = Vectors.dense(1.00010477, 3.44237083, 4.86740286) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected1) + assert(model1.interceptVector ~== interceptsExpected1 relTol 0.01) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected1) + assert(model2.interceptVector ~== interceptsExpected1 relTol 0.01) + + // Bound constrained optimization with bound on both side. + val upperBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(2.0)) + val upperBoundsOnIntercepts = Vectors.dense(Array.fill(3)(2.0)) + + val trainer3 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer4 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setUpperBoundsOnCoefficients(upperBoundsOnCoefficients) + .setUpperBoundsOnIntercepts(upperBoundsOnIntercepts) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model3 = trainer3.fit(multinomialDataset) + val model4 = trainer4.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected3 = new DenseMatrix(3, 4, Array( + 1.61967097, 1.16027835, 1.45131448, 1.97390431, + 1.30529317, 2.0, 1.12985473, 1.26652854, + 1.61647195, 1.0, 1.40642959, 1.72985589), isTransposed = true) + val interceptsExpected3 = Vectors.dense(1.0, 2.0, 2.0) + + checkCoefficientsEquivalent(model3.coefficientMatrix, coefficientsExpected3) + assert(model3.interceptVector ~== interceptsExpected3 relTol 0.01) + checkCoefficientsEquivalent(model4.coefficientMatrix, coefficientsExpected3) + assert(model4.interceptVector ~== interceptsExpected3 relTol 0.01) + + // Bound constrained optimization with infinite bound on both side. + val trainer5 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer6 = new LogisticRegression() + .setLowerBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.NegativeInfinity))) + .setLowerBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.NegativeInfinity))) + .setUpperBoundsOnCoefficients(Matrices.dense(3, 4, Array.fill(12)(Double.PositiveInfinity))) + .setUpperBoundsOnIntercepts(Vectors.dense(Array.fill(3)(Double.PositiveInfinity))) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model5 = trainer5.fit(multinomialDataset) + val model6 = trainer6.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + // It should be same as unbound constrained optimization with LBFGS. + val coefficientsExpected5 = new DenseMatrix(3, 4, Array( + 0.24337896, -0.05916156, 0.14446790, 0.35976165, + -0.3443375, 0.9181331, -0.2283959, -0.4388066, + 0.10095851, -0.85897154, 0.08392798, 0.07904499), isTransposed = true) + val interceptsExpected5 = Vectors.dense(-2.10320093, 0.3394473, 1.76375361) + + checkCoefficientsEquivalent(model5.coefficientMatrix, coefficientsExpected5) + assert(model5.interceptVector ~== interceptsExpected5 relTol 0.01) + checkCoefficientsEquivalent(model6.coefficientMatrix, coefficientsExpected5) + assert(model6.interceptVector ~== interceptsExpected5 relTol 0.01) + } + test("multinomial logistic regression without intercept without regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) @@ -1204,6 +1550,9 @@ class LogisticRegressionSuite -0.3180040, 0.9679074, -0.2252219, -0.4319914, 0.2452411, -0.6046524, 0.1050710, 0.1180180), isTransposed = true) + model1.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + model2.coefficientMatrix.colIter.foreach(v => assert(v.toArray.sum ~== 0.0 absTol eps)) + assert(model1.coefficientMatrix ~== coefficientsR relTol 0.05) assert(model1.coefficientMatrix.toArray.sum ~== 0.0 absTol eps) assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) @@ -1214,6 +1563,35 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept without regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.62410051, 1.38219391, 1.34486618, 1.74641729, + 1.23058989, 2.71787825, 1.0, 1.00007073, + 1.79478632, 1.14360459, 1.33011603, 1.55093897), isTransposed = true) + + checkCoefficientsEquivalent(model1.coefficientMatrix, coefficientsExpected) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + checkCoefficientsEquivalent(model2.coefficientMatrix, coefficientsExpected) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with L1 regularization") { // use tighter constraints because OWL-QN solver takes longer to converge @@ -1512,6 +1890,46 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression with intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + val lowerBoundsOnIntercepts = Vectors.dense(Array.fill(3)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setLowerBoundsOnIntercepts(lowerBoundsOnIntercepts) + .setRegParam(0.1) + .setFitIntercept(true) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.0, 1.01647497, + 1.0, 1.44105616, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpectedWithStd = Vectors.dense(2.52055893, 1.0, 2.560682) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03189386, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0), isTransposed = true) + val interceptsExpected = Vectors.dense(1.06418835, 1.0, 1.20494701) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd relTol 0.01) + assert(model1.interceptVector ~== interceptsExpectedWithStd relTol 0.01) + assert(model2.coefficientMatrix ~== coefficientsExpected relTol 0.01) + assert(model2.interceptVector ~== interceptsExpected relTol 0.01) + } + test("multinomial logistic regression without intercept with L2 regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(false) .setElasticNetParam(0.0).setRegParam(0.1).setStandardization(true).setWeightCol("weight") @@ -1609,6 +2027,41 @@ class LogisticRegressionSuite assert(model2.interceptVector.toArray.sum ~== 0.0 absTol eps) } + test("multinomial logistic regression without intercept with L2 regularization with bound") { + val lowerBoundsOnCoefficients = Matrices.dense(3, 4, Array.fill(12)(1.0)) + + val trainer1 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(true) + .setWeightCol("weight") + val trainer2 = new LogisticRegression() + .setLowerBoundsOnCoefficients(lowerBoundsOnCoefficients) + .setRegParam(0.1) + .setFitIntercept(false) + .setStandardization(false) + .setWeightCol("weight") + + val model1 = trainer1.fit(multinomialDataset) + val model2 = trainer2.fit(multinomialDataset) + + // The solution is generated by https://github.com/yanboliang/bound-optimization. + val coefficientsExpectedWithStd = new DenseMatrix(3, 4, Array( + 1.01324653, 1.0, 1.0, 1.0415767, + 1.0, 1.0, 1.0, 1.0, + 1.02244888, 1.0, 1.0, 1.0), isTransposed = true) + val coefficientsExpected = new DenseMatrix(3, 4, Array( + 1.0, 1.0, 1.03932259, 1.0, + 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.03274649, 1.0), isTransposed = true) + + assert(model1.coefficientMatrix ~== coefficientsExpectedWithStd absTol 0.01) + assert(model1.interceptVector.toArray === Array.fill(3)(0.0)) + assert(model2.coefficientMatrix ~== coefficientsExpected absTol 0.01) + assert(model2.interceptVector.toArray === Array.fill(3)(0.0)) + } + test("multinomial logistic regression with intercept with elasticnet regularization") { val trainer1 = (new LogisticRegression).setFitIntercept(true).setWeightCol("weight") .setElasticNetParam(0.5).setRegParam(0.1).setStandardization(true) @@ -2267,4 +2720,19 @@ object LogisticRegressionSuite { val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i))) testData } + + /** + * When no regularization is applied, the multinomial coefficients lack identifiability + * because we do not use a pivot class. We can add any constant value to the coefficients + * and get the same likelihood. If fitting under bound constrained optimization, we don't + * choose the mean centered coefficients like what we do for unbound problems, since they + * may out of the bounds. We use this function to check whether two coefficients are equivalent. + */ + def checkCoefficientsEquivalent(coefficients1: Matrix, coefficients2: Matrix): Unit = { + coefficients1.colIter.zip(coefficients2.colIter).foreach { case (col1: Vector, col2: Vector) => + (col1.asBreeze - col2.asBreeze).toArray.toSeq.sliding(2).foreach { + case Seq(v1, v2) => assert(v1 ~= v2 absTol 1E-3) + } + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index b56f8e19ca53..3a2be236f125 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -168,7 +168,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(m1.pi ~== m2.pi relTol 0.01) assert(m1.theta ~== m2.theta relTol 0.01) } - val testParams = Seq( + val testParams = Seq[(String, Dataset[_])]( ("bernoulli", bernoulliDataset), ("multinomial", dataset) ) diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index aac29137d791..420fb17ddce8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -26,6 +26,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.types._ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -162,6 +164,29 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa .setSplits(Array(0.1, 0.8, 0.9)) testDefaultReadWrite(t) } + + test("Bucket numeric features") { + val splits = Array(-3.0, 0.0, 3.0) + val data = Array(-2.0, -1.0, 0.0, 1.0, 2.0) + val expectedBuckets = Array(0.0, 0.0, 1.0, 1.0, 1.0) + val dataFrame: DataFrame = data.zip(expectedBuckets).toSeq.toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + val types = Seq(ShortType, IntegerType, LongType, FloatType, DoubleType, + ByteType, DecimalType(10, 0)) + for (mType <- types) { + val df = dataFrame.withColumn("feature", col("feature").cast(mType)) + bucketizer.transform(df).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, "The result is not correct after bucketing in type " + + mType.toString + ". " + s"Expected $y but found $x.") + } + } + } } private object BucketizerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 6806cb03bc42..87f8b9034dde 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -122,6 +122,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setMinConfidence(0.5678) assert(fpGrowth.getMinSupport === 0.4567) assert(model.getMinConfidence === 0.5678) + // numPartitions should not have default value. + assert(fpGrowth.isDefined(fpGrowth.numPartitions) === false) MLTestingUtils.checkCopyAndUids(fpGrowth, model) ParamsSuite.checkParams(fpGrowth) ParamsSuite.checkParams(model) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala index 572959200f47..3d6a9f8d84ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala @@ -191,8 +191,8 @@ class LBFGSSuite extends SparkFunSuite with MLlibTestSparkContext with Matchers // With smaller convergenceTol, it takes more steps. assert(lossLBFGS3.length > lossLBFGS2.length) - // Based on observation, lossLBFGS2 runs 5 iterations, no theoretically guaranteed. - assert(lossLBFGS3.length == 6) + // Based on observation, lossLBFGS3 runs 7 iterations, no theoretically guaranteed. + assert(lossLBFGS3.length == 7) assert((lossLBFGS3(4) - lossLBFGS3(5)) / lossLBFGS3(4) < convergenceTol) } diff --git a/pom.xml b/pom.xml index 14370d92a908..ccd8546a269c 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT pom Spark Project Parent POM http://spark.apache.org/ @@ -58,10 +58,6 @@ https://issues.apache.org/jira/browse/SPARK - - ${maven.version} - - Dev Mailing List @@ -136,13 +132,12 @@ 10.12.1.1 1.8.2 1.6.0 - 9.2.16.v20160414 + 9.3.11.v20160721 3.1.0 0.8.0 2.4.0 2.0.8 3.1.2 - 1.7.7 hadoop2 0.9.3 @@ -659,7 +654,7 @@ org.scalanlp breeze_${scala.binary.version} - 0.12 + 0.13.1 diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 77dae289f775..e52baf51aed1 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -318,8 +318,8 @@ object SparkBuild extends PomBuild { enable(MimaBuild.mimaSettings(sparkHome, x))(x) } - /* Generate and pick the spark build info from extra-resources and override a dependency */ - enable(Core.settings ++ CoreDependencyOverrides.settings)(core) + /* Generate and pick the spark build info from extra-resources */ + enable(Core.settings)(core) /* Unsafe settings */ enable(Unsafe.settings)(unsafe) @@ -443,16 +443,6 @@ object DockerIntegrationTests { ) } -/** - * Overrides to work around sbt's dependency resolution being different from Maven's in Unidoc. - * - * Note that, this is a hack that should be removed in the future. See SPARK-20343 - */ -object CoreDependencyOverrides { - lazy val settings = Seq( - dependencyOverrides += "org.apache.avro" % "avro" % "1.7.7") -} - /** * Overrides to work around sbt's dependency resolution being different from Maven's. */ diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 2961cda553d6..3be07325f416 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -240,6 +240,32 @@ def signal_handler(signal, frame): if isinstance(threading.current_thread(), threading._MainThread): signal.signal(signal.SIGINT, signal_handler) + def __repr__(self): + return "".format( + master=self.master, + appName=self.appName, + ) + + def _repr_html_(self): + return """ +

    +

    SparkContext

    + +

    Spark UI

    + +
    +
    Version
    +
    v{sc.version}
    +
    Master
    +
    {sc.master}
    +
    AppName
    +
    {sc.appName}
    +
    +
    + """.format( + sc=self + ) + def _initialize_context(self, jconf): """ Initialize SparkContext in function to allow subclass specific initialization diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index b4fc357e42d7..a9756ea4af99 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -185,36 +185,33 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors >>> bdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], []))]).toDF() - >>> blor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight") + ... Row(label=1.0, weight=1.0, features=Vectors.dense(0.0, 5.0)), + ... Row(label=0.0, weight=2.0, features=Vectors.dense(1.0, 2.0)), + ... Row(label=1.0, weight=3.0, features=Vectors.dense(2.0, 1.0)), + ... Row(label=0.0, weight=4.0, features=Vectors.dense(3.0, 3.0))]).toDF() + >>> blor = LogisticRegression(regParam=0.01, weightCol="weight") >>> blorModel = blor.fit(bdf) >>> blorModel.coefficients - DenseVector([5.5...]) + DenseVector([-1.080..., -0.646...]) >>> blorModel.intercept - -2.68... - >>> mdf = sc.parallelize([ - ... Row(label=1.0, weight=2.0, features=Vectors.dense(1.0)), - ... Row(label=0.0, weight=2.0, features=Vectors.sparse(1, [], [])), - ... Row(label=2.0, weight=2.0, features=Vectors.dense(3.0))]).toDF() - >>> mlor = LogisticRegression(maxIter=5, regParam=0.01, weightCol="weight", - ... family="multinomial") + 3.112... + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> mdf = spark.read.format("libsvm").load(data_path) + >>> mlor = LogisticRegression(regParam=0.1, elasticNetParam=1.0, family="multinomial") >>> mlorModel = mlor.fit(mdf) - >>> print(mlorModel.coefficientMatrix) - DenseMatrix([[-2.3...], - [ 0.2...], - [ 2.1... ]]) + >>> mlorModel.coefficientMatrix + SparseMatrix(3, 4, [0, 1, 2, 3], [3, 2, 1], [1.87..., -2.75..., -0.50...], 1) >>> mlorModel.interceptVector - DenseVector([2.0..., 0.8..., -2.8...]) - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0))]).toDF() + DenseVector([0.04..., -0.42..., 0.37...]) + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 1.0))]).toDF() >>> result = blorModel.transform(test0).head() >>> result.prediction - 0.0 + 1.0 >>> result.probability - DenseVector([0.99..., 0.00...]) + DenseVector([0.02..., 0.97...]) >>> result.rawPrediction - DenseVector([8.22..., -8.22...]) - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(1, [0], [1.0]))]).toDF() + DenseVector([-3.54..., 3.54...]) + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() >>> blorModel.transform(test1).head().prediction 1.0 >>> blor.setParams("vector") @@ -224,8 +221,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti >>> lr_path = temp_path + "/lr" >>> blor.save(lr_path) >>> lr2 = LogisticRegression.load(lr_path) - >>> lr2.getMaxIter() - 5 + >>> lr2.getRegParam() + 0.01 >>> model_path = temp_path + "/lr_model" >>> blorModel.save(model_path) >>> model2 = LogisticRegressionModel.load(model_path) @@ -1482,31 +1479,33 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): >>> from pyspark.sql import Row >>> from pyspark.ml.linalg import Vectors - >>> df = sc.parallelize([ - ... Row(label=0.0, features=Vectors.dense(1.0, 0.8)), - ... Row(label=1.0, features=Vectors.sparse(2, [], [])), - ... Row(label=2.0, features=Vectors.dense(0.5, 0.5))]).toDF() - >>> lr = LogisticRegression(maxIter=5, regParam=0.01) + >>> data_path = "data/mllib/sample_multiclass_classification_data.txt" + >>> df = spark.read.format("libsvm").load(data_path) + >>> lr = LogisticRegression(regParam=0.01) >>> ovr = OneVsRest(classifier=lr) >>> model = ovr.fit(df) - >>> [x.coefficients for x in model.models] - [DenseVector([3.3925, 1.8785]), DenseVector([-4.3016, -6.3163]), DenseVector([-4.5855, 6.1785])] + >>> model.models[0].coefficients + DenseVector([0.5..., -1.0..., 3.4..., 4.2...]) + >>> model.models[1].coefficients + DenseVector([-2.1..., 3.1..., -2.6..., -2.3...]) + >>> model.models[2].coefficients + DenseVector([0.3..., -3.4..., 1.0..., -1.1...]) >>> [x.intercept for x in model.models] - [-3.64747..., 2.55078..., -1.10165...] - >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0))]).toDF() + [-2.7..., -2.5..., -1.3...] + >>> test0 = sc.parallelize([Row(features=Vectors.dense(-1.0, 0.0, 1.0, 1.0))]).toDF() >>> model.transform(test0).head().prediction - 1.0 - >>> test1 = sc.parallelize([Row(features=Vectors.sparse(2, [0], [1.0]))]).toDF() - >>> model.transform(test1).head().prediction 0.0 - >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4))]).toDF() - >>> model.transform(test2).head().prediction + >>> test1 = sc.parallelize([Row(features=Vectors.sparse(4, [0], [1.0]))]).toDF() + >>> model.transform(test1).head().prediction 2.0 + >>> test2 = sc.parallelize([Row(features=Vectors.dense(0.5, 0.4, 0.3, 0.2))]).toDF() + >>> model.transform(test2).head().prediction + 0.0 >>> model_path = temp_path + "/ovr_model" >>> model.save(model_path) >>> model2 = OneVsRestModel.load(model_path) >>> model2.transform(test0).head().prediction - 1.0 + 0.0 .. versionadded:: 2.0.0 """ diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 8bc899a0788b..bcfb36880eb0 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,6 +82,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) + >>> user_recs = model.recommendForAllUsers(3) + >>> user_recs.where(user_recs.user == 0)\ + .select("recommendations.item", "recommendations.rating").collect() + [Row(item=[0, 1, 2], rating=[3.910..., 1.992..., -0.138...])] + >>> item_recs = model.recommendForAllItems(3) + >>> item_recs.where(item_recs.item == 2)\ + .select("recommendations.user", "recommendations.rating").collect() + [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -384,6 +392,28 @@ def itemFactors(self): """ return self._call_java("itemFactors") + @since("2.2.0") + def recommendForAllUsers(self, numItems): + """ + Returns top `numItems` items recommended for each user, for all users. + + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForAllUsers", numItems) + + @since("2.2.0") + def recommendForAllItems(self, numUsers): + """ + Returns top `numUsers` users recommended for each item, for all items. + + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForAllItems", numUsers) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 600655c912ca..4cb802514be5 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -28,14 +28,13 @@ from pyspark import RDD, since from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper -from pyspark.mllib.linalg import _convert_to_vector, Matrix, QRDecomposition +from pyspark.mllib.linalg import _convert_to_vector, DenseMatrix, Matrix, QRDecomposition from pyspark.mllib.stat import MultivariateStatisticalSummary from pyspark.storagelevel import StorageLevel -__all__ = ['DistributedMatrix', 'RowMatrix', 'IndexedRow', - 'IndexedRowMatrix', 'MatrixEntry', 'CoordinateMatrix', - 'BlockMatrix'] +__all__ = ['BlockMatrix', 'CoordinateMatrix', 'DistributedMatrix', 'IndexedRow', + 'IndexedRowMatrix', 'MatrixEntry', 'RowMatrix', 'SingularValueDecomposition'] class DistributedMatrix(object): @@ -301,6 +300,136 @@ def tallSkinnyQR(self, computeQ=False): R = decomp.call("R") return QRDecomposition(Q, R) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the RowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a RowMatrix whose + columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the Scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: :py:class:`SingularValueDecomposition` + + >>> rows = sc.parallelize([[3, 1, 1], [-1, 3, 1]]) + >>> rm = RowMatrix(rows) + + >>> svd_model = rm.computeSVD(2, True) + >>> svd_model.U.rows.collect() + [DenseVector([-0.7071, 0.7071]), DenseVector([-0.7071, -0.7071])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def computePrincipalComponents(self, k): + """ + Computes the k principal components of the given row matrix + + .. note:: This cannot be computed on matrices with more than 65535 columns. + + :param k: Number of principal components to keep. + :returns: :py:class:`pyspark.mllib.linalg.DenseMatrix` + + >>> rows = sc.parallelize([[1, 2, 3], [2, 4, 5], [3, 6, 1]]) + >>> rm = RowMatrix(rows) + + >>> # Returns the two principal components of rm + >>> pca = rm.computePrincipalComponents(2) + >>> pca + DenseMatrix(3, 2, [-0.349, -0.6981, 0.6252, -0.2796, -0.5592, -0.7805], 0) + + >>> # Transform into new dimensions with the greatest variance. + >>> rm.multiply(pca).rows.collect() # doctest: +NORMALIZE_WHITESPACE + [DenseVector([0.1305, -3.7394]), DenseVector([-0.3642, -6.6983]), \ + DenseVector([-4.6102, -4.9745])] + """ + return self._java_matrix_wrapper.call("computePrincipalComponents", k) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`RowMatrix` + + >>> rm = RowMatrix(sc.parallelize([[0, 1], [2, 3]])) + >>> rm.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [DenseVector([2.0, 3.0]), DenseVector([6.0, 11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + j_model = self._java_matrix_wrapper.call("multiply", matrix) + return RowMatrix(j_model) + + +class SingularValueDecomposition(JavaModelWrapper): + """ + Represents singular value decomposition (SVD) factors. + + .. versionadded:: 2.2.0 + """ + + @property + @since('2.2.0') + def U(self): + """ + Returns a distributed matrix whose columns are the left + singular vectors of the SingularValueDecomposition if computeU was set to be True. + """ + u = self.call("U") + if u is not None: + mat_name = u.getClass().getSimpleName() + if mat_name == "RowMatrix": + return RowMatrix(u) + elif mat_name == "IndexedRowMatrix": + return IndexedRowMatrix(u) + else: + raise TypeError("Expected RowMatrix/IndexedRowMatrix got %s" % mat_name) + + @property + @since('2.2.0') + def s(self): + """ + Returns a DenseVector with singular values in descending order. + """ + return self.call("s") + + @property + @since('2.2.0') + def V(self): + """ + Returns a DenseMatrix whose columns are the right singular + vectors of the SingularValueDecomposition. + """ + return self.call("V") + class IndexedRow(object): """ @@ -528,6 +657,68 @@ def toBlockMatrix(self, rowsPerBlock=1024, colsPerBlock=1024): colsPerBlock) return BlockMatrix(java_block_matrix, rowsPerBlock, colsPerBlock) + @since('2.2.0') + def computeSVD(self, k, computeU=False, rCond=1e-9): + """ + Computes the singular value decomposition of the IndexedRowMatrix. + + The given row matrix A of dimension (m X n) is decomposed into + U * s * V'T where + + * U: (m X k) (left singular vectors) is a IndexedRowMatrix + whose columns are the eigenvectors of (A X A') + * s: DenseVector consisting of square root of the eigenvalues + (singular values) in descending order. + * v: (n X k) (right singular vectors) is a Matrix whose columns + are the eigenvectors of (A' X A) + + For more specific details on implementation, please refer + the scala documentation. + + :param k: Number of leading singular values to keep (`0 < k <= n`). + It might return less than k if there are numerically zero singular values + or there are not enough Ritz values converged before the maximum number of + Arnoldi update iterations is reached (in case that matrix A is ill-conditioned). + :param computeU: Whether or not to compute U. If set to be + True, then U is computed by A * V * s^-1 + :param rCond: Reciprocal condition number. All singular values + smaller than rCond * s[0] are treated as zero + where s[0] is the largest singular value. + :returns: SingularValueDecomposition object + + >>> rows = [(0, (3, 1, 1)), (1, (-1, 3, 1))] + >>> irm = IndexedRowMatrix(sc.parallelize(rows)) + >>> svd_model = irm.computeSVD(2, True) + >>> svd_model.U.rows.collect() # doctest: +NORMALIZE_WHITESPACE + [IndexedRow(0, [-0.707106781187,0.707106781187]),\ + IndexedRow(1, [-0.707106781187,-0.707106781187])] + >>> svd_model.s + DenseVector([3.4641, 3.1623]) + >>> svd_model.V + DenseMatrix(3, 2, [-0.4082, -0.8165, -0.4082, 0.8944, -0.4472, 0.0], 0) + """ + j_model = self._java_matrix_wrapper.call( + "computeSVD", int(k), bool(computeU), float(rCond)) + return SingularValueDecomposition(j_model) + + @since('2.2.0') + def multiply(self, matrix): + """ + Multiply this matrix by a local dense matrix on the right. + + :param matrix: a local dense matrix whose number of rows must match the number of columns + of this matrix + :returns: :py:class:`IndexedRowMatrix` + + >>> mat = IndexedRowMatrix(sc.parallelize([(0, (0, 1)), (1, (2, 3))])) + >>> mat.multiply(DenseMatrix(2, 2, [0, 2, 1, 3])).rows.collect() + [IndexedRow(0, [2.0,3.0]), IndexedRow(1, [6.0,11.0])] + """ + if not isinstance(matrix, DenseMatrix): + raise ValueError("Only multiplication with DenseMatrix " + "is supported.") + return IndexedRowMatrix(self._java_matrix_wrapper.call("multiply", matrix)) + class MatrixEntry(object): """ diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 523b3f111331..1037bab7f108 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -23,6 +23,7 @@ import sys import tempfile import array as pyarray +from math import sqrt from time import time, sleep from shutil import rmtree @@ -54,6 +55,7 @@ from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT +from pyspark.mllib.linalg.distributed import RowMatrix from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD @@ -1699,6 +1701,67 @@ def test_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(output[i])) +class DimensionalityReductionTests(MLlibTestCase): + + denseData = [ + Vectors.dense([0.0, 1.0, 2.0]), + Vectors.dense([3.0, 4.0, 5.0]), + Vectors.dense([6.0, 7.0, 8.0]), + Vectors.dense([9.0, 0.0, 1.0]) + ] + sparseData = [ + Vectors.sparse(3, [(1, 1.0), (2, 2.0)]), + Vectors.sparse(3, [(0, 3.0), (1, 4.0), (2, 5.0)]), + Vectors.sparse(3, [(0, 6.0), (1, 7.0), (2, 8.0)]), + Vectors.sparse(3, [(0, 9.0), (2, 1.0)]) + ] + + def assertEqualUpToSign(self, vecA, vecB): + eq1 = vecA - vecB + eq2 = vecA + vecB + self.assertTrue(sum(abs(eq1)) < 1e-6 or sum(abs(eq2)) < 1e-6) + + def test_svd(self): + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + m = 4 + n = 3 + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + rm = mat.computeSVD(k, computeU=True) + self.assertEqual(rm.s.size, k) + self.assertEqual(rm.U.numRows(), m) + self.assertEqual(rm.U.numCols(), k) + self.assertEqual(rm.V.numRows, n) + self.assertEqual(rm.V.numCols, k) + + # Test that U returned is None if computeU is set to False. + self.assertEqual(mat.computeSVD(1).U, None) + + # Test that low rank matrices cannot have number of singular values + # greater than a limit. + rm = RowMatrix(self.sc.parallelize(tile([1, 2, 3], (3, 1)))) + self.assertEqual(rm.computeSVD(3, False, 1e-6).s.size, 1) + + def test_pca(self): + expected_pcs = array([ + [0.0, 1.0, 0.0], + [sqrt(2.0) / 2.0, 0.0, sqrt(2.0) / 2.0], + [sqrt(2.0) / 2.0, 0.0, -sqrt(2.0) / 2.0] + ]) + n = 3 + denseMat = RowMatrix(self.sc.parallelize(self.denseData)) + sparseMat = RowMatrix(self.sc.parallelize(self.sparseData)) + for mat in [denseMat, sparseMat]: + for k in range(1, 4): + pcs = mat.computePrincipalComponents(k) + self.assertEqual(pcs.numRows, n) + self.assertEqual(pcs.numCols, k) + + # We can just test the updated principal component for equality. + self.assertEqualUpToSign(pcs.toArray()[:, k - 1], expected_pcs[:, k - 1]) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 774caf53f3a4..d62ba9623b44 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -371,6 +371,35 @@ def withWatermark(self, eventTime, delayThreshold): jdf = self._jdf.withWatermark(eventTime, delayThreshold) return DataFrame(jdf, self.sql_ctx) + @since(2.2) + def hint(self, name, *parameters): + """Specifies some hint on the current DataFrame. + + :param name: A name of the hint. + :param parameters: Optional parameters. + :return: :class:`DataFrame` + + >>> df.join(df2.hint("broadcast"), "name").show() + +----+---+------+ + |name|age|height| + +----+---+------+ + | Bob| 5| 85| + +----+---+------+ + """ + if len(parameters) == 1 and isinstance(parameters[0], list): + parameters = parameters[0] + + if not isinstance(name, str): + raise TypeError("name should be provided as str, got {0}".format(type(name))) + + for p in parameters: + if not isinstance(p, str): + raise TypeError( + "all parameters should be str, got {0} of type {1}".format(p, type(p))) + + jdf = self._jdf.hint(name, self._jseq(parameters)) + return DataFrame(jdf, self.sql_ctx) + @since(1.3) def count(self): """Returns the number of rows in this :class:`DataFrame`. @@ -1238,7 +1267,7 @@ def fillna(self, value, subset=None): Value to replace null values with. If the value is a dict, then `subset` is ignored and `value` must be a mapping from column name (string) to replacement value. The replacement value must be - an int, long, float, or string. + an int, long, float, boolean, or string. :param subset: optional list of column names to consider. Columns specified in subset that do not have matching data type are ignored. For example, if `value` is a string, and subset contains a non-string column, diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py index 9f4772eec9f2..c1bf2bd76fb7 100644 --- a/python/pyspark/sql/session.py +++ b/python/pyspark/sql/session.py @@ -221,6 +221,17 @@ def __init__(self, sparkContext, jsparkSession=None): or SparkSession._instantiatedSession._sc._jsc is None: SparkSession._instantiatedSession = self + def _repr_html_(self): + return """ +
    +

    SparkSession - {catalogImplementation}

    + {sc_HTML} +
    + """.format( + catalogImplementation=self.conf.get("spark.sql.catalogImplementation"), + sc_HTML=self.sparkContext._repr_html_() + ) + @since(2.0) def newSession(self): """ diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 2b2444304e04..2aa2d23c6f0d 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -1711,6 +1711,10 @@ def test_fillna(self): self.assertEqual(row.age, None) self.assertEqual(row.height, None) + # fillna with dictionary for boolean types + row = self.spark.createDataFrame([Row(a=None), Row(a=True)]).fillna({"a": True}).first() + self.assertEqual(row.a, True) + def test_bitwise_operations(self): from pyspark.sql import functions row = Row(a=170, b=75) @@ -1902,6 +1906,22 @@ def test_functions_broadcast(self): # planner should not crash without a join broadcast(df1)._jdf.queryExecution().executedPlan() + def test_generic_hints(self): + from pyspark.sql import DataFrame + + df1 = self.spark.range(10e10).toDF("id") + df2 = self.spark.range(10e10).toDF("id") + + self.assertIsInstance(df1.hint("broadcast"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", []), DataFrame) + + # Dummy rules + self.assertIsInstance(df1.hint("broadcast", "foo", "bar"), DataFrame) + self.assertIsInstance(df1.hint("broadcast", ["foo", "bar"]), DataFrame) + + plan = df1.join(df2.hint("broadcast"), "id")._jdf.queryExecution().executedPlan() + self.assertEqual(1, plan.toString().count("BroadcastHashJoin")) + def test_toDF_with_schema_string(self): data = [Row(key=i, value=str(i)) for i in range(100)] rdd = self.sc.parallelize(data, 5) diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 41bf8c269b79..c0bb1968b4b9 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "2.2.0.dev0" +__version__ = "2.2.1.dev0" diff --git a/repl/pom.xml b/repl/pom.xml index a256ae3b8418..f3c49dfb0060 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/resource-managers/mesos/pom.xml b/resource-managers/mesos/pom.xml index 03846d9f5a3b..547836050a61 100644 --- a/resource-managers/mesos/pom.xml +++ b/resource-managers/mesos/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index cd98110ddcc0..127fadabcce5 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -101,7 +101,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Launch Time - {state.startDate} + {UIUtils.formatDate(state.startDate)} Finish Time @@ -154,7 +154,7 @@ private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") Memory{driver.mem} - Submitted{driver.submissionDate} + Submitted{UIUtils.formatDate(driver.submissionDate)} Supervise{driver.supervise} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 13ba7d311e57..c9107c3e73d3 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -68,7 +68,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} cpus: {submission.cores}, mem: {submission.mem} @@ -88,10 +88,10 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( {id} {historyCol} - {state.driverDescription.submissionDate} + {UIUtils.formatDate(state.driverDescription.submissionDate)} {state.driverDescription.command.mainClass} cpus: {state.driverDescription.cores}, mem: {state.driverDescription.mem} - {state.startDate} + {UIUtils.formatDate(state.startDate)} {state.slaveId.getValue} {stateString(state.mesosTaskStatus)} @@ -101,7 +101,7 @@ private[mesos] class MesosClusterPage(parent: MesosClusterUI) extends WebUIPage( val id = submission.submissionId {id} - {submission.submissionDate} + {UIUtils.formatDate(submission.submissionDate)} {submission.command.mainClass} {submission.retryState.get.lastFailureStatus} {submission.retryState.get.nextRetry} diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 2a36ec4fa811..8f5b97ccb1f8 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -60,8 +60,16 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private val maxCoresOption = conf.getOption("spark.cores.max").map(_.toInt) + private val executorCoresOption = conf.getOption("spark.executor.cores").map(_.toInt) + + private val minCoresPerExecutor = executorCoresOption.getOrElse(1) + // Maximum number of cores to acquire - private val maxCores = maxCoresOption.getOrElse(Int.MaxValue) + private val maxCores = { + val cores = maxCoresOption.getOrElse(Int.MaxValue) + // Set maxCores to a multiple of smallest executor we can launch + cores - (cores % minCoresPerExecutor) + } private val useFetcherCache = conf.getBoolean("spark.mesos.fetcherCache.enable", false) @@ -489,8 +497,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def executorCores(offerCPUs: Int): Int = { - sc.conf.getInt("spark.executor.cores", - math.min(offerCPUs, maxCores - totalCoresAcquired)) + executorCoresOption.getOrElse( + math.min(offerCPUs, maxCores - totalCoresAcquired) + ) } override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index c040f05d93b3..0418bfbaa5ed 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -199,6 +199,40 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite verifyDeclinedOffer(driver, createOfferId("o2"), true) } + test("mesos declines offers with a filter when maxCores not a multiple of executor.cores") { + val maxCores = 4 + val executorCores = 3 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyDeclinedOffer(driver, createOfferId("o2"), true) + } + + test("mesos declines offers with a filter when reached spark.cores.max with executor.cores") { + val maxCores = 4 + val executorCores = 2 + setBackend(Map( + "spark.cores.max" -> maxCores.toString, + "spark.executor.cores" -> executorCores.toString + )) + val executorMemory = backend.executorMemory(sc) + offerResources(List( + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1), + Resources(executorMemory, maxCores + 1) + )) + verifyTaskLaunched(driver, "o1") + verifyTaskLaunched(driver, "o2") + verifyDeclinedOffer(driver, createOfferId("o3"), true) + } + test("mesos assigns tasks round-robin on offers") { val executorCores = 4 val maxCores = executorCores * 2 diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index a1b641c8eeb8..e00ed33d2ba1 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala index 99fb58a28934..59adb7e22d18 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala @@ -24,6 +24,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.io.{ByteStreams, Files} @@ -87,24 +88,30 @@ class YarnClusterSuite extends BaseYarnClusterSuite { testBasicYarnApp(false) } - test("run Spark in yarn-client mode with different configurations") { + test("run Spark in yarn-client mode with different configurations, ensuring redaction") { testBasicYarnApp(true, Map( "spark.driver.memory" -> "512m", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } - test("run Spark in yarn-cluster mode with different configurations") { + test("run Spark in yarn-cluster mode with different configurations, ensuring redaction") { testBasicYarnApp(false, Map( "spark.driver.memory" -> "512m", "spark.driver.cores" -> "1", "spark.executor.cores" -> "1", "spark.executor.memory" -> "512m", - "spark.executor.instances" -> "2" + "spark.executor.instances" -> "2", + // Sending some senstive information, which we'll make sure gets redacted + "spark.executorEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD, + "spark.yarn.appMasterEnv.HADOOP_CREDSTORE_PASSWORD" -> YarnClusterDriver.SECRET_PASSWORD )) } @@ -349,6 +356,7 @@ private object YarnClusterDriverUseSparkHadoopUtilConf extends Logging with Matc private object YarnClusterDriver extends Logging with Matchers { val WAIT_TIMEOUT_MILLIS = 10000 + val SECRET_PASSWORD = "secret_password" def main(args: Array[String]): Unit = { if (args.length != 1) { @@ -395,6 +403,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(executorInfos.nonEmpty) executorInfos.foreach { info => assert(info.logUrlMap.nonEmpty) + info.logUrlMap.values.foreach { url => + val log = Source.fromURL(url).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Executor logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } } // If we are running in yarn-cluster mode, verify that driver logs links and present and are @@ -406,8 +421,13 @@ private object YarnClusterDriver extends Logging with Matchers { assert(driverLogs.contains("stderr")) assert(driverLogs.contains("stdout")) val urlStr = driverLogs("stderr") - // Ensure that this is a valid URL, else this will throw an exception - new URL(urlStr) + driverLogs.foreach { kv => + val log = Source.fromURL(kv._2).mkString + assert( + !log.contains(SECRET_PASSWORD), + s"Driver logs contain sensitive info (${SECRET_PASSWORD}): \n${log} " + ) + } val containerId = YarnSparkHadoopUtil.get.getContainerId val user = Utils.getCurrentUserName() assert(urlStr.endsWith(s"/node/containerlogs/$containerId/$user/stderr?start=-4096")) diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala index 4079d9e40fc4..0a413b2c23de 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackendSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.scheduler.cluster +import scala.language.reflectiveCalls + import org.mockito.Mockito.when import org.scalatest.mock.MockitoSugar diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 765c92b8d3b9..5ecee28a1f0b 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java index bd5e2d7ecca9..5f1032d1229d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/streaming/GroupStateTimeout.java @@ -37,7 +37,9 @@ public class GroupStateTimeout { * `map/flatMapGroupsWithState` by calling `GroupState.setTimeoutDuration()`. See documentation * on `GroupState` for more details. */ - public static GroupStateTimeout ProcessingTimeTimeout() { return ProcessingTimeTimeout$.MODULE$; } + public static GroupStateTimeout ProcessingTimeTimeout() { + return ProcessingTimeTimeout$.MODULE$; + } /** * Timeout based on event-time. The event-time timestamp for timeout can be set for each @@ -51,4 +53,5 @@ public class GroupStateTimeout { /** No timeout. */ public static GroupStateTimeout NoTimeout() { return NoTimeout$.MODULE$; } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 9816b33ae8df..25783bdc39f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -136,6 +136,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: + ResolveAggAliasInGroupBy :: ResolveMissingReferences :: ExtractGenerator :: ResolveGenerate :: @@ -150,6 +151,7 @@ class Analyzer( ResolveAggregateFunctions :: TimeWindowing :: ResolveInlineTables(conf) :: + ResolveTimeZone(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -161,8 +163,6 @@ class Analyzer( HandleNullInputsForUDF), Batch("FixNullability", Once, FixNullability), - Batch("ResolveTimeZone", Once, - ResolveTimeZone), Batch("Subquery", Once, UpdateOuterReferences), Batch("Cleanup", fixedPoint, @@ -173,7 +173,7 @@ class Analyzer( * Analyze cte definitions and substitute child plan with analyzed cte definitions. */ object CTESubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => @@ -201,7 +201,7 @@ class Analyzer( * Substitute child plan with WindowSpecDefinitions. */ object WindowsSubstitution extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { @@ -243,7 +243,7 @@ class Analyzer( private def hasUnresolvedAlias(exprs: Seq[NamedExpression]) = exprs.exists(_.find(_.isInstanceOf[UnresolvedAlias]).isDefined) - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Aggregate(groups, aggs, child) if child.resolved && hasUnresolvedAlias(aggs) => Aggregate(groups, assignAliases(aggs), child) @@ -615,7 +615,7 @@ class Analyzer( case _ => plan } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => EliminateSubqueryAliases(lookupTableFromCatalog(u)) match { case v: View => @@ -787,7 +787,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if !p.childrenResolved => p // If the projection list contains Stars, expand it. @@ -845,11 +845,10 @@ class Analyzer( case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") - q transformExpressionsUp { + q.transformExpressionsUp { case u @ UnresolvedAttribute(nameParts) => - // Leave unchanged if resolution fails. Hopefully will be resolved next round. - val result = - withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } + // Leave unchanged if resolution fails. Hopefully will be resolved next round. + val result = withPosition(u) { q.resolveChildren(nameParts, resolver).getOrElse(u) } logDebug(s"Resolving $u to $result") result case UnresolvedExtractValue(child, fieldExpr) if child.resolved => @@ -962,11 +961,11 @@ class Analyzer( * have no effect on the results. */ object ResolveOrdinalInOrderByAndGroupBy extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p // Replace the index with the related attribute for ORDER BY, // which is a 1-base position of the projection list. - case s @ Sort(orders, global, child) + case Sort(orders, global, child) if orders.exists(_.child.isInstanceOf[UnresolvedOrdinal]) => val newOrders = orders map { case s @ SortOrder(UnresolvedOrdinal(index), direction, nullOrdering, _) => @@ -983,17 +982,11 @@ class Analyzer( // Replace the index with the corresponding expression in aggregateExpressions. The index is // a 1-base position of aggregateExpressions, which is output columns (select expression) - case a @ Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && + case Aggregate(groups, aggs, child) if aggs.forall(_.resolved) && groups.exists(_.isInstanceOf[UnresolvedOrdinal]) => val newGroups = groups.map { - case ordinal @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => - aggs(index - 1) match { - case e if ResolveAggregateFunctions.containsAggregate(e) => - ordinal.failAnalysis( - s"GROUP BY position $index is an aggregate function, and " + - "aggregate functions are not allowed in GROUP BY") - case o => o - } + case u @ UnresolvedOrdinal(index) if index > 0 && index <= aggs.size => + aggs(index - 1) case ordinal @ UnresolvedOrdinal(index) => ordinal.failAnalysis( s"GROUP BY position $index is not in select list " + @@ -1004,6 +997,27 @@ class Analyzer( } } + /** + * Replace unresolved expressions in grouping keys with resolved ones in SELECT clauses. + * This rule is expected to run after [[ResolveReferences]] applied. + */ + object ResolveAggAliasInGroupBy extends Rule[LogicalPlan] { + + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { + case agg @ Aggregate(groups, aggs, child) + if conf.groupByAliases && child.resolved && aggs.forall(_.resolved) && + groups.exists(_.isInstanceOf[UnresolvedAttribute]) => + // This is a strict check though, we put this to apply the rule only in alias expressions + def notResolvableByChild(attrName: String): Boolean = + !child.output.exists(a => resolver(a.name, attrName)) + agg.copy(groupingExpressions = groups.map { + case u: UnresolvedAttribute if notResolvableByChild(u.name) => + aggs.find(ne => resolver(ne.name, u.name)).getOrElse(u) + case e => e + }) + } + } + /** * In many dialects of SQL it is valid to sort by attributes that are not present in the SELECT * clause. This rule detects such queries and adds the required attributes to the original @@ -1013,7 +1027,7 @@ class Analyzer( * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ object ResolveMissingReferences extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // Skip sort with aggregate. This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa @@ -1137,7 +1151,7 @@ class Analyzer( * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. @@ -1449,7 +1463,7 @@ class Analyzer( /** * Resolve and rewrite all subqueries in an operator tree.. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { // In case of HAVING (a filter after an aggregate) we use both the aggregate and // its child for resolution. case f @ Filter(_, a: Aggregate) if f.childrenResolved => @@ -1464,7 +1478,7 @@ class Analyzer( * Turns projections that contain aggregate expressions into aggregations. */ object GlobalAggregates extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) if containsAggregates(projectList) => Aggregate(Nil, projectList, child) } @@ -1490,7 +1504,7 @@ class Analyzer( * underlying aggregate operator and then projected away after the original operator. */ object ResolveAggregateFunctions extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case filter @ Filter(havingCondition, aggregate @ Aggregate(grouping, originalAggExprs, child)) if aggregate.resolved => @@ -1662,7 +1676,7 @@ class Analyzer( } } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, _) if projectList.exists(hasNestedGenerator) => val nestedGenerator = projectList.find(hasNestedGenerator).get throw new AnalysisException("Generators are not supported when it's nested in " + @@ -1720,7 +1734,7 @@ class Analyzer( * that wrap the [[Generator]]. */ object ResolveGenerate extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case g: Generate if !g.child.resolved || !g.generator.resolved => g case g: Generate if !g.resolved => g.copy(generatorOutput = makeGeneratorOutput(g.generator, g.generatorOutput.map(_.name))) @@ -2037,7 +2051,7 @@ class Analyzer( * put them into an inner Project and finally project them away at the outer Project. */ object PullOutNondeterministic extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p: Project => p case f: Filter => f @@ -2082,7 +2096,7 @@ class Analyzer( * and we should return null if the input is null. */ object HandleNullInputsForUDF extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.resolved => p // Skip unresolved nodes. case p => p transformExpressionsUp { @@ -2147,7 +2161,7 @@ class Analyzer( * Then apply a Project on a normal Join to eliminate natural or using join. */ object ResolveNaturalAndUsingJoin extends Rule[LogicalPlan] { - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case j @ Join(left, right, UsingJoin(joinType, usingCols), condition) if left.resolved && right.resolved && j.duplicateResolved => commonNaturalJoinProcessing(left, right, joinType, usingCols, None) @@ -2212,7 +2226,7 @@ class Analyzer( * to the given input attributes. */ object ResolveDeserializer extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2230,8 +2244,8 @@ class Analyzer( val result = resolved transformDown { case UnresolvedMapObjects(func, inputData, cls) if inputData.resolved => inputData.dataType match { - case ArrayType(et, _) => - val expr = MapObjects(func, inputData, et, cls) transformUp { + case ArrayType(et, cn) => + val expr = MapObjects(func, inputData, et, cn, cls) transformUp { case UnresolvedExtractValue(child, fieldName) if child.resolved => ExtractValue(child, fieldName, resolver) } @@ -2298,7 +2312,7 @@ class Analyzer( * constructed is an inner class. */ object ResolveNewInstance extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2332,7 +2346,7 @@ class Analyzer( "type of the field in the target object") } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p if !p.childrenResolved => p case p if p.resolved => p @@ -2347,23 +2361,6 @@ class Analyzer( } } } - - /** - * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local - * time zone. - */ - object ResolveTimeZone extends Rule[LogicalPlan] { - - override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - // Casts could be added in the subquery plan through the rule TypeCoercion while coercing - // the types between the value expression and list query expression of IN expression. - // We need to subject the subquery plan through ResolveTimeZone again to setup timezone - // information for time zone aware expressions. - case e: ListQuery => e.withNewPlan(apply(e.plan)) - } - } } /** @@ -2403,7 +2400,7 @@ object CleanupAliases extends Rule[LogicalPlan] { case other => trimAliases(other) } - override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case Project(projectList, child) => val cleanedProjectList = projectList.map(trimNonTopLevelAliases(_).asInstanceOf[NamedExpression]) @@ -2471,7 +2468,7 @@ object TimeWindowing extends Rule[LogicalPlan] { * @return the logical plan that will generate the time windows using the Expand operator, with * the Filter operator for correctness and Project for usability. */ - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case p: LogicalPlan if p.children.size == 1 => val child = p.children.head val windowExpressions = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index da0c6b098f5c..61797bc34dc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -254,6 +254,11 @@ trait CheckAnalysis extends PredicateHelper { } def checkValidGroupingExprs(expr: Expression): Unit = { + if (expr.find(_.isInstanceOf[AggregateExpression]).isDefined) { + failAnalysis( + "aggregate functions are not allowed in GROUP BY, but found " + expr.sql) + } + // Check if the data type of expr is orderable. if (!RowOrdering.isOrderable(expr.dataType)) { failAnalysis( @@ -271,8 +276,8 @@ trait CheckAnalysis extends PredicateHelper { } } - aggregateExprs.foreach(checkValidAggregateExpression) groupingExprs.foreach(checkValidGroupingExprs) + aggregateExprs.foreach(checkValidAggregateExpression) case Sort(orders, _, _) => orders.foreach { order => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala index c4827b81e8b6..df688fa0e58a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveHints.scala @@ -86,7 +86,13 @@ object ResolveHints { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case h: Hint if BROADCAST_HINT_NAMES.contains(h.name.toUpperCase(Locale.ROOT)) => - applyBroadcastHint(h.child, h.parameters.toSet) + if (h.parameters.isEmpty) { + // If there is no table alias specified, turn the entire subtree into a BroadcastHint. + BroadcastHint(h.child) + } else { + // Otherwise, find within the subtree query plans that should be broadcasted. + applyBroadcastHint(h.child, h.parameters.toSet) + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index a991dd96e282..f2df3e132629 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -29,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. */ -case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -99,12 +98,9 @@ case class ResolveInlineTables(conf: SQLConf) extends Rule[LogicalPlan] { val castedExpr = if (e.dataType.sameType(targetType)) { e } else { - Cast(e, targetType) + cast(e, targetType) } - castedExpr.transform { - case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => - e.withTimeZone(conf.sessionLocalTimeZone) - }.eval() + castedExpr.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala index 8841309939c2..de6de24350f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveTableValuedFunctions.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.analysis +import java.util.Locale + import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules._ @@ -103,7 +105,7 @@ object ResolveTableValuedFunctions extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case u: UnresolvedTableValuedFunction if u.functionArgs.forall(_.resolved) => - builtinFunctions.get(u.functionName.toLowerCase()) match { + builtinFunctions.get(u.functionName.toLowerCase(Locale.ROOT)) match { case Some(tvf) => val resolved = tvf.flatMap { case (argList, resolver) => argList.implicitCast(u.functionArgs) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index 3f76f26dbe4e..6ab4153bac70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -267,7 +267,7 @@ object UnsupportedOperationChecker { throwError("Limits are not supported on streaming DataFrames/Datasets") case Sort(_, _, _) if !containsCompleteData(subPlan) => - throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on" + + throwError("Sorting is not supported on streaming DataFrames/Datasets, unless it is on " + "aggregated DataFrame/Dataset in Complete output mode") case Sample(_, _, _, _, child) if child.isStreaming => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala new file mode 100644 index 000000000000..a27aa845bf0a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.DataType + +/** + * Replace [[TimeZoneAwareExpression]] without timezone id by its copy with session local + * time zone. + */ +case class ResolveTimeZone(conf: SQLConf) extends Rule[LogicalPlan] { + private val transformTimeZoneExprs: PartialFunction[Expression, Expression] = { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + // Casts could be added in the subquery plan through the rule TypeCoercion while coercing + // the types between the value expression and list query expression of IN expression. + // We need to subject the subquery plan through ResolveTimeZone again to setup timezone + // information for time zone aware expressions. + case e: ListQuery => e.withNewPlan(apply(e.plan)) + } + + override def apply(plan: LogicalPlan): LogicalPlan = + plan.resolveExpressions(transformTimeZoneExprs) + + def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs) +} + +/** + * Mix-in trait for constructing valid [[Cast]] expressions. + */ +trait CastSupport { + /** + * Configuration used to create a valid cast expression. + */ + def conf: SQLConf + + /** + * Create a Cast expression with the session local time zone. + */ + def cast(child: Expression, dataType: DataType): Cast = { + Cast(child, dataType, Option(conf.sessionLocalTimeZone)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala index 3bd54c257d98..ea46dd728240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/view.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf * This should be only done after the batch of Resolution, because the view attributes are not * completely resolved during the batch of Resolution. */ -case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { +case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case v @ View(desc, output, child) if child.resolved && output != child.output => val resolver = conf.resolver @@ -78,7 +78,7 @@ case class AliasViewChild(conf: SQLConf) extends Rule[LogicalPlan] { throw new AnalysisException(s"Cannot up cast ${originAttr.sql} from " + s"${originAttr.dataType.simpleString} to ${attr.simpleString} as it may truncate\n") } else { - Alias(Cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, + Alias(cast(originAttr, attr.dataType), attr.name)(exprId = attr.exprId, qualifier = attr.qualifier, explicitMetadata = Some(attr.metadata)) } case (_, originAttr) => originAttr diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index 08a01e860189..974ef900e2ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.catalyst.analysis.{FunctionAlreadyExistsException, NoSuchDatabaseException, NoSuchFunctionException, NoSuchTableException} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ListenerBus /** * Interface for the system catalog (of functions, partitions, tables, and databases). @@ -30,7 +31,8 @@ import org.apache.spark.sql.types.StructType * * Implementations should throw [[NoSuchDatabaseException]] when databases don't exist. */ -abstract class ExternalCatalog { +abstract class ExternalCatalog + extends ListenerBus[ExternalCatalogEventListener, ExternalCatalogEvent] { import CatalogTypes.TablePartitionSpec protected def requireDbExists(db: String): Unit = { @@ -61,9 +63,22 @@ abstract class ExternalCatalog { // Databases // -------------------------------------------------------------------------- - def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + final def createDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = { + val db = dbDefinition.name + postToAll(CreateDatabasePreEvent(db)) + doCreateDatabase(dbDefinition, ignoreIfExists) + postToAll(CreateDatabaseEvent(db)) + } + + protected def doCreateDatabase(dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit + + final def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = { + postToAll(DropDatabasePreEvent(db)) + doDropDatabase(db, ignoreIfNotExists, cascade) + postToAll(DropDatabaseEvent(db)) + } - def dropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit + protected def doDropDatabase(db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit /** * Alter a database whose name matches the one specified in `dbDefinition`, @@ -88,11 +103,39 @@ abstract class ExternalCatalog { // Tables // -------------------------------------------------------------------------- - def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit + final def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { + val db = tableDefinition.database + val name = tableDefinition.identifier.table + postToAll(CreateTablePreEvent(db, name)) + doCreateTable(tableDefinition, ignoreIfExists) + postToAll(CreateTableEvent(db, name)) + } - def dropTable(db: String, table: String, ignoreIfNotExists: Boolean, purge: Boolean): Unit + protected def doCreateTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit - def renameTable(db: String, oldName: String, newName: String): Unit + final def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit = { + postToAll(DropTablePreEvent(db, table)) + doDropTable(db, table, ignoreIfNotExists, purge) + postToAll(DropTableEvent(db, table)) + } + + protected def doDropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean, + purge: Boolean): Unit + + final def renameTable(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameTablePreEvent(db, oldName, newName)) + doRenameTable(db, oldName, newName) + postToAll(RenameTableEvent(db, oldName, newName)) + } + + protected def doRenameTable(db: String, oldName: String, newName: String): Unit /** * Alter a table whose database and name match the ones specified in `tableDefinition`, assuming @@ -269,11 +312,30 @@ abstract class ExternalCatalog { // Functions // -------------------------------------------------------------------------- - def createFunction(db: String, funcDefinition: CatalogFunction): Unit + final def createFunction(db: String, funcDefinition: CatalogFunction): Unit = { + val name = funcDefinition.identifier.funcName + postToAll(CreateFunctionPreEvent(db, name)) + doCreateFunction(db, funcDefinition) + postToAll(CreateFunctionEvent(db, name)) + } - def dropFunction(db: String, funcName: String): Unit + protected def doCreateFunction(db: String, funcDefinition: CatalogFunction): Unit - def renameFunction(db: String, oldName: String, newName: String): Unit + final def dropFunction(db: String, funcName: String): Unit = { + postToAll(DropFunctionPreEvent(db, funcName)) + doDropFunction(db, funcName) + postToAll(DropFunctionEvent(db, funcName)) + } + + protected def doDropFunction(db: String, funcName: String): Unit + + final def renameFunction(db: String, oldName: String, newName: String): Unit = { + postToAll(RenameFunctionPreEvent(db, oldName, newName)) + doRenameFunction(db, oldName, newName) + postToAll(RenameFunctionEvent(db, oldName, newName)) + } + + protected def doRenameFunction(db: String, oldName: String, newName: String): Unit def getFunction(db: String, funcName: String): CatalogFunction @@ -281,4 +343,9 @@ abstract class ExternalCatalog { def listFunctions(db: String, pattern: String): Seq[String] + override protected def doPostEvent( + listener: ExternalCatalogEventListener, + event: ExternalCatalogEvent): Unit = { + listener.onEvent(event) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 9ca1c71d1dcb..81dd8efc0015 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -98,7 +98,7 @@ class InMemoryCatalog( // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = synchronized { if (catalog.contains(dbDefinition.name)) { @@ -119,7 +119,7 @@ class InMemoryCatalog( } } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = synchronized { @@ -180,7 +180,7 @@ class InMemoryCatalog( // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { assert(tableDefinition.identifier.database.isDefined) @@ -221,7 +221,7 @@ class InMemoryCatalog( } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -264,7 +264,10 @@ class InMemoryCatalog( } } - override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireTableExists(db, oldName) requireTableNotExists(db, newName) val oldDesc = catalog(db).tables(oldName) @@ -565,18 +568,21 @@ class InMemoryCatalog( // Functions // -------------------------------------------------------------------------- - override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { + override protected def doCreateFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) requireFunctionNotExists(db, func.identifier.funcName) catalog(db).functions.put(func.identifier.funcName, func) } - override def dropFunction(db: String, funcName: String): Unit = synchronized { + override protected def doDropFunction(db: String, funcName: String): Unit = synchronized { requireFunctionExists(db, funcName) catalog(db).functions.remove(funcName) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = synchronized { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3fbf83f3a38a..6c6d600190b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -115,14 +115,14 @@ class SessionCatalog( * Format table name, taking into account case sensitivity. */ protected[this] def formatTableName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** * Format database name, taking into account case sensitivity. */ protected[this] def formatDatabaseName(name: String): String = { - if (conf.caseSensitiveAnalysis) name else name.toLowerCase + if (conf.caseSensitiveAnalysis) name else name.toLowerCase(Locale.ROOT) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala new file mode 100644 index 000000000000..459973a13bb1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/events.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.scheduler.SparkListenerEvent + +/** + * Event emitted by the external catalog when it is modified. Events are either fired before or + * after the modification (the event should document this). + */ +trait ExternalCatalogEvent extends SparkListenerEvent + +/** + * Listener interface for external catalog modification events. + */ +trait ExternalCatalogEventListener { + def onEvent(event: ExternalCatalogEvent): Unit +} + +/** + * Event fired when a database is create or dropped. + */ +trait DatabaseEvent extends ExternalCatalogEvent { + /** + * Database of the object that was touched. + */ + val database: String +} + +/** + * Event fired before a database is created. + */ +case class CreateDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been created. + */ +case class CreateDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired before a database is dropped. + */ +case class DropDatabasePreEvent(database: String) extends DatabaseEvent + +/** + * Event fired after a database has been dropped. + */ +case class DropDatabaseEvent(database: String) extends DatabaseEvent + +/** + * Event fired when a table is created, dropped or renamed. + */ +trait TableEvent extends DatabaseEvent { + /** + * Name of the table that was touched. + */ + val name: String +} + +/** + * Event fired before a table is created. + */ +case class CreateTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been created. + */ +case class CreateTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is dropped. + */ +case class DropTablePreEvent(database: String, name: String) extends TableEvent + +/** + * Event fired after a table has been dropped. + */ +case class DropTableEvent(database: String, name: String) extends TableEvent + +/** + * Event fired before a table is renamed. + */ +case class RenameTablePreEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired after a table has been renamed. + */ +case class RenameTableEvent( + database: String, + name: String, + newName: String) + extends TableEvent + +/** + * Event fired when a function is created, dropped or renamed. + */ +trait FunctionEvent extends DatabaseEvent { + /** + * Name of the function that was touched. + */ + val name: String +} + +/** + * Event fired before a function is created. + */ +case class CreateFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been created. + */ +case class CreateFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is dropped. + */ +case class DropFunctionPreEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired after a function has been dropped. + */ +case class DropFunctionEvent(database: String, name: String) extends FunctionEvent + +/** + * Event fired before a function is renamed. + */ +case class RenameFunctionPreEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent + +/** + * Event fired after a function has been renamed. + */ +case class RenameFunctionEvent( + database: String, + name: String, + newName: String) + extends FunctionEvent diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index bb1273f5c3d8..a53ef426f79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -89,6 +89,31 @@ object Cast { case _ => false } + /** + * Return true if we need to use the `timeZone` information casting `from` type to `to` type. + * The patterns matched reflect the current implementation in the Cast node. + * c.f. usage of `timeZone` in: + * * Cast.castToString + * * Cast.castToDate + * * Cast.castToTimestamp + */ + def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match { + case (StringType, TimestampType) => true + case (DateType, TimestampType) => true + case (TimestampType, StringType) => true + case (TimestampType, DateType) => true + case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType) + case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) => + needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue) + case (StructType(fromFields), StructType(toFields)) => + fromFields.length == toFields.length && + fromFields.zip(toFields).exists { + case (fromField, toField) => + needsTimeZone(fromField.dataType, toField.dataType) + } + case _ => false + } + /** * Return true iff we may truncate during casting `from` type to `to` type. e.g. long -> int, * timestamp -> date. @@ -165,6 +190,13 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = copy(timeZoneId = Option(timeZoneId)) + // When this cast involves TimeZone, it's only resolved if the timeZoneId is set; + // Otherwise behave like Expression.resolved. + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && (!needsTimeZone || timeZoneId.isDefined) + + private[this] def needsTimeZone: Boolean = Cast.needsTimeZone(child.dataType, dataType) + // [[func]] assumes the input is no longer null because eval already does the null check. @inline private[this] def buildCast[T](a: Any, func: T => Any): Any = func(a.asInstanceOf[T]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index f8fe774823e5..bb8fd5032d63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -24,7 +24,6 @@ import java.util.{Calendar, TimeZone} import scala.util.control.NonFatal import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -34,6 +33,9 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * Common base class for time zone aware expressions. */ trait TimeZoneAwareExpression extends Expression { + /** The expression is only resolved when the time zone has been set. */ + override lazy val resolved: Boolean = + childrenResolved && checkInputDataTypes().isSuccess && timeZoneId.isDefined /** the timezone ID to be used to evaluate value. */ def timeZoneId: Option[String] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index df4d406b84d6..6b90354367f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{ByteArrayOutputStream, CharArrayWriter, StringWriter} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, CharArrayWriter, InputStreamReader, StringWriter} import scala.util.parsing.combinator.RegexParsers @@ -149,7 +149,9 @@ case class GetJsonObject(json: Expression, path: Expression) if (parsed.isDefined) { try { - Utils.tryWithResource(jsonFactory.createParser(jsonStr.getBytes)) { parser => + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, jsonStr)) { parser => val output = new ByteArrayOutputStream() val matched = Utils.tryWithResource( jsonFactory.createGenerator(output, JsonEncoding.UTF8)) { generator => @@ -393,8 +395,10 @@ case class JsonTuple(children: Seq[Expression]) } try { - Utils.tryWithResource(jsonFactory.createParser(json.getBytes)) { - parser => parseRow(parser, input) + /* We know the bytes are UTF-8 encoded. Pass a Reader to avoid having Jackson + detect character encoding which could fail for some malformed strings */ + Utils.tryWithResource(CreateJacksonParser.utf8String(jsonFactory, json)) { parser => + parseRow(parser, input) } } catch { case _: JsonProcessingException => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index f446c3e4a75f..1a202ecf745c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -451,6 +451,8 @@ object MapObjects { * @param function The function applied on the collection elements. * @param inputData An expression that when evaluated returns a collection object. * @param elementType The data type of elements in the collection. + * @param elementNullable When false, indicating elements in the collection are always + * non-null value. * @param customCollectionCls Class of the resulting collection (returning ObjectType) * or None (returning ArrayType) */ @@ -458,11 +460,12 @@ object MapObjects { function: Expression => Expression, inputData: Expression, elementType: DataType, + elementNullable: Boolean = true, customCollectionCls: Option[Class[_]] = None): MapObjects = { val id = curId.getAndIncrement() val loopValue = s"MapObjects_loopValue$id" val loopIsNull = s"MapObjects_loopIsNull$id" - val loopVar = LambdaVariable(loopValue, loopIsNull, elementType) + val loopVar = LambdaVariable(loopValue, loopIsNull, elementType, elementNullable) MapObjects( loopValue, loopIsNull, elementType, function(loopVar), inputData, customCollectionCls) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala index e0ed03a68981..025a388aacaa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/CreateJacksonParser.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.json -import java.io.InputStream +import java.io.{ByteArrayInputStream, InputStream, InputStreamReader} import com.fasterxml.jackson.core.{JsonFactory, JsonParser} import org.apache.hadoop.io.Text @@ -33,7 +33,10 @@ private[sql] object CreateJacksonParser extends Serializable { val bb = record.getByteBuffer assert(bb.hasArray) - jsonFactory.createParser(bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + val bain = new ByteArrayInputStream( + bb.array(), bb.arrayOffset() + bb.position(), bb.remaining()) + + jsonFactory.createParser(new InputStreamReader(bain, "UTF-8")) } def text(jsonFactory: JsonFactory, record: Text): JsonParser = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index d221b0611a89..f2b9764b0f08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -119,7 +119,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: SQLConf) CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: - Batch("Typed Filter Optimization", fixedPoint, + Batch("Object Expressions Optimization", fixedPoint, + EliminateMapObjects, CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, @@ -440,8 +441,7 @@ object ColumnPruning extends Rule[LogicalPlan] { g.copy(child = prunedChild(g.child, g.references)) // Turn off `join` for Generate if no column from it's child is used - case p @ Project(_, g: Generate) - if g.join && !g.outer && p.references.subsetOf(g.generatedSet) => + case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) // Eliminate unneeded attributes from right side of a Left Existence Join. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 8445ee06bd89..34382bd27240 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -153,6 +154,11 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { case TrueLiteral Or _ => TrueLiteral case _ Or TrueLiteral => TrueLiteral + case a And b if Not(a).semanticEquals(b) => FalseLiteral + case a Or b if Not(a).semanticEquals(b) => TrueLiteral + case a And b if a.semanticEquals(Not(b)) => FalseLiteral + case a Or b if a.semanticEquals(Not(b)) => TrueLiteral + case a And b if a.semanticEquals(b) => a case a Or b if a.semanticEquals(b) => a @@ -368,6 +374,8 @@ case class NullPropagation(conf: SQLConf) extends Rule[LogicalPlan] { case EqualNullSafe(Literal(null, _), r) => IsNull(r) case EqualNullSafe(l, Literal(null, _)) => IsNull(l) + case AssertNotNull(c, _) if !c.nullable => c + // For Coalesce, remove null literals. case e @ Coalesce(children) => val newChildren = children.filterNot(isNullLiteral) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala index c3ab58744953..2fe303977442 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/joins.scala @@ -134,8 +134,8 @@ case class EliminateOuterJoin(conf: SQLConf) extends Rule[LogicalPlan] with Pred val leftConditions = conditions.filter(_.references.subsetOf(join.left.outputSet)) val rightConditions = conditions.filter(_.references.subsetOf(join.right.outputSet)) - val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) - val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) + lazy val leftHasNonNullPredicate = leftConditions.exists(canFilterOutNull) + lazy val rightHasNonNullPredicate = rightConditions.exists(canFilterOutNull) join.joinType match { case RightOuter if leftHasNonNullPredicate => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala index 257dbfac8c3e..8cdc6425bcad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/objects.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -96,3 +97,15 @@ object CombineTypedFilters extends Rule[LogicalPlan] { } } } + +/** + * Removes MapObjects when the following conditions are satisfied + * 1. Mapobject(... lambdavariable(..., false) ...), which means types for input and output + * are primitive types with non-nullable + * 2. no custom collection class specified representation of data item. + */ +object EliminateMapObjects extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case MapObjects(_, _, _, LambdaVariable(_, _, _, false), inputData, None) => inputData + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index e1db1ef5b869..a48a693a95c9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -215,7 +215,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { */ protected def visitNonOptionalPartitionSpec( ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + visitPartitionSpec(ctx).map { + case (key, None) => throw new ParseException(s"Found an empty partition key '$key'.", ctx) + case (key, Some(value)) => key -> value + } } /** @@ -1488,8 +1491,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case ("decimal", precision :: scale :: Nil) => DecimalType(precision.getText.toInt, scale.getText.toInt) case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + val dtStr = if (params.nonEmpty) s"$dt(${params.mkString(",")})" else dt + throw new ParseException(s"DataType $dtStr is not supported.", ctx) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 80ab75cc17fa..dcccbd0ed8d6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -34,8 +34,7 @@ import org.apache.spark.sql.types.{DataType, StructType} abstract class AbstractSqlParser extends ParserInterface with Logging { /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. + override def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => astBuilder.visitSingleDataType(parser.singleDataType()) } @@ -50,8 +49,10 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } /** Creates FunctionIdentifier for a given SQL string. */ - def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = { + parse(sqlText) { parser => + astBuilder.visitSingleFunctionIdentifier(parser.singleFunctionIdentifier()) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index db3598bde04d..75240d219622 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -17,30 +17,51 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} /** * Interface for a parser. */ +@DeveloperApi trait ParserInterface { - /** Creates LogicalPlan for a given SQL string. */ + /** + * Parse a string to a [[LogicalPlan]]. + */ + @throws[ParseException]("Text cannot be parsed to a LogicalPlan") def parsePlan(sqlText: String): LogicalPlan - /** Creates Expression for a given SQL string. */ + /** + * Parse a string to an [[Expression]]. + */ + @throws[ParseException]("Text cannot be parsed to an Expression") def parseExpression(sqlText: String): Expression - /** Creates TableIdentifier for a given SQL string. */ + /** + * Parse a string to a [[TableIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a TableIdentifier") def parseTableIdentifier(sqlText: String): TableIdentifier - /** Creates FunctionIdentifier for a given SQL string. */ + /** + * Parse a string to a [[FunctionIdentifier]]. + */ + @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier /** - * Creates StructType for a given SQL string, which is a comma separated list of field - * definitions which will preserve the correct Hive metadata. + * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list + * of field definitions which will preserve the correct Hive metadata. */ + @throws[ParseException]("Text cannot be parsed to a schema") def parseTableSchema(sqlText: String): StructType + + /** + * Parse a string to a [[DataType]]. + */ + @throws[ParseException]("Text cannot be parsed to a DataType") + def parseDataType(sqlText: String): DataType } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3ad757ebba85..f663d7b8a8f7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -83,7 +83,7 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param qualifier Qualifier for the attributes of generator(UDTF) * @param generatorOutput The output schema of the Generator. * @param child Children logical plan node diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index cc4c0835954b..2109c1c23b70 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -444,6 +444,11 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { case None => Nil case Some(null) => Nil case Some(any) => any :: Nil + case table: CatalogTable => + table.storage.serde match { + case Some(serde) => table.identifier :: serde :: Nil + case _ => table.identifier :: Nil + } case other => other :: Nil }.mkString(", ") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e1798e22b9f..b24419a41edb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -421,6 +421,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val GROUP_BY_ALIASES = buildConf("spark.sql.groupByAliases") + .doc("When true, aliases in a select list can be used in group by clauses. When false, " + + "an analysis exception is thrown in the case.") + .booleanConf + .createWithDefault(true) + // The output committer class used by data sources. The specified class needs to be a // subclass of org.apache.hadoop.mapreduce.OutputCommitter. val OUTPUT_COMMITTER_CLASS = @@ -1003,6 +1009,8 @@ class SQLConf extends Serializable with Logging { def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) + def groupByAliases: Boolean = getConf(GROUP_BY_ALIASES) + def crossJoinEnabled: Boolean = getConf(SQLConf.CROSS_JOINS_ENABLED) def sessionLocalTimeZone: String = getConf(SQLConf.SESSION_LOCAL_TIMEZONE) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index af1a9cee2962..c6c0a605d89f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -81,4 +81,10 @@ object StaticSQLConf { "SQL configuration and the current database.") .booleanConf .createWithDefault(false) + + val SPARK_SESSION_EXTENSIONS = buildStaticConf("spark.sql.extensions") + .doc("Name of the class used to configure Spark Session extensions. The class should " + + "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") + .stringConf + .createOptional } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index e8f6884c025c..80916ee9c537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -132,14 +132,22 @@ final class Decimal extends Ordered[Decimal] with Serializable { } /** - * Set this Decimal to the given BigInteger value. Will have precision 38 and scale 0. + * If the value is not in the range of long, convert it to BigDecimal and + * the precision and scale are based on the converted value. + * + * This code avoids BigDecimal object allocation as possible to improve runtime efficiency */ def set(bigintval: BigInteger): Decimal = { - this.decimalVal = null - this.longVal = bigintval.longValueExact() - this._precision = DecimalType.MAX_PRECISION - this._scale = 0 - this + try { + this.decimalVal = null + this.longVal = bigintval.longValueExact() + this._precision = DecimalType.MAX_PRECISION + this._scale = 0 + this + } catch { + case _: ArithmeticException => + set(BigDecimal(bigintval)) + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index f45a82686984..d0fe81505225 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** @@ -91,12 +92,13 @@ class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { test("convert TimeZoneAwareExpression") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) - val converted = ResolveInlineTables(conf).convert(table) + val withTimeZone = ResolveTimeZone(conf).apply(table) + val LocalRelation(output, data) = ResolveInlineTables(conf).apply(withTimeZone) val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] - assert(converted.output.map(_.dataType) == Seq(TimestampType)) - assert(converted.data.size == 1) - assert(converted.data(0).getLong(0) == correct) + assert(output.map(_.dataType) == Seq(TimestampType)) + assert(data.size == 1) + assert(data.head.getLong(0) == correct) } test("nullability inference in convert") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 011d09ff6064..2624f5586fd5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -787,6 +788,12 @@ class TypeCoercionSuite extends PlanTest { } } + private val timeZoneResolver = ResolveTimeZone(new SQLConf) + + private def widenSetOperationTypes(plan: LogicalPlan): LogicalPlan = { + timeZoneResolver(TypeCoercion.WidenSetOperationTypes(plan)) + } + test("WidenSetOperationTypes for except and intersect") { val firstTable = LocalRelation( AttributeReference("i", IntegerType)(), @@ -799,11 +806,10 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("f", FloatType)(), AttributeReference("l", LongType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val r1 = wt(Except(firstTable, secondTable)).asInstanceOf[Except] - val r2 = wt(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Except(firstTable, secondTable)).asInstanceOf[Except] + val r2 = widenSetOperationTypes(Intersect(firstTable, secondTable)).asInstanceOf[Intersect] checkOutput(r1.left, expectedTypes) checkOutput(r1.right, expectedTypes) checkOutput(r2.left, expectedTypes) @@ -838,10 +844,9 @@ class TypeCoercionSuite extends PlanTest { AttributeReference("p", ByteType)(), AttributeReference("q", DoubleType)()) - val wt = TypeCoercion.WidenSetOperationTypes val expectedTypes = Seq(StringType, DecimalType.SYSTEM_DEFAULT, FloatType, DoubleType) - val unionRelation = wt( + val unionRelation = widenSetOperationTypes( Union(firstTable :: secondTable :: thirdTable :: forthTable :: Nil)).asInstanceOf[Union] assert(unionRelation.children.length == 4) checkOutput(unionRelation.children.head, expectedTypes) @@ -862,17 +867,15 @@ class TypeCoercionSuite extends PlanTest { } } - val dp = TypeCoercion.WidenSetOperationTypes - val left1 = LocalRelation( AttributeReference("l", DecimalType(10, 8))()) val right1 = LocalRelation( AttributeReference("r", DecimalType(5, 5))()) val expectedType1 = Seq(DecimalType(10, 8)) - val r1 = dp(Union(left1, right1)).asInstanceOf[Union] - val r2 = dp(Except(left1, right1)).asInstanceOf[Except] - val r3 = dp(Intersect(left1, right1)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(left1, right1)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(left1, right1)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(left1, right1)).asInstanceOf[Intersect] checkOutput(r1.children.head, expectedType1) checkOutput(r1.children.last, expectedType1) @@ -891,17 +894,17 @@ class TypeCoercionSuite extends PlanTest { val plan2 = LocalRelation( AttributeReference("r", rType)()) - val r1 = dp(Union(plan1, plan2)).asInstanceOf[Union] - val r2 = dp(Except(plan1, plan2)).asInstanceOf[Except] - val r3 = dp(Intersect(plan1, plan2)).asInstanceOf[Intersect] + val r1 = widenSetOperationTypes(Union(plan1, plan2)).asInstanceOf[Union] + val r2 = widenSetOperationTypes(Except(plan1, plan2)).asInstanceOf[Except] + val r3 = widenSetOperationTypes(Intersect(plan1, plan2)).asInstanceOf[Intersect] checkOutput(r1.children.last, Seq(expectedType)) checkOutput(r2.right, Seq(expectedType)) checkOutput(r3.right, Seq(expectedType)) - val r4 = dp(Union(plan2, plan1)).asInstanceOf[Union] - val r5 = dp(Except(plan2, plan1)).asInstanceOf[Except] - val r6 = dp(Intersect(plan2, plan1)).asInstanceOf[Intersect] + val r4 = widenSetOperationTypes(Union(plan2, plan1)).asInstanceOf[Union] + val r5 = widenSetOperationTypes(Except(plan2, plan1)).asInstanceOf[Except] + val r6 = widenSetOperationTypes(Intersect(plan2, plan1)).asInstanceOf[Intersect] checkOutput(r4.children.last, Seq(expectedType)) checkOutput(r5.left, Seq(expectedType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala new file mode 100644 index 000000000000..2539ea615ff9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogEventSuite.scala @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.catalog + +import java.net.URI +import java.nio.file.{Files, Path} + +import scala.collection.mutable + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.types.StructType + +/** + * Test Suite for external catalog events + */ +class ExternalCatalogEventSuite extends SparkFunSuite { + + protected def newCatalog: ExternalCatalog = new InMemoryCatalog() + + private def testWithCatalog( + name: String)( + f: (ExternalCatalog, Seq[ExternalCatalogEvent] => Unit) => Unit): Unit = test(name) { + val catalog = newCatalog + val recorder = mutable.Buffer.empty[ExternalCatalogEvent] + catalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + recorder += event + } + }) + f(catalog, (expected: Seq[ExternalCatalogEvent]) => { + val actual = recorder.clone() + recorder.clear() + assert(expected === actual) + }) + } + + private def createDbDefinition(uri: URI): CatalogDatabase = { + CatalogDatabase(name = "db5", description = "", locationUri = uri, Map.empty) + } + + private def createDbDefinition(): CatalogDatabase = { + createDbDefinition(preparePath(Files.createTempDirectory("db_"))) + } + + private def preparePath(path: Path): URI = path.normalize().toUri + + testWithCatalog("database") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createDatabase(dbDefinition, ignoreIfExists = true) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + intercept[AnalysisException] { + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + } + checkEvents(CreateDatabasePreEvent("db5") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropDatabase("db4", ignoreIfNotExists = false, cascade = false) + } + checkEvents(DropDatabasePreEvent("db4") :: Nil) + + catalog.dropDatabase("db5", ignoreIfNotExists = false, cascade = false) + checkEvents(DropDatabasePreEvent("db5") :: DropDatabaseEvent("db5") :: Nil) + + catalog.dropDatabase("db4", ignoreIfNotExists = true, cascade = false) + checkEvents(DropDatabasePreEvent("db4") :: DropDatabaseEvent("db4") :: Nil) + } + + testWithCatalog("table") { (catalog, checkEvents) => + val path1 = Files.createTempDirectory("db_") + val path2 = Files.createTempDirectory(path1, "tbl_") + val uri1 = preparePath(path1) + val uri2 = preparePath(path2) + + // CREATE + val dbDefinition = createDbDefinition(uri1) + + val storage = CatalogStorageFormat.empty.copy( + locationUri = Option(uri2)) + val tableDefinition = CatalogTable( + identifier = TableIdentifier("tbl1", Some("db5")), + tableType = CatalogTableType.MANAGED, + storage = storage, + schema = new StructType().add("id", "long")) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = false) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + catalog.createTable(tableDefinition, ignoreIfExists = true) + checkEvents(CreateTablePreEvent("db5", "tbl1") :: CreateTableEvent("db5", "tbl1") :: Nil) + + intercept[AnalysisException] { + catalog.createTable(tableDefinition, ignoreIfExists = false) + } + checkEvents(CreateTablePreEvent("db5", "tbl1") :: Nil) + + // RENAME + catalog.renameTable("db5", "tbl1", "tbl2") + checkEvents( + RenameTablePreEvent("db5", "tbl1", "tbl2") :: + RenameTableEvent("db5", "tbl1", "tbl2") :: Nil) + + intercept[AnalysisException] { + catalog.renameTable("db5", "tbl1", "tbl2") + } + checkEvents(RenameTablePreEvent("db5", "tbl1", "tbl2") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropTable("db5", "tbl1", ignoreIfNotExists = false, purge = true) + } + checkEvents(DropTablePreEvent("db5", "tbl1") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = false, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + + catalog.dropTable("db5", "tbl2", ignoreIfNotExists = true, purge = true) + checkEvents(DropTablePreEvent("db5", "tbl2") :: DropTableEvent("db5", "tbl2") :: Nil) + } + + testWithCatalog("function") { (catalog, checkEvents) => + // CREATE + val dbDefinition = createDbDefinition() + + val functionDefinition = CatalogFunction( + identifier = FunctionIdentifier("fn7", Some("db5")), + className = "", + resources = Seq.empty) + + val newIdentifier = functionDefinition.identifier.copy(funcName = "fn4") + val renamedFunctionDefinition = functionDefinition.copy(identifier = newIdentifier) + + catalog.createDatabase(dbDefinition, ignoreIfExists = false) + checkEvents(CreateDatabasePreEvent("db5") :: CreateDatabaseEvent("db5") :: Nil) + + catalog.createFunction("db5", functionDefinition) + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: CreateFunctionEvent("db5", "fn7") :: Nil) + + intercept[AnalysisException] { + catalog.createFunction("db5", functionDefinition) + } + checkEvents(CreateFunctionPreEvent("db5", "fn7") :: Nil) + + // RENAME + catalog.renameFunction("db5", "fn7", "fn4") + checkEvents( + RenameFunctionPreEvent("db5", "fn7", "fn4") :: + RenameFunctionEvent("db5", "fn7", "fn4") :: Nil) + intercept[AnalysisException] { + catalog.renameFunction("db5", "fn7", "fn4") + } + checkEvents(RenameFunctionPreEvent("db5", "fn7", "fn4") :: Nil) + + // DROP + intercept[AnalysisException] { + catalog.dropFunction("db5", "fn7") + } + checkEvents(DropFunctionPreEvent("db5", "fn7") :: Nil) + + catalog.dropFunction("db5", "fn4") + checkEvents(DropFunctionPreEvent("db5", "fn4") :: DropFunctionEvent("db5", "fn4") :: Nil) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 9978f35a0381..ca89bf7db0b4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -160,7 +160,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Seconds") { assert(Second(Literal.create(null, DateType), gmtId).resolved === false) - assert(Second(Cast(Literal(d), TimestampType), None).resolved === true) + assert(Second(Cast(Literal(d), TimestampType, gmtId), gmtId).resolved === true) checkEvaluation(Second(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Second(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 15) checkEvaluation(Second(Literal(ts), gmtId), 15) @@ -220,7 +220,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Hour") { assert(Hour(Literal.create(null, DateType), gmtId).resolved === false) - assert(Hour(Literal(ts), None).resolved === true) + assert(Hour(Literal(ts), gmtId).resolved === true) checkEvaluation(Hour(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation(Hour(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 13) checkEvaluation(Hour(Literal(ts), gmtId), 13) @@ -246,7 +246,7 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { test("Minute") { assert(Minute(Literal.create(null, DateType), gmtId).resolved === false) - assert(Minute(Literal(ts), None).resolved === true) + assert(Minute(Literal(ts), gmtId).resolved === true) checkEvaluation(Minute(Cast(Literal(d), TimestampType, gmtId), gmtId), 0) checkEvaluation( Minute(Cast(Literal(sdf.format(d)), TimestampType, gmtId), gmtId), 10) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 1ba6dd1c5e8c..b6399edb68dd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -25,10 +25,12 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.analysis.ResolveTimeZone import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} -import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -45,7 +47,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { protected def checkEvaluation( expression: => Expression, expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val serializer = new JavaSerializer(new SparkConf()).newInstance - val expr: Expression = serializer.deserialize(serializer.serialize(expression)) + val resolver = ResolveTimeZone(new SQLConf) + val expr = resolver.resolveTimeZones(serializer.deserialize(serializer.serialize(expression))) val catalystValue = CatalystTypeConverters.convertToCatalyst(expected) checkEvaluationWithoutCodegen(expr, catalystValue, inputRow) checkEvaluationWithGeneratedMutableProjection(expr, catalystValue, inputRow) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index c5b72235e5db..f892e8020460 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -39,6 +39,10 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { |"fb:testid":"1234"} |""".stripMargin + /* invalid json with leading nulls would trigger java.io.CharConversionException + in Jackson's JsonFactory.createParser(byte[]) due to RFC-4627 encoding detection */ + val badJson = "\u0000\u0000\u0000A\u0001AAA" + test("$.store.bicycle") { checkEvaluation( GetJsonObject(Literal(json), Literal("$.store.bicycle")), @@ -224,6 +228,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { null) } + test("SPARK-16548: character conversion") { + checkEvaluation( + GetJsonObject(Literal(badJson), Literal("$.a")), + null + ) + } + test("non foldable literal") { checkEvaluation( GetJsonObject(NonFoldableLiteral(json), NonFoldableLiteral("$.fb:testid")), @@ -340,6 +351,12 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { InternalRow(null, null, null, null, null)) } + test("SPARK-16548: json_tuple - invalid json with leading nulls") { + checkJsonTuple( + JsonTuple(Literal(badJson) :: jsonTupleQuery), + InternalRow(null, null, null, null, null)) + } + test("json_tuple - preserve newlines") { checkJsonTuple( JsonTuple(Literal("{\"a\":\"b\nc\"}") :: Literal("a") :: Nil), @@ -436,6 +453,13 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("SPARK-20549: from_json bad UTF-8") { + val schema = StructType(StructField("a", IntegerType) :: Nil) + checkEvaluation( + JsonToStructs(schema, Map.empty, Literal(badJson), gmtId), + null) + } + test("from_json with timestamp") { val schema = StructType(StructField("t", TimestampType) :: Nil) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index 935bff7cef2e..c275f997ba6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.Row class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -42,6 +43,16 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { val testRelation = LocalRelation('a.int, 'b.int, 'c.int, 'd.string) + val testRelationWithData = LocalRelation.fromExternalRows( + testRelation.output, Seq(Row(1, 2, 3, "abc")) + ) + + private def checkCondition(input: Expression, expected: LogicalPlan): Unit = { + val plan = testRelationWithData.where(input).analyze + val actual = Optimize.execute(plan) + comparePlans(actual, expected) + } + private def checkCondition(input: Expression, expected: Expression): Unit = { val plan = testRelation.where(input).analyze val actual = Optimize.execute(plan) @@ -160,4 +171,12 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { testRelation.where('a > 2 || ('b > 3 && 'b < 5))) comparePlans(actual, expected) } + + test("Complementation Laws") { + checkCondition('a && !'a, testRelation) + checkCondition(!'a && 'a, testRelation) + + checkCondition('a || !'a, testRelationWithData) + checkCondition(!'a || 'a, testRelationWithData) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala new file mode 100644 index 000000000000..d4f37e2a5e87 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateMapObjectsSuite.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.objects.Invoke +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types._ + +class EliminateMapObjectsSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = { + Batch("EliminateMapObjects", FixedPoint(50), + NullPropagation(conf), + SimplifyCasts, + EliminateMapObjects) :: Nil + } + } + + implicit private def intArrayEncoder = ExpressionEncoder[Array[Int]]() + implicit private def doubleArrayEncoder = ExpressionEncoder[Array[Double]]() + + test("SPARK-20254: Remove unnecessary data conversion for primitive array") { + val intObjType = ObjectType(classOf[Array[Int]]) + val intInput = LocalRelation('a.array(ArrayType(IntegerType, false))) + val intQuery = intInput.deserialize[Array[Int]].analyze + val intOptimized = Optimize.execute(intQuery) + val intExpected = DeserializeToObject( + Invoke(intInput.output(0), "toIntArray", intObjType, Nil, true, false), + AttributeReference("obj", intObjType, true)(), intInput) + comparePlans(intOptimized, intExpected) + + val doubleObjType = ObjectType(classOf[Array[Double]]) + val doubleInput = LocalRelation('a.array(ArrayType(DoubleType, false))) + val doubleQuery = doubleInput.deserialize[Array[Double]].analyze + val doubleOptimized = Optimize.execute(doubleQuery) + val doubleExpected = DeserializeToObject( + Invoke(doubleInput.output(0), "toDoubleArray", doubleObjType, Nil, true, false), + AttributeReference("obj", doubleObjType, true)(), doubleInput) + comparePlans(doubleOptimized, doubleExpected) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 3964fa3924b2..449052336900 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -30,7 +30,7 @@ class DataTypeParserSuite extends SparkFunSuite { } } - def intercept(sql: String): Unit = + def intercept(sql: String): ParseException = intercept[ParseException](CatalystSqlParser.parseDataType(sql)) def unsupported(dataTypeString: String): Unit = { @@ -118,6 +118,11 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("struct") + test("Do not print empty parentheses for no params") { + assert(intercept("unkwon").getMessage.contains("unkwon is not supported")) + assert(intercept("unkwon(1,2,3)").getMessage.contains("unkwon(1,2,3) is not supported")) + } + // DataType parser accepts certain reserved keywords. checkDataType( "Struct", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 714883a4099c..93c231e30b49 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -212,4 +212,10 @@ class DecimalSuite extends SparkFunSuite with PrivateMethodTester { } } } + + test("SPARK-20341: support BigInt's value does not fit in long value range") { + val bigInt = scala.math.BigInt("9223372036854775808") + val decimal = Decimal.apply(bigInt) + assert(decimal.toJavaBigDecimal.unscaledValue.toString === "9223372036854775808") + } } diff --git a/sql/core/pom.xml b/sql/core/pom.xml index b203f31a76f0..c9ac366ed6e6 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index eb97118872ea..5a810cae1e18 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -66,7 +66,6 @@ import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.types.StructType$; import org.apache.spark.util.AccumulatorV2; -import org.apache.spark.util.LongAccumulator; /** * Base class for custom RecordReaders for Parquet that directly materialize to `T`. @@ -153,14 +152,16 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont } // For test purpose. - // If the predefined accumulator exists, the row group number to read will be updated - // to the accumulator. So we can check if the row groups are filtered or not in test case. + // If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read + // will be updated to the accumulator. So we can check if the row groups are filtered or not + // in test case. TaskContext taskContext = TaskContext$.MODULE$.get(); if (taskContext != null) { - Option> accu = taskContext.taskMetrics() - .lookForAccumulatorByName("numRowGroups"); - if (accu.isDefined()) { - ((LongAccumulator)accu.get()).add((long)blocks.size()); + Option> accu = taskContext.taskMetrics().externalAccums().lastOption(); + if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { + @SuppressWarnings("unchecked") + AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); + intAccum.add(blocks.size()); } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java index 354c878aca00..b105e60a2d34 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java @@ -180,7 +180,7 @@ public Object[] array() { @Override public boolean getBoolean(int ordinal) { - throw new UnsupportedOperationException(); + return data.getBoolean(offset + ordinal); } @Override @@ -188,7 +188,7 @@ public boolean getBoolean(int ordinal) { @Override public short getShort(int ordinal) { - throw new UnsupportedOperationException(); + return data.getShort(offset + ordinal); } @Override @@ -199,7 +199,7 @@ public short getShort(int ordinal) { @Override public float getFloat(int ordinal) { - throw new UnsupportedOperationException(); + return data.getFloat(offset + ordinal); } @Override diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e988c0722bd7..a7d3744d00e9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -436,28 +436,29 @@ public void loadBytes(ColumnVector.Array array) { // Split out the slow path. @Override protected void reserveInternal(int newCapacity) { + int oldCapacity = (this.data == 0L) ? 0 : capacity; if (this.resultArray != null) { this.lengthData = - Platform.reallocateMemory(lengthData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(lengthData, oldCapacity * 4, newCapacity * 4); this.offsetData = - Platform.reallocateMemory(offsetData, elementsAppended * 4, newCapacity * 4); + Platform.reallocateMemory(offsetData, oldCapacity * 4, newCapacity * 4); } else if (type instanceof ByteType || type instanceof BooleanType) { - this.data = Platform.reallocateMemory(data, elementsAppended, newCapacity); + this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 2, newCapacity * 2); + this.data = Platform.reallocateMemory(data, oldCapacity * 2, newCapacity * 2); } else if (type instanceof IntegerType || type instanceof FloatType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { - this.data = Platform.reallocateMemory(data, elementsAppended * 4, newCapacity * 4); + this.data = Platform.reallocateMemory(data, oldCapacity * 4, newCapacity * 4); } else if (type instanceof LongType || type instanceof DoubleType || DecimalType.is64BitDecimalType(type) || type instanceof TimestampType) { - this.data = Platform.reallocateMemory(data, elementsAppended * 8, newCapacity * 8); + this.data = Platform.reallocateMemory(data, oldCapacity * 8, newCapacity * 8); } else if (resultStruct != null) { // Nothing to store. } else { throw new RuntimeException("Unhandled " + type); } - this.nulls = Platform.reallocateMemory(nulls, elementsAppended, newCapacity); - Platform.setMemory(nulls + elementsAppended, (byte)0, newCapacity - elementsAppended); + this.nulls = Platform.reallocateMemory(nulls, oldCapacity, newCapacity); + Platform.setMemory(nulls + oldCapacity, (byte)0, newCapacity - oldCapacity); capacity = newCapacity; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 9b410bacff5d..94ed32294cfa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -410,53 +410,53 @@ protected void reserveInternal(int newCapacity) { int[] newLengths = new int[newCapacity]; int[] newOffsets = new int[newCapacity]; if (this.arrayLengths != null) { - System.arraycopy(this.arrayLengths, 0, newLengths, 0, elementsAppended); - System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, elementsAppended); + System.arraycopy(this.arrayLengths, 0, newLengths, 0, capacity); + System.arraycopy(this.arrayOffsets, 0, newOffsets, 0, capacity); } arrayLengths = newLengths; arrayOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ByteType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; - if (byteData != null) System.arraycopy(byteData, 0, newData, 0, elementsAppended); + if (byteData != null) System.arraycopy(byteData, 0, newData, 0, capacity); byteData = newData; } } else if (type instanceof ShortType) { if (shortData == null || shortData.length < newCapacity) { short[] newData = new short[newCapacity]; - if (shortData != null) System.arraycopy(shortData, 0, newData, 0, elementsAppended); + if (shortData != null) System.arraycopy(shortData, 0, newData, 0, capacity); shortData = newData; } } else if (type instanceof IntegerType || type instanceof DateType || DecimalType.is32BitDecimalType(type)) { if (intData == null || intData.length < newCapacity) { int[] newData = new int[newCapacity]; - if (intData != null) System.arraycopy(intData, 0, newData, 0, elementsAppended); + if (intData != null) System.arraycopy(intData, 0, newData, 0, capacity); intData = newData; } } else if (type instanceof LongType || type instanceof TimestampType || DecimalType.is64BitDecimalType(type)) { if (longData == null || longData.length < newCapacity) { long[] newData = new long[newCapacity]; - if (longData != null) System.arraycopy(longData, 0, newData, 0, elementsAppended); + if (longData != null) System.arraycopy(longData, 0, newData, 0, capacity); longData = newData; } } else if (type instanceof FloatType) { if (floatData == null || floatData.length < newCapacity) { float[] newData = new float[newCapacity]; - if (floatData != null) System.arraycopy(floatData, 0, newData, 0, elementsAppended); + if (floatData != null) System.arraycopy(floatData, 0, newData, 0, capacity); floatData = newData; } } else if (type instanceof DoubleType) { if (doubleData == null || doubleData.length < newCapacity) { double[] newData = new double[newCapacity]; - if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, elementsAppended); + if (doubleData != null) System.arraycopy(doubleData, 0, newData, 0, capacity); doubleData = newData; } } else if (resultStruct != null) { @@ -466,7 +466,7 @@ protected void reserveInternal(int newCapacity) { } byte[] newNulls = new byte[newCapacity]; - if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, elementsAppended); + if (nulls != null) System.arraycopy(nulls, 0, newNulls, 0, capacity); nulls = newNulls; capacity = newCapacity; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 520663f62440..5f602dc25fb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1073,6 +1073,22 @@ class Dataset[T] private[sql]( */ def apply(colName: String): Column = col(colName) + /** + * Specifies some hint on the current Dataset. As an example, the following code specifies + * that one of the plan can be broadcasted: + * + * {{{ + * df1.join(df2.hint("broadcast")) + * }}} + * + * @group basic + * @since 2.2.0 + */ + @scala.annotation.varargs + def hint(name: String, parameters: String*): Dataset[T] = withTypedPlan { + Hint(name, parameters, logicalPlan) + } + /** * Selects column based on the column name and return it as a [[Column]]. * @@ -1726,15 +1742,23 @@ class Dataset[T] private[sql]( // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the - // ordering deterministic. - // MapType cannot be sorted. - val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) - .map(SortOrder(_, Ascending)), global = false, logicalPlan) + // ordering deterministic. Note that MapTypes cannot be sorted and are explicitly pruned out + // from the sort order. + val sortOrder = logicalPlan.output + .filter(attr => RowOrdering.isOrderable(attr.dataType)) + .map(SortOrder(_, Ascending)) + val plan = if (sortOrder.nonEmpty) { + Sort(sortOrder, global = false, logicalPlan) + } else { + // SPARK-12662: If sort order is empty, we materialize the dataset to guarantee determinism + cache() + logicalPlan + } val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => new Dataset[T]( - sparkSession, Sample(x(0), x(1), withReplacement = false, seed, sorted)(), encoder) + sparkSession, Sample(x(0), x(1), withReplacement = false, seed, plan)(), encoder) }.toArray } @@ -2778,7 +2802,7 @@ class Dataset[T] private[sql]( * Wrap a Dataset action to track all Spark jobs in the body so that we can connect them with * an execution. */ - private[sql] def withNewExecutionId[U](body: => U): U = { + private def withNewExecutionId[U](body: => U): U = { SQLExecution.withNewExecutionId(sparkSession, queryExecution)(body) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 95f3463dfe62..a519492ed8f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Range} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.internal.{BaseSessionStateBuilder, CatalogImpl, SessionState, SessionStateBuilder, SharedState} +import org.apache.spark.sql.internal._ import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.streaming._ @@ -77,11 +77,12 @@ import org.apache.spark.util.Utils class SparkSession private( @transient val sparkContext: SparkContext, @transient private val existingSharedState: Option[SharedState], - @transient private val parentSessionState: Option[SessionState]) + @transient private val parentSessionState: Option[SessionState], + @transient private[sql] val extensions: SparkSessionExtensions) extends Serializable with Closeable with Logging { self => private[sql] def this(sc: SparkContext) { - this(sc, None, None) + this(sc, None, None, new SparkSessionExtensions) } sparkContext.assertNotStopped() @@ -219,7 +220,7 @@ class SparkSession private( * @since 2.0.0 */ def newSession(): SparkSession = { - new SparkSession(sparkContext, Some(sharedState), parentSessionState = None) + new SparkSession(sparkContext, Some(sharedState), parentSessionState = None, extensions) } /** @@ -235,7 +236,7 @@ class SparkSession private( * implementation is Hive, this will initialize the metastore, which may take some time. */ private[sql] def cloneSession(): SparkSession = { - val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState)) + val result = new SparkSession(sparkContext, Some(sharedState), Some(sessionState), extensions) result.sessionState // force copy of SessionState result } @@ -754,6 +755,8 @@ object SparkSession { private[this] val options = new scala.collection.mutable.HashMap[String, String] + private[this] val extensions = new SparkSessionExtensions + private[this] var userSuppliedContext: Option[SparkContext] = None private[spark] def sparkContext(sparkContext: SparkContext): Builder = synchronized { @@ -847,6 +850,17 @@ object SparkSession { } } + /** + * Inject extensions into the [[SparkSession]]. This allows a user to add Analyzer rules, + * Optimizer rules, Planning Strategies or a customized parser. + * + * @since 2.2.0 + */ + def withExtensions(f: SparkSessionExtensions => Unit): Builder = { + f(extensions) + this + } + /** * Gets an existing [[SparkSession]] or, if there is no existing one, creates a new * one based on the options set in this builder. @@ -903,7 +917,26 @@ object SparkSession { } sc } - session = new SparkSession(sparkContext) + + // Initialize extensions if the user has defined a configurator class. + val extensionConfOption = sparkContext.conf.get(StaticSQLConf.SPARK_SESSION_EXTENSIONS) + if (extensionConfOption.isDefined) { + val extensionConfClassName = extensionConfOption.get + try { + val extensionConfClass = Utils.classForName(extensionConfClassName) + val extensionConf = extensionConfClass.newInstance() + .asInstanceOf[SparkSessionExtensions => Unit] + extensionConf(extensions) + } catch { + // Ignore the error if we cannot find the class or when the class has the wrong type. + case e @ (_: ClassCastException | + _: ClassNotFoundException | + _: NoClassDefFoundError) => + logWarning(s"Cannot use $extensionConfClassName to configure session extensions.", e) + } + } + + session = new SparkSession(sparkContext, None, None, extensions) options.foreach { case (k, v) => session.sessionState.conf.setConfString(k, v) } defaultSession.set(session) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala new file mode 100644 index 000000000000..f99c108161f9 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSessionExtensions.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.collection.mutable + +import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * :: Experimental :: + * Holder for injection points to the [[SparkSession]]. We make NO guarantee about the stability + * regarding binary compatibility and source compatibility of methods here. + * + * This current provides the following extension points: + * - Analyzer Rules. + * - Check Analysis Rules + * - Optimizer Rules. + * - Planning Strategies. + * - Customized Parser. + * - (External) Catalog listeners. + * + * The extensions can be used by calling withExtension on the [[SparkSession.Builder]], for + * example: + * {{{ + * SparkSession.builder() + * .master("...") + * .conf("...", true) + * .withExtensions { extensions => + * extensions.injectResolutionRule { session => + * ... + * } + * extensions.injectParser { (session, parser) => + * ... + * } + * } + * .getOrCreate() + * }}} + * + * Note that none of the injected builders should assume that the [[SparkSession]] is fully + * initialized and should not touch the session's internals (e.g. the SessionState). + */ +@DeveloperApi +@Experimental +@InterfaceStability.Unstable +class SparkSessionExtensions { + type RuleBuilder = SparkSession => Rule[LogicalPlan] + type CheckRuleBuilder = SparkSession => LogicalPlan => Unit + type StrategyBuilder = SparkSession => Strategy + type ParserBuilder = (SparkSession, ParserInterface) => ParserInterface + + private[this] val resolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + resolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer resolution `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed as part of the resolution phase of analysis. + */ + def injectResolutionRule(builder: RuleBuilder): Unit = { + resolutionRuleBuilders += builder + } + + private[this] val postHocResolutionRuleBuilders = mutable.Buffer.empty[RuleBuilder] + + /** + * Build the analyzer post-hoc resolution `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildPostHocResolutionRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + postHocResolutionRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an analyzer `Rule` builder into the [[SparkSession]]. These analyzer + * rules will be executed after resolution. + */ + def injectPostHocResolutionRule(builder: RuleBuilder): Unit = { + postHocResolutionRuleBuilders += builder + } + + private[this] val checkRuleBuilders = mutable.Buffer.empty[CheckRuleBuilder] + + /** + * Build the check analysis `Rule`s using the given [[SparkSession]]. + */ + private[sql] def buildCheckRules(session: SparkSession): Seq[LogicalPlan => Unit] = { + checkRuleBuilders.map(_.apply(session)) + } + + /** + * Inject an check analysis `Rule` builder into the [[SparkSession]]. The injected rules will + * be executed after the analysis phase. A check analysis rule is used to detect problems with a + * LogicalPlan and should throw an exception when a problem is found. + */ + def injectCheckRule(builder: CheckRuleBuilder): Unit = { + checkRuleBuilders += builder + } + + private[this] val optimizerRules = mutable.Buffer.empty[RuleBuilder] + + private[sql] def buildOptimizerRules(session: SparkSession): Seq[Rule[LogicalPlan]] = { + optimizerRules.map(_.apply(session)) + } + + /** + * Inject an optimizer `Rule` builder into the [[SparkSession]]. The injected rules will be + * executed during the operator optimization batch. An optimizer rule is used to improve the + * quality of an analyzed logical plan; these rules should never modify the result of the + * LogicalPlan. + */ + def injectOptimizerRule(builder: RuleBuilder): Unit = { + optimizerRules += builder + } + + private[this] val plannerStrategyBuilders = mutable.Buffer.empty[StrategyBuilder] + + private[sql] def buildPlannerStrategies(session: SparkSession): Seq[Strategy] = { + plannerStrategyBuilders.map(_.apply(session)) + } + + /** + * Inject a planner `Strategy` builder into the [[SparkSession]]. The injected strategy will + * be used to convert a `LogicalPlan` into a executable + * [[org.apache.spark.sql.execution.SparkPlan]]. + */ + def injectPlannerStrategy(builder: StrategyBuilder): Unit = { + plannerStrategyBuilders += builder + } + + private[this] val parserBuilders = mutable.Buffer.empty[ParserBuilder] + + private[sql] def buildParser( + session: SparkSession, + initial: ParserInterface): ParserInterface = { + parserBuilders.foldLeft(initial) { (parser, builder) => + builder(session, parser) + } + } + + /** + * Inject a custom parser into the [[SparkSession]]. Note that the builder is passed a session + * and an initial parser. The latter allows for a user to create a partial parser and to delegate + * to the underlying parser for completeness. If a user injects more parsers, then the parsers + * are stacked on top of each other. + */ + def injectParser(builder: ParserBuilder): Unit = { + parserBuilders += builder + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala index f87d05884b27..c35e5638e927 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GenerateExec.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} private[execution] sealed case class LazyIterator(func: () => TraversableOnce[InternalRow]) extends Iterator[InternalRow] { - lazy val results = func().toIterator + lazy val results: Iterator[InternalRow] = func().toIterator override def hasNext: Boolean = results.hasNext override def next(): InternalRow = results.next() } @@ -50,7 +50,7 @@ private[execution] sealed case class LazyIterator(func: () => TraversableOnce[In * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the - * given `generator` is empty. `outer` has no effect when `join` is false. + * given `generator` is empty. * @param generatorOutput the qualified output attributes of the generator of this node, which * constructed in analysis phase, and we can not change it, as the * parent node bound with it already. @@ -78,15 +78,15 @@ case class GenerateExec( override def outputPartitioning: Partitioning = child.outputPartitioning - val boundGenerator = BindReferences.bindReference(generator, child.output) + lazy val boundGenerator: Generator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { // boundGenerator.terminate() should be triggered after all of the rows in the partition - val rows = if (join) { - child.execute().mapPartitionsInternal { iter => - val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val numOutputRows = longMetric("numOutputRows") + child.execute().mapPartitionsWithIndexInternal { (index, iter) => + val generatorNullRow = new GenericInternalRow(generator.elementSchema.length) + val rows = if (join) { val joinedRow = new JoinedRow - iter.flatMap { row => // we should always set the left (child output) joinedRow.withLeft(row) @@ -101,18 +101,21 @@ case class GenerateExec( // keep it the same as Hive does joinedRow.withRight(row) } + } else { + iter.flatMap { row => + val outputRows = boundGenerator.eval(row) + if (outer && outputRows.isEmpty) { + Seq(generatorNullRow) + } else { + outputRows + } + } ++ LazyIterator(boundGenerator.terminate) } - } else { - child.execute().mapPartitionsInternal { iter => - iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate) - } - } - val numOutputRows = longMetric("numOutputRows") - rows.mapPartitionsWithIndexInternal { (index, iter) => + // Convert the rows to unsafe rows. val proj = UnsafeProjection.create(output, output) proj.initialize(index) - iter.map { r => + rows.map { r => numOutputRows += 1 proj(r) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 6566502bd8a8..4e718d609c92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -36,7 +36,7 @@ class SparkPlanner( experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( FileSourceStrategy :: - DataSourceStrategy :: + DataSourceStrategy(conf) :: SpecialLimits :: Aggregation :: JoinSelection :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 44278e37c527..64698d552757 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -331,10 +331,11 @@ case class SampleExec( case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) extends LeafExecNode with CodegenSupport { - def start: Long = range.start - def step: Long = range.step - def numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) - def numElements: BigInt = range.numElements + val start: Long = range.start + val end: Long = range.end + val step: Long = range.step + val numSlices: Int = range.numSlices.getOrElse(sparkContext.defaultParallelism) + val numElements: BigInt = range.numElements override val output: Seq[Attribute] = range.output @@ -463,9 +464,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) | $number = $batchEnd; | } | - | if ($taskContext.isInterrupted()) { - | throw new TaskKilledException(); - | } + | $taskContext.killTaskIfInterrupted(); | | long $nextBatchTodo; | if ($numElementsTodo > ${batchSize}L) { @@ -540,7 +539,7 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) } } - override def simpleString: String = range.simpleString + override def simpleString: String = s"Range ($start, $end, step=$step, splits=$numSlices)" } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 214e8d309de1..7063b08f7c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -42,7 +42,9 @@ case class InMemoryTableScanExec( override def output: Seq[Attribute] = attributes private def updateAttribute(expr: Expression): Expression = { - val attrMap = AttributeMap(relation.child.output.zip(output)) + // attributes can be pruned so using relation's output. + // E.g., relation.output is [id, item] but this scan's output can be [item] only. + val attrMap = AttributeMap(relation.child.output.zip(relation.output)) expr.transform { case attr: Attribute => attrMap.getOrElse(attr, attr) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 2d83d512e702..d307122b5c70 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, QualifiedTableName} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} @@ -48,7 +48,7 @@ import org.apache.spark.unsafe.types.UTF8String * Note that, this rule must be run after `PreprocessTableCreation` and * `PreprocessTableInsertion`. */ -case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { +case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { def resolver: Resolver = conf.resolver @@ -98,11 +98,11 @@ case class DataSourceAnalysis(conf: SQLConf) extends Rule[LogicalPlan] { val potentialSpecs = staticPartitions.filter { case (partKey, partValue) => resolver(field.name, partKey) } - if (potentialSpecs.size == 0) { + if (potentialSpecs.isEmpty) { None } else if (potentialSpecs.size == 1) { val partValue = potentialSpecs.head._2 - Some(Alias(Cast(Literal(partValue), field.dataType), field.name)()) + Some(Alias(cast(Literal(partValue), field.dataType), field.name)()) } else { throw new AnalysisException( s"Partition column ${field.name} have multiple values specified, " + @@ -258,7 +258,9 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] /** * A Strategy for planning scans over data sources defined using the sources API. */ -object DataSourceStrategy extends Strategy with Logging { +case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with CastSupport { + import DataSourceStrategy._ + def apply(plan: LogicalPlan): Seq[execution.SparkPlan] = plan match { case PhysicalOperation(projects, filters, l @ LogicalRelation(t: CatalystScan, _, _)) => pruneFilterProjectRaw( @@ -298,7 +300,7 @@ object DataSourceStrategy extends Strategy with Logging { // Restriction: Bucket pruning works iff the bucketing column has one and only one column. def getBucketId(bucketColumn: Attribute, numBuckets: Int, value: Any): Int = { val mutableRow = new SpecificInternalRow(Seq(bucketColumn.dataType)) - mutableRow(0) = Cast(Literal(value), bucketColumn.dataType).eval(null) + mutableRow(0) = cast(Literal(value), bucketColumn.dataType).eval(null) val bucketIdGeneration = UnsafeProjection.create( HashPartitioning(bucketColumn :: Nil, numBuckets).partitionIdExpression :: Nil, bucketColumn :: Nil) @@ -436,7 +438,9 @@ object DataSourceStrategy extends Strategy with Logging { private[this] def toCatalystRDD(relation: LogicalRelation, rdd: RDD[Row]): RDD[InternalRow] = { toCatalystRDD(relation, relation.output, rdd) } +} +object DataSourceStrategy { /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * @@ -527,8 +531,8 @@ object DataSourceStrategy extends Strategy with Logging { * all [[Filter]]s that are completely filtered at the DataSource. */ protected[sql] def selectFilters( - relation: BaseRelation, - predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { + relation: BaseRelation, + predicates: Seq[Expression]): (Seq[Expression], Seq[Filter], Set[Filter]) = { // For conciseness, all Catalyst filter expressions of type `expressions.Expression` below are // called `predicate`s, while all data source filters of type `sources.Filter` are simply called diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 9897ab73b0da..91e31650617e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics +import org.apache.spark.sql.execution.streaming.FileStreamSink import org.apache.spark.sql.SparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.util.SerializableConfiguration @@ -36,20 +37,28 @@ import org.apache.spark.util.SerializableConfiguration * A [[FileIndex]] that generates the list of files to process by recursively listing all the * files present in `paths`. * - * @param rootPaths the list of root table paths to scan + * @param rootPathsSpecified the list of root table paths to scan (some of which might be + * filtered out later) * @param parameters as set of options to control discovery * @param partitionSchema an optional partition schema that will be use to provide types for the * discovered partitions */ class InMemoryFileIndex( sparkSession: SparkSession, - override val rootPaths: Seq[Path], + rootPathsSpecified: Seq[Path], parameters: Map[String, String], partitionSchema: Option[StructType], fileStatusCache: FileStatusCache = NoopCache) extends PartitioningAwareFileIndex( sparkSession, parameters, partitionSchema, fileStatusCache) { + // Filter out streaming metadata dirs or files such as "/.../_spark_metadata" (the metadata dir) + // or "/.../_spark_metadata/0" (a file in the metadata dir). `rootPathsSpecified` might contain + // such streaming metadata dir or files, e.g. when after globbing "basePath/*" where "basePath" + // is the output of a streaming query. + override val rootPaths = + rootPathsSpecified.filterNot(FileStreamSink.ancestorIsMetadataDirectory(_, hadoopConf)) + @volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _ @volatile private var cachedLeafDirToChildrenFiles: Map[Path, Array[FileStatus]] = _ @volatile private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index c3583209efc5..2d70172487e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -243,7 +243,7 @@ object PartitioningUtils { if (equalSignIndex == -1) { None } else { - val columnName = columnSpec.take(equalSignIndex) + val columnName = unescapePathName(columnSpec.take(equalSignIndex)) assert(columnName.nonEmpty, s"Empty partition column name in '$columnSpec'") val rawColumnValue = columnSpec.drop(equalSignIndex + 1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 5fc3c2753b6c..0183805d5625 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -652,8 +652,17 @@ object JdbcUtils extends Logging { case e: SQLException => val cause = e.getNextException if (cause != null && e.getCause != cause) { + // If there is no cause already, set 'next exception' as cause. If cause is null, + // it *may* be because no cause was set yet if (e.getCause == null) { - e.initCause(cause) + try { + e.initCause(cause) + } catch { + // Or it may be null because the cause *was* explicitly initialized, to *null*, + // in which case this fails. There is no other way to detect it. + // addSuppressed in this case as well. + case _: IllegalStateException => e.addSuppressed(cause) + } } else { e.addSuppressed(cause) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 7abf2ae5166b..3f4a78580f1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -22,7 +22,7 @@ import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode, SparkSession} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.command.DDLUtils @@ -315,7 +315,7 @@ case class PreprocessTableCreation(sparkSession: SparkSession) extends Rule[Logi * table. It also does data type casting and field renaming, to make sure that the columns to be * inserted have the correct data type and fields have the correct names. */ -case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { +case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { private def preprocess( insert: InsertIntoTable, tblName: String, @@ -367,7 +367,7 @@ case class PreprocessTableInsertion(conf: SQLConf) extends Rule[LogicalPlan] { // Renaming is needed for handling the following cases like // 1) Column names/types do not match, e.g., INSERT INTO TABLE tab1 SELECT 1, 2 // 2) Target tables have column metadata - Alias(Cast(actual, expected.dataType), expected.name)( + Alias(cast(actual, expected.dataType), expected.name)( explicitMetadata = Option(expected.metadata)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 07ec4e9429e4..6885d0bf67cc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -53,6 +53,26 @@ object FileStreamSink extends Logging { case _ => false } } + + /** + * Returns true if the path is the metadata dir or its ancestor is the metadata dir. + * E.g.: + * - ancestorIsMetadataDirectory(/.../_spark_metadata) => true + * - ancestorIsMetadataDirectory(/.../_spark_metadata/0) => true + * - ancestorIsMetadataDirectory(/a/b/c) => false + */ + def ancestorIsMetadataDirectory(path: Path, hadoopConf: Configuration): Boolean = { + val fs = path.getFileSystem(hadoopConf) + var currentPath = path.makeQualified(fs.getUri, fs.getWorkingDirectory) + while (currentPath != null) { + if (currentPath.getName == FileStreamSink.metadataDir) { + return true + } else { + currentPath = currentPath.getParent + } + } + return false + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bcf0d970f7ec..b6ddf7437ea1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -23,6 +23,7 @@ import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.ReentrantLock +import scala.collection.mutable.{Map => MutableMap} import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -148,15 +149,18 @@ class StreamExecution( "logicalPlan must be initialized in StreamExecutionThread " + s"but the current thread was ${Thread.currentThread}") var nextSourceId = 0L + val toExecutionRelationMap = MutableMap[StreamingRelation, StreamingExecutionRelation]() val _logicalPlan = analyzedPlan.transform { - case StreamingRelation(dataSource, _, output) => - // Materialize source to avoid creating it in every batch - val metadataPath = s"$checkpointRoot/sources/$nextSourceId" - val source = dataSource.createSource(metadataPath) - nextSourceId += 1 - // We still need to use the previous `output` instead of `source.schema` as attributes in - // "df.logicalPlan" has already used attributes of the previous `output`. - StreamingExecutionRelation(source, output) + case streamingRelation@StreamingRelation(dataSource, _, output) => + toExecutionRelationMap.getOrElseUpdate(streamingRelation, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$checkpointRoot/sources/$nextSourceId" + val source = dataSource.createSource(metadataPath) + nextSourceId += 1 + // We still need to use the previous `output` instead of `source.schema` as attributes in + // "df.logicalPlan" has already used attributes of the previous `output`. + StreamingExecutionRelation(source, output) + }) } sources = _logicalPlan.collect { case s: StreamingExecutionRelation => s.source } uniqueSources = sources.distinct @@ -252,6 +256,8 @@ class StreamExecution( */ private def runBatches(): Unit = { try { + sparkSession.sparkContext.setJobGroup(runId.toString, getBatchDescriptionString, + interruptOnCancel = true) if (sparkSession.sessionState.conf.streamingMetricsEnabled) { sparkSession.sparkContext.env.metricsSystem.registerSource(streamMetrics) } @@ -289,6 +295,7 @@ class StreamExecution( if (currentBatchId < 0) { // We'll do this initialization only once populateStartOffsets(sparkSessionToRunBatches) + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) logDebug(s"Stream running from $committedOffsets to $availableOffsets") } else { constructNextBatch() @@ -308,6 +315,7 @@ class StreamExecution( logDebug(s"batch ${currentBatchId} committed") // We'll increase currentBatchId after we complete processing current batch's data currentBatchId += 1 + sparkSession.sparkContext.setJobDescription(getBatchDescriptionString) } else { currentStatus = currentStatus.copy(isDataAvailable = false) updateStatusMessage("Waiting for data to arrive") @@ -684,8 +692,11 @@ class StreamExecution( // intentionally state.set(TERMINATED) if (microBatchThread.isAlive) { + sparkSession.sparkContext.cancelJobGroup(runId.toString) microBatchThread.interrupt() microBatchThread.join() + // microBatchThread may spawn new jobs, so we need to cancel again to prevent a leak + sparkSession.sparkContext.cancelJobGroup(runId.toString) } logInfo(s"Query $prettyIdString was stopped") } @@ -825,6 +836,11 @@ class StreamExecution( } } + private def getBatchDescriptionString: String = { + val batchDescription = if (currentBatchId < 0) "init" else currentBatchId.toString + Option(name).map(_ + "
    ").getOrElse("") + + s"id = $id
    runId = $runId
    batch = $batchDescription" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 2b14eca919fa..2a801d87b12e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.internal import org.apache.spark.SparkConf import org.apache.spark.annotation.{Experimental, InterfaceStability} -import org.apache.spark.sql.{ExperimentalMethods, SparkSession, Strategy, UDFRegistration} +import org.apache.spark.sql.{ExperimentalMethods, SparkSession, UDFRegistration, _} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer @@ -63,6 +63,11 @@ abstract class BaseSessionStateBuilder( */ protected def newBuilder: NewBuilder + /** + * Session extensions defined in the [[SparkSession]]. + */ + protected def extensions: SparkSessionExtensions = session.extensions + /** * Extract entries from `SparkConf` and put them in the `SQLConf` */ @@ -108,7 +113,9 @@ abstract class BaseSessionStateBuilder( * * Note: this depends on the `conf` field. */ - protected lazy val sqlParser: ParserInterface = new SparkSqlParser(conf) + protected lazy val sqlParser: ParserInterface = { + extensions.buildParser(session, new SparkSqlParser(conf)) + } /** * ResourceLoader that is used to load function resources and jars. @@ -171,7 +178,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildResolutionRules(session) + } /** * Custom post resolution rules to add to the Analyzer. Prefer overriding this instead of @@ -179,7 +188,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = Nil + protected def customPostHocResolutionRules: Seq[Rule[LogicalPlan]] = { + extensions.buildPostHocResolutionRules(session) + } /** * Custom check rules to add to the Analyzer. Prefer overriding this instead of creating @@ -187,7 +198,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `analyzer` function. */ - protected def customCheckRules: Seq[LogicalPlan => Unit] = Nil + protected def customCheckRules: Seq[LogicalPlan => Unit] = { + extensions.buildCheckRules(session) + } /** * Logical query plan optimizer. @@ -207,7 +220,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `optimizer` function. */ - protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = Nil + protected def customOperatorOptimizationRules: Seq[Rule[LogicalPlan]] = { + extensions.buildOptimizerRules(session) + } /** * Planner that converts optimized logical plans to physical plans. @@ -227,7 +242,9 @@ abstract class BaseSessionStateBuilder( * * Note that this may NOT depend on the `planner` function. */ - protected def customPlanningStrategies: Seq[Strategy] = Nil + protected def customPlanningStrategies: Seq[Strategy] = { + extensions.buildPlannerStrategies(session) + } /** * Create a query execution object. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index aebb663df5c9..0b8e53868c99 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import scala.util.control.NonFatal import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ @@ -98,14 +99,27 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { CatalogImpl.makeDataset(tables, sparkSession) } + /** + * Returns a Table for the given table/view or temporary view. + * + * Note that this function requires the table already exists in the Catalog. + * + * If the table metadata retrieval failed due to any reason (e.g., table serde class + * is not accessible or the table type is not accepted by Spark SQL), this function + * still returns the corresponding Table without the description and tableType) + */ private def makeTable(tableIdent: TableIdentifier): Table = { - val metadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val metadata = try { + Some(sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent)) + } catch { + case NonFatal(_) => None + } val isTemp = sessionCatalog.isTemporaryTable(tableIdent) new Table( name = tableIdent.table, - database = metadata.identifier.database.orNull, - description = metadata.comment.orNull, - tableType = if (isTemp) "TEMPORARY" else metadata.tableType.name, + database = metadata.map(_.identifier.database).getOrElse(tableIdent.database).orNull, + description = metadata.map(_.comment.orNull).orNull, + tableType = if (isTemp) "TEMPORARY" else metadata.map(_.tableType.name).orNull, isTemporary = isTemp) } @@ -197,7 +211,11 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * `AnalysisException` when no `Table` can be found. */ override def getTable(dbName: String, tableName: String): Table = { - makeTable(TableIdentifier(tableName, Option(dbName))) + if (tableExists(dbName, tableName)) { + makeTable(TableIdentifier(tableName, Option(dbName))) + } else { + throw new AnalysisException(s"Table or view '$tableName' not found in database '$dbName'") + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index 0289471bf841..a93b70114607 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -17,10 +17,14 @@ package org.apache.spark.sql.internal +import java.net.URL +import java.util.Locale + import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FsUrlStreamHandlerFactory import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -107,6 +111,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } + // Make sure we propagate external catalog events to the spark listener bus + externalCatalog.addListener(new ExternalCatalogEventListener { + override def onEvent(event: ExternalCatalogEvent): Unit = { + sparkContext.listenerBus.post(event) + } + }) + /** * A manager for global temporary views. */ @@ -114,7 +125,7 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // System preserved database should not exists in metastore. However it's hard to guarantee it // for every session, because case-sensitivity differs. Here we always lowercase it to make our // life easier. - val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase + val globalTempDB = sparkContext.conf.get(GLOBAL_TEMP_DATABASE).toLowerCase(Locale.ROOT) if (externalCatalog.databaseExists(globalTempDB)) { throw new SparkException( s"$globalTempDB is a system preserved database, please rename your existing database " + @@ -145,7 +156,13 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { } } -object SharedState { +object SharedState extends Logging { + try { + URL.setURLStreamHandlerFactory(new FsUrlStreamHandlerFactory()) + } catch { + case e: Error => + logWarning("URL.setURLStreamHandlerFactory failed to set FsUrlStreamHandlerFactory") + } private val HIVE_EXTERNAL_CATALOG_CLASS_NAME = "org.apache.spark.sql.hive.HiveExternalCatalog" diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql index 9c8d851e36e9..6566338f3d4a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by-ordinal.sql @@ -49,6 +49,9 @@ select a, count(a) from (select 1 as a) tmp group by 1 order by 1; -- group by ordinal followed by having select count(a), a from (select 1 as a) tmp group by 2 having a > 0; +-- mixed cases: group-by ordinals and aliases +select a, a AS k, count(b) from data group by k, 1; + -- turn of group by ordinal set spark.sql.groupByOrdinal=false; diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql index 4d0ed4315300..a7994f3beaff 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-by.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-by.sql @@ -35,3 +35,21 @@ FROM testData; -- Aggregate with foldable input and multiple distinct groups. SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS c) GROUP BY a; + +-- Aliases in SELECT could be used in GROUP BY +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1; + +-- Aggregate functions cannot be used in GROUP BY +SELECT COUNT(b) AS k FROM testData GROUP BY k; + +-- Test data. +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v); +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a; + +-- turn off group by aliases +set spark.sql.groupByAliases=false; + +-- Check analysis exceptions +SELECT a AS k, COUNT(b) FROM testData GROUP BY k; diff --git a/sql/core/src/test/resources/sql-tests/inputs/having.sql b/sql/core/src/test/resources/sql-tests/inputs/having.sql index 364c022d959d..868a911e787f 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/having.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/having.sql @@ -13,3 +13,6 @@ SELECT count(k) FROM hav GROUP BY v + 1 HAVING v + 1 = 2; -- SPARK-11032: resolve having correctly SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0); + +-- SPARK-20329: make sure we handle timezones correctly +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1; diff --git a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out index c0930bbde69a..9ecbe19078dd 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by-ordinal.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 19 +-- Number of queries: 20 -- !query 0 @@ -122,7 +122,7 @@ select a, b, sum(b) from data group by 3 struct<> -- !query 11 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 39 +aggregate functions are not allowed in GROUP BY, but found sum(CAST(data.`b` AS BIGINT)); -- !query 12 @@ -131,7 +131,7 @@ select a, b, sum(b) + 2 from data group by 3 struct<> -- !query 12 output org.apache.spark.sql.AnalysisException -GROUP BY position 3 is an aggregate function, and aggregate functions are not allowed in GROUP BY; line 1 pos 43 +aggregate functions are not allowed in GROUP BY, but found (sum(CAST(data.`b` AS BIGINT)) + CAST(2 AS BIGINT)); -- !query 13 @@ -173,16 +173,26 @@ struct -- !query 17 -set spark.sql.groupByOrdinal=false +select a, a AS k, count(b) from data group by k, 1 -- !query 17 schema -struct +struct -- !query 17 output -spark.sql.groupByOrdinal false +1 1 2 +2 2 2 +3 3 2 -- !query 18 -select sum(b) from data group by -1 +set spark.sql.groupByOrdinal=false -- !query 18 schema -struct +struct -- !query 18 output +spark.sql.groupByOrdinal false + + +-- !query 19 +select sum(b) from data group by -1 +-- !query 19 schema +struct +-- !query 19 output 9 diff --git a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out index 4b87d5161fc0..6bf9dff883c1 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-by.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-by.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 15 +-- Number of queries: 22 -- !query 0 @@ -139,3 +139,67 @@ SELECT COUNT(DISTINCT b), COUNT(DISTINCT b, c) FROM (SELECT 1 AS a, 2 AS b, 3 AS struct -- !query 14 output 1 1 + + +-- !query 15 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 15 schema +struct +-- !query 15 output +1 2 +2 2 +3 2 +NULL 1 + + +-- !query 16 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k HAVING k > 1 +-- !query 16 schema +struct +-- !query 16 output +2 2 +3 2 + + +-- !query 17 +SELECT COUNT(b) AS k FROM testData GROUP BY k +-- !query 17 schema +struct<> +-- !query 17 output +org.apache.spark.sql.AnalysisException +aggregate functions are not allowed in GROUP BY, but found count(testdata.`b`); + + +-- !query 18 +CREATE OR REPLACE TEMPORARY VIEW testDataHasSameNameWithAlias AS SELECT * FROM VALUES +(1, 1, 3), (1, 2, 1) AS testDataHasSameNameWithAlias(k, a, v) +-- !query 18 schema +struct<> +-- !query 18 output + + + +-- !query 19 +SELECT k AS a, COUNT(v) FROM testDataHasSameNameWithAlias GROUP BY a +-- !query 19 schema +struct<> +-- !query 19 output +org.apache.spark.sql.AnalysisException +expression 'testdatahassamenamewithalias.`k`' is neither present in the group by, nor is it an aggregate function. Add to group by or wrap in first() (or first_value) if you don't care which value you get.; + + +-- !query 20 +set spark.sql.groupByAliases=false +-- !query 20 schema +struct +-- !query 20 output +spark.sql.groupByAliases false + + +-- !query 21 +SELECT a AS k, COUNT(b) FROM testData GROUP BY k +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`k`' given input columns: [a, b]; line 1 pos 47 diff --git a/sql/core/src/test/resources/sql-tests/results/having.sql.out b/sql/core/src/test/resources/sql-tests/results/having.sql.out index e0923832673c..d87ee5221647 100644 --- a/sql/core/src/test/resources/sql-tests/results/having.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/having.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 4 +-- Number of queries: 5 -- !query 0 @@ -38,3 +38,12 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0) struct -- !query 3 output 1 + + +-- !query 4 +SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1 +-- !query 4 schema +struct<(a + CAST(b AS BIGINT)):bigint> +-- !query 4 output +3 +7 diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out index 315e1730ce7d..fedabaee2237 100644 --- a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -141,7 +141,7 @@ struct<> -- !query 13 output org.apache.spark.sql.AnalysisException -DataType invalidtype() is not supported.(line 1, pos 2) +DataType invalidtype is not supported.(line 1, pos 2) == SQL == a InvalidType diff --git a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out index 9f0b95994be5..732b11050f46 100644 --- a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out @@ -88,7 +88,7 @@ Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nul == Physical Plan == *Project [coalesce(cast(id#xL as string), x) AS ifnull(`id`, 'x')#x, id#xL AS nullif(`id`, 'x')#xL, coalesce(cast(id#xL as string), x) AS nvl(`id`, 'x')#x, x AS nvl2(`id`, 'x', 'y')#x] -+- *Range (0, 2, step=1, splits=None) ++- *Range (0, 2, step=1, splits=2) -- !query 9 diff --git a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out index acd4ecf14617..e2ee970d35f6 100644 --- a/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/table-valued-functions.sql.out @@ -102,4 +102,4 @@ EXPLAIN select * from RaNgE(2) struct -- !query 8 output == Physical Plan == -*Range (0, 2, step=1, splits=None) +*Range (0, 2, step=1, splits=2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 3e85d9552312..7e61a6802515 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -19,13 +19,12 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { +import org.apache.spark.SparkConf - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") - super.beforeAll() - } +class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -38,12 +37,9 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo } class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { - - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code @@ -55,15 +51,14 @@ class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeA } } -class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with -BeforeAndAfter { +class TwoLevelAggregateHashMapWithVectorizedMapSuite + extends DataFrameAggregateSuite + with BeforeAndAfter { - protected override def beforeAll(): Unit = { - sparkConf.set("spark.sql.codegen.fallback", "false") - sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") - sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.sql.codegen.fallback", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index b0f398dab745..bc708ca88d7e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -39,6 +39,9 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType)))) } + private lazy val nullData = Seq( + (Some(1), Some(1)), (Some(1), Some(2)), (Some(1), None), (None, None)).toDF("a", "b") + test("column names with space") { val df = Seq((1, "a")).toDF("name with space", "name.with.dot") @@ -283,23 +286,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { } test("<=>") { - checkAnswer( - testData2.filter($"a" === 1), - testData2.collect().toSeq.filter(r => r.getInt(0) == 1)) - - checkAnswer( - testData2.filter($"a" === $"b"), - testData2.collect().toSeq.filter(r => r.getInt(0) == r.getInt(1))) - } - - test("=!=") { - val nullData = spark.createDataFrame(sparkContext.parallelize( - Row(1, 1) :: - Row(1, 2) :: - Row(1, null) :: - Row(null, null) :: Nil), - StructType(Seq(StructField("a", IntegerType), StructField("b", IntegerType)))) - checkAnswer( nullData.filter($"b" <=> 1), Row(1, 1) :: Nil) @@ -321,7 +307,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { checkAnswer( nullData2.filter($"a" <=> null), Row(null) :: Nil) + } + test("=!=") { + checkAnswer( + nullData.filter($"b" =!= 1), + Row(1, 2) :: Nil) + + checkAnswer(nullData.filter($"b" =!= null), Nil) + + checkAnswer( + nullData.filter($"a" =!= $"b"), + Row(1, 2) :: Nil) } test(">") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7d..8569c2d76b69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -538,4 +538,11 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(3, 0, 0.0, 1, 5.0), Row(2, 1, 4.0, 0, 0.0)) ) } + + test("aggregate function in GROUP BY") { + val e = intercept[AnalysisException] { + testData.groupBy(sum($"key")).count() + } + assert(e.message.contains("aggregate functions are not allowed in GROUP BY")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 541ffb58e727..4a52af6c32c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -151,7 +151,7 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } - test("broadcast join hint") { + test("broadcast join hint using broadcast function") { val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") @@ -174,6 +174,22 @@ class DataFrameJoinSuite extends QueryTest with SharedSQLContext { } } + test("broadcast join hint using Dataset.hint") { + // make sure a giant join is not broadcastable + val plan1 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong), "id") + .queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoinExec => p }.size == 0) + + // now with a hint it should be broadcasted + val plan2 = + spark.range(10e10.toLong) + .join(spark.range(10e10.toLong).hint("broadcast"), "id") + .queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoinExec => p }.size == 1) + } + test("join - outer join conversion") { val df = Seq((1, 2, "1"), (3, 4, "3")).toDF("int", "int2", "str").as("a") val df2 = Seq((1, 3, "1"), (5, 6, "5")).toDF("int", "int2", "str").as("b") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala index 5e323c02b253..7b495656b93d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameRangeSuite.scala @@ -185,6 +185,12 @@ class DataFrameRangeSuite extends QueryTest with SharedSQLContext with Eventuall } } } + + test("SPARK-20430 Initialize Range parameters in a driver side") { + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + checkAnswer(sql("SELECT * FROM range(3)"), Row(0) :: Row(1) :: Row(2) :: Nil) + } + } } object DataFrameRangeSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 97890a035a62..dd118f88e3bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -68,25 +68,38 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } test("randomSplit on reordered partitions") { - // This test ensures that randomSplit does not create overlapping splits even when the - // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of - // rows in each partition. - val data = - sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") - val splits = data.randomSplit(Array[Double](2, 3), seed = 1) - assert(splits.length == 2, "wrong number of splits") + def testNonOverlappingSplits(data: DataFrame): Unit = { + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) - // Verify that the splits span the entire dataset - assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + // Verify that the splits don't overlap + assert(splits(0).collect().toSeq.intersect(splits(1).collect().toSeq).isEmpty) - // Verify that the splits don't overlap - assert(splits(0).intersect(splits(1)).collect().isEmpty) + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } - // Verify that the results are deterministic across multiple runs - val firstRun = splits.toSeq.map(_.collect().toSeq) - val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) - assert(firstRun == secondRun) + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. + val dataWithInts = sparkContext.parallelize(1 to 600, 2) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int") + val dataWithMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Map(i -> i.toString))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "map") + val dataWithArrayOfMaps = sparkContext.parallelize(1 to 600, 2) + .map(i => (i, Array(Map(i -> i.toString)))) + .mapPartitions(scala.util.Random.shuffle(_)).toDF("int", "arrayOfMaps") + + testNonOverlappingSplits(dataWithInts) + testNonOverlappingSplits(dataWithMaps) + testNonOverlappingSplits(dataWithArrayOfMaps) } test("pearson correlation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 52bd4e19f895..b4893b56a8a8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1722,4 +1722,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { "Cannot have map type columns in DataFrame which calls set operations")) } } + + test("SPARK-20359: catalyst outer join optimization should not throw npe") { + val df1 = Seq("a", "b", "c").toDF("x") + .withColumn("y", udf{ (x: String) => x.substring(0, 1) + "!" }.apply($"x")) + val df2 = Seq("a", "b").toDF("x1") + df1 + .join(df2, df1("x") === df2("x1"), "left_outer") + .filter($"x1".isNotNull || !$"y".isin("a!")) + .count + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala index 92c5656f65bb..68f7de047b39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSerializerRegistratorSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql import com.esotericsoftware.kryo.{Kryo, Serializer} import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.spark.SparkConf import org.apache.spark.serializer.KryoRegistrator import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.test.TestSparkSession /** * Test suite to test Kryo custom registrators. @@ -30,12 +30,10 @@ import org.apache.spark.sql.test.TestSparkSession class DatasetSerializerRegistratorSuite extends QueryTest with SharedSQLContext { import testImplicits._ - /** - * Initialize the [[TestSparkSession]] with a [[KryoRegistrator]]. - */ - protected override def beforeAll(): Unit = { - sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) - super.beforeAll() + + override protected def sparkConf: SparkConf = { + // Make sure we use the KryoRegistrator + super.sparkConf.set("spark.kryo.registrator", TestRegistrator().getClass.getCanonicalName) } test("Kryo registrator") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala index cef5bbf0e85a..b9871afd59e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala @@ -91,7 +91,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(explode_outer('intList)), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) } test("single posexplode") { @@ -105,7 +105,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { val df = Seq((1, Seq(1, 2, 3)), (2, Seq())).toDF("a", "intList") checkAnswer( df.select(posexplode_outer('intList)), - Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil) + Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Row(null, null) :: Nil) } test("explode and other columns") { @@ -161,7 +161,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('intList).as('int)).select('int), - Row(1) :: Row(2) :: Row(3) :: Nil) + Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil) checkAnswer( df.select(explode('intList).as('int)).select(sum('int)), @@ -182,7 +182,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map)), - Row("a", "b") :: Row("c", "d") :: Nil) + Row("a", "b") :: Row(null, null) :: Row("c", "d") :: Nil) } test("explode on map with aliases") { @@ -198,7 +198,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select(explode_outer('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"), - Row("a", "b") :: Nil) + Row("a", "b") :: Row(null, null) :: Nil) } test("self join explode") { @@ -279,7 +279,7 @@ class GeneratorFunctionSuite extends QueryTest with SharedSQLContext { ) checkAnswer( df2.selectExpr("inline_outer(col1)"), - Row(3, "4") :: Row(5, "6") :: Nil + Row(null, null) :: Row(3, "4") :: Row(5, "6") :: Nil ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 8465e8d036a6..69a500c845a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -274,7 +274,7 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { val errMsg2 = intercept[AnalysisException] { df3.selectExpr("""from_json(value, 'time InvalidType')""") } - assert(errMsg2.getMessage.contains("DataType invalidtype() is not supported")) + assert(errMsg2.getMessage.contains("DataType invalidtype is not supported")) val errMsg3 = intercept[AnalysisException] { df3.selectExpr("from_json(value, 'time Timestamp', named_struct('a', 1))") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 0dd9296a3f0f..3ecbf96b4196 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql import java.io.File import java.math.MathContext +import java.net.{MalformedURLException, URL} import java.sql.Timestamp import java.util.concurrent.atomic.AtomicBoolean @@ -2606,4 +2607,16 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { case ae: AnalysisException => assert(ae.plan == null && ae.getMessage == ae.getSimpleMessage) } } + + test("SPARK-12868: Allow adding jars from hdfs ") { + val jarFromHdfs = "hdfs://doesnotmatter/test.jar" + val jarFromInvalidFs = "fffs://doesnotmatter/test.jar" + + // if 'hdfs' is not supported, MalformedURLException will be thrown + new URL(jarFromHdfs) + + intercept[MalformedURLException] { + new URL(jarFromInvalidFs) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala new file mode 100644 index 000000000000..43db79663322 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -0,0 +1,144 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{SparkPlan, SparkStrategy} +import org.apache.spark.sql.types.{DataType, StructType} + +/** + * Test cases for the [[SparkSessionExtensions]]. + */ +class SparkSessionExtensionSuite extends SparkFunSuite { + type ExtensionsBuilder = SparkSessionExtensions => Unit + private def create(builder: ExtensionsBuilder): ExtensionsBuilder = builder + + private def stop(spark: SparkSession): Unit = { + spark.stop() + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + + private def withSession(builder: ExtensionsBuilder)(f: SparkSession => Unit): Unit = { + val spark = SparkSession.builder().master("local[1]").withExtensions(builder).getOrCreate() + try f(spark) finally { + stop(spark) + } + } + + test("inject analyzer rule") { + withSession(_.injectResolutionRule(MyRule)) { session => + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } + } + + test("inject check analysis rule") { + withSession(_.injectCheckRule(MyCheckRule)) { session => + assert(session.sessionState.analyzer.extendedCheckRules.contains(MyCheckRule(session))) + } + } + + test("inject optimizer rule") { + withSession(_.injectOptimizerRule(MyRule)) { session => + assert(session.sessionState.optimizer.batches.flatMap(_.rules).contains(MyRule(session))) + } + } + + test("inject spark planner strategy") { + withSession(_.injectPlannerStrategy(MySparkStrategy)) { session => + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + } + } + + test("inject parser") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + } + withSession(extension) { session => + assert(session.sessionState.sqlParser == CatalystSqlParser) + } + } + + test("inject stacked parsers") { + val extension = create { extensions => + extensions.injectParser((_, _) => CatalystSqlParser) + extensions.injectParser(MyParser) + extensions.injectParser(MyParser) + } + withSession(extension) { session => + val parser = MyParser(session, MyParser(session, CatalystSqlParser)) + assert(session.sessionState.sqlParser == parser) + } + } + + test("use custom class for extensions") { + val session = SparkSession.builder() + .master("local[1]") + .config("spark.sql.extensions", classOf[MyExtensions].getCanonicalName) + .getOrCreate() + try { + assert(session.sessionState.planner.strategies.contains(MySparkStrategy(session))) + assert(session.sessionState.analyzer.extendedResolutionRules.contains(MyRule(session))) + } finally { + stop(session) + } + } +} + +case class MyRule(spark: SparkSession) extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan +} + +case class MyCheckRule(spark: SparkSession) extends (LogicalPlan => Unit) { + override def apply(plan: LogicalPlan): Unit = { } +} + +case class MySparkStrategy(spark: SparkSession) extends SparkStrategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = Seq.empty +} + +case class MyParser(spark: SparkSession, delegate: ParserInterface) extends ParserInterface { + override def parsePlan(sqlText: String): LogicalPlan = + delegate.parsePlan(sqlText) + + override def parseExpression(sqlText: String): Expression = + delegate.parseExpression(sqlText) + + override def parseTableIdentifier(sqlText: String): TableIdentifier = + delegate.parseTableIdentifier(sqlText) + + override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = + delegate.parseFunctionIdentifier(sqlText) + + override def parseTableSchema(sqlText: String): StructType = + delegate.parseTableSchema(sqlText) + + override def parseDataType(sqlText: String): DataType = + delegate.parseDataType(sqlText) +} + +class MyExtensions extends (SparkSessionExtensions => Unit) { + def apply(e: SparkSessionExtensions): Unit = { + e.injectPlannerStrategy(MySparkStrategy) + e.injectResolutionRule(MyRule) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala index 05a2b2c862c7..f7f1ccea281c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/DataSourceScanExecRedactionSuite.scala @@ -18,22 +18,17 @@ package org.apache.spark.sql.execution import org.apache.hadoop.fs.Path +import org.apache.spark.SparkConf import org.apache.spark.sql.QueryTest import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.util.Utils /** * Suite that tests the redaction of DataSourceScanExec */ class DataSourceScanExecRedactionSuite extends QueryTest with SharedSQLContext { - import Utils._ - - override def beforeAll(): Unit = { - sparkConf.set("spark.redaction.string.regex", - "file:/[\\w_]+") - super.beforeAll() - } + override protected def sparkConf: SparkConf = super.sparkConf + .set("spark.redaction.string.regex", "file:/[\\w_]+") test("treeString is redacted") { withTempDir { dir => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 1c1931b6a6da..05637821f71f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.execution import java.util.Locale +import scala.language.reflectiveCalls + import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.test.SharedSQLContext diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 8a2993bdf4b2..8a798fb44469 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -107,6 +107,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -148,6 +149,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -187,6 +189,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -225,6 +228,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } @@ -273,6 +277,7 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 1e6a6a8ba336..109b1d9db60d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -414,4 +414,19 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { assert(partitionedAttrs.subsetOf(inMemoryScan.outputSet)) } } + + test("SPARK-20356: pruned InMemoryTableScanExec should have correct ordering and partitioning") { + withSQLConf("spark.sql.shuffle.partitions" -> "200") { + val df1 = Seq(("a", 1), ("b", 1), ("c", 2)).toDF("item", "group") + val df2 = Seq(("a", 1), ("b", 2), ("c", 3)).toDF("item", "id") + val df3 = df1.join(df2, Seq("item")).select($"id", $"group".as("item")).distinct() + + df3.unpersist() + val agg_without_cache = df3.groupBy($"item").count() + + df3.cache() + val agg_with_cache = df3.groupBy($"item").count() + checkAnswer(agg_without_cache, agg_with_cache) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 97c61dc8694b..8a6bc62fec96 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -530,13 +530,13 @@ class DDLCommandSuite extends PlanTest { """.stripMargin val sql4 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDE 'org.apache.class' WITH SERDEPROPERTIES ('columns'='foo,bar', |'field.delim' = ',') """.stripMargin val sql5 = """ - |ALTER TABLE table_name PARTITION (test, dt='2008-08-08', + |ALTER TABLE table_name PARTITION (test=1, dt='2008-08-08', |country='us') SET SERDEPROPERTIES ('columns'='foo,bar', 'field.delim' = ',') """.stripMargin val parsed1 = parser.parsePlan(sql1) @@ -558,12 +558,12 @@ class DDLCommandSuite extends PlanTest { tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDePropertiesCommand( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) + Some(Map("test" -> "1", "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -832,6 +832,14 @@ class DDLCommandSuite extends PlanTest { assert(e.contains("Found duplicate keys 'a'")) } + test("empty values in non-optional partition specs") { + val e = intercept[ParseException] { + parser.parsePlan( + "SHOW PARTITIONS dbx.tab1 PARTITION (a='1', b)") + }.getMessage + assert(e.contains("Found an empty partition key 'b'")) + } + test("drop table") { val tableName1 = "db.tab" val tableName2 = "tab" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fe74ab49f91b..0abcff76060f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -49,7 +49,8 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable = { val storage = CatalogStorageFormat.empty.copy(locationUri = Some(catalog.defaultTablePath(name))) val metadata = new MetadataBuilder() @@ -70,46 +71,6 @@ class InMemoryCatalogedDDLSuite extends DDLSuite with SharedSQLContext with Befo tracksPartitionsInCatalog = true) } - test("alter table: set location (datasource table)") { - testSetLocation(isDatasourceTable = true) - } - - test("alter table: set properties (datasource table)") { - testSetProperties(isDatasourceTable = true) - } - - test("alter table: unset properties (datasource table)") { - testUnsetProperties(isDatasourceTable = true) - } - - test("alter table: set serde (datasource table)") { - testSetSerde(isDatasourceTable = true) - } - - test("alter table: set serde partition (datasource table)") { - testSetSerdePartition(isDatasourceTable = true) - } - - test("alter table: change column (datasource table)") { - testChangeColumn(isDatasourceTable = true) - } - - test("alter table: add partition (datasource table)") { - testAddPartitions(isDatasourceTable = true) - } - - test("alter table: drop partition (datasource table)") { - testDropPartitions(isDatasourceTable = true) - } - - test("alter table: rename partition (datasource table)") { - testRenamePartitions(isDatasourceTable = true) - } - - test("drop table - data source table") { - testDropTable(isDatasourceTable = true) - } - test("create a managed Hive source table") { assume(spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "in-memory") val tabName = "tbl" @@ -163,7 +124,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { spark.sparkContext.conf.get(CATALOG_IMPLEMENTATION) == "hive" } - protected def generateTable(catalog: SessionCatalog, name: TableIdentifier): CatalogTable + protected def generateTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): CatalogTable private val escapedIdentifier = "`(.+)`".r @@ -205,8 +169,11 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { ignoreIfExists = false) } - private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { - catalog.createTable(generateTable(catalog, name), ignoreIfExists = false) + private def createTable( + catalog: SessionCatalog, + name: TableIdentifier, + isDataSource: Boolean = true): Unit = { + catalog.createTable(generateTable(catalog, name, isDataSource), ignoreIfExists = false) } private def createTablePartition( @@ -223,6 +190,46 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { new Path(CatalogUtils.URIToString(warehousePath), s"$dbName.db").toUri } + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties (datasource table)") { + testSetProperties(isDatasourceTable = true) + } + + test("alter table: unset properties (datasource table)") { + testUnsetProperties(isDatasourceTable = true) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: set serde partition (datasource table)") { + testSetSerdePartition(isDatasourceTable = true) + } + + test("alter table: change column (datasource table)") { + testChangeColumn(isDatasourceTable = true) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: rename partition (datasource table)") { + testRenamePartitions(isDatasourceTable = true) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + test("the qualified path of a database is stored in the catalog") { val catalog = spark.sessionState.catalog @@ -835,32 +842,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("alter table: set location") { - testSetLocation(isDatasourceTable = false) - } - - test("alter table: set properties") { - testSetProperties(isDatasourceTable = false) - } - - test("alter table: unset properties") { - testUnsetProperties(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde") { - testSetSerde(isDatasourceTable = false) - } - - // TODO: move this test to HiveDDLSuite.scala - ignore("alter table: set serde partition") { - testSetSerdePartition(isDatasourceTable = false) - } - - test("alter table: change column") { - testChangeColumn(isDatasourceTable = false) - } - test("alter table: bucketing is not supported") { val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) @@ -885,10 +866,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } - test("alter table: add partition") { - testAddPartitions(isDatasourceTable = false) - } - test("alter table: recover partitions (sequential)") { withSQLConf("spark.rdd.parallelListingThreshold" -> "10") { testRecoverPartitions() @@ -957,17 +934,10 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") } - test("alter table: drop partition") { - testDropPartitions(isDatasourceTable = false) - } - test("alter table: drop partition is not supported for views") { assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") } - test("alter table: rename partition") { - testRenamePartitions(isDatasourceTable = false) - } test("show databases") { sql("CREATE DATABASE showdb2B") @@ -1011,18 +981,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(catalog.listTables("default") == Nil) } - test("drop table") { - testDropTable(isDatasourceTable = false) - } - protected def testDropTable(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) assert(catalog.listTables("dbx") == Seq(tableIdent)) sql("DROP TABLE dbx.tab1") assert(catalog.listTables("dbx") == Nil) @@ -1046,22 +1012,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { e.getMessage.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) } - private def convertToDatasourceTable( - catalog: SessionCatalog, - tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( - provider = Some("csv"))) - assert(catalog.getTableMetadata(tableIdent).provider == Some("csv")) - } - protected def testSetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1084,13 +1042,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testUnsetProperties(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getProps: Map[String, String] = { if (isUsingHiveMetastore) { normalizeCatalogTable(catalog.getTableMetadata(tableIdent)).properties @@ -1121,15 +1079,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetLocation(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val partSpec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, partSpec, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isDefined) assert(normalizeSerdeProp(catalog.getTableMetadata(tableIdent).storage.properties).isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) @@ -1171,13 +1129,13 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerde(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def checkSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getTableMetadata(tableIdent).storage.properties if (isUsingHiveMetastore) { @@ -1187,8 +1145,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getTableMetadata(tableIdent).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some(expectedSerde)) } else { assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) } @@ -1229,18 +1191,18 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testSetSerdePartition(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val spec = Map("a" -> "1", "b" -> "2") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, spec, tableIdent) createTablePartition(catalog, Map("a" -> "1", "b" -> "3"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "2"), tableIdent) createTablePartition(catalog, Map("a" -> "2", "b" -> "3"), tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } def checkPartitionSerdeProps(expectedSerdeProps: Map[String, String]): Unit = { val serdeProp = catalog.getPartition(tableIdent, spec).storage.properties if (isUsingHiveMetastore) { @@ -1250,8 +1212,12 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } if (isUsingHiveMetastore) { - assert(catalog.getPartition(tableIdent, spec).storage.serde == - Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) + val expectedSerde = if (isDatasourceTable) { + "org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe" + } else { + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe" + } + assert(catalog.getPartition(tableIdent, spec).storage.serde == Some(expectedSerde)) } else { assert(catalog.getPartition(tableIdent, spec).storage.serde.isEmpty) } @@ -1295,6 +1261,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testAddPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1303,11 +1272,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) // basic add partition @@ -1354,6 +1320,9 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testDropPartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "5") @@ -1362,7 +1331,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { val part4 = Map("a" -> "4", "b" -> "8") val part5 = Map("a" -> "9", "b" -> "9") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) @@ -1370,9 +1339,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { createTablePartition(catalog, part5, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3, part4, part5)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic drop partition sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (a='4', b='8'), PARTITION (a='3', b='7')") @@ -1407,20 +1373,20 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testRenamePartitions(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val tableIdent = TableIdentifier("tab1", Some("dbx")) val part1 = Map("a" -> "1", "b" -> "q") val part2 = Map("a" -> "2", "b" -> "c") val part3 = Map("a" -> "3", "b" -> "p") createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) + createTable(catalog, tableIdent, isDatasourceTable) createTablePartition(catalog, part1, tableIdent) createTablePartition(catalog, part2, tableIdent) createTablePartition(catalog, part3, tableIdent) assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } // basic rename partition sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='q') RENAME TO PARTITION (a='100', b='p')") @@ -1451,14 +1417,14 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } protected def testChangeColumn(isDatasourceTable: Boolean): Unit = { + if (!isUsingHiveMetastore) { + assert(isDatasourceTable, "InMemoryCatalog only supports data source tables") + } val catalog = spark.sessionState.catalog val resolver = spark.sessionState.conf.resolver val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") - createTable(catalog, tableIdent) - if (isDatasourceTable) { - convertToDatasourceTable(catalog, tableIdent) - } + createTable(catalog, tableIdent, isDatasourceTable) def getMetadata(colName: String): Metadata = { val column = catalog.getTableMetadata(tableIdent).schema.fields.find { field => resolver(field.name, colName) @@ -1601,13 +1567,15 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } test("drop current database") { - sql("CREATE DATABASE temp") - sql("USE temp") - sql("DROP DATABASE temp") - val e = intercept[AnalysisException] { + withDatabase("temp") { + sql("CREATE DATABASE temp") + sql("USE temp") + sql("DROP DATABASE temp") + val e = intercept[AnalysisException] { sql("CREATE TABLE t (a INT, b INT) USING parquet") }.getMessage - assert(e.contains("Database 'temp' not found")) + assert(e.contains("Database 'temp' not found")) + } } test("drop default database") { @@ -1837,22 +1805,25 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("tbl"), Row(1)) val defaultTablePath = spark.sessionState.catalog .getTableMetadata(TableIdentifier("tbl")).storage.locationUri.get - - sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") - spark.catalog.refreshTable("tbl") - // SET LOCATION won't move data from previous table path to new table path. - assert(spark.table("tbl").count() == 0) - // the previous table path should be still there. - assert(new File(defaultTablePath).exists()) - - sql("INSERT INTO tbl SELECT 2") - checkAnswer(spark.table("tbl"), Row(2)) - // newly inserted data will go to the new table path. - assert(dir.listFiles().nonEmpty) - - sql("DROP TABLE tbl") - // the new table path will be removed after DROP TABLE. - assert(!dir.exists()) + try { + sql(s"ALTER TABLE tbl SET LOCATION '${dir.toURI}'") + spark.catalog.refreshTable("tbl") + // SET LOCATION won't move data from previous table path to new table path. + assert(spark.table("tbl").count() == 0) + // the previous table path should be still there. + assert(new File(defaultTablePath).exists()) + + sql("INSERT INTO tbl SELECT 2") + checkAnswer(spark.table("tbl"), Row(2)) + // newly inserted data will go to the new table path. + assert(dir.listFiles().nonEmpty) + + sql("DROP TABLE tbl") + // the new table path will be removed after DROP TABLE. + assert(!dir.exists()) + } finally { + Utils.deleteRecursively(new File(defaultTablePath)) + } } } } @@ -2125,7 +2096,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for database") { - try { + withDatabase ("tmpdb") { withTable("t") { withTempDir { dir => val loc = new File(dir, specialChars) @@ -2140,8 +2111,6 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { assert(tblloc.listFiles().nonEmpty) } } - } finally { - spark.sql("DROP DATABASE IF EXISTS tmpdb") } } } @@ -2295,5 +2264,24 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } } + + test(s"basic DDL using locale tr - caseSensitive $caseSensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> s"$caseSensitive") { + withLocale("tr") { + val dbName = "DaTaBaSe_I" + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + sql(s"USE $dbName") + + val tabName = "tAb_I" + withTable(tabName) { + sql(s"CREATE TABLE $tabName(col_I int) USING PARQUET") + sql(s"INSERT OVERWRITE TABLE $tabName SELECT 1") + checkAnswer(sql(s"SELECT col_I FROM $tabName"), Row(1) :: Nil) + } + } + } + } + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index a9511cbd9e4c..b4616826e40b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.functions.col import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator} @@ -236,6 +237,17 @@ class FileIndexSuite extends SharedSQLContext { val fileStatusCache = FileStatusCache.getOrCreate(spark) fileStatusCache.putLeafFiles(new Path("/tmp", "abc"), files.toArray) } + + test("SPARK-20367 - properly unescape column names in inferPartitioning") { + withTempPath { path => + val colToUnescape = "Column/#%'?" + spark + .range(1) + .select(col("id").as(colToUnescape), col("id")) + .write.partitionBy(colToUnescape).parquet(path.getAbsolutePath) + assert(spark.read.parquet(path.getAbsolutePath).schema.exists(_.name == colToUnescape)) + } + } } class FakeParentPathFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index f36162858bf7..fa3c69612704 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ - protected override val sparkConf = new SparkConf().set("spark.default.parallelism", "1") + protected override def sparkConf = super.sparkConf.set("spark.default.parallelism", "1") test("unpartitioned table, single partition") { val table = @@ -395,7 +395,7 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi val fileCatalog = new InMemoryFileIndex( sparkSession = spark, - rootPaths = Seq(new Path(tempDir)), + rootPathsSpecified = Seq(new Path(tempDir)), parameters = Map.empty[String, String], partitionSchema = None) // This should not fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 9a3328fcecee..dd53b561326f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.util.{AccumulatorContext, LongAccumulator} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2} /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -499,18 +499,20 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex val path = s"${dir.getCanonicalPath}/table" (1 to 1024).map(i => (101, i)).toDF("a", "b").write.parquet(path) - Seq(("true", (x: Long) => x == 0), ("false", (x: Long) => x > 0)).map { case (push, func) => - withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> push) { - val accu = new LongAccumulator - accu.register(sparkContext, Some("numRowGroups")) + Seq(true, false).foreach { enablePushDown => + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> enablePushDown.toString) { + val accu = new NumRowGroupsAcc + sparkContext.register(accu) val df = spark.read.parquet(path).filter("a < 100") df.foreachPartition(_.foreach(v => accu.add(0))) df.collect - val numRowGroups = AccumulatorContext.lookForAccumulatorByName("numRowGroups") - assert(numRowGroups.isDefined) - assert(func(numRowGroups.get.asInstanceOf[LongAccumulator].value)) + if (enablePushDown) { + assert(accu.value == 0) + } else { + assert(accu.value > 0) + } AccumulatorContext.remove(accu.id) } } @@ -537,3 +539,27 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + +class NumRowGroupsAcc extends AccumulatorV2[Integer, Integer] { + private var _sum = 0 + + override def isZero: Boolean = _sum == 0 + + override def copy(): AccumulatorV2[Integer, Integer] = { + val acc = new NumRowGroupsAcc() + acc._sum = _sum + acc + } + + override def reset(): Unit = _sum = 0 + + override def add(v: Integer): Unit = _sum += v + + override def merge(other: AccumulatorV2[Integer, Integer]): Unit = other match { + case a: NumRowGroupsAcc => _sum += a._sum + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + override def value: Integer = _sum +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index c36609586c80..2efff3f57d7d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.SparkException +import org.apache.spark.{DebugFilesystem, SparkException} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow @@ -316,6 +316,39 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext } } + /** + * this is part of test 'Enabling/disabling ignoreCorruptFiles' but run in a loop + * to increase the chance of failure + */ + ignore("SPARK-20407 ParquetQuerySuite 'Enabling/disabling ignoreCorruptFiles' flaky test") { + def testIgnoreCorruptFiles(): Unit = { + withTempDir { dir => + val basePath = dir.getCanonicalPath + spark.range(1).toDF("a").write.parquet(new Path(basePath, "first").toString) + spark.range(1, 2).toDF("a").write.parquet(new Path(basePath, "second").toString) + spark.range(2, 3).toDF("a").write.json(new Path(basePath, "third").toString) + val df = spark.read.parquet( + new Path(basePath, "first").toString, + new Path(basePath, "second").toString, + new Path(basePath, "third").toString) + checkAnswer( + df, + Seq(Row(0), Row(1))) + } + } + + for (i <- 1 to 100) { + DebugFilesystem.clearOpenStreams() + withSQLConf(SQLConf.IGNORE_CORRUPT_FILES.key -> "false") { + val exception = intercept[SparkException] { + testIgnoreCorruptFiles() + } + assert(exception.getMessage().contains("is not a Parquet file")) + } + DebugFilesystem.assertNoOpenStreams() + } + } + test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") { withTempPath { dir => val basePath = dir.getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala index 20ac06f048c6..3d480b148db5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/CompactibleFileStreamLogSuite.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.test.SharedSQLContext class CompactibleFileStreamLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") import CompactibleFileStreamLog._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala index 662c4466b21b..7689bc03a4cc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLogSuite.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.UninterruptibleThread class HDFSMetadataLogSuite extends SparkFunSuite with SharedSQLContext { /** To avoid caching of FS objects */ - override protected val sparkConf = - new SparkConf().set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") + override protected def sparkConf = + super.sparkConf.set(s"spark.hadoop.fs.$scheme.impl.disable.cache", "true") private implicit def toOption[A](a: A): Option[A] = Option(a) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index 9b65419dba23..ba0ca666b5c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -90,6 +90,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { originalDataFrame: DataFrame): Unit = { // This test verifies parts of the plan. Disable whole stage codegen. withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "false") { + val strategy = DataSourceStrategy(spark.sessionState.conf) val bucketedDataFrame = spark.table("bucketed_table").select("i", "j", "k") val BucketSpec(numBuckets, bucketColumnNames, _) = bucketSpec // Limit: bucket pruning only works when the bucket column has one and only one column @@ -98,7 +99,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { val bucketColumn = bucketedDataFrame.schema.toAttributes(bucketColumnIndex) val matchedBuckets = new BitSet(numBuckets) bucketValues.foreach { value => - matchedBuckets.set(DataSourceStrategy.getBucketId(bucketColumn, numBuckets, value)) + matchedBuckets.set(strategy.getBucketId(bucketColumn, numBuckets, value)) } // Filter could hide the bug in bucket pruning. Thus, skipping all the filters diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala index b16c9f8fc96b..735e07c21373 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceAnalysisSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, Expression, Literal} import org.apache.spark.sql.execution.datasources.DataSourceAnalysis import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.types.{DataType, IntegerType, StructType} class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -49,7 +49,11 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { } Seq(true, false).foreach { caseSensitive => - val rule = DataSourceAnalysis(new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive)) + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + def cast(e: Expression, dt: DataType): Expression = { + Cast(e, dt, Option(conf.sessionLocalTimeZone)) + } + val rule = DataSourceAnalysis(conf) test( s"convertStaticPartitions only handle INSERT having at least static partitions " + s"(caseSensitive: $caseSensitive)") { @@ -150,7 +154,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { if (!caseSensitive) { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "C" -> Some("3")), @@ -162,7 +166,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { { val nonPartitionedAttributes = Seq('e.int, 'f.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType), Cast(Literal("3"), IntegerType)) + Seq(cast(Literal("1"), IntegerType), cast(Literal("3"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1"), "c" -> Some("3")), @@ -174,7 +178,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { // Test the case having a single static partition column. { val nonPartitionedAttributes = Seq('e.int, 'f.int) - val expected = nonPartitionedAttributes ++ Seq(Cast(Literal("1"), IntegerType)) + val expected = nonPartitionedAttributes ++ Seq(cast(Literal("1"), IntegerType)) val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes, providedPartitions = Map("b" -> Some("1")), @@ -189,7 +193,7 @@ class DataSourceAnalysisSuite extends SparkFunSuite with BeforeAndAfterAll { val dynamicPartitionAttributes = Seq('g.int) val expected = nonPartitionedAttributes ++ - Seq(Cast(Literal("1"), IntegerType)) ++ + Seq(cast(Literal("1"), IntegerType)) ++ dynamicPartitionAttributes val actual = rule.convertStaticPartitions( sourceAttributes = nonPartitionedAttributes ++ dynamicPartitionAttributes, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index 1211242b9fbb..1a2d3a13f3a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.streaming import java.util.Locale +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, DataFrame} import org.apache.spark.sql.execution.DataSourceScanExec import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.streaming.{MemoryStream, MetadataLogFileIndex} +import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -145,6 +147,43 @@ class FileStreamSinkSuite extends StreamTest { } } + test("partitioned writing and batch reading with 'basePath'") { + withTempDir { outputDir => + withTempDir { checkpointDir => + val outputPath = outputDir.getAbsolutePath + val inputData = MemoryStream[Int] + val ds = inputData.toDS() + + var query: StreamingQuery = null + + try { + query = + ds.map(i => (i, -i, i * 1000)) + .toDF("id1", "id2", "value") + .writeStream + .partitionBy("id1", "id2") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .format("parquet") + .start(outputPath) + + inputData.addData(1, 2, 3) + failAfter(streamingTimeout) { + query.processAllAvailable() + } + + val readIn = spark.read.option("basePath", outputPath).parquet(s"$outputDir/*/*") + checkDatasetUnorderly( + readIn.as[(Int, Int, Int)], + (1000, 1, -1), (2000, 2, -2), (3000, 3, -3)) + } finally { + if (query != null) { + query.stop() + } + } + } + } + } + // This tests whether FileStreamSink works with aggregations. Specifically, it tests // whether the correct streaming QueryExecution (i.e. IncrementalExecution) is used to // to execute the trigger for writing data to file sink. See SPARK-18440 for more details. @@ -266,4 +305,22 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("FileStreamSink.ancestorIsMetadataDirectory()") { + val hadoopConf = spark.sparkContext.hadoopConfiguration + def assertAncestorIsMetadataDirectory(path: String): Unit = + assert(FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + def assertAncestorIsNotMetadataDirectory(path: String): Unit = + assert(!FileStreamSink.ancestorIsMetadataDirectory(new Path(path), hadoopConf)) + + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}") + assertAncestorIsMetadataDirectory(s"/a/${FileStreamSink.metadataDir}/") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c") + assertAncestorIsMetadataDirectory(s"/a/b/${FileStreamSink.metadataDir}/c/") + + assertAncestorIsNotMetadataDirectory(s"/a/b/c") + assertAncestorIsNotMetadataDirectory(s"/a/b/c/${FileStreamSink.metadataDir}extra") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 13fe51a55773..1fc062974e18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -25,6 +25,8 @@ import scala.util.control.ControlThrowable import org.apache.commons.io.FileUtils +import org.apache.spark.SparkContext +import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.command.ExplainCommand @@ -69,6 +71,27 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("SPARK-20432: union one stream with itself") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load().select("a") + val unioned = df.union(df) + withTempDir { outputDir => + withTempDir { checkpointDir => + val query = + unioned + .writeStream.format("parquet") + .option("checkpointLocation", checkpointDir.getAbsolutePath) + .start(outputDir.getAbsolutePath) + try { + query.processAllAvailable() + val outputDf = spark.read.parquet(outputDir.getAbsolutePath).as[Long] + checkDatasetUnorderly[Long](outputDf, (0L to 10L).union((0L to 10L)).toArray: _*) + } finally { + query.stop() + } + } + } + } + test("union two streams") { val inputData1 = MemoryStream[Int] val inputData2 = MemoryStream[Int] @@ -120,6 +143,33 @@ class StreamSuite extends StreamTest { assertDF(df) } + test("Within the same streaming query, one StreamingRelation should only be transformed to one " + + "StreamingExecutionRelation") { + val df = spark.readStream.format(classOf[FakeDefaultSource].getName).load() + var query: StreamExecution = null + try { + query = + df.union(df) + .writeStream + .format("memory") + .queryName("memory") + .start() + .asInstanceOf[StreamingQueryWrapper] + .streamingQuery + query.awaitInitialization(streamingTimeout.toMillis) + val executionRelations = + query + .logicalPlan + .collect { case ser: StreamingExecutionRelation => ser } + assert(executionRelations.size === 2) + assert(executionRelations.distinct.size === 1) + } finally { + if (query != null) { + query.stop() + } + } + } + test("unsupported queries") { val streamInput = MemoryStream[Int] val batchInput = Seq(1, 2, 3).toDS() @@ -500,6 +550,70 @@ class StreamSuite extends StreamTest { } } } + + test("calling stop() on a query cancels related jobs") { + val input = MemoryStream[Int] + val query = input + .toDS() + .map { i => + while (!org.apache.spark.TaskContext.get().isInterrupted()) { + // keep looping till interrupted by query.stop() + Thread.sleep(100) + } + i + } + .writeStream + .format("console") + .start() + + input.addData(1) + // wait for jobs to start + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().nonEmpty) + } + + query.stop() + // make sure jobs are stopped + eventually(timeout(streamingTimeout)) { + assert(sparkContext.statusTracker.getActiveJobIds().isEmpty) + } + } + + test("batch id is updated correctly in the job description") { + val queryName = "memStream" + @volatile var jobDescription: String = null + def assertDescContainsQueryNameAnd(batch: Integer): Unit = { + // wait for listener event to be processed + spark.sparkContext.listenerBus.waitUntilEmpty(streamingTimeout.toMillis) + assert(jobDescription.contains(queryName) && jobDescription.contains(s"batch = $batch")) + } + + spark.sparkContext.addSparkListener(new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescription = jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION) + } + }) + + val input = MemoryStream[Int] + val query = input + .toDS() + .map(_ + 1) + .writeStream + .format("memory") + .queryName(queryName) + .start() + + input.addData(1) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 0) + input.addData(2, 3) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 1) + input.addData(4) + query.processAllAvailable() + assertDescContainsQueryNameAnd(batch = 2) + query.stop() + } } abstract class FakeSource extends StreamSourceProvider { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index f796a4cb4a39..4345a70601c3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -69,6 +69,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest with BeforeAndAfte ) } + test("count distinct") { + val inputData = MemoryStream[(Int, Seq[Int])] + + val aggregated = + inputData.toDF() + .select($"*", explode($"_2") as 'value) + .groupBy($"_1") + .agg(size(collect_set($"value"))) + .as[(Int, Int)] + + testStream(aggregated, Update)( + AddData(inputData, (1, Seq(1, 2))), + CheckLastBatch((1, 2)) + ) + } + test("simple count, complete mode") { val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index b8a694c17731..59c6a6fade17 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -21,6 +21,7 @@ import java.util.UUID import scala.collection.mutable import scala.concurrent.duration._ +import scala.language.reflectiveCalls import org.scalactic.TolerantNumerics import org.scalatest.concurrent.AsyncAssertions.Waiter diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 6a4cc95d36be..f6d47734d7e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.test import java.io.File import java.net.URI import java.nio.file.Files -import java.util.UUID +import java.util.{Locale, UUID} +import scala.concurrent.duration._ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterAll +import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ @@ -49,7 +51,7 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ private[sql] trait SQLTestUtils - extends SparkFunSuite + extends SparkFunSuite with Eventually with BeforeAndAfterAll with SQLTestData { self => @@ -138,6 +140,15 @@ private[sql] trait SQLTestUtils } } + /** + * Waits for all tasks on all executors to be finished. + */ + protected def waitForTasksToFinish(): Unit = { + eventually(timeout(10.seconds)) { + assert(spark.sparkContext.statusTracker + .getExecutorInfos.map(_.numRunningTasks()).sum == 0) + } + } /** * Creates a temporary directory, which is then passed to `f` and will be deleted after `f` * returns. @@ -146,7 +157,11 @@ private[sql] trait SQLTestUtils */ protected def withTempDir(f: File => Unit): Unit = { val dir = Utils.createTempDir().getCanonicalFile - try f(dir) finally Utils.deleteRecursively(dir) + try f(dir) finally { + // wait for all tasks to finish before deleting files + waitForTasksToFinish() + Utils.deleteRecursively(dir) + } } /** @@ -222,12 +237,39 @@ private[sql] trait SQLTestUtils try f(dbName) finally { if (spark.catalog.currentDatabase == dbName) { - spark.sql(s"USE ${DEFAULT_DATABASE}") + spark.sql(s"USE $DEFAULT_DATABASE") } spark.sql(s"DROP DATABASE $dbName CASCADE") } } + /** + * Drops database `dbName` after calling `f`. + */ + protected def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + spark.sql(s"DROP DATABASE IF EXISTS $name") + } + spark.sql(s"USE $DEFAULT_DATABASE") + } + } + + /** + * Enables Locale `language` before executing `f`, then switches back to the default locale of JVM + * after `f` returns. + */ + protected def withLocale(language: String)(f: => Unit): Unit = { + val originalLocale = Locale.getDefault + try { + // Add Locale setting + Locale.setDefault(new Locale(language)) + f + } finally { + Locale.setDefault(originalLocale) + } + } + /** * Activates database `db` before executing `f`, then switches back to `default` database after * `f` returns. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index e122b39f6fc4..81c69a338abc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,19 +17,22 @@ package org.apache.spark.sql.test +import scala.concurrent.duration._ + import org.scalatest.BeforeAndAfterEach +import org.scalatest.concurrent.Eventually import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { +trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - protected val sparkConf = new SparkConf() + protected def sparkConf = { + new SparkConf().set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + } /** * The [[TestSparkSession]] to use for all tests in this suite. @@ -50,8 +53,7 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected implicit def sqlContext: SQLContext = _spark.sqlContext protected def createSparkSession: TestSparkSession = { - new TestSparkSession( - sparkConf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + new TestSparkSession(sparkConf) } /** @@ -84,6 +86,10 @@ trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach { protected override def afterEach(): Unit = { super.afterEach() - DebugFilesystem.assertNoOpenStreams() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } } } diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 9c879218ddc0..0c344fa4975e 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index 0f249d7d5935..3dca86630723 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../../pom.xml diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 8b0fdf49cefa..ba48facff293 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -137,17 +137,33 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } + /** + * Checks the validity of column names. Hive metastore disallows the table to use comma in + * data column names. Partition columns do not have such a restriction. Views do not have such + * a restriction. + */ + private def verifyColumnNames(table: CatalogTable): Unit = { + if (table.tableType != VIEW) { + table.dataSchema.map(_.name).foreach { colName => + if (colName.contains(",")) { + throw new AnalysisException("Cannot create a table having a column whose name contains " + + s"commas in Hive metastore. Table: ${table.identifier}; Column: $colName") + } + } + } + } + // -------------------------------------------------------------------------- // Databases // -------------------------------------------------------------------------- - override def createDatabase( + override protected def doCreateDatabase( dbDefinition: CatalogDatabase, ignoreIfExists: Boolean): Unit = withClient { client.createDatabase(dbDefinition, ignoreIfExists) } - override def dropDatabase( + override protected def doDropDatabase( db: String, ignoreIfNotExists: Boolean, cascade: Boolean): Unit = withClient { @@ -194,7 +210,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Tables // -------------------------------------------------------------------------- - override def createTable( + override protected def doCreateTable( tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = withClient { assert(tableDefinition.identifier.database.isDefined) @@ -202,6 +218,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val table = tableDefinition.identifier.table requireDbExists(db) verifyTableProperties(tableDefinition) + verifyColumnNames(tableDefinition) if (tableExists(db, table) && !ignoreIfExists) { throw new TableAlreadyExistsException(db = db, table = table) @@ -456,7 +473,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat } } - override def dropTable( + override protected def doDropTable( db: String, table: String, ignoreIfNotExists: Boolean, @@ -465,7 +482,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.dropTable(db, table, ignoreIfNotExists, purge) } - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameTable( + db: String, + oldName: String, + newName: String): Unit = withClient { val rawTable = getRawTable(db, oldName) // Note that Hive serde tables don't use path option in storage properties to store the value @@ -611,6 +631,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat requireTableExists(db, table) val rawTable = getRawTable(db, table) val withNewSchema = rawTable.copy(schema = schema) + verifyColumnNames(withNewSchema) // Add table metadata such as table schema, partition columns, etc. to table properties. val updatedTable = withNewSchema.copy( properties = withNewSchema.properties ++ tableMetaToTableProps(withNewSchema)) @@ -1056,7 +1077,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // Functions // -------------------------------------------------------------------------- - override def createFunction( + override protected def doCreateFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { requireDbExists(db) @@ -1069,12 +1090,15 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } - override def dropFunction(db: String, name: String): Unit = withClient { + override protected def doDropFunction(db: String, name: String): Unit = withClient { requireFunctionExists(db, name) client.dropFunction(db, name) } - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + override protected def doRenameFunction( + db: String, + oldName: String, + newName: String): Unit = withClient { requireFunctionExists(db, oldName) requireFunctionNotExists(db, newName) client.renameFunction(db, oldName, newName) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 9d3b31f39c0f..e16c9e46b772 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -101,7 +101,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ Seq( FileSourceStrategy, - DataSourceStrategy, + DataSourceStrategy(conf), SpecialLimits, InMemoryScans, HiveTableScans, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 3906968aaff1..c3d734e5a036 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ @@ -50,15 +50,28 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA protected override def generateTable( catalog: SessionCatalog, - name: TableIdentifier): CatalogTable = { + name: TableIdentifier, + isDataSource: Boolean): CatalogTable = { val storage = - CatalogStorageFormat( - locationUri = Some(catalog.defaultTablePath(name)), - inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), - serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), - compressed = false, - properties = Map("serialization.format" -> "1")) + if (isDataSource) { + val serde = HiveSerDe.sourceToSerDe("parquet") + assert(serde.isDefined, "The default format is not Hive compatible") + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = serde.get.inputFormat, + outputFormat = serde.get.outputFormat, + serde = serde.get.serde, + compressed = false, + properties = Map("serialization.format" -> "1")) + } else { + CatalogStorageFormat( + locationUri = Some(catalog.defaultTablePath(name)), + inputFormat = Some("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Some("org.apache.hadoop.hive.ql.io.HiveSequenceFileOutputFormat"), + serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"), + compressed = false, + properties = Map("serialization.format" -> "1")) + } val metadata = new MetadataBuilder() .putString("key", "value") .build() @@ -71,7 +84,7 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA .add("col2", "string") .add("a", "int") .add("b", "int"), - provider = Some("hive"), + provider = if (isDataSource) Some("parquet") else Some("hive"), partitionColumnNames = Seq("a", "b"), createTime = 0L, tracksPartitionsInCatalog = true) @@ -107,6 +120,46 @@ class HiveCatalogedDDLSuite extends DDLSuite with TestHiveSingleton with BeforeA ) } + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set properties") { + testSetProperties(isDatasourceTable = false) + } + + test("alter table: unset properties") { + testUnsetProperties(isDatasourceTable = false) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde partition") { + testSetSerdePartition(isDatasourceTable = false) + } + + test("alter table: change column") { + testChangeColumn(isDatasourceTable = false) + } + + test("alter table: rename partition") { + testRenamePartitions(isDatasourceTable = false) + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + } class HiveDDLSuite @@ -130,7 +183,7 @@ class HiveDDLSuite if (dbPath.isEmpty) { hiveContext.sessionState.catalog.defaultTablePath(tableIdentifier) } else { - new Path(new Path(dbPath.get), tableIdentifier.table) + new Path(new Path(dbPath.get), tableIdentifier.table).toUri } val filesystemPath = new Path(expectedTablePath.toString) val fs = filesystemPath.getFileSystem(spark.sessionState.newHadoopConf()) @@ -1197,6 +1250,14 @@ class HiveDDLSuite s"CREATE INDEX $indexName ON TABLE $tabName (a) AS 'COMPACT' WITH DEFERRED REBUILD") val indexTabName = spark.sessionState.catalog.listTables("default", s"*$indexName*").head.table + + // Even if index tables exist, listTables and getTable APIs should still work + checkAnswer( + spark.catalog.listTables().toDF(), + Row(indexTabName, "default", null, null, false) :: + Row(tabName, "default", null, "MANAGED", false) :: Nil) + assert(spark.catalog.getTable("default", indexTabName).name === indexTabName) + intercept[TableAlreadyExistsException] { sql(s"CREATE TABLE $indexTabName(b int)") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index 8a37bc3665d3..aa1ca2909074 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -43,11 +43,29 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto test("explain extended command") { checkKeywordsExist(sql(" explain select * from src where key=123 "), - "== Physical Plan ==") + "== Physical Plan ==", + "org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe") + checkKeywordsNotExist(sql(" explain select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", - "== Optimized Logical Plan ==") + "== Optimized Logical Plan ==", + "Owner", + "Database", + "Created", + "Last Access", + "Type", + "Provider", + "Properties", + "Statistics", + "Location", + "Serde Library", + "InputFormat", + "OutputFormat", + "Partition Provider", + "Schema" + ) + checkKeywordsExist(sql(" explain extended select * from src where key=123 "), "== Parsed Logical Plan ==", "== Analyzed Logical Plan ==", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 75f3744ff35b..c944f28d10ef 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1976,6 +1976,30 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("Auto alias construction of get_json_object") { + val df = Seq(("1", """{"f1": "value1", "f5": 5.23}""")).toDF("key", "jstring") + val expectedMsg = "Cannot create a table having a column whose name contains commas " + + "in Hive metastore. Table: `default`.`t`; Column: get_json_object(jstring, $.f1)" + + withTable("t") { + val e = intercept[AnalysisException] { + df.select($"key", functions.get_json_object($"jstring", "$.f1")) + .write.format("hive").saveAsTable("t") + }.getMessage + assert(e.contains(expectedMsg)) + } + + withTempView("tempView") { + withTable("t") { + df.createTempView("tempView") + val e = intercept[AnalysisException] { + sql("CREATE TABLE t AS SELECT key, get_json_object(jstring, '$.f1') FROM tempView") + }.getMessage + assert(e.contains(expectedMsg)) + } + } + } + test("SPARK-19912 String literals should be escaped for Hive metastore partition pruning") { withTable("spark_19912") { Seq( diff --git a/streaming/pom.xml b/streaming/pom.xml index de1be9c13e05..604007c6feaa 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 938ba2f6ac20..b2e8e469d197 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.11 - 2.2.0-SNAPSHOT + 2.2.1-SNAPSHOT ../pom.xml