diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index e0ffde922dac..bc3aceba2256 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -69,6 +69,7 @@ exportMethods("arrange",
"first",
"freqItems",
"gapply",
+ "gapplyCollect",
"group_by",
"groupBy",
"head",
@@ -234,6 +235,7 @@ exportMethods("%in%",
"over",
"percent_rank",
"pmod",
+ "posexplode",
"quarter",
"rand",
"randn",
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 567758d2e2f2..ec09aab6f969 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1339,7 +1339,7 @@ setMethod("dapplyCollect",
#' gapply
#'
-#' Group the SparkDataFrame using the specified columns and apply the R function to each
+#' Groups the SparkDataFrame using the specified columns and applies the R function to each
#' group.
#'
#' @param x A SparkDataFrame
@@ -1351,9 +1351,11 @@ setMethod("dapplyCollect",
#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
#' The schema must match to output of `func`. It has to be defined for each
#' output column with preferred output column name and corresponding data type.
+#' @return a SparkDataFrame
#' @family SparkDataFrame functions
#' @rdname gapply
#' @name gapply
+#' @seealso \link{gapplyCollect}
#' @export
#' @examples
#'
@@ -1369,14 +1371,22 @@ setMethod("dapplyCollect",
#' columns with data types integer and string and the mean which is a double.
#' schema <- structType(structField("a", "integer"), structField("c", "string"),
#' structField("avg", "double"))
-#' df1 <- gapply(
+#' result <- gapply(
#' df,
-#' list("a", "c"),
+#' c("a", "c"),
#' function(key, x) {
#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
-#' },
-#' schema)
-#' collect(df1)
+#' }, schema)
+#'
+#' We can also group the data and afterwards call gapply on the GroupedData object.
+#' For example:
+#' gdf <- group_by(df, "a", "c")
+#' result <- gapply(
+#' gdf,
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' }, schema)
+#' collect(result)
#'
#' Result
#' ------
@@ -1394,7 +1404,7 @@ setMethod("dapplyCollect",
#' structField("Petal_Width", "double"))
#' df1 <- gapply(
#' df,
-#' list(df$"Species"),
+#' df$"Species",
#' function(key, x) {
#' m <- suppressWarnings(lm(Sepal_Length ~
#' Sepal_Width + Petal_Length + Petal_Width, x))
@@ -1402,8 +1412,8 @@ setMethod("dapplyCollect",
#' }, schema)
#' collect(df1)
#'
-#'Result
-#'---------
+#' Result
+#' ---------
#' Model (Intercept) Sepal_Width Petal_Length Petal_Width
#' 1 0.699883 0.3303370 0.9455356 -0.1697527
#' 2 1.895540 0.3868576 0.9083370 -0.6792238
@@ -1418,6 +1428,89 @@ setMethod("gapply",
gapply(grouped, func, schema)
})
+#' gapplyCollect
+#'
+#' Groups the SparkDataFrame using the specified columns, applies the R function to each
+#' group and collects the result back to R as a data.frame.
+#'
+#' @param x A SparkDataFrame
+#' @param cols Grouping columns
+#' @param func A function to be applied to each group partition specified by grouping
+#' column of the SparkDataFrame. The function `func` takes as argument
+#' a key - grouping columns and a data frame - a local R data.frame.
+#' The output of `func` is a local R data.frame.
+#' @return a data.frame
+#' @family SparkDataFrame functions
+#' @rdname gapplyCollect
+#' @name gapplyCollect
+#' @seealso \link{gapply}
+#' @export
+#' @examples
+#'
+#' \dontrun{
+#' Computes the arithmetic mean of the second column by grouping
+#' on the first and third columns. Output the grouping values and the average.
+#'
+#' df <- createDataFrame (
+#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
+#' c("a", "b", "c", "d"))
+#'
+#' result <- gapplyCollect(
+#' df,
+#' c("a", "c"),
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' colnames(y) <- c("key_a", "key_c", "mean_b")
+#' y
+#' })
+#'
+#' We can also group the data and afterwards call gapplyCollect on the GroupedData object.
+#' For example:
+#' gdf <- group_by(df, "a", "c")
+#' result <- gapplyCollect(
+#' gdf,
+#' function(key, x) {
+#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+#' colnames(y) <- c("key_a", "key_c", "mean_b")
+#' y
+#' })
+#'
+#' Result
+#' ------
+#' key_a key_c mean_b
+#' 3 3 3.0
+#' 1 1 1.5
+#'
+#' Fits linear models on iris dataset by grouping on the 'Species' column and
+#' using 'Sepal_Length' as a target variable, 'Sepal_Width', 'Petal_Length'
+#' and 'Petal_Width' as training features.
+#'
+#' df <- createDataFrame (iris)
+#' result <- gapplyCollect(
+#' df,
+#' df$"Species",
+#' function(key, x) {
+#' m <- suppressWarnings(lm(Sepal_Length ~
+#' Sepal_Width + Petal_Length + Petal_Width, x))
+#' data.frame(t(coef(m)))
+#' })
+#'
+#' Result
+#' ---------
+#' Model X.Intercept. Sepal_Width Petal_Length Petal_Width
+#' 1 0.699883 0.3303370 0.9455356 -0.1697527
+#' 2 1.895540 0.3868576 0.9083370 -0.6792238
+#' 3 2.351890 0.6548350 0.2375602 0.2521257
+#'
+#'}
+#' @note gapplyCollect(SparkDataFrame) since 2.0.0
+setMethod("gapplyCollect",
+ signature(x = "SparkDataFrame"),
+ function(x, cols, func) {
+ grouped <- do.call("groupBy", c(x, cols))
+ gapplyCollect(grouped, func)
+ })
+
############################## RDD Map Functions ##################################
# All of the following functions mirror the existing RDD map functions, #
# but allow for use with DataFrames by first converting to an RRDD before calling #
@@ -2524,8 +2617,7 @@ setMethod("describe",
setMethod("describe",
signature(x = "SparkDataFrame"),
function(x) {
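+            # Passing an empty column list lets the JVM-side describe() pick the columns it
+            # supports, so columns of unsupported types (e.g. boolean, see SPARK-16425) no
+            # longer cause errors.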
- colList <- as.list(c(columns(x)))
- sdf <- callJMethod(x@sdf, "describe", colList)
+ sdf <- callJMethod(x@sdf, "describe", list())
dataFrame(sdf)
})
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 8df73db36e95..bc0daa25c9f6 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -714,11 +714,14 @@ dropTempView <- function(viewName) {
#'
#' The data source is specified by the `source` and a set of options(...).
#' If `source` is not specified, the default data source configured by
-#' "spark.sql.sources.default" will be used.
+#' "spark.sql.sources.default" will be used. \cr
+#' Similar to R's read.csv, when `source` is "csv", the string value "NA" is interpreted
+#' as NA by default.
#'
#' @param path The path of files to load
#' @param source The name of external data source
#' @param schema The data schema defined in structType
+#' @param na.strings The string value to be interpreted as NA when `source` is "csv". Default is "NA"
#' @return SparkDataFrame
#' @rdname read.df
#' @name read.df
@@ -735,7 +738,7 @@ dropTempView <- function(viewName) {
#' @name read.df
#' @method read.df default
#' @note read.df since 1.4.0
-read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
+read.df.default <- function(path = NULL, source = NULL, schema = NULL, na.strings = "NA", ...) {
sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
@@ -744,6 +747,9 @@ read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
if (is.null(source)) {
source <- getDefaultSqlSource()
}
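+  # For CSV, forward na.strings to the data source as its "nullValue" option,
+  # unless the caller has already supplied "nullValue" explicitly.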
+ if (source == "csv" && is.null(options[["nullValue"]])) {
+ options[["nullValue"]] <- na.strings
+ }
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 09e5afa97060..52d46f9d7612 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2934,3 +2934,20 @@ setMethod("sort_array",
jc <- callJStatic("org.apache.spark.sql.functions", "sort_array", x@jc, asc)
column(jc)
})
+
+#' posexplode
+#'
+#' Creates a new row for each element with position in the given array or map column.
+#'
+#' @rdname posexplode
+#' @name posexplode
+#' @family collection_funcs
+#' @export
+#' @examples \dontrun{posexplode(df$c)}
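+#'
+#' # A fuller sketch (assumes a SparkDataFrame df with an array column "c"):
+#' # each array element becomes its own row, with its index in column "pos"
+#' # and its value in column "col".
+#' \dontrun{head(select(df, posexplode(df$c)))}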
+#' @note posexplode since 2.1.0
+setMethod("posexplode",
+ signature(x = "Column"),
+ function(x) {
+ jc <- callJStatic("org.apache.spark.sql.functions", "posexplode", x@jc)
+ column(jc)
+ })
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 27dfd67ffc93..e4ec508795a1 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -469,6 +469,10 @@ setGeneric("dapplyCollect", function(x, func) { standardGeneric("dapplyCollect")
#' @export
setGeneric("gapply", function(x, ...) { standardGeneric("gapply") })
+#' @rdname gapplyCollect
+#' @export
+setGeneric("gapplyCollect", function(x, ...) { standardGeneric("gapplyCollect") })
+
#' @rdname summary
#' @export
setGeneric("describe", function(x, col, ...) { standardGeneric("describe") })
@@ -1050,6 +1054,10 @@ setGeneric("percent_rank", function(x) { standardGeneric("percent_rank") })
#' @export
setGeneric("pmod", function(y, x) { standardGeneric("pmod") })
+#' @rdname posexplode
+#' @export
+setGeneric("posexplode", function(x) { standardGeneric("posexplode") })
+
#' @rdname quarter
#' @export
setGeneric("quarter", function(x) { standardGeneric("quarter") })
@@ -1247,6 +1255,7 @@ setGeneric("spark.glm", function(data, formula, ...) { standardGeneric("spark.gl
#' @export
setGeneric("glm")
+#' predict
#' @rdname predict
#' @export
setGeneric("predict", function(object, ...) { standardGeneric("predict") })
@@ -1271,6 +1280,7 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @export
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+#' write.ml
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R
index 0687f14adf7b..5ed7e8abb43d 100644
--- a/R/pkg/R/group.R
+++ b/R/pkg/R/group.R
@@ -196,64 +196,51 @@ createMethods()
#' gapply
#'
-#' Applies a R function to each group in the input GroupedData
-#'
-#' @param x a GroupedData
-#' @param func A function to be applied to each group partition specified by GroupedData.
-#' The function `func` takes as argument a key - grouping columns and
-#' a data frame - a local R data.frame.
-#' The output of `func` is a local R data.frame.
-#' @param schema The schema of the resulting SparkDataFrame after the function is applied.
-#' The schema must match to output of `func`. It has to be defined for each
-#' output column with preferred output column name and corresponding data type.
-#' @return a SparkDataFrame
+#' @param x A GroupedData
#' @rdname gapply
#' @name gapply
#' @export
-#' @examples
-#' \dontrun{
-#' Computes the arithmetic mean of the second column by grouping
-#' on the first and third columns. Output the grouping values and the average.
-#'
-#' df <- createDataFrame (
-#' list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
-#' c("a", "b", "c", "d"))
-#'
-#' Here our output contains three columns, the key which is a combination of two
-#' columns with data types integer and string and the mean which is a double.
-#' schema <- structType(structField("a", "integer"), structField("c", "string"),
-#' structField("avg", "double"))
-#' df1 <- gapply(
-#' df,
-#' list("a", "c"),
-#' function(key, x) {
-#' y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
-#' },
-#' schema)
-#' collect(df1)
-#'
-#' Result
-#' ------
-#' a c avg
-#' 3 3 3.0
-#' 1 1 1.5
-#' }
#' @note gapply(GroupedData) since 2.0.0
setMethod("gapply",
signature(x = "GroupedData"),
function(x, func, schema) {
- try(if (is.null(schema)) stop("schema cannot be NULL"))
- packageNamesArr <- serialize(.sparkREnv[[".packages"]],
- connection = NULL)
- broadcastArr <- lapply(ls(.broadcastNames),
- function(name) { get(name, .broadcastNames) })
- sdf <- callJStatic(
- "org.apache.spark.sql.api.r.SQLUtils",
- "gapply",
- x@sgd,
- serialize(cleanClosure(func), connection = NULL),
- packageNamesArr,
- broadcastArr,
- schema$jobj)
- dataFrame(sdf)
+ if (is.null(schema)) stop("schema cannot be NULL")
+ gapplyInternal(x, func, schema)
+ })
+
+#' gapplyCollect
+#'
+#' @param x A GroupedData
+#' @rdname gapplyCollect
+#' @name gapplyCollect
+#' @export
+#' @note gapplyCollect(GroupedData) since 2.0.0
+setMethod("gapplyCollect",
+ signature(x = "GroupedData"),
+ function(x, func) {
+ gdf <- gapplyInternal(x, func, NULL)
+ content <- callJMethod(gdf@sdf, "collect")
+    # content is a list of items of struct type. Each item has a single field,
+    # which is a serialized data.frame corresponding to one group of the
+    # SparkDataFrame.
+ ldfs <- lapply(content, function(x) { unserialize(x[[1]]) })
+ ldf <- do.call(rbind, ldfs)
+ row.names(ldf) <- NULL
+ ldf
})
+
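+# Shared implementation of gapply() and gapplyCollect(): serializes the user function
+# together with the attached package names and broadcast variables and invokes the
+# JVM-side gapply. When schema is NULL the grouped results come back as serialized R
+# data.frames, which gapplyCollect() deserializes and rbinds on the driver.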
+gapplyInternal <- function(x, func, schema) {
+ packageNamesArr <- serialize(.sparkREnv[[".packages"]],
+ connection = NULL)
+ broadcastArr <- lapply(ls(.broadcastNames),
+ function(name) { get(name, .broadcastNames) })
+ sdf <- callJStatic(
+ "org.apache.spark.sql.api.r.SQLUtils",
+ "gapply",
+ x@sgd,
+ serialize(cleanClosure(func), connection = NULL),
+ packageNamesArr,
+ broadcastArr,
+ if (class(schema) == "structType") { schema$jobj } else { NULL })
+ dataFrame(sdf)
+}
\ No newline at end of file
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 8e6c2ddf93cf..4fe73671f80d 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -267,9 +267,10 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
return(list(apriori = apriori, tables = tables))
})
-#' Fit a k-means model
+#' K-Means Clustering Model
#'
-#' Fit a k-means model, similarly to R's kmeans().
+#' Fits a k-means clustering model against a Spark DataFrame, similarly to R's kmeans().
+#' Users can print the fitted model, make predictions with it, and save the model to a given path.
#'
#' @param data SparkDataFrame for training
#' @param formula A symbolic description of the model to be fitted. Currently only a few formula
@@ -278,14 +279,32 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
#' @param k Number of centers
#' @param maxIter Maximum iteration number
+#' @param initMode The initialization algorithm chosen to fit the model
-#' @return A fitted k-means model
+#' @return \code{spark.kmeans} returns a fitted k-means model
#' @rdname spark.kmeans
+#' @name spark.kmeans
#' @export
#' @examples
#' \dontrun{
-#' model <- spark.kmeans(data, ~ ., k = 4, initMode = "random")
+#' sparkR.session()
+#' data(iris)
+#' df <- createDataFrame(iris)
+#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random")
+#' summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, df)
+#' head(select(fitted, "Sepal_Length", "prediction"))
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and print
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
#' }
#' @note spark.kmeans since 2.0.0
+#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"),
function(data, formula, k = 2, maxIter = 20, initMode = c("k-means||", "random")) {
formula <- paste(deparse(formula), collapse = "")
@@ -301,7 +320,7 @@ setMethod("spark.kmeans", signature(data = "SparkDataFrame", formula = "formula"
#' Note: A saved-loaded model does not support this method.
#'
#' @param object A fitted k-means model
-#' @return SparkDataFrame containing fitted values
+#' @return \code{fitted} returns a SparkDataFrame containing fitted values
#' @rdname fitted
#' @export
#' @examples
@@ -323,20 +342,12 @@ setMethod("fitted", signature(object = "KMeansModel"),
}
})
-#' Get the summary of a k-means model
-#'
-#' Returns the summary of a k-means model produced by spark.kmeans(),
-#' similarly to R's summary().
+# Get the summary of a k-means model
#'
-#' @param object a fitted k-means model
-#' @return the model's coefficients, size and cluster
-#' @rdname summary
+#' @param object A fitted k-means model
+#' @return \code{summary} returns the model's coefficients, size and cluster
+#' @rdname spark.kmeans
#' @export
-#' @examples
-#' \dontrun{
-#' model <- spark.kmeans(trainingData, ~ ., 2)
-#' summary(model)
-#' }
#' @note summary(KMeansModel) since 2.0.0
setMethod("summary", signature(object = "KMeansModel"),
function(object, ...) {
@@ -358,19 +369,11 @@ setMethod("summary", signature(object = "KMeansModel"),
cluster = cluster, is.loaded = is.loaded))
})
-#' Predicted values based on model
-#'
-#' Makes predictions from a k-means model or a model produced by spark.kmeans().
+# Predicted values based on a k-means model
#'
-#' @param object A fitted k-means model
-#' @rdname predict
+#' @return \code{predict} returns the predicted values based on a k-means model
+#' @rdname spark.kmeans
#' @export
-#' @examples
-#' \dontrun{
-#' model <- spark.kmeans(trainingData, ~ ., 2)
-#' predicted <- predict(model, testData)
-#' showDF(predicted)
-#' }
#' @note predict(KMeansModel) since 2.0.0
setMethod("predict", signature(object = "KMeansModel"),
function(object, newData) {
@@ -442,11 +445,11 @@ setMethod("write.ml", signature(object = "NaiveBayesModel", path = "character"),
# Saves the AFT survival regression model to the input path.
-#' @param path The directory where the model is savedist containing the model's coefficien
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
#' @rdname spark.survreg
-#' @name write.ml
#' @export
#' @note write.ml(AFTSurvivalRegressionModel, character) since 2.0.0
#' @seealso \link{read.ml}
@@ -477,24 +480,15 @@ setMethod("write.ml", signature(object = "GeneralizedLinearRegressionModel", pat
invisible(callJMethod(writer, "save", path))
})
-#' Save fitted MLlib model to the input path
-#'
-#' Save the k-means model to the input path.
+# Save fitted MLlib model to the input path
#'
-#' @param object A fitted k-means model
#' @param path The directory where the model is saved
#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
#' which means throw exception if the output path exists.
#'
-#' @rdname write.ml
+#' @rdname spark.kmeans
#' @name write.ml
#' @export
-#' @examples
-#' \dontrun{
-#' model <- spark.kmeans(trainingData, ~ ., k = 2)
-#' path <- "path/to/model"
-#' write.ml(model, path)
-#' }
#' @note write.ml(KMeansModel, character) since 2.0.0
setMethod("write.ml", signature(object = "KMeansModel", path = "character"),
function(object, path, overwrite = FALSE) {
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 74def5ce4245..bd7b5f062e6d 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -208,6 +208,44 @@ test_that("create DataFrame from RDD", {
unsetHiveContext()
})
+test_that("read csv as DataFrame", {
+ csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
+ mockLinesCsv <- c("year,make,model,comment,blank",
+ "\"2012\",\"Tesla\",\"S\",\"No comment\",",
+ "1997,Ford,E350,\"Go get one now they are going fast\",",
+ "2015,Chevy,Volt",
+ "NA,Dummy,Placeholder")
+ writeLines(mockLinesCsv, csvPath)
+
+  # "header" defaults to false, so set it to "true"; use "inferSchema" so "year" is read as "int"
+ df <- read.df(csvPath, "csv", header = "true", inferSchema = "true")
+ expect_equal(count(df), 4)
+ expect_equal(columns(df), c("year", "make", "model", "comment", "blank"))
+ expect_equal(sort(unlist(collect(where(df, df$year == 2015)))),
+ sort(unlist(list(year = 2015, make = "Chevy", model = "Volt"))))
+
+ # since "year" is "int", let's skip the NA values
+ withoutna <- na.omit(df, how = "any", cols = "year")
+ expect_equal(count(withoutna), 3)
+
+ unlink(csvPath)
+ csvPath <- tempfile(pattern = "sparkr-test", fileext = ".csv")
+ mockLinesCsv <- c("year,make,model,comment,blank",
+ "\"2012\",\"Tesla\",\"S\",\"No comment\",",
+ "1997,Ford,E350,\"Go get one now they are going fast\",",
+ "2015,Chevy,Volt",
+ "Empty,Dummy,Placeholder")
+ writeLines(mockLinesCsv, csvPath)
+
+  df2 <- read.df(csvPath, "csv", header = "true", inferSchema = "true", na.strings = "Empty")
+ expect_equal(count(df2), 4)
+ withoutna2 <- na.omit(df2, how = "any", cols = "year")
+ expect_equal(count(withoutna2), 3)
+ expect_equal(count(where(withoutna2, withoutna2$make == "Dummy")), 0)
+
+ unlink(csvPath)
+})
+
test_that("convert NAs to null type in DataFrames", {
rdd <- parallelize(sc, list(list(1L, 2L), list(NA, 4L)))
df <- createDataFrame(rdd, list("a", "b"))
@@ -1047,7 +1085,7 @@ test_that("column functions", {
c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c)
c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c)
c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c)
- c7 <- mean(c) + min(c) + month(c) + negate(c) + quarter(c)
+ c7 <- mean(c) + min(c) + month(c) + negate(c) + posexplode(c) + quarter(c)
c8 <- reverse(c) + rint(c) + round(c) + rtrim(c) + sha1(c) + monotonically_increasing_id()
c9 <- signum(c) + sin(c) + sinh(c) + size(c) + stddev(c) + soundex(c) + sqrt(c) + sum(c)
c10 <- sumDistinct(c) + tan(c) + tanh(c) + toDegrees(c) + toRadians(c)
@@ -1699,6 +1737,7 @@ test_that("mutate(), transform(), rename() and names()", {
})
test_that("read/write ORC files", {
+ setHiveContext(sc)
df <- read.df(jsonPath, "json")
# Test write.df and read.df
@@ -1715,6 +1754,7 @@ test_that("read/write ORC files", {
expect_equal(count(orcDF), count(df))
unlink(orcPath2)
+ unsetHiveContext()
})
test_that("read/write Parquet files", {
@@ -1776,13 +1816,17 @@ test_that("describe() and summarize() on a DataFrame", {
expect_equal(collect(stats)[2, "age"], "24.5")
expect_equal(collect(stats)[3, "age"], "7.7781745930520225")
stats <- describe(df)
- expect_equal(collect(stats)[4, "name"], "Andy")
+ expect_equal(collect(stats)[4, "name"], NULL)
expect_equal(collect(stats)[5, "age"], "30")
stats2 <- summary(df)
- expect_equal(collect(stats2)[4, "name"], "Andy")
+ expect_equal(collect(stats2)[4, "name"], NULL)
expect_equal(collect(stats2)[5, "age"], "30")
+ # SPARK-16425: SparkR summary() fails on column of type logical
+ df <- withColumn(df, "boolean", df$age == 30)
+ summary(df)
+
# Test base::summary is working
expect_equal(length(summary(attenu, digits = 4)), 35)
})
@@ -2231,21 +2275,24 @@ test_that("repartition by columns on DataFrame", {
expect_equal(nrow(df1), 2)
})
-test_that("gapply() on a DataFrame", {
+test_that("gapply() and gapplyCollect() on a DataFrame", {
df <- createDataFrame (
list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)),
c("a", "b", "c", "d"))
expected <- collect(df)
- df1 <- gapply(df, list("a"), function(key, x) { x }, schema(df))
+ df1 <- gapply(df, "a", function(key, x) { x }, schema(df))
actual <- collect(df1)
expect_identical(actual, expected)
+ df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x })
+ expect_identical(df1Collect, expected)
+
# Computes the sum of second column by grouping on the first and third columns
# and checks if the sum is larger than 2
schema <- structType(structField("a", "integer"), structField("e", "boolean"))
df2 <- gapply(
df,
- list(df$"a", df$"c"),
+ c(df$"a", df$"c"),
function(key, x) {
y <- data.frame(key[1], sum(x$b) > 2)
},
@@ -2254,13 +2301,24 @@ test_that("gapply() on a DataFrame", {
expected <- c(TRUE, TRUE)
expect_identical(actual, expected)
+ df2Collect <- gapplyCollect(
+ df,
+ c(df$"a", df$"c"),
+ function(key, x) {
+ y <- data.frame(key[1], sum(x$b) > 2)
+ colnames(y) <- c("a", "e")
+ y
+ })
+ actual <- df2Collect$e
+ expect_identical(actual, expected)
+
# Computes the arithmetic mean of the second column by grouping
# on the first and third columns. Output the groupping value and the average.
schema <- structType(structField("a", "integer"), structField("c", "string"),
structField("avg", "double"))
df3 <- gapply(
df,
- list("a", "c"),
+ c("a", "c"),
function(key, x) {
y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
},
@@ -2275,11 +2333,22 @@ test_that("gapply() on a DataFrame", {
rownames(expected) <- NULL
expect_identical(actual, expected)
+ df3Collect <- gapplyCollect(
+ df,
+ c("a", "c"),
+ function(key, x) {
+ y <- data.frame(key, mean(x$b), stringsAsFactors = FALSE)
+ colnames(y) <- c("a", "c", "avg")
+ y
+ })
+ actual <- df3Collect[order(df3Collect$a), ]
+ expect_identical(actual$avg, expected$avg)
+
irisDF <- suppressWarnings(createDataFrame (iris))
schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double"))
# Groups by `Sepal_Length` and computes the average for `Sepal_Width`
df4 <- gapply(
- cols = list("Sepal_Length"),
+ cols = "Sepal_Length",
irisDF,
function(key, x) {
y <- data.frame(key, mean(x$Sepal_Width), stringsAsFactors = FALSE)
diff --git a/R/pkg/inst/worker/daemon.R b/R/pkg/inst/worker/daemon.R
index f55beac6c8c0..b92e6be995ca 100644
--- a/R/pkg/inst/worker/daemon.R
+++ b/R/pkg/inst/worker/daemon.R
@@ -44,7 +44,7 @@ while (TRUE) {
if (inherits(p, "masterProcess")) {
close(inputCon)
Sys.setenv(SPARKR_WORKER_PORT = port)
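+      # Wrap source() in try() so that an error in the worker script does not prevent
+      # the SIGUSR1 signal and mcexit cleanup below from running.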
- source(script)
+ try(source(script))
# Set SIGUSR1 so that child can exit
tools::pskill(Sys.getpid(), tools::SIGUSR1)
parallel:::mcexit(0L)
diff --git a/bin/pyspark b/bin/pyspark
index 396a07c9f413..ac8aa04dba8a 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -50,9 +50,11 @@ if [[ -z "$PYSPARK_DRIVER_PYTHON" ]]; then
PYSPARK_DRIVER_PYTHON="${PYSPARK_PYTHON:-"$DEFAULT_PYTHON"}"
fi
+WORKS_WITH_IPYTHON=$($DEFAULT_PYTHON -c 'import sys; print(sys.version_info >= (2, 7, 0))')
+
# Determine the Python executable to use for the executors:
if [[ -z "$PYSPARK_PYTHON" ]]; then
- if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $DEFAULT_PYTHON != "python2.7" ]]; then
+  if [[ $PYSPARK_DRIVER_PYTHON == *ipython* && $WORKS_WITH_IPYTHON != "True" ]]; then
echo "IPython requires Python 2.7+; please install python2.7 or set PYSPARK_PYTHON" 1>&2
exit 1
else
diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
index 922c37a10efd..e79eef032589 100644
--- a/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
+++ b/common/network-common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java
@@ -48,11 +48,27 @@
* use this functionality in both a Guava 11 environment and a Guava >14 environment.
*/
public final class LimitedInputStream extends FilterInputStream {
+ private final boolean closeWrappedStream;
private long left;
private long mark = -1;
public LimitedInputStream(InputStream in, long limit) {
+ this(in, limit, true);
+ }
+
+ /**
+ * Create a LimitedInputStream that will read {@code limit} bytes from {@code in}.
+ *
+ * If {@code closeWrappedStream} is true, this will close {@code in} when it is closed.
+ * Otherwise, the stream is left open for reading its remaining content.
+ *
+ * @param in a {@link InputStream} to read from
+ * @param limit the number of bytes to read
+ * @param closeWrappedStream whether to close {@code in} when {@link #close} is called
+ */
+ public LimitedInputStream(InputStream in, long limit, boolean closeWrappedStream) {
super(in);
+ this.closeWrappedStream = closeWrappedStream;
Preconditions.checkNotNull(in);
Preconditions.checkArgument(limit >= 0, "limit must be non-negative");
left = limit;
@@ -102,4 +118,11 @@ public LimitedInputStream(InputStream in, long limit) {
left -= skipped;
return skipped;
}
+
+ @Override
+ public void close() throws IOException {
+ if (closeWrappedStream) {
+ super.close();
+ }
+ }
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 014aef86b5cc..cf38a04ed7cf 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -72,7 +72,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
private final TaskContext taskContext;
private final ShuffleWriteMetrics writeMetrics;
- /** Force this sorter to spill when there are this many elements in memory. For testing only */
+ /**
+ * Force this sorter to spill when there are this many elements in memory. The default value is
+   * 1024 * 1024 * 1024, which allows the maximum size of the pointer array to be 8G,
+   * since each record occupies 8 bytes in the in-memory sorter's pointer array.
+ */
private final long numElementsForSpillThreshold;
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
@@ -114,7 +117,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
this.numElementsForSpillThreshold =
- conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", Long.MAX_VALUE);
+ conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold", 1024 * 1024 * 1024);
this.writeMetrics = writeMetrics;
this.inMemSorter = new ShuffleInMemorySorter(
this, initialSize, conf.getBoolean("spark.shuffle.sort.useRadixSort", true));
@@ -372,7 +375,9 @@ public void insertRecord(Object recordBase, long recordOffset, int length, int p
// for tests
assert(inMemSorter != null);
- if (inMemSorter.numRecords() > numElementsForSpillThreshold) {
+ if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
+ logger.info("Spilling data because number of spilledRecords crossed the threshold " +
+ numElementsForSpillThreshold);
spill();
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index daa63d47e6ae..44e6aa73d975 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -346,12 +346,19 @@ private long[] mergeSpillsWithFileStream(
for (int i = 0; i < spills.length; i++) {
final long partitionLengthInSpill = spills[i].partitionLengths[partition];
if (partitionLengthInSpill > 0) {
- InputStream partitionInputStream =
- new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill);
- if (compressionCodec != null) {
- partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ InputStream partitionInputStream = null;
+ boolean innerThrewException = true;
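+          // Do not close the underlying spill file stream when this wrapper is closed;
+          // the same stream is read again for the remaining partitions of this spill file.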
+ try {
+ partitionInputStream =
+ new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false);
+ if (compressionCodec != null) {
+ partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream);
+ }
+ ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
+ innerThrewException = false;
+ } finally {
+ Closeables.close(partitionInputStream, innerThrewException);
}
- ByteStreams.copy(partitionInputStream, mergedFileOutputStream);
}
}
mergedFileOutputStream.flush();
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index e14a23f4a6a8..50f5b068b276 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -59,6 +59,13 @@ public final class UnsafeExternalSorter extends MemoryConsumer {
/** The buffer size to use when writing spills using DiskBlockObjectWriter */
private final int fileBufferSizeBytes;
+ /**
+ * Force this sorter to spill when there are this many elements in memory. The default value is
+   * 1024 * 1024 * 1024 / 2, which allows the maximum size of the pointer array to be 8G,
+   * since each record occupies 16 bytes (a record pointer plus a key prefix) in the
+   * in-memory sorter's array.
+ */
+ public static final long DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD = 1024 * 1024 * 1024 / 2;
+
+ private final long numElementsForSpillThreshold;
/**
* Memory pages that hold the records being sorted. The pages in this list are freed when
* spilling, although in principle we could recycle these pages across spills (on the other hand,
@@ -88,10 +95,11 @@ public static UnsafeExternalSorter createWithExistingInMemorySorter(
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
+ long numElementsForSpillThreshold,
UnsafeInMemorySorter inMemorySorter) throws IOException {
UnsafeExternalSorter sorter = new UnsafeExternalSorter(taskMemoryManager, blockManager,
serializerManager, taskContext, recordComparator, prefixComparator, initialSize,
- pageSizeBytes, inMemorySorter, false /* ignored */);
+ numElementsForSpillThreshold, pageSizeBytes, inMemorySorter, false /* ignored */);
sorter.spill(Long.MAX_VALUE, sorter);
// The external sorter will be used to insert records, in-memory sorter is not needed.
sorter.inMemSorter = null;
@@ -107,10 +115,11 @@ public static UnsafeExternalSorter create(
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
+ long numElementsForSpillThreshold,
boolean canUseRadixSort) {
return new UnsafeExternalSorter(taskMemoryManager, blockManager, serializerManager,
- taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes, null,
- canUseRadixSort);
+ taskContext, recordComparator, prefixComparator, initialSize, pageSizeBytes,
+ numElementsForSpillThreshold, null, canUseRadixSort);
}
private UnsafeExternalSorter(
@@ -122,6 +131,7 @@ private UnsafeExternalSorter(
PrefixComparator prefixComparator,
int initialSize,
long pageSizeBytes,
+ long numElementsForSpillThreshold,
@Nullable UnsafeInMemorySorter existingInMemorySorter,
boolean canUseRadixSort) {
super(taskMemoryManager, pageSizeBytes, taskMemoryManager.getTungstenMemoryMode());
@@ -143,6 +153,7 @@ private UnsafeExternalSorter(
this.inMemSorter = existingInMemorySorter;
}
this.peakMemoryUsedBytes = getMemoryUsage();
+ this.numElementsForSpillThreshold = numElementsForSpillThreshold;
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
// the end of the task. This is necessary to avoid memory leaks in when the downstream operator
@@ -372,6 +383,13 @@ private void acquireNewPageIfNecessary(int required) {
public void insertRecord(Object recordBase, long recordOffset, int length, long prefix)
throws IOException {
+ assert(inMemSorter != null);
+ if (inMemSorter.numRecords() >= numElementsForSpillThreshold) {
+ logger.info("Spilling data because number of spilledRecords crossed the threshold " +
+ numElementsForSpillThreshold);
+ spill();
+ }
+
growPointerArrayIfNecessary();
// Need 4 bytes to store the record length.
final int required = length + 4;
@@ -383,7 +401,6 @@ public void insertRecord(Object recordBase, long recordOffset, int length, long
pageCursor += 4;
Platform.copyMemory(recordBase, recordOffset, base, pageCursor, length);
pageCursor += length;
- assert(inMemSorter != null);
inMemSorter.insertRecord(recordAddress, prefix);
}
diff --git a/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png b/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png
index ffe255063035..cee28916e8db 100644
Binary files a/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png and b/core/src/main/resources/org/apache/spark/ui/static/spark-logo-77x50px-hd.png differ
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 632b0ae9c2c3..e8d6d587b482 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -232,7 +232,11 @@ private object TorrentBroadcast extends Logging {
val out = compressionCodec.map(c => c.compressedOutputStream(cbbos)).getOrElse(cbbos)
val ser = serializer.newInstance()
val serOut = ser.serializeStream(out)
- serOut.writeObject[T](obj).close()
+ Utils.tryWithSafeFinally {
+ serOut.writeObject[T](obj)
+ } {
+ serOut.close()
+ }
cbbos.toChunkedByteBuffer.getChunks()
}
@@ -246,8 +250,11 @@ private object TorrentBroadcast extends Logging {
val in: InputStream = compressionCodec.map(c => c.compressedInputStream(is)).getOrElse(is)
val ser = serializer.newInstance()
val serIn = ser.deserializeStream(in)
- val obj = serIn.readObject[T]()
- serIn.close()
+ val obj = Utils.tryWithSafeFinally {
+ serIn.readObject[T]()
+ } {
+ serIn.close()
+ }
obj
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index bb1793d451df..90c71cc6cfab 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -232,6 +232,10 @@ class SparkHadoopUtil extends Logging {
recurse(baseStatus)
}
+ def isGlobPath(pattern: Path): Boolean = {
+ pattern.toString.exists("{}[]*?\\".toSet.contains)
+ }
+
def globPath(pattern: Path): Seq[Path] = {
val fs = pattern.getFileSystem(conf)
Option(fs.globStatus(pattern)).map { statuses =>
@@ -240,11 +244,7 @@ class SparkHadoopUtil extends Logging {
}
def globPathIfNecessary(pattern: Path): Seq[Path] = {
- if (pattern.toString.exists("{}[]*?\\".toSet.contains)) {
- globPath(pattern)
- } else {
- Seq(pattern)
- }
+ if (isGlobPath(pattern)) globPath(pattern) else Seq(pattern)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/internal/Logging.scala b/core/src/main/scala/org/apache/spark/internal/Logging.scala
index c51050c13d3a..66a0cfec6296 100644
--- a/core/src/main/scala/org/apache/spark/internal/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/internal/Logging.scala
@@ -32,10 +32,7 @@ private[spark] trait Logging {
// Make the log field transient so that objects with Logging can
// be serialized and used on another machine
- @transient lazy val log: Logger = {
- initializeLogIfNecessary(false)
- LoggerFactory.getLogger(logName)
- }
+ @transient private var log_ : Logger = null
// Method to get the logger name for this object
protected def logName = {
@@ -43,6 +40,15 @@ private[spark] trait Logging {
this.getClass.getName.stripSuffix("$")
}
+ // Method to get or create the logger for this object
+ protected def log: Logger = {
+ if (log_ == null) {
+ initializeLogIfNecessary(false)
+ log_ = LoggerFactory.getLogger(logName)
+ }
+ log_
+ }
+
// Log methods that take only a String
protected def logInfo(msg: => String) {
if (log.isInfoEnabled) log.info(msg)
diff --git a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
index 6819222e15a1..6bba259acc39 100644
--- a/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/source/StaticSources.scala
@@ -47,4 +47,16 @@ object CodegenMetrics extends Source {
* Histogram of the time it took to compile source code text (in milliseconds).
*/
val METRIC_COMPILATION_TIME = metricRegistry.histogram(MetricRegistry.name("compilationTime"))
+
+ /**
+ * Histogram of the bytecode size of each class generated by CodeGenerator.
+ */
+ val METRIC_GENERATED_CLASS_BYTECODE_SIZE =
+ metricRegistry.histogram(MetricRegistry.name("generatedClassSize"))
+
+ /**
+ * Histogram of the bytecode size of each method in classes generated by CodeGenerator.
+ */
+ val METRIC_GENERATED_METHOD_BYTECODE_SIZE =
+ metricRegistry.histogram(MetricRegistry.name("generatedMethodSize"))
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
index e88e4ad4750d..99e6d3958374 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala
@@ -244,7 +244,6 @@ private[spark] class MesosCoarseGrainedSchedulerBackend(
d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
appId = frameworkId.getValue
mesosExternalShuffleClient.foreach(_.init(appId))
- logInfo("Registered as framework ID " + appId)
markRegistered()
}
diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
index d17a7894fd8a..f0ed41f6903f 100644
--- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala
@@ -32,6 +32,7 @@ import org.apache.commons.io.IOUtils
import org.apache.spark.{SparkEnv, SparkException}
import org.apache.spark.io.CompressionCodec
+import org.apache.spark.util.Utils
/**
* Custom serializer used for generic Avro records. If the user registers the schemas
@@ -72,8 +73,11 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
def compress(schema: Schema): Array[Byte] = compressCache.getOrElseUpdate(schema, {
val bos = new ByteArrayOutputStream()
val out = codec.compressedOutputStream(bos)
- out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
- out.close()
+ Utils.tryWithSafeFinally {
+ out.write(schema.toString.getBytes(StandardCharsets.UTF_8))
+ } {
+ out.close()
+ }
bos.toByteArray
})
@@ -86,7 +90,12 @@ private[serializer] class GenericAvroSerializer(schemas: Map[Long, String])
schemaBytes.array(),
schemaBytes.arrayOffset() + schemaBytes.position(),
schemaBytes.remaining())
- val bytes = IOUtils.toByteArray(codec.compressedInputStream(bis))
+ val in = codec.compressedInputStream(bis)
+ val bytes = Utils.tryWithSafeFinally {
+ IOUtils.toByteArray(in)
+ } {
+ in.close()
+ }
new Schema.Parser().parse(new String(bytes, StandardCharsets.UTF_8))
})
diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
index 5783df5d8220..b21d36d4a8d8 100644
--- a/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
+++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllJobsResource.scala
@@ -68,7 +68,12 @@ private[v1] object AllJobsResource {
listener: JobProgressListener,
includeStageDetails: Boolean): JobData = {
listener.synchronized {
- val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageInfo =
+ if (job.stageIds.isEmpty) {
+ None
+ } else {
+ listener.stageIdToInfo.get(job.stageIds.max)
+ }
val lastStageData = lastStageInfo.flatMap { s =>
listener.stageIdToData.get((s.stageId, s.attemptId))
}
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index f77cc2f9b7aa..156cf1748b2a 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1772,50 +1772,66 @@ private[spark] object Utils extends Logging {
}
/**
- * Terminates a process waiting for at most the specified duration. Returns whether
- * the process terminated.
+ * Terminates a process waiting for at most the specified duration.
+ *
+ * @return the process exit value if it was successfully terminated, else None
*/
def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = {
- try {
- // Java8 added a new API which will more forcibly kill the process. Use that if available.
- val destroyMethod = process.getClass().getMethod("destroyForcibly");
- destroyMethod.setAccessible(true)
- destroyMethod.invoke(process)
- } catch {
- case NonFatal(e) =>
- if (!e.isInstanceOf[NoSuchMethodException]) {
- logWarning("Exception when attempting to kill process", e)
- }
- process.destroy()
- }
+ // Politely destroy first
+ process.destroy()
+
if (waitForProcess(process, timeoutMs)) {
+ // Successful exit
Option(process.exitValue())
} else {
- None
+ // Java 8 added a new API which will more forcibly kill the process. Use that if available.
+ try {
+ classOf[Process].getMethod("destroyForcibly").invoke(process)
+ } catch {
+ case _: NoSuchMethodException => return None // Not available; give up
+ case NonFatal(e) => logWarning("Exception when attempting to kill process", e)
+ }
+ // Wait, again, although this really should return almost immediately
+ if (waitForProcess(process, timeoutMs)) {
+ Option(process.exitValue())
+ } else {
+ logWarning("Timed out waiting to forcibly kill process")
+ None
+ }
}
}
/**
* Wait for a process to terminate for at most the specified duration.
- * Return whether the process actually terminated after the given timeout.
+ *
+ * @return whether the process actually terminated before the given timeout.
*/
def waitForProcess(process: Process, timeoutMs: Long): Boolean = {
- var terminated = false
- val startTime = System.currentTimeMillis
- while (!terminated) {
- try {
- process.exitValue()
- terminated = true
- } catch {
- case e: IllegalThreadStateException =>
- // Process not terminated yet
- if (System.currentTimeMillis - startTime > timeoutMs) {
- return false
+ try {
+ // Use Java 8 method if available
+ classOf[Process].getMethod("waitFor", java.lang.Long.TYPE, classOf[TimeUnit])
+ .invoke(process, timeoutMs.asInstanceOf[java.lang.Long], TimeUnit.MILLISECONDS)
+ .asInstanceOf[Boolean]
+ } catch {
+ case _: NoSuchMethodException =>
+ // Otherwise implement it manually
+ var terminated = false
+ val startTime = System.currentTimeMillis
+ while (!terminated) {
+ try {
+ process.exitValue()
+ terminated = true
+ } catch {
+ case e: IllegalThreadStateException =>
+ // Process not terminated yet
+ if (System.currentTimeMillis - startTime > timeoutMs) {
+ return false
+ }
+ Thread.sleep(100)
}
- Thread.sleep(100)
- }
+ }
+ true
}
- true
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
index fb4706e78d38..89b0874e3865 100644
--- a/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
+++ b/core/src/main/scala/org/apache/spark/util/io/ChunkedByteBuffer.scala
@@ -31,14 +31,13 @@ import org.apache.spark.storage.StorageUtils
* Read-only byte buffer which is physically stored as multiple chunks rather than a single
* contiguous array.
*
- * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must be non-empty and have
- * position == 0. Ownership of these buffers is transferred to the ChunkedByteBuffer,
- * so if these buffers may also be used elsewhere then the caller is responsible for
- * copying them as needed.
+ * @param chunks an array of [[ByteBuffer]]s. Each buffer in this array must have position == 0.
+ * Ownership of these buffers is transferred to the ChunkedByteBuffer, so if these
+ * buffers may also be used elsewhere then the caller is responsible for copying
+ * them as needed.
*/
private[spark] class ChunkedByteBuffer(var chunks: Array[ByteBuffer]) {
require(chunks != null, "chunks must not be null")
- require(chunks.forall(_.limit() > 0), "chunks must be non-empty")
require(chunks.forall(_.position() == 0), "chunks' positions must be 0")
private[this] var disposed: Boolean = false
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 2cae4beb4c77..960698f4ebac 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -176,6 +176,7 @@ private UnsafeExternalSorter newSorter() throws IOException {
prefixComparator,
/* initialSize */ 1024,
pageSizeBytes,
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
shouldUseRadixSort());
}
@@ -399,6 +400,7 @@ public void forcedSpillingWithoutComparator() throws Exception {
null,
/* initialSize */ 1024,
pageSizeBytes,
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
shouldUseRadixSort());
long[] record = new long[100];
int recordSize = record.length * 8;
@@ -435,6 +437,7 @@ public void testPeakMemoryUsed() throws Exception {
prefixComparator,
1024,
pageSizeBytes,
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD,
shouldUseRadixSort());
// Peak memory should be monotonically increasing. More specifically, every time
diff --git a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
index f205d4f0d60b..38b48a4c9e65 100644
--- a/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/ChunkedByteBufferSuite.scala
@@ -38,12 +38,6 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
emptyChunkedByteBuffer.toInputStream(dispose = true).close()
}
- test("chunks must be non-empty") {
- intercept[IllegalArgumentException] {
- new ChunkedByteBuffer(Array(ByteBuffer.allocate(0)))
- }
- }
-
test("getChunks() duplicates chunks") {
val chunkedByteBuffer = new ChunkedByteBuffer(Array(ByteBuffer.allocate(8)))
chunkedByteBuffer.getChunks().head.position(4)
@@ -63,8 +57,9 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
}
test("toArray()") {
+ val empty = ByteBuffer.wrap(Array[Byte]())
val bytes = ByteBuffer.wrap(Array.tabulate(8)(_.toByte))
- val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes))
+ val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes, bytes, empty))
assert(chunkedByteBuffer.toArray === bytes.array() ++ bytes.array())
}
@@ -79,9 +74,10 @@ class ChunkedByteBufferSuite extends SparkFunSuite {
}
test("toInputStream()") {
+ val empty = ByteBuffer.wrap(Array[Byte]())
val bytes1 = ByteBuffer.wrap(Array.tabulate(256)(_.toByte))
val bytes2 = ByteBuffer.wrap(Array.tabulate(128)(_.toByte))
- val chunkedByteBuffer = new ChunkedByteBuffer(Array(bytes1, bytes2))
+ val chunkedByteBuffer = new ChunkedByteBuffer(Array(empty, bytes1, bytes2))
assert(chunkedByteBuffer.size === bytes1.limit() + bytes2.limit())
val inputStream = chunkedByteBuffer.toInputStream(dispose = false)
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index df279b5a37c7..f5d0fb00b732 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -863,7 +863,7 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
assert(terminated.isDefined)
Utils.waitForProcess(process, 5000)
val duration = System.currentTimeMillis() - start
- assert(duration < 5000)
+ assert(duration < 6000) // add a little extra time to allow a force kill to finish
assert(!pidExists(pid))
} finally {
signal(pid, "SIGKILL")
diff --git a/graphx/data/followers.txt b/data/graphx/followers.txt
similarity index 100%
rename from graphx/data/followers.txt
rename to data/graphx/followers.txt
diff --git a/graphx/data/users.txt b/data/graphx/users.txt
similarity index 100%
rename from graphx/data/users.txt
rename to data/graphx/users.txt
diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh
index 65e80fc76056..2833dc765111 100755
--- a/dev/create-release/release-build.sh
+++ b/dev/create-release/release-build.sh
@@ -80,7 +80,7 @@ NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads
BASE_DIR=$(pwd)
MVN="build/mvn --force"
-PUBLISH_PROFILES="-Pyarn -Phive -Phadoop-2.2"
+PUBLISH_PROFILES="-Pyarn -Phive -Phive-thriftserver -Phadoop-2.2"
PUBLISH_PROFILES="$PUBLISH_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl"
rm -rf spark
@@ -254,8 +254,7 @@ if [[ "$1" == "publish-snapshot" ]]; then
# Generate random point for Zinc
export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
- $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES \
- -Phive-thriftserver deploy
+ $MVN -DzincPort=$ZINC_PORT --settings $tmp_settings -DskipTests $PUBLISH_PROFILES deploy
./dev/change-scala-version.sh 2.10
$MVN -DzincPort=$ZINC_PORT -Dscala-2.10 --settings $tmp_settings \
-DskipTests $PUBLISH_PROFILES clean deploy
@@ -291,8 +290,7 @@ if [[ "$1" == "publish-release" ]]; then
# Generate random point for Zinc
export ZINC_PORT=$(python -S -c "import random; print random.randrange(3030,4030)")
- $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES \
- -Phive-thriftserver clean install
+ $MVN -DzincPort=$ZINC_PORT -Dmaven.repo.local=$tmp_repo -DskipTests $PUBLISH_PROFILES clean install
./dev/change-scala-version.sh 2.10
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index ce5725764be6..927e010383de 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -408,6 +408,7 @@ def __hash__(self):
"pyspark.ml.tuning",
"pyspark.ml.tests",
"pyspark.ml.evaluation",
+ "pyspark.ml.stat.distribution",
],
blacklisted_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index d493f62f0e57..2d0c3fd71293 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -73,6 +73,7 @@
Spark Streaming
DataFrames, Datasets and SQL
+ Structured Streaming
MLlib (Machine Learning)
GraphX (Graph Processing)
SparkR (R on Spark)
diff --git a/docs/_plugins/include_example.rb b/docs/_plugins/include_example.rb
index f7485826a762..306888801df2 100644
--- a/docs/_plugins/include_example.rb
+++ b/docs/_plugins/include_example.rb
@@ -32,8 +32,18 @@ def render(context)
@code_dir = File.join(site.source, config_dir)
clean_markup = @markup.strip
- @file = File.join(@code_dir, clean_markup)
- @lang = clean_markup.split('.').last
+
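+    # The markup is either "<file>" or "<label> <file>"; an optional label selects the
+    # "$example on:label$" / "$example off:label$" region inside the example source file.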
+ parts = clean_markup.strip.split(' ')
+ if parts.length > 1 then
+ @snippet_label = ':' + parts[0]
+ snippet_file = parts[1]
+ else
+ @snippet_label = ''
+ snippet_file = parts[0]
+ end
+
+ @file = File.join(@code_dir, snippet_file)
+ @lang = snippet_file.split('.').last
code = File.open(@file).read.encode("UTF-8")
code = select_lines(code)
@@ -41,7 +51,7 @@ def render(context)
rendered_code = Pygments.highlight(code, :lexer => @lang)
hint = "Find full example code at " \
- "\"examples/src/main/#{clean_markup}\" in the Spark repo.
"
+ "\"examples/src/main/#{snippet_file}\" in the Spark repo."
rendered_code + hint
end
@@ -66,13 +76,13 @@ def select_lines(code)
# Select the array of start labels from code.
startIndices = lines
.each_with_index
- .select { |l, i| l.include? "$example on$" }
+ .select { |l, i| l.include? "$example on#{@snippet_label}$" }
.map { |l, i| i }
# Select the array of end labels from code.
endIndices = lines
.each_with_index
- .select { |l, i| l.include? "$example off$" }
+ .select { |l, i| l.include? "$example off#{@snippet_label}$" }
.map { |l, i| i }
raise "Start indices amount is not equal to end indices amount, see #{@file}." \
@@ -92,7 +102,10 @@ def select_lines(code)
if start == endline
lastIndex = endline
range = Range.new(start + 1, endline - 1)
- result += trim_codeblock(lines[range]).join
+ trimmed = trim_codeblock(lines[range])
+      # Filter out any example tags left over from overlapping labels.
+      tags_filtered = trimmed.select { |l| !l.include? '$example ' }
+      result += tags_filtered.join
result += "\n"
end
result
diff --git a/docs/configuration.md b/docs/configuration.md
index cee59cf2aa05..1e95b862441f 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -1564,8 +1564,8 @@ spark.sql("SET -v").show(n=200, truncate=False)
{% highlight r %}
-# sqlContext is an existing sqlContext.
-properties <- sql(sqlContext, "SET -v")
+sparkR.session()
+properties <- sql("SET -v")
showDF(properties, numRows = 200, truncate = FALSE)
{% endhighlight %}
diff --git a/docs/graphx-programming-guide.md b/docs/graphx-programming-guide.md
index 81cf17475fb6..2e9966c0a2b6 100644
--- a/docs/graphx-programming-guide.md
+++ b/docs/graphx-programming-guide.md
@@ -603,29 +603,7 @@ slightly unreliable and instead opted for more explicit user control.
In the following example we use the [`aggregateMessages`][Graph.aggregateMessages] operator to
compute the average age of the more senior followers of each user.
-{% highlight scala %}
-// Import random graph generation library
-import org.apache.spark.graphx.util.GraphGenerators
-// Create a graph with "age" as the vertex property. Here we use a random graph for simplicity.
-val graph: Graph[Double, Int] =
- GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
-// Compute the number of older followers and their total age
-val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
- triplet => { // Map Function
- if (triplet.srcAttr > triplet.dstAttr) {
- // Send message to destination vertex containing counter and age
- triplet.sendToDst(1, triplet.srcAttr)
- }
- },
- // Add counter and age
- (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function
-)
-// Divide total age by number of older followers to get average age of older followers
-val avgAgeOfOlderFollowers: VertexRDD[Double] =
- olderFollowers.mapValues( (id, value) => value match { case (count, totalAge) => totalAge / count } )
-// Display the results
-avgAgeOfOlderFollowers.collect.foreach(println(_))
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala %}
> The `aggregateMessages` operation performs optimally when the messages (and the sums of
> messages) are constant sized (e.g., floats and addition instead of lists and concatenation).
@@ -793,29 +771,7 @@ second argument list contains the user defined functions for receiving messages
We can use the Pregel operator to express computation such as single source
shortest path in the following example.
-{% highlight scala %}
-import org.apache.spark.graphx._
-// Import random graph generation library
-import org.apache.spark.graphx.util.GraphGenerators
-// A graph with edge attributes containing distances
-val graph: Graph[Long, Double] =
- GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
-val sourceId: VertexId = 42 // The ultimate source
-// Initialize the graph such that all vertices except the root have distance infinity.
-val initialGraph = graph.mapVertices((id, _) => if (id == sourceId) 0.0 else Double.PositiveInfinity)
-val sssp = initialGraph.pregel(Double.PositiveInfinity)(
- (id, dist, newDist) => math.min(dist, newDist), // Vertex Program
- triplet => { // Send Message
- if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
- Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
- } else {
- Iterator.empty
- }
- },
- (a,b) => math.min(a,b) // Merge Message
- )
-println(sssp.vertices.collect.mkString("\n"))
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/SSSPExample.scala %}
@@ -1007,66 +963,21 @@ PageRank measures the importance of each vertex in a graph, assuming an edge fro
GraphX comes with static and dynamic implementations of PageRank as methods on the [`PageRank` object][PageRank]. Static PageRank runs for a fixed number of iterations, while dynamic PageRank runs until the ranks converge (i.e., stop changing by more than a specified tolerance). [`GraphOps`][GraphOps] allows calling these algorithms directly as methods on `Graph`.
-GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `graphx/data/users.txt`, and a set of relationships between users is given in `graphx/data/followers.txt`. We compute the PageRank of each user as follows:
+GraphX also includes an example social network dataset that we can run PageRank on. A set of users is given in `data/graphx/users.txt`, and a set of relationships between users is given in `data/graphx/followers.txt`. We compute the PageRank of each user as follows:
-{% highlight scala %}
-// Load the edges as a graph
-val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
-// Run PageRank
-val ranks = graph.pageRank(0.0001).vertices
-// Join the ranks with the usernames
-val users = sc.textFile("graphx/data/users.txt").map { line =>
- val fields = line.split(",")
- (fields(0).toLong, fields(1))
-}
-val ranksByUsername = users.join(ranks).map {
- case (id, (username, rank)) => (username, rank)
-}
-// Print the result
-println(ranksByUsername.collect().mkString("\n"))
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/PageRankExample.scala %}
## Connected Components
The connected components algorithm labels each connected component of the graph with the ID of its lowest-numbered vertex. For example, in a social network, connected components can approximate clusters. GraphX contains an implementation of the algorithm in the [`ConnectedComponents` object][ConnectedComponents], and we compute the connected components of the example social network dataset from the [PageRank section](#pagerank) as follows:
-{% highlight scala %}
-// Load the graph as in the PageRank example
-val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
-// Find the connected components
-val cc = graph.connectedComponents().vertices
-// Join the connected components with the usernames
-val users = sc.textFile("graphx/data/users.txt").map { line =>
- val fields = line.split(",")
- (fields(0).toLong, fields(1))
-}
-val ccByUsername = users.join(cc).map {
- case (id, (username, cc)) => (username, cc)
-}
-// Print the result
-println(ccByUsername.collect().mkString("\n"))
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala %}
## Triangle Counting
A vertex is part of a triangle when it has two adjacent vertices with an edge between them. GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount] that determines the number of triangles passing through each vertex, providing a measure of clustering. We compute the triangle count of the social network dataset from the [PageRank section](#pagerank). *Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`) and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].*
-{% highlight scala %}
-// Load the edges in canonical order and partition the graph for triangle count
-val graph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt", true).partitionBy(PartitionStrategy.RandomVertexCut)
-// Find the triangle count for each vertex
-val triCounts = graph.triangleCount().vertices
-// Join the triangle counts with the usernames
-val users = sc.textFile("graphx/data/users.txt").map { line =>
- val fields = line.split(",")
- (fields(0).toLong, fields(1))
-}
-val triCountByUsername = users.join(triCounts).map { case (id, (username, tc)) =>
- (username, tc)
-}
-// Print the result
-println(triCountByUsername.collect().mkString("\n"))
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala %}
# Examples
@@ -1076,36 +987,4 @@ to important relationships and users, run page-rank on the sub-graph, and
then finally return attributes associated with the top users. I can do
all of this in just a few lines with GraphX:
-{% highlight scala %}
-// Connect to the Spark cluster
-val sc = new SparkContext("spark://master.amplab.org", "research")
-
-// Load my user data and parse into tuples of user id and attribute list
-val users = (sc.textFile("graphx/data/users.txt")
- .map(line => line.split(",")).map( parts => (parts.head.toLong, parts.tail) ))
-
-// Parse the edge data which is already in userId -> userId format
-val followerGraph = GraphLoader.edgeListFile(sc, "graphx/data/followers.txt")
-
-// Attach the user attributes
-val graph = followerGraph.outerJoinVertices(users) {
- case (uid, deg, Some(attrList)) => attrList
- // Some users may not have attributes so we set them as empty
- case (uid, deg, None) => Array.empty[String]
-}
-
-// Restrict the graph to users with usernames and names
-val subgraph = graph.subgraph(vpred = (vid, attr) => attr.size == 2)
-
-// Compute the PageRank
-val pagerankGraph = subgraph.pageRank(0.001)
-
-// Get the attributes of the top pagerank users
-val userInfoWithPageRank = subgraph.outerJoinVertices(pagerankGraph.vertices) {
- case (uid, attrList, Some(pr)) => (pr, attrList.toList)
- case (uid, attrList, None) => (0.0, attrList.toList)
-}
-
-println(userInfoWithPageRank.vertices.top(5)(Ordering.by(_._2._1)).mkString("\n"))
-
-{% endhighlight %}
+{% include_example scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala %}
diff --git a/docs/img/cluster-overview.png b/docs/img/cluster-overview.png
index 317554c5f2a5..b1b7c1ab5a5f 100644
Binary files a/docs/img/cluster-overview.png and b/docs/img/cluster-overview.png differ
diff --git a/docs/img/edge-cut.png b/docs/img/edge-cut.png
deleted file mode 100644
index 698f4ff181e4..000000000000
Binary files a/docs/img/edge-cut.png and /dev/null differ
diff --git a/docs/img/edge_cut_vs_vertex_cut.png b/docs/img/edge_cut_vs_vertex_cut.png
index ae30396d3fe1..5b1ed78a42c0 100644
Binary files a/docs/img/edge_cut_vs_vertex_cut.png and b/docs/img/edge_cut_vs_vertex_cut.png differ
diff --git a/docs/img/graph_parallel.png b/docs/img/graph_parallel.png
deleted file mode 100644
index 330be5567cf9..000000000000
Binary files a/docs/img/graph_parallel.png and /dev/null differ
diff --git a/docs/img/graphx_logo.png b/docs/img/graphx_logo.png
index 9869ac148cad..04eb4d9217cd 100644
Binary files a/docs/img/graphx_logo.png and b/docs/img/graphx_logo.png differ
diff --git a/docs/img/graphx_performance_comparison.png b/docs/img/graphx_performance_comparison.png
deleted file mode 100644
index 62dcf098c904..000000000000
Binary files a/docs/img/graphx_performance_comparison.png and /dev/null differ
diff --git a/docs/img/ml-Pipeline.png b/docs/img/ml-Pipeline.png
index 607928906bed..33e2150b2415 100644
Binary files a/docs/img/ml-Pipeline.png and b/docs/img/ml-Pipeline.png differ
diff --git a/docs/img/ml-PipelineModel.png b/docs/img/ml-PipelineModel.png
index 9ebc16719d36..4d6f5c27af1c 100644
Binary files a/docs/img/ml-PipelineModel.png and b/docs/img/ml-PipelineModel.png differ
diff --git a/docs/img/property_graph.png b/docs/img/property_graph.png
index 6f3f89a010c5..72e13b91aaa8 100644
Binary files a/docs/img/property_graph.png and b/docs/img/property_graph.png differ
diff --git a/docs/img/spark-logo-hd.png b/docs/img/spark-logo-hd.png
index e4508e7553d4..464eda547ba9 100644
Binary files a/docs/img/spark-logo-hd.png and b/docs/img/spark-logo-hd.png differ
diff --git a/docs/img/spark-webui-accumulators.png b/docs/img/spark-webui-accumulators.png
index 237052d7b5db..55c50d608b11 100644
Binary files a/docs/img/spark-webui-accumulators.png and b/docs/img/spark-webui-accumulators.png differ
diff --git a/docs/img/streaming-arch.png b/docs/img/streaming-arch.png
index ac35f1d34cf3..f86fb6bcbbdd 100644
Binary files a/docs/img/streaming-arch.png and b/docs/img/streaming-arch.png differ
diff --git a/docs/img/streaming-dstream-ops.png b/docs/img/streaming-dstream-ops.png
index a1c5634aa3c3..73084ff1a1f0 100644
Binary files a/docs/img/streaming-dstream-ops.png and b/docs/img/streaming-dstream-ops.png differ
diff --git a/docs/img/streaming-dstream-window.png b/docs/img/streaming-dstream-window.png
index 276d2fee5e30..4db3c97c291f 100644
Binary files a/docs/img/streaming-dstream-window.png and b/docs/img/streaming-dstream-window.png differ
diff --git a/docs/img/streaming-dstream.png b/docs/img/streaming-dstream.png
index 90f43b8c7138..326a3aef0fa6 100644
Binary files a/docs/img/streaming-dstream.png and b/docs/img/streaming-dstream.png differ
diff --git a/docs/img/streaming-flow.png b/docs/img/streaming-flow.png
index a870cb9b1839..bf5dd40d3559 100644
Binary files a/docs/img/streaming-flow.png and b/docs/img/streaming-flow.png differ
diff --git a/docs/img/streaming-kinesis-arch.png b/docs/img/streaming-kinesis-arch.png
index bea5fa88df98..4fb026064dad 100644
Binary files a/docs/img/streaming-kinesis-arch.png and b/docs/img/streaming-kinesis-arch.png differ
diff --git a/docs/img/structured-streaming-example-model.png b/docs/img/structured-streaming-example-model.png
new file mode 100644
index 000000000000..d63498fdec4d
Binary files /dev/null and b/docs/img/structured-streaming-example-model.png differ
diff --git a/docs/img/structured-streaming-late-data.png b/docs/img/structured-streaming-late-data.png
new file mode 100644
index 000000000000..f9389c60552c
Binary files /dev/null and b/docs/img/structured-streaming-late-data.png differ
diff --git a/docs/img/structured-streaming-model.png b/docs/img/structured-streaming-model.png
new file mode 100644
index 000000000000..02becb91ef6a
Binary files /dev/null and b/docs/img/structured-streaming-model.png differ
diff --git a/docs/img/structured-streaming-stream-as-a-table.png b/docs/img/structured-streaming-stream-as-a-table.png
new file mode 100644
index 000000000000..bc6352446409
Binary files /dev/null and b/docs/img/structured-streaming-stream-as-a-table.png differ
diff --git a/docs/img/structured-streaming-window.png b/docs/img/structured-streaming-window.png
new file mode 100644
index 000000000000..875716c70868
Binary files /dev/null and b/docs/img/structured-streaming-window.png differ
diff --git a/docs/img/structured-streaming.pptx b/docs/img/structured-streaming.pptx
new file mode 100644
index 000000000000..6aad2ed33e92
Binary files /dev/null and b/docs/img/structured-streaming.pptx differ
diff --git a/docs/img/triplet.png b/docs/img/triplet.png
index 8b82a09bed29..5d38ccebd3f2 100644
Binary files a/docs/img/triplet.png and b/docs/img/triplet.png differ
diff --git a/docs/img/vertex-cut.png b/docs/img/vertex-cut.png
deleted file mode 100644
index 0a508dcee99e..000000000000
Binary files a/docs/img/vertex-cut.png and /dev/null differ
diff --git a/docs/img/vertex_routing_edge_tables.png b/docs/img/vertex_routing_edge_tables.png
index 4379becc87ee..2d1f3808e721 100644
Binary files a/docs/img/vertex_routing_edge_tables.png and b/docs/img/vertex_routing_edge_tables.png differ
diff --git a/docs/mllib-guide.md b/docs/mllib-guide.md
index c28d13732eed..17fd3e1edf4b 100644
--- a/docs/mllib-guide.md
+++ b/docs/mllib-guide.md
@@ -104,9 +104,105 @@ and the migration guide below will explain all changes between releases.
## From 1.6 to 2.0
-The deprecations and changes of behavior in the `spark.mllib` or `spark.ml` packages include:
+### Breaking changes
-Deprecations:
+There were several breaking changes in Spark 2.0, which are outlined below.
+
+**Linear algebra classes for DataFrame-based APIs**
+
+Spark's linear algebra dependencies were moved to a new project, `mllib-local`
+(see [SPARK-13944](https://issues.apache.org/jira/browse/SPARK-13944)).
+As part of this change, the linear algebra classes were copied to a new package, `spark.ml.linalg`.
+The DataFrame-based APIs in `spark.ml` now depend on the `spark.ml.linalg` classes,
+leading to a few breaking changes, predominantly in various model classes
+(see [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810) for a full list).
+
+**Note:** the RDD-based APIs in `spark.mllib` continue to depend on the previous package `spark.mllib.linalg`.
+
+_Converting vectors and matrices_
+
+While most pipeline components support backward compatibility for loading,
+some existing `DataFrames` and pipelines created in Spark versions prior to 2.0 that contain vector or matrix
+columns may need to be migrated to the new `spark.ml` vector and matrix types.
+Utilities for converting `DataFrame` columns from `spark.mllib.linalg` to `spark.ml.linalg` types
+(and vice versa) can be found in `spark.mllib.util.MLUtils`.
+
+There are also utility methods available for converting single instances of
+vectors and matrices. Use the `asML` method on a `mllib.linalg.Vector` / `mllib.linalg.Matrix`
+for converting to `ml.linalg` types, and
+`mllib.linalg.Vectors.fromML` / `mllib.linalg.Matrices.fromML`
+for converting to `mllib.linalg` types.
+
+
+
+
+{% highlight scala %}
+import org.apache.spark.mllib.util.MLUtils
+
+// convert DataFrame columns
+val convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF)
+val convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF)
+// convert a single vector or matrix
+val mlVec: org.apache.spark.ml.linalg.Vector = mllibVec.asML
+val mlMat: org.apache.spark.ml.linalg.Matrix = mllibMat.asML
+{% endhighlight %}
+
+Refer to the [`MLUtils` Scala docs](api/scala/index.html#org.apache.spark.mllib.util.MLUtils$) for further detail.
+
+
+
+
+{% highlight java %}
+import org.apache.spark.mllib.util.MLUtils;
+import org.apache.spark.sql.Dataset;
+
+// convert DataFrame columns
+Dataset convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF);
+Dataset convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF);
+// convert a single vector or matrix
+org.apache.spark.ml.linalg.Vector mlVec = mllibVec.asML();
+org.apache.spark.ml.linalg.Matrix mlMat = mllibMat.asML();
+{% endhighlight %}
+
+Refer to the [`MLUtils` Java docs](api/java/org/apache/spark/mllib/util/MLUtils.html) for further detail.
+
+
+
+
+{% highlight python %}
+from pyspark.mllib.util import MLUtils
+
+# convert DataFrame columns
+convertedVecDF = MLUtils.convertVectorColumnsToML(vecDF)
+convertedMatrixDF = MLUtils.convertMatrixColumnsToML(matrixDF)
+# convert a single vector or matrix
+mlVec = mllibVec.asML()
+mlMat = mllibMat.asML()
+{% endhighlight %}
+
+Refer to the [`MLUtils` Python docs](api/python/pyspark.mllib.html#pyspark.mllib.util.MLUtils) for further detail.
+
+
+
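Complementing the snippets above, here is a minimal Scala sketch (assuming only the Spark 2.0 conversion APIs described in this section) of converting a single vector in both directions:

{% highlight scala %}
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.mllib.{linalg => oldlinalg}

// old (spark.mllib) -> new (spark.ml)
val oldVec: oldlinalg.Vector = oldlinalg.Vectors.dense(1.0, 2.0, 3.0)
val newVec: newlinalg.Vector = oldVec.asML

// new (spark.ml) -> old (spark.mllib)
val backVec: oldlinalg.Vector = oldlinalg.Vectors.fromML(newVec)
{% endhighlight %}

Matrices follow the same pattern via `asML` and `mllib.linalg.Matrices.fromML`.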
+**Deprecated methods removed**
+
+Several deprecated methods were removed in the `spark.mllib` and `spark.ml` packages:
+
+* `setScoreCol` in `ml.evaluation.BinaryClassificationEvaluator`
+* `weights` in `LinearRegression` and `LogisticRegression` in `spark.ml`
+* `setMaxNumIterations` in `mllib.optimization.LBFGS` (marked as `DeveloperApi`)
+* `treeReduce` and `treeAggregate` in `mllib.rdd.RDDFunctions` (these functions are available on `RDD`s directly, and were marked as `DeveloperApi`)
+* `defaultStategy` in `mllib.tree.configuration.Strategy`
+* `build` in `mllib.tree.Node`
+* libsvm loaders for multiclass and load/save labeledData methods in `mllib.util.MLUtils`
+
+A full list of breaking changes can be found at [SPARK-14810](https://issues.apache.org/jira/browse/SPARK-14810).
+
+### Deprecations and changes of behavior
+
+**Deprecations**
+
+Deprecations in the `spark.mllib` and `spark.ml` packages include:
* [SPARK-14984](https://issues.apache.org/jira/browse/SPARK-14984):
In `spark.ml.regression.LinearRegressionSummary`, the `model` field has been deprecated.
@@ -125,7 +221,9 @@ Deprecations:
In `spark.ml.util.MLReader` and `spark.ml.util.MLWriter`, the `context` method has been deprecated in favor of `session`.
* In `spark.ml.feature.ChiSqSelectorModel`, the `setLabelCol` method has been deprecated since it was not used by `ChiSqSelectorModel`.
-Changes of behavior:
+**Changes of behavior**
+
+Changes of behavior in the `spark.mllib` and `spark.ml` packages include:
* [SPARK-7780](https://issues.apache.org/jira/browse/SPARK-7780):
`spark.mllib.classification.LogisticRegressionWithLBFGS` directly calls `spark.ml.classification.LogisticRegresson` for binary classification now.
diff --git a/docs/running-on-mesos.md b/docs/running-on-mesos.md
index 4a0ab623c108..5219e99fee73 100644
--- a/docs/running-on-mesos.md
+++ b/docs/running-on-mesos.md
@@ -180,30 +180,58 @@ Note that jars or python files that are passed to spark-submit should be URIs re
# Mesos Run Modes
-Spark can run over Mesos in two modes: "coarse-grained" (default) and "fine-grained".
-
-The "coarse-grained" mode will launch only *one* long-running Spark task on each Mesos
-machine, and dynamically schedule its own "mini-tasks" within it. The benefit is much lower startup
-overhead, but at the cost of reserving the Mesos resources for the complete duration of the
-application.
-
-Coarse-grained is the default mode. You can also set `spark.mesos.coarse` property to true
-to turn it on explicitly in [SparkConf](configuration.html#spark-properties):
-
-{% highlight scala %}
-conf.set("spark.mesos.coarse", "true")
-{% endhighlight %}
-
-In addition, for coarse-grained mode, you can control the maximum number of resources Spark will
-acquire. By default, it will acquire *all* cores in the cluster (that get offered by Mesos), which
-only makes sense if you run just one application at a time. You can cap the maximum number of cores
-using `conf.set("spark.cores.max", "10")` (for example).
-
-In "fine-grained" mode, each Spark task runs as a separate Mesos task. This allows
-multiple instances of Spark (and other frameworks) to share machines at a very fine granularity,
-where each application gets more or fewer machines as it ramps up and down, but it comes with an
-additional overhead in launching each task. This mode may be inappropriate for low-latency
-requirements like interactive queries or serving web requests.
+Spark can run over Mesos in two modes: "coarse-grained" (default) and
+"fine-grained" (deprecated).
+
+## Coarse-Grained
+
+In "coarse-grained" mode, each Spark executor runs as a single Mesos
+task. Spark executors are sized according to the following
+configuration variables:
+
+* Executor memory: `spark.executor.memory`
+* Executor cores: `spark.executor.cores`
+* Number of executors: `spark.cores.max`/`spark.executor.cores`
+
+Please see the [Spark Configuration](configuration.html) page for
+details and default values.
+
+Executors are brought up eagerly when the application starts, until
+`spark.cores.max` is reached. If you don't set `spark.cores.max`, the
+Spark application will reserve all resources offered to it by Mesos,
+so be sure to set this variable in any sort of
+multi-tenant cluster, including one that runs multiple concurrent
+Spark applications.
+
+The scheduler will start executors round-robin on the offers Mesos
+gives it, but there are no spread guarantees, as Mesos does not
+provide such guarantees on the offer stream.
+
+The benefit of coarse-grained mode is much lower startup overhead, but
+at the cost of reserving Mesos resources for the complete duration of
+the application. To configure your job to dynamically adjust to its
+resource requirements, look into
+[Dynamic Allocation](#dynamic-resource-allocation-with-mesos).
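
As a minimal sketch of the sizing knobs just described (the values are illustrative only, not recommendations), a coarse-grained application on a shared cluster might be configured like this:

{% highlight scala %}
import org.apache.spark.SparkConf

val conf = new SparkConf()
  .setAppName("CoarseGrainedOnMesos")
  .set("spark.executor.memory", "4g")  // memory per executor
  .set("spark.executor.cores", "2")    // cores per executor
  .set("spark.cores.max", "8")         // total cores, i.e. 8 / 2 = 4 executors
{% endhighlight %}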
+
+## Fine-Grained (deprecated)
+
+**NOTE:** Fine-grained mode is deprecated as of Spark 2.0.0. Consider
+ using [Dynamic Allocation](#dynamic-resource-allocation-with-mesos)
+ for some of the benefits. For a full explanation see
+ [SPARK-11857](https://issues.apache.org/jira/browse/SPARK-11857)
+
+In "fine-grained" mode, each Spark task inside the Spark executor runs
+as a separate Mesos task. This allows multiple instances of Spark (and
+other frameworks) to share cores at a very fine granularity, where
+each application gets more or fewer cores as it ramps up and down, but
+it comes with an additional overhead in launching each task. This mode
+may be inappropriate for low-latency requirements like interactive
+queries or serving web requests.
+
+Note that while Spark tasks in fine-grained mode relinquish cores as
+they terminate, they do not relinquish memory, as the JVM does not
+return memory to the operating system. Nor do executors terminate
+when they are idle.
To run in fine-grained mode, set the `spark.mesos.coarse` property to false in your
[SparkConf](configuration.html#spark-properties):
@@ -212,7 +240,9 @@ To run in fine-grained mode, set the `spark.mesos.coarse` property to false in y
conf.set("spark.mesos.coarse", "false")
{% endhighlight %}
-You may also make use of `spark.mesos.constraints` to set attribute based constraints on mesos resource offers. By default, all resource offers will be accepted.
+You may also make use of `spark.mesos.constraints` to set
+attribute-based constraints on Mesos resource offers. By default, all
+resource offers will be accepted.
{% highlight scala %}
conf.set("spark.mesos.constraints", "os:centos7;us-east-1:false")
@@ -246,7 +276,7 @@ In either case, HDFS runs separately from Hadoop MapReduce, without being schedu
# Dynamic Resource Allocation with Mesos
-Mesos supports dynamic allocation only with coarse-grain mode, which can resize the number of
+Mesos supports dynamic allocation only with coarse-grained mode, which can resize the number of
executors based on statistics of the application. For general information,
see [Dynamic Resource Allocation](job-scheduling.html#dynamic-resource-allocation).
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md
index 6c6bc8db6a1f..68419e133159 100644
--- a/docs/sql-programming-guide.md
+++ b/docs/sql-programming-guide.md
@@ -63,52 +63,23 @@ Throughout this document, we will often refer to Scala/Java Datasets of `Row`s a
-The entry point into all functionality in Spark is the [`SparkSession`](api/scala/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.build()`:
-
-{% highlight scala %}
-import org.apache.spark.sql.SparkSession
-
-val spark = SparkSession.build()
- .master("local")
- .appName("Word Count")
- .config("spark.some.config.option", "some-value")
- .getOrCreate()
-
-// this is used to implicitly convert an RDD to a DataFrame.
-import spark.implicits._
-{% endhighlight %}
+The entry point into all functionality in Spark is the [`SparkSession`](api/scala/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder()`:
+{% include_example init_session scala/org/apache/spark/examples/sql/RDDRelation.scala %}
-The entry point into all functionality in Spark is the [`SparkSession`](api/java/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.build()`:
+The entry point into all functionality in Spark is the [`SparkSession`](api/java/index.html#org.apache.spark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder()`:
-{% highlight java %}
-import org.apache.spark.sql.SparkSession
-
-SparkSession spark = SparkSession.build()
- .master("local")
- .appName("Word Count")
- .config("spark.some.config.option", "some-value")
- .getOrCreate();
-{% endhighlight %}
+{% include_example init_session java/org/apache/spark/examples/sql/JavaSparkSQL.java %}
-The entry point into all functionality in Spark is the [`SparkSession`](api/python/pyspark.sql.html#pyspark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.build`:
-
-{% highlight python %}
-from pyspark.sql import SparkSession
-
-spark = SparkSession.build \
- .master("local") \
- .appName("Word Count") \
- .config("spark.some.config.option", "some-value") \
- .getOrCreate()
-{% endhighlight %}
+The entry point into all functionality in Spark is the [`SparkSession`](api/python/pyspark.sql.html#pyspark.sql.SparkSession) class. To create a basic `SparkSession`, just use `SparkSession.builder`:
+{% include_example init_session python/sql.py %}
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index db06a65b994b..2ee3b80185c2 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -1534,7 +1534,7 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
***
## DataFrame and SQL Operations
-You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
+You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SparkSession using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SparkSession. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL.
@@ -1546,9 +1546,9 @@ val words: DStream[String] = ...
words.foreachRDD { rdd =>
- // Get the singleton instance of SQLContext
- val sqlContext = SQLContext.getOrCreate(rdd.sparkContext)
- import sqlContext.implicits._
+ // Get the singleton instance of SparkSession
+ val spark = SparkSession.builder.config(rdd.sparkContext.getConf).getOrCreate()
+ import spark.implicits._
// Convert RDD[String] to DataFrame
val wordsDataFrame = rdd.toDF("word")
@@ -1558,7 +1558,7 @@ words.foreachRDD { rdd =>
// Do word count on DataFrame using SQL and print it
val wordCountsDataFrame =
- sqlContext.sql("select word, count(*) as total from words group by word")
+ spark.sql("select word, count(*) as total from words group by word")
wordCountsDataFrame.show()
}
@@ -1593,8 +1593,8 @@ words.foreachRDD(
@Override
public Void call(JavaRDD<String> rdd, Time time) {
- // Get the singleton instance of SQLContext
- SQLContext sqlContext = SQLContext.getOrCreate(rdd.context());
+ // Get the singleton instance of SparkSession
+ SparkSession spark = SparkSession.builder().config(rdd.sparkContext().getConf()).getOrCreate();
// Convert RDD[String] to RDD[case class] to DataFrame
JavaRDD<JavaRow> rowRDD = rdd.map(new Function<String, JavaRow>() {
@@ -1604,14 +1604,14 @@ words.foreachRDD(
return record;
}
});
- DataFrame wordsDataFrame = sqlContext.createDataFrame(rowRDD, JavaRow.class);
+ DataFrame wordsDataFrame = spark.createDataFrame(rowRDD, JavaRow.class);
// Creates a temporary view using the DataFrame
wordsDataFrame.createOrReplaceTempView("words");
// Do word count on table using SQL and print it
DataFrame wordCountsDataFrame =
- sqlContext.sql("select word, count(*) as total from words group by word");
+ spark.sql("select word, count(*) as total from words group by word");
wordCountsDataFrame.show();
return null;
}
@@ -1624,11 +1624,14 @@ See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/ma
{% highlight python %}
-# Lazily instantiated global instance of SQLContext
-def getSqlContextInstance(sparkContext):
- if ('sqlContextSingletonInstance' not in globals()):
- globals()['sqlContextSingletonInstance'] = SQLContext(sparkContext)
- return globals()['sqlContextSingletonInstance']
+# Lazily instantiated global instance of SparkSession
+def getSparkSessionInstance(sparkConf):
+ if ('sparkSessionSingletonInstance' not in globals()):
+ globals()['sparkSessionSingletonInstance'] = SparkSession\
+ .builder\
+ .config(conf=sparkConf)\
+ .getOrCreate()
+ return globals()['sparkSessionSingletonInstance']
...
@@ -1639,18 +1642,18 @@ words = ... # DStream of strings
def process(time, rdd):
print("========= %s =========" % str(time))
try:
- # Get the singleton instance of SQLContext
- sqlContext = getSqlContextInstance(rdd.context)
+ # Get the singleton instance of SparkSession
+ spark = getSparkSessionInstance(rdd.context.getConf())
# Convert RDD[String] to RDD[Row] to DataFrame
rowRdd = rdd.map(lambda w: Row(word=w))
- wordsDataFrame = sqlContext.createDataFrame(rowRdd)
+ wordsDataFrame = spark.createDataFrame(rowRdd)
# Creates a temporary view using the DataFrame
wordsDataFrame.createOrReplaceTempView("words")
# Do word count on table using SQL and print it
- wordCountsDataFrame = sqlContext.sql("select word, count(*) as total from words group by word")
+ wordCountsDataFrame = spark.sql("select word, count(*) as total from words group by word")
wordCountsDataFrame.show()
except:
pass
diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md
new file mode 100644
index 000000000000..79493968db27
--- /dev/null
+++ b/docs/structured-streaming-programming-guide.md
@@ -0,0 +1,1158 @@
+---
+layout: global
+displayTitle: Structured Streaming Programming Guide [Alpha]
+title: Structured Streaming Programming Guide
+---
+
+* This will become a table of contents (this text will be scraped).
+{:toc}
+
+# Overview
+Structured Streaming is a scalable and fault-tolerant stream processing engine built on the Spark SQL engine. You can express your streaming computation the same way you would express a batch computation on static data. The Spark SQL engine will take care of running it incrementally and continuously and updating the final result as streaming data continues to arrive. You can use the [Dataset/DataFrame API](sql-programming-guide.html) in Scala, Java or Python to express streaming aggregations, event-time windows, stream-to-batch joins, etc. The computation is executed on the same optimized Spark SQL engine. Finally, the system ensures end-to-end exactly-once fault-tolerance guarantees through checkpointing and Write Ahead Logs. In short, *Structured Streaming provides fast, scalable, fault-tolerant, end-to-end exactly-once stream processing without the user having to reason about streaming.*
+
+**Spark 2.0 is the ALPHA RELEASE of Structured Streaming** and the APIs are still experimental. In this guide, we are going to walk you through the programming model and the APIs. First, let's start with a simple example - a streaming word count.
+
+# Quick Example
+Let’s say you want to maintain a running word count of text data received from a data server listening on a TCP socket. Let’s see how you can express this using Structured Streaming. You can see the full code in
+[Scala]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala)/
+[Java]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java)/
+[Python]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/sql/streaming/structured_network_wordcount.py). And if you
+[download Spark](http://spark.apache.org/downloads.html), you can directly run the example. In any case, let’s walk through the example step-by-step and understand how it works. First, we have to import the necessary classes and create a local SparkSession, the starting point of all functionality related to Spark.
+
+
+
+
+
+
+
+{% highlight scala %}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SparkSession
+
+val spark = SparkSession
+ .builder
+ .appName("StructuredNetworkWordCount")
+ .getOrCreate()
+{% endhighlight %}
+
+Next, let’s create a streaming DataFrame that represents text data received from a server listening on localhost:9999, and transform the DataFrame to calculate word counts.
+
+{% highlight scala %}
+// Create DataFrame representing the stream of input lines from connection to localhost:9999
+val lines = spark.readStream
+ .format("socket")
+ .option("host", "localhost")
+ .option("port", 9999)
+ .load()
+
+// Split the lines into words
+val words = lines.as[String].flatMap(_.split(" "))
+
+// Generate running word count
+val wordCounts = words.groupBy("value").count()
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as[String]`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+
+
+
+
+{% highlight java %}
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.streaming.StreamingQuery;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaStructuredNetworkWordCount")
+ .getOrCreate();
+
+{% endhighlight %}
+
+Next, let’s create a streaming DataFrame that represents text data received from a server listening on localhost:9999, and transform the DataFrame to calculate word counts.
+
+{% highlight java %}
+// Create DataFrame representing the stream of input lines from connection to localhost:9999
+Dataset<Row> lines = spark
+ .readStream()
+ .format("socket")
+ .option("host", "localhost")
+ .option("port", 9999)
+ .load();
+
+// Split the lines into words
+Dataset<String> words = lines
+ .as(Encoders.STRING())
+ .flatMap(
+ new FlatMapFunction<String, String>() {
+ @Override
+ public Iterator<String> call(String x) {
+ return Arrays.asList(x.split(" ")).iterator();
+ }
+ }, Encoders.STRING());
+
+// Generate running word count
+Dataset<Row> wordCounts = words.groupBy("value").count();
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have converted the DataFrame to a Dataset of String using `.as(Encoders.STRING())`, so that we can apply the `flatMap` operation to split each line into multiple words. The resultant `words` Dataset contains all the words. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+
+
+
+
+{% highlight python %}
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import explode
+from pyspark.sql.functions import split
+
+spark = SparkSession\
+ .builder\
+ .appName("StructuredNetworkWordCount")\
+ .getOrCreate()
+{% endhighlight %}
+
+Next, let’s create a streaming DataFrame that represents text data received from a server listening on localhost:9999, and transform the DataFrame to calculate word counts.
+
+{% highlight python %}
+# Create DataFrame representing the stream of input lines from connection to localhost:9999
+lines = spark\
+ .readStream\
+ .format('socket')\
+ .option('host', 'localhost')\
+ .option('port', 9999)\
+ .load()
+
+# Split the lines into words
+words = lines.select(
+ explode(
+ split(lines.value, ' ')
+ ).alias('word')
+)
+
+# Generate running word count
+wordCounts = words.groupBy('word').count()
+{% endhighlight %}
+
+This `lines` DataFrame represents an unbounded table containing the streaming text data. This table contains one column of strings named “value”, and each line in the streaming text data becomes a row in the table. Note that this is not currently receiving any data as we are just setting up the transformation, and have not yet started it. Next, we have used two built-in SQL functions - split and explode, to split each line into multiple rows with a word each. In addition, we use the function `alias` to name the new column “word”. Finally, we have defined the `wordCounts` DataFrame by grouping by the unique values in the Dataset and counting them. Note that this is a streaming DataFrame which represents the running word counts of the stream.
+
+
+
+
+We have now set up the query on the streaming data. All that is left is to actually start receiving data and computing the counts. To do this, we set it up to print the complete set of counts (specified by `outputMode("complete")`) to the console every time they are updated. Then we start the streaming computation using `start()`.
+
+
+
+
+{% highlight scala %}
+// Start running the query that prints the running counts to the console
+val query = wordCounts.writeStream
+ .outputMode("complete")
+ .format("console")
+ .start()
+
+query.awaitTermination()
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+// Start running the query that prints the running counts to the console
+StreamingQuery query = wordCounts.writeStream()
+ .outputMode("complete")
+ .format("console")
+ .start();
+
+query.awaitTermination();
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+# Start running the query that prints the running counts to the console
+query = wordCounts\
+ .writeStream\
+ .outputMode('complete')\
+ .format('console')\
+ .start()
+
+query.awaitTermination()
+{% endhighlight %}
+
+
+
+
+After this code is executed, the streaming computation will have started in the background. The `query` object is a handle to that active streaming query, and we have decided to wait for the termination of the query using `query.awaitTermination()` to prevent the process from exiting while the query is active.
+
+To actually execute this example code, you can either compile the code in your own
+[Spark application](quick-start.html#self-contained-applications), or simply
+[run the example](index.html#running-the-examples-and-shell) once you have downloaded Spark. We are showing the latter. You will first need to run Netcat (a small utility found in most Unix-like systems) as a data server by using
+
+
+ $ nc -lk 9999
+
+Then, in a different terminal, you can start the example by using
+
+
+
+{% highlight bash %}
+$ ./bin/run-example org.apache.spark.examples.sql.streaming.StructuredNetworkWordCount localhost 9999
+{% endhighlight %}
+
+
+{% highlight bash %}
+$ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetworkWordCount localhost 9999
+{% endhighlight %}
+
+
+ {% highlight bash %}
+$ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999
+{% endhighlight %}
+
+
+
+Then, any lines typed in the terminal running the netcat server will be counted and printed on screen every second. It will look something like the following.
+
+
+ |
+{% highlight bash %}
+# TERMINAL 1:
+# Running Netcat
+
+$ nc -lk 9999
+apache spark
+apache hadoop
+
+
+...
+{% endhighlight %}
+ |
+ |
+
+
+
+
+{% highlight bash %}
+# TERMINAL 2: RUNNING StructuredNetworkWordCount
+
+$ ./bin/run-example org.apache.spark.examples.sql.streaming.StructuredNetworkWordCount localhost 9999
+
+-------------------------------------------
+Batch: 0
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 1|
+| spark| 1|
++------+-----+
+
+-------------------------------------------
+Batch: 1
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 2|
+| spark| 1|
+|hadoop| 1|
++------+-----+
+...
+{% endhighlight %}
+
+
+
+{% highlight bash %}
+# TERMINAL 2: RUNNING JavaStructuredNetworkWordCount
+
+$ ./bin/run-example org.apache.spark.examples.sql.streaming.JavaStructuredNetworkWordCount localhost 9999
+
+-------------------------------------------
+Batch: 0
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 1|
+| spark| 1|
++------+-----+
+
+-------------------------------------------
+Batch: 1
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 2|
+| spark| 1|
+|hadoop| 1|
++------+-----+
+...
+{% endhighlight %}
+
+
+{% highlight bash %}
+# TERMINAL 2: RUNNING structured_network_wordcount.py
+
+$ ./bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py localhost 9999
+
+-------------------------------------------
+Batch: 0
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 1|
+| spark| 1|
++------+-----+
+
+-------------------------------------------
+Batch: 1
+-------------------------------------------
++------+-----+
+| value|count|
++------+-----+
+|apache| 2|
+| spark| 1|
+|hadoop| 1|
++------+-----+
+...
+{% endhighlight %}
+
+
+ |
+
+
+
+# Programming Model
+
+The key idea in Structured Streaming is to treat a live data stream as a
+table that is being continuously appended to. This leads to a new stream
+processing model that is very similar to a batch processing model. You will
+express your streaming computation as a standard batch-like query, as if on a static
+table, and Spark runs it as an *incremental* query on the *unbounded* input
+table. Let’s understand this model in more detail.
+
+## Basic Concepts
+Consider the input data stream as the “Input Table”. Every data item that is
+arriving on the stream is like a new row being appended to the Input Table.
+
+
+
+A query on the input will generate the “Result Table”. Every trigger interval (say, every 1 second), new rows get appended to the Input Table, which eventually updates the Result Table. Whenever the result table gets updated, we would want to write the changed result rows to an external sink.
+
+
+
+The “Output” is defined as what gets written out to the external storage. The output can be defined in different modes
+
+ - *Complete Mode* - The entire updated Result Table will be written to the external storage. It is up to the storage connector to decide how to handle writing of the entire table.
+
+ - *Append Mode* - Only the new rows appended in the Result Table since the last trigger will be written to the external storage. This is applicable only on the queries where existing rows in the Result Table are not expected to change.
+
+ - *Update Mode* - Only the rows that were updated in the Result Table since the last trigger will be written to the external storage (not available yet in Spark 2.0). Note that this is different from the Complete Mode in that this mode does not output the rows that are not changed.
+
+Note that each mode is applicable to certain types of queries. This is discussed in detail [later](#output-modes).
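+
+As a minimal sketch of how a mode is chosen (reusing the `lines` and `wordCounts` DataFrames from the quick example above), the output mode is set on the writer before the query is started:
+
+{% highlight scala %}
+// Complete mode: the full aggregated result table is written on every trigger.
+val aggregateQuery = wordCounts.writeStream
+  .outputMode("complete")
+  .format("console")
+  .start()
+
+// Append mode: only rows added since the last trigger are written; this suits
+// queries, such as a plain selection, whose existing result rows never change.
+val selectQuery = lines.select("value").writeStream
+  .outputMode("append")
+  .format("console")
+  .start()
+{% endhighlight %}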
+
+To illustrate the use of this model, let’s understand the model in context of
+the Quick Example above. The first `lines` DataFrame is the input table, and
+the final `wordCounts` DataFrame is the result table. Note that the query on
+the streaming `lines` DataFrame to generate `wordCounts` is *exactly the same* as
+it would be for a static DataFrame. However, when this query is started, Spark
+will continuously check for new data from the socket connection. If there is
+new data, Spark will run an “incremental” query that combines the previous
+running counts with the new data to compute updated counts, as shown below.
+
+
+
+This model is significantly different from many other stream processing
+engines. Many streaming systems require the user to maintain running
+aggregations themselves, thus having to reason about fault-tolerance, and
+data consistency (at-least-once, or at-most-once, or exactly-once). In this
+model, Spark is responsible for updating the Result Table when there is new
+data, thus relieving the users from reasoning about it. As an example, let’s
+see how this model handles event-time based processing and late arriving data.
+
+## Handling Event-time and Late Data
+Event-time is the time embedded in the data itself. For many applications, you may want to operate on this event-time. For example, if you want to get the number of events generated by IoT devices every minute, then you probably want to use the time when the data was generated (that is, event-time in the data), rather than the time Spark receives them. This event-time is very naturally expressed in this model -- each event from the devices is a row in the table, and event-time is a column value in the row. This allows window-based aggregations (e.g. number of events every minute) to be just a special type of grouping and aggregation on the event-time column -- each time window is a group and each row can belong to multiple windows/groups. Therefore, such event-time-window-based aggregation queries can be defined consistently on both a static dataset (e.g. from collected device events logs) as well as on a data stream, making the life of the user much easier.
+
+Furthermore, this model naturally handles data that has arrived later than expected based on its event-time. Since Spark is updating the Result Table, it has full control over updating/cleaning up the aggregates when there is late data. While not yet implemented in Spark 2.0, event-time watermarking will be used to manage this data. This is explained in more detail later in the [Window Operations](#window-operations-on-event-time) section.
+
+## Fault Tolerance Semantics
+Delivering end-to-end exactly-once semantics was one of the key goals behind the design of Structured Streaming. To achieve that, we have designed the Structured Streaming sources, the sinks and the execution engine to reliably track the exact progress of the processing so that it can handle any kind of failure by restarting and/or reprocessing. Every streaming source is assumed to have offsets (similar to Kafka offsets, or Kinesis sequence numbers)
+to track the read position in the stream. The engine uses checkpointing and write ahead logs to record the offset range of the data being processed in each trigger. The streaming sinks are designed to be idempotent for handling reprocessing. Together, using replayable sources and idempotent sinks, Structured Streaming can ensure **end-to-end exactly-once semantics** under any failure.
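+
+As a minimal sketch of how this progress tracking is configured (reusing `wordCounts` from the quick example; the checkpoint path is an illustrative placeholder), the checkpoint location is supplied as a writer option:
+
+{% highlight scala %}
+// Offsets and progress are recorded under the checkpoint location, so the query
+// can be restarted after a failure without losing or duplicating results.
+val query = wordCounts.writeStream
+  .outputMode("complete")
+  .format("console")
+  .option("checkpointLocation", "/path/to/checkpoint/dir")  // illustrative path
+  .start()
+{% endhighlight %}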
+
+# API using Datasets and DataFrames
+Since Spark 2.0, DataFrames and Datasets can represent static, bounded data, as well as streaming, unbounded data. Similar to static Datasets/DataFrames, you can use the common entry point `SparkSession` (
+[Scala](api/scala/index.html#org.apache.spark.sql.SparkSession)/
+[Java](api/java/org/apache/spark/sql/SparkSession.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.SparkSession) docs) to create streaming DataFrames/Datasets from streaming sources, and apply the same operations on them as static DataFrames/Datasets. If you are not familiar with Datasets/DataFrames, you are strongly advised to familiarize yourself with them using the
+[DataFrame/Dataset Programming Guide](sql-programming-guide.html).
+
+## Creating streaming DataFrames and streaming Datasets
+Streaming DataFrames can be created through the `DataStreamReader` interface
+([Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamReader)/
+[Java](api/java/org/apache/spark/sql/streaming/DataStreamReader.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamReader) docs) returned by `SparkSession.readStream()`. Similar to the read interface for creating a static DataFrame, you can specify the details of the source - data format, schema, options, etc. In Spark 2.0, there are a few built-in sources.
+
+ - **File sources** - Reads files written in a directory as a stream of data. Supported file formats are text, csv, json, parquet. See the docs of the DataStreamReader interface for a more up-to-date list, and supported options for each file format. Note that the files must be atomically placed in the given directory, which, in most file systems, can be achieved by file move operations.
+
+ - **Socket source (for testing)** - Reads UTF8 text data from a socket connection. The listening server socket is at the driver. Note that this should be used only for testing as this does not provide end-to-end fault-tolerance guarantees.
+
+Here are some examples.
+
+
+
+
+{% highlight scala %}
+val spark: SparkSession = …
+
+// Read text from socket
+val socketDF = spark
+ .readStream
+ .format("socket")
+ .option("host", "localhost")
+ .option("port", 9999)
+ .load()
+
+socketDF.isStreaming // Returns True for DataFrames that have streaming sources
+
+socketDF.printSchema
+
+// Read all the csv files written atomically in a directory
+val userSchema = new StructType().add("name", "string").add("age", "integer")
+val csvDF = spark
+ .readStream
+ .option("sep", ";")
+ .schema(userSchema) // Specify schema of the csv files
+ .csv("/path/to/directory") // Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+SparkSession spark = ...
+
+// Read text from socket
+Dataset<Row> socketDF = spark
+ .readStream()
+ .format("socket")
+ .option("host", "localhost")
+ .option("port", 9999)
+ .load();
+
+socketDF.isStreaming(); // Returns True for DataFrames that have streaming sources
+
+socketDF.printSchema();
+
+// Read all the csv files written atomically in a directory
+StructType userSchema = new StructType().add("name", "string").add("age", "integer");
+Dataset<Row> csvDF = spark
+ .readStream()
+ .option("sep", ";")
+ .schema(userSchema) // Specify schema of the csv files
+ .csv("/path/to/directory"); // Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+spark = SparkSession ...
+
+# Read text from socket
+socketDF = spark \
+ .readStream \
+ .format("socket") \
+ .option("host", "localhost") \
+ .option("port", 9999) \
+ .load()
+
+socketDF.isStreaming # Returns True for DataFrames that have streaming sources
+
+socketDF.printSchema()
+
+# Read all the csv files written atomically in a directory
+userSchema = StructType().add("name", "string").add("age", "integer")
+csvDF = spark \
+ .readStream \
+ .option("sep", ";") \
+ .schema(userSchema) \
+ .csv("/path/to/directory") # Equivalent to format("csv").load("/path/to/directory")
+{% endhighlight %}
+
+
+
+
+These examples generate streaming DataFrames that are untyped, meaning that the schema of the DataFrame is not checked at compile time, only checked at runtime when the query is submitted. Some operations like `map`, `flatMap`, etc. need the type to be known at compile time. To do those, you can convert these untyped streaming DataFrames to typed streaming Datasets using the same methods as static DataFrame. See the SQL Programming Guide for more details. Additionally, more details on the supported streaming sources are discussed later in the document.
+
+## Operations on streaming DataFrames/Datasets
+You can apply all kinds of operations on streaming DataFrames/Datasets - ranging from untyped, SQL-like operations (e.g. select, where, groupBy), to typed RDD-like operations (e.g. map, filter, flatMap). See the [SQL programming guide](sql-programming-guide.html) for more details. Let’s take a look at a few example operations that you can use.
+
+### Basic Operations - Selection, Projection, Aggregation
+Most of the common operations on DataFrame/Dataset are supported for streaming. The few operations that are not supported are [discussed later](#unsupported-operations) in this section.
+
+
+
+
+{% highlight scala %}
+case class DeviceData(device: String, `type`: String, signal: Double, time: DateTime)
+
+val df: DataFrame = ... // streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: string }
+val ds: Dataset[DeviceData] = df.as[DeviceData] // streaming Dataset with IOT device data
+
+// Select the devices which have signal more than 10
+df.select("device").where("signal > 10") // using untyped APIs
+ds.filter(_.signal > 10).map(_.device) // using typed APIs
+
+// Running count of the number of updates for each device type
+df.groupBy("type").count() // using untyped API
+
+// Running average signal for each device type
+import org.apache.spark.sql.expressions.scalalang.typed
+ds.groupByKey(_.`type`).agg(typed.avg(_.signal)) // using typed API
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+import org.apache.spark.api.java.function.*;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.expressions.javalang.typed;
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder;
+
+public class DeviceData {
+ private String device;
+ private String type;
+ private Double signal;
+ private java.sql.Date time;
+ ...
+ // Getter and setter methods for each field
+}
+
+Dataset<Row> df = ...; // streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: DateType }
+Dataset<DeviceData> ds = df.as(ExpressionEncoder.javaBean(DeviceData.class)); // streaming Dataset with IOT device data
+
+// Select the devices which have signal more than 10
+df.select("device").where("signal > 10"); // using untyped APIs
+ds.filter(new FilterFunction<DeviceData>() { // using typed APIs
+ @Override
+ public boolean call(DeviceData value) throws Exception {
+ return value.getSignal() > 10;
+ }
+}).map(new MapFunction<DeviceData, String>() {
+ @Override
+ public String call(DeviceData value) throws Exception {
+ return value.getDevice();
+ }
+}, Encoders.STRING());
+
+// Running count of the number of updates for each device type
+df.groupBy("type").count(); // using untyped API
+
+// Running average signal for each device type
+ds.groupByKey(new MapFunction<DeviceData, String>() { // using typed API
+ @Override
+ public String call(DeviceData value) throws Exception {
+ return value.getType();
+ }
+}, Encoders.STRING()).agg(typed.avg(new MapFunction() {
+ @Override
+ public Double call(DeviceData value) throws Exception {
+ return value.getSignal();
+ }
+}));
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+
+df = ... # streaming DataFrame with IOT device data with schema { device: string, type: string, signal: double, time: DateType }
+
+# Select the devices which have signal more than 10
+df.select("device").where("signal > 10")
+
+# Running count of the number of updates for each device type
+df.groupBy("type").count()
+{% endhighlight %}
+
+
+
+### Window Operations on Event Time
+Aggregations over a sliding event-time window are straightforward with Structured Streaming. Window-based aggregations are conceptually very similar to grouped aggregations. In a grouped aggregation, aggregate values (e.g. counts) are maintained for each unique value in the user-specified grouping column. In window-based aggregations, aggregate values are maintained for each window that the event-time of a row falls into. Let's understand this with an illustration.
+
+Imagine our quick example is modified and the stream now contains lines along with the time when the line was generated. Instead of running word counts, we want to count words within 10 minute windows, updating every 5 minutes. That is, word counts for words received in the 10 minute windows 12:00 - 12:10, 12:05 - 12:15, 12:10 - 12:20, etc. Note that 12:00 - 12:10 means data that arrived after 12:00 but before 12:10. Now, consider a word that was received at 12:07. This word should increment the counts corresponding to two windows, 12:00 - 12:10 and 12:05 - 12:15. So the counts will be indexed by both the grouping key (i.e. the word) and the window (which can be calculated from the event-time).
+
+The result tables would look something like the following.
+
+
+
+Since this windowing is similar to grouping, in code, you can use `groupBy()` and `window()` operations to express windowed aggregations.
+
+
+
+
+{% highlight scala %}
+// Number of events in every 1 minute time windows
+df.groupBy(window(df.col("time"), "1 minute"))
+ .count()
+
+
+// Average number of events for each device type in every 1 minute time windows
+df.groupBy(
+ df.col("type"),
+ window(df.col("time"), "1 minute"))
+ .avg("signal")
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+import static org.apache.spark.sql.functions.window;
+
+// Number of events in every 1 minute time windows
+df.groupBy(window(df.col("time"), "1 minute"))
+ .count();
+
+// Average number of events for each device type in every 1 minute time windows
+df.groupBy(
+ df.col("type"),
+ window(df.col("time"), "1 minute"))
+ .avg("signal");
+
+{% endhighlight %}
+
+
+
+{% highlight python %}
+from pyspark.sql.functions import window
+
+# Number of events in every 1 minute time windows
+df.groupBy(window("time", "1 minute")).count()
+
+# Average number of events for each device type in every 1 minute time windows
+df.groupBy("type", window("time", "1 minute")).avg("signal")
+{% endhighlight %}
+
+
+
+
+
+Now consider what happens if one of the events arrives late to the application.
+For example, consider a word that was generated at 12:04 but was received at 12:11.
+Since this windowing is based on the time in the data, the time 12:04 should be considered for windowing. This happens naturally in our window-based grouping - the late data is automatically placed in the proper windows and the correct aggregates are updated as illustrated below.
+
+
+
+### Join Operations
+Streaming DataFrames can be joined with static DataFrames to create new streaming DataFrames. Here are a few examples.
+
+
+
+
+{% highlight scala %}
+val staticDf = spark.read. ...
+val streamingDf = spark.readStream. ...
+
+streamingDf.join(staticDf, "type")                 // inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer")  // right outer join with a static DF
+
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+Dataset<Row> staticDf = spark.read(). ...;
+Dataset<Row> streamingDf = spark.readStream(). ...;
+streamingDf.join(staticDf, "type");                 // inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer");  // right outer join with a static DF
+{% endhighlight %}
+
+
+
+
+
+{% highlight python %}
+staticDf = spark.read. ...
+streamingDf = spark.readStream. ...
+streamingDf.join(staticDf, "type")                 # inner equi-join with a static DF
+streamingDf.join(staticDf, "type", "right_outer")  # right outer join with a static DF
+{% endhighlight %}
+
+
+
+
+### Unsupported Operations
+However, note that not all of the operations applicable on static DataFrames/Datasets are supported on streaming DataFrames/Datasets yet. While some of these unsupported operations will be supported in future releases of Spark, there are others that are fundamentally hard to implement efficiently on streaming data. For example, sorting is not supported on the input streaming Dataset, as it requires keeping track of all the data received in the stream, which is fundamentally hard to execute efficiently. As of Spark 2.0, some of the unsupported operations are as follows.
+
+- Multiple streaming aggregations (i.e. a chain of aggregations on a streaming DF) are not yet supported on streaming Datasets.
+
+- Limit and take first N rows are not supported on streaming Datasets.
+
+- Distinct operations on streaming Datasets are not supported.
+
+- Sorting operations are supported on streaming Datasets only after an aggregation and in Complete Output Mode.
+
+- Outer joins between a streaming and a static Dataset are conditionally supported.
+
+ + Full outer join with a streaming Dataset is not supported
+
+ + Left outer join with a streaming Dataset on the left is not supported
+
+ + Right outer join with a streaming Dataset on the right is not supported
+
+- Any kind of joins between two streaming Datasets are not yet supported.
+
+In addition, there are some Dataset methods that will not work on streaming Datasets. They are actions that immediately run queries and return results, which does not make sense on a streaming Dataset. Rather, those functionalities can be achieved by explicitly starting a streaming query (see the next section regarding that).
+
+- `count()` - Cannot return a single count from a streaming Dataset. Instead, use `ds.groupBy().count()` which returns a streaming Dataset containing a running count.
+
+- `foreach()` - Instead use `ds.writeStream.foreach(...)` (see next section).
+
+- `show()` - Instead use the console sink (see next section).
+
+If you try any of these operations, you will see an AnalysisException like "operation XYZ is not supported with streaming DataFrames/Datasets".
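+
+For illustration, here is a minimal sketch (assuming a streaming Dataset named `ds`, as in the earlier device-data examples) of the streaming-friendly alternatives to these actions.
+
+{% highlight scala %}
+// ds.count(), ds.show() and ds.foreach(...) are not allowed on a streaming Dataset
+
+// Instead of a one-off count(), maintain a running count as a streaming aggregation
+val runningCount = ds.groupBy().count()
+
+// Instead of show(), print each trigger's result using the console sink
+runningCount.writeStream
+  .outputMode("complete")
+  .format("console")
+  .start()
+
+// Instead of foreach(), use ds.writeStream.foreach(...) as described later in this section
+{% endhighlight %}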
+
+## Starting Streaming Queries
+Once you have defined the final result DataFrame/Dataset, all that is left is for you to start the streaming computation. To do that, you have to use the
+`DataStreamWriter` (
+[Scala](api/scala/index.html#org.apache.spark.sql.streaming.DataStreamWriter)/
+[Java](api/java/org/apache/spark/sql/streaming/DataStreamWriter.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.DataStreamWriter) docs) returned through `Dataset.writeStream()`. You will have to specify one or more of the following in this interface.
+
+- *Details of the output sink:* Data format, location, etc.
+
+- *Output mode:* Specify what gets written to the output sink.
+
+- *Query name:* Optionally, specify a unique name of the query for identification.
+
+- *Trigger interval:* Optionally, specify the trigger interval. If it is not specified, the system will check for availability of new data as soon as the previous processing has completed. If a trigger time is missed because the previous processing has not completed, then the system will attempt to trigger at the next trigger point, not immediately after the processing has completed. See the sketch after this list for how a trigger interval can be set.
+
+- *Checkpoint location:* For some output sinks where the end-to-end fault-tolerance can be guaranteed, specify the location where the system will write all the checkpoint information. This should be a directory in an HDFS-compatible fault-tolerant file system. The semantics of checkpointing are discussed in more detail in the next section.
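+
+To make these options concrete, here is a minimal sketch that sets all of them on a streaming DataFrame `df`; the paths, query name and trigger interval are placeholders, not values from the earlier examples.
+
+{% highlight scala %}
+import org.apache.spark.sql.streaming.ProcessingTime
+
+df.writeStream
+  .format("parquet")                                        // details of the output sink
+  .option("path", "path/to/destination/directory")          // sink location (placeholder)
+  .outputMode("append")                                     // output mode
+  .queryName("deviceEvents")                                // optional query name (placeholder)
+  .trigger(ProcessingTime("10 seconds"))                    // optional trigger interval
+  .option("checkpointLocation", "path/to/checkpoint/dir")   // checkpoint location (placeholder)
+  .start()
+{% endhighlight %}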
+
+#### Output Modes
+There are two types of output mode currently implemented.
+
+- **Append mode (default)** - This is the default mode, where only the new rows added to the result table since the last trigger will be outputted to the sink. This is only applicable to queries that *do not have any aggregations* (e.g. queries with only select, where, map, flatMap, filter, join, etc.).
+
+- **Complete mode** - The whole result table will be outputted to the sink. This is only applicable to queries that *have aggregations*.
+
+#### Output Sinks
+There are a few types of built-in output sinks.
+
+- **File sink** - Stores the output to a directory. As of Spark 2.0, this only supports Parquet file format, and Append output mode.
+
+- **Foreach sink** - Runs arbitrary computation on the records in the output. See later in the section for more details.
+
+- **Console sink (for debugging)** - Prints the output to the console/stdout every time there is a trigger. Both Append and Complete output modes are supported. This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger.
+
+- **Memory sink (for debugging)** - The output is stored in memory as an in-memory table. Both Append and Complete output modes are supported. This should be used for debugging purposes on low data volumes as the entire output is collected and stored in the driver's memory after every trigger.
+
+Here is a table of all the sinks, and the corresponding settings.
+
+
+
+| Sink | Supported Output Modes | Usage | Fault-tolerant | Notes |
+| ---- | ---------------------- | ----- | -------------- | ----- |
+| File Sink (only parquet in Spark 2.0) | Append | `writeStream.format("parquet").start()` | Yes | Supports writes to partitioned tables. Partitioning by time may be useful. |
+| Foreach Sink | All modes | `writeStream.foreach(...).start()` | Depends on ForeachWriter implementation | More details in the next section |
+| Console Sink | Append, Complete | `writeStream.format("console").start()` | No | |
+| Memory Sink | Append, Complete | `writeStream.format("memory").queryName("table").start()` | No | Saves the output data as a table, for interactive querying. Table name is the query name. |
+
+Finally, you have to call `start()` to actually start the execution of the query. This returns a StreamingQuery object which is a handle to the continuously running execution. You can use this object to manage the query, which we will discuss in the next subsection. For now, let’s understand all this with a few examples.
+
+
+
+
+
+{% highlight scala %}
+// ========== DF with no aggregations ==========
+val noAggDF = deviceDataDf.select("device").where("signal > 10")
+
+// Print new data to console
+noAggDF
+ .writeStream
+ .format("console")
+ .start()
+
+// Write new data to Parquet files
+noAggDF
+  .writeStream
+  .format("parquet")
+  .option("path", "path/to/destination/directory")
+  .start()
+
+// ========== DF with aggregation ==========
+val aggDF = df.groupBy("device").count()
+
+// Print updated aggregations to console
+aggDF
+ .writeStream
+ .outputMode("complete")
+ .format("console")
+ .start()
+
+// Have all the aggregates in an in memory table
+aggDF
+ .writeStream
+ .queryName("aggregates") // this query name will be the table name
+ .outputMode("complete")
+ .format("memory")
+ .start()
+
+spark.sql("select * from aggregates").show() // interactively query in-memory table
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+// ========== DF with no aggregations ==========
+Dataset<Row> noAggDF = deviceDataDf.select("device").where("signal > 10");
+
+// Print new data to console
+noAggDF
+ .writeStream()
+ .format("console")
+ .start();
+
+// Write new data to Parquet files
+noAggDF
+  .writeStream()
+  .format("parquet")
+  .option("path", "path/to/destination/directory")
+  .start();
+
+// ========== DF with aggregation ==========
+Dataset<Row> aggDF = df.groupBy("device").count();
+
+// Print updated aggregations to console
+aggDF
+ .writeStream()
+ .outputMode("complete")
+ .format("console")
+ .start();
+
+// Have all the aggregates in an in memory table
+aggDF
+ .writeStream()
+ .queryName("aggregates") // this query name will be the table name
+ .outputMode("complete")
+ .format("memory")
+ .start();
+
+spark.sql("select * from aggregates").show(); // interactively query in-memory table
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+# ========== DF with no aggregations ==========
+noAggDF = deviceDataDf.select("device").where("signal > 10")
+
+# Print new data to console
+noAggDF\
+    .writeStream\
+    .format("console")\
+    .start()
+
+# Write new data to Parquet files
+noAggDF\
+    .writeStream\
+    .format("parquet")\
+    .option("path", "path/to/destination/directory")\
+    .start()
+
+# ========== DF with aggregation ==========
+aggDF = df.groupBy("device").count()
+
+# Print updated aggregations to console
+aggDF\
+    .writeStream\
+    .outputMode("complete")\
+    .format("console")\
+    .start()
+
+# Have all the aggregates in an in memory table. The query name will be the table name
+aggDF\
+    .writeStream\
+    .queryName("aggregates")\
+    .outputMode("complete")\
+    .format("memory")\
+    .start()
+
+spark.sql("select * from aggregates").show() # interactively query in-memory table
+{% endhighlight %}
+
+
+
+
+#### Using Foreach
+The `foreach` operation allows arbitrary operations to be computed on the output data. As of Spark 2.0, this is available only for Scala and Java. To use this, you will have to implement the interface `ForeachWriter` ([Scala](api/scala/index.html#org.apache.spark.sql.ForeachWriter)/
+[Java](api/java/org/apache/spark/sql/ForeachWriter.html) docs), which has methods that get called whenever there is a sequence of rows generated as output after a trigger. Note the following important points, and see the sketch after this list.
+
+- The writer must be serializable, as it will be serialized and sent to the executors for execution.
+
+- All three methods, `open`, `process` and `close`, will be called on the executors.
+
+- The writer must do all the initialization (e.g. opening connections, starting a transaction, etc.) only when the `open` method is called. Be aware that if there is any initialization in the class as soon as the object is created, then that initialization will happen in the driver (because that is where the instance is created), which may not be what you intend.
+
+- `version` and `partition` are two parameters in `open` that uniquely represent a set of rows that needs to be pushed out. `version` is a monotonically increasing id that increases with every trigger. `partition` is an id that represents a partition of the output, since the output is distributed and will be processed on multiple executors.
+
+- `open` can use the `version` and `partition` to choose whether it needs to write the sequence of rows. Accordingly, it can return `true` (proceed with writing), or `false` (no need to write). If `false` is returned, then `process` will not be called on any row. For example, after a partial failure, some of the output partitions of the failed trigger may have already been committed to a database. Based on metadata stored in the database, the writer can identify partitions that have already been committed and accordingly return false to skip committing them again.
+
+- Whenever `open` is called, `close` will also be called (unless the JVM exits due to some error). This is true even if `open` returns false. If there is any error in processing and writing the data, `close` will be called with the error. It is your responsibility to clean up state (e.g. connections, transactions, etc.) that has been created in `open` so that there are no resource leaks.
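+
+Putting these points together, here is a minimal sketch of a `ForeachWriter` for a streaming `Dataset[String]` named `ds`; the method bodies are placeholders for your own connection handling.
+
+{% highlight scala %}
+import org.apache.spark.sql.ForeachWriter
+
+val writer = new ForeachWriter[String] {
+  override def open(partitionId: Long, version: Long): Boolean = {
+    // open connections / start transactions here, not in the constructor;
+    // return false to skip partitions that have already been committed
+    true
+  }
+
+  override def process(record: String): Unit = {
+    // write a single record using the resources created in open()
+  }
+
+  override def close(errorOrNull: Throwable): Unit = {
+    // clean up connections/transactions; called even if open() returned false
+  }
+}
+
+ds.writeStream.foreach(writer).start()
+{% endhighlight %}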
+
+## Managing Streaming Queries
+The `StreamingQuery` object created when a query is started can be used to monitor and manage the query.
+
+
+
+
+{% highlight scala %}
+val query = df.writeStream.format("console").start() // get the query object
+
+query.id          // get the unique identifier of the running query
+
+query.name        // get the auto-generated or user-specified name of the query
+
+query.explain()   // print detailed explanations of the query
+
+query.stop()      // stop the query
+
+query.awaitTermination()   // block until query is terminated, with stop() or with error
+
+query.exception        // the exception if the query has been terminated with error
+
+query.sourceStatuses   // progress information about data that has been read from the input sources
+
+query.sinkStatus       // progress information about data that has been written to the output sink
+{% endhighlight %}
+
+
+
+
+
+{% highlight java %}
+StreamingQuery query = df.writeStream().format("console").start(); // get the query object
+
+query.id();          // get the unique identifier of the running query
+
+query.name();        // get the auto-generated or user-specified name of the query
+
+query.explain();     // print detailed explanations of the query
+
+query.stop();        // stop the query
+
+query.awaitTermination();   // block until query is terminated, with stop() or with error
+
+query.exception();        // the exception if the query has been terminated with error
+
+query.sourceStatuses();   // progress information about data that has been read from the input sources
+
+query.sinkStatus();       // progress information about data that has been written to the output sink
+
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+query = df.writeStream.format("console").start()   # get the query object
+
+query.id          # get the unique identifier of the running query
+
+query.name        # get the auto-generated or user-specified name of the query
+
+query.explain()   # print detailed explanations of the query
+
+query.stop()      # stop the query
+
+query.awaitTermination()   # block until query is terminated, with stop() or with error
+
+query.exception()      # the exception if the query has been terminated with error
+
+query.sourceStatus()   # progress information about data that has been read from the input sources
+
+query.sinkStatus()     # progress information about data that has been written to the output sink
+
+{% endhighlight %}
+
+
+
+
+You can start any number of queries in a single SparkSession. They will all be running concurrently sharing the cluster resources. You can use `sparkSession.streams()` to get the `StreamingQueryManager` (
+[Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryManager)/
+[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryManager.html)/
+[Python](api/python/pyspark.sql.html#pyspark.sql.streaming.StreamingQueryManager) docs) that can be used to manage the currently active queries.
+
+
+
+
+{% highlight scala %}
+val spark: SparkSession = ...
+
+spark.streams.active // get the list of currently active streaming queries
+
+spark.streams.get(id) // get a query object by its unique id
+
+spark.streams.awaitAnyTermination() // block until any one of them terminates
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+SparkSession spark = ...;
+
+spark.streams().active();   // get the list of currently active streaming queries
+
+spark.streams().get(id);    // get a query object by its unique id
+
+spark.streams().awaitAnyTermination();   // block until any one of them terminates
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+spark = ... # spark session
+
+spark.streams.active   # get the list of currently active streaming queries
+
+spark.streams.get(id)  # get a query object by its unique id
+
+spark.streams.awaitAnyTermination()   # block until any one of them terminates
+{% endhighlight %}
+
+
+
+
+Finally, for asynchronous monitoring of streaming queries, you can create and attach a `StreamingQueryListener` (
+[Scala](api/scala/index.html#org.apache.spark.sql.streaming.StreamingQueryListener)/
+[Java](api/java/org/apache/spark/sql/streaming/StreamingQueryListener.html) docs), which will give you regular callback-based updates when queries are started and terminated.
+
+## Recovering from Failures with Checkpointing
+In case of a failure or intentional shutdown, you can recover the progress and state of a previous query, and continue where it left off. This is done using checkpointing and write ahead logs. You can configure a query with a checkpoint location, and the query will save all the progress information (i.e. range of offsets processed in each trigger) and the running aggregates (e.g. word counts in the quick example) to the checkpoint location. As of Spark 2.0, this checkpoint location has to be a path in an HDFS-compatible file system, and can be set as an option in the DataStreamWriter when [starting a query](#starting-streaming-queries).
+
+
+
+
+{% highlight scala %}
+aggDF
+ .writeStream
+ .outputMode("complete")
+  .option("checkpointLocation", "path/to/HDFS/dir")
+ .format("memory")
+ .start()
+{% endhighlight %}
+
+
+
+
+{% highlight java %}
+aggDF
+ .writeStream()
+ .outputMode("complete")
+  .option("checkpointLocation", "path/to/HDFS/dir")
+ .format("memory")
+ .start();
+{% endhighlight %}
+
+
+
+
+{% highlight python %}
+aggDF\
+    .writeStream\
+    .outputMode("complete")\
+    .option("checkpointLocation", "path/to/HDFS/dir")\
+    .format("memory")\
+    .start()
+{% endhighlight %}
+
+
+
+
+# Where to go from here
+- Examples: See and run the
+[Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/sql/streaming)/[Java]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/sql/streaming)/[Python]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/python/sql/streaming)
+examples.
+- Spark Summit 2016 Talk - [A Deep Dive into Structured Streaming](https://spark-summit.org/2016/events/a-deep-dive-into-structured-streaming/)
+
+
+
+
+
+
+
+
+
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
index e512979ac71b..7fc6c007b684 100644
--- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
+++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSparkSQL.java
@@ -26,7 +26,9 @@
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+// $example on:init_session$
import org.apache.spark.sql.SparkSession;
+// $example off:init_session$
public class JavaSparkSQL {
public static class Person implements Serializable {
@@ -51,10 +53,13 @@ public void setAge(int age) {
}
public static void main(String[] args) throws Exception {
+ // $example on:init_session$
SparkSession spark = SparkSession
.builder()
.appName("JavaSparkSQL")
+ .config("spark.some.config.option", "some-value")
.getOrCreate();
+ // $example off:init_session$
System.out.println("=== Data source: RDD ===");
// Load a text file and convert each line to a Java Bean.
diff --git a/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java
new file mode 100644
index 000000000000..a2cf9389543e
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/sql/streaming/JavaStructuredNetworkWordCount.java
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.examples.sql.streaming;
+
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.sql.*;
+import org.apache.spark.sql.streaming.StreamingQuery;
+
+import java.util.Arrays;
+import java.util.Iterator;
+
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ *
+ * Usage: JavaStructuredNetworkWordCount <hostname> <port>
+ * <hostname> and <port> describe the TCP server that Structured Streaming
+ * would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ bin/run-example sql.streaming.JavaStructuredNetworkWordCount
+ * localhost 9999`
+ */
+public final class JavaStructuredNetworkWordCount {
+
+ public static void main(String[] args) throws Exception {
+ if (args.length < 2) {
+ System.err.println("Usage: JavaNetworkWordCount ");
+ System.exit(1);
+ }
+
+ String host = args[0];
+ int port = Integer.parseInt(args[1]);
+
+ SparkSession spark = SparkSession
+ .builder()
+ .appName("JavaStructuredNetworkWordCount")
+ .getOrCreate();
+
+ // Create DataFrame representing the stream of input lines from connection to host:port
+ Dataset<String> lines = spark
+ .readStream()
+ .format("socket")
+ .option("host", host)
+ .option("port", port)
+ .load().as(Encoders.STRING());
+
+ // Split the lines into words
+ Dataset<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
+ @Override
+ public Iterator<String> call(String x) {
+ return Arrays.asList(x.split(" ")).iterator();
+ }
+ }, Encoders.STRING());
+
+ // Generate running word count
+ Dataset<Row> wordCounts = words.groupBy("value").count();
+
+ // Start running the query that prints the running counts to the console
+ StreamingQuery query = wordCounts.writeStream()
+ .outputMode("complete")
+ .format("console")
+ .start();
+
+ query.awaitTermination();
+ }
+}
diff --git a/examples/src/main/python/ml/decision_tree_regression_example.py b/examples/src/main/python/ml/decision_tree_regression_example.py
index b734d4974a4f..58d7ad921d8e 100644
--- a/examples/src/main/python/ml/decision_tree_regression_example.py
+++ b/examples/src/main/python/ml/decision_tree_regression_example.py
@@ -31,7 +31,7 @@
if __name__ == "__main__":
spark = SparkSession\
.builder\
- .appName("decision_tree_classification_example")\
+ .appName("DecisionTreeRegressionExample")\
.getOrCreate()
# $example on$
diff --git a/examples/src/main/python/ml/elementwise_product_example.py b/examples/src/main/python/ml/elementwise_product_example.py
index 598deae886ee..590053998bcc 100644
--- a/examples/src/main/python/ml/elementwise_product_example.py
+++ b/examples/src/main/python/ml/elementwise_product_example.py
@@ -30,10 +30,12 @@
.getOrCreate()
# $example on$
+ # Create some vector data; also works for sparse vectors
data = [(Vectors.dense([1.0, 2.0, 3.0]),), (Vectors.dense([4.0, 5.0, 6.0]),)]
df = spark.createDataFrame(data, ["vector"])
transformer = ElementwiseProduct(scalingVec=Vectors.dense([0.0, 1.0, 2.0]),
inputCol="vector", outputCol="transformedVector")
+ # Batch transform the vectors to create new column:
transformer.transform(df).show()
# $example off$
diff --git a/examples/src/main/python/ml/lda_example.py b/examples/src/main/python/ml/lda_example.py
index 6ca56adf3cb1..5ce810fccc6f 100644
--- a/examples/src/main/python/ml/lda_example.py
+++ b/examples/src/main/python/ml/lda_example.py
@@ -35,7 +35,7 @@
# Creates a SparkSession
spark = SparkSession \
.builder \
- .appName("PythonKMeansExample") \
+ .appName("LDAExample") \
.getOrCreate()
# $example on$
diff --git a/examples/src/main/python/ml/polynomial_expansion_example.py b/examples/src/main/python/ml/polynomial_expansion_example.py
index 9475e33218cf..b46c1ba2f439 100644
--- a/examples/src/main/python/ml/polynomial_expansion_example.py
+++ b/examples/src/main/python/ml/polynomial_expansion_example.py
@@ -35,7 +35,7 @@
(Vectors.dense([0.0, 0.0]),),
(Vectors.dense([0.6, -1.1]),)],
["features"])
- px = PolynomialExpansion(degree=2, inputCol="features", outputCol="polyFeatures")
+ px = PolynomialExpansion(degree=3, inputCol="features", outputCol="polyFeatures")
polyDF = px.transform(df)
for expanded in polyDF.select("polyFeatures").take(3):
print(expanded)
diff --git a/examples/src/main/python/ml/quantile_discretizer_example.py b/examples/src/main/python/ml/quantile_discretizer_example.py
index 5444cacd957f..6f422f840ad2 100644
--- a/examples/src/main/python/ml/quantile_discretizer_example.py
+++ b/examples/src/main/python/ml/quantile_discretizer_example.py
@@ -24,7 +24,7 @@
if __name__ == "__main__":
- spark = SparkSession.builder.appName("PythonQuantileDiscretizerExample").getOrCreate()
+ spark = SparkSession.builder.appName("QuantileDiscretizerExample").getOrCreate()
# $example on$
data = [(0, 18.0,), (1, 19.0,), (2, 8.0,), (3, 5.0,), (4, 2.2,)]
diff --git a/examples/src/main/python/ml/random_forest_classifier_example.py b/examples/src/main/python/ml/random_forest_classifier_example.py
index a7fc765318b9..eb9ded9af555 100644
--- a/examples/src/main/python/ml/random_forest_classifier_example.py
+++ b/examples/src/main/python/ml/random_forest_classifier_example.py
@@ -50,7 +50,7 @@
(trainingData, testData) = data.randomSplit([0.7, 0.3])
# Train a RandomForest model.
- rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures")
+ rf = RandomForestClassifier(labelCol="indexedLabel", featuresCol="indexedFeatures", numTrees=10)
# Chain indexers and forest in a Pipeline
pipeline = Pipeline(stages=[labelIndexer, featureIndexer, rf])
diff --git a/examples/src/main/python/ml/simple_params_example.py b/examples/src/main/python/ml/simple_params_example.py
index 54fbc2c9d05d..2f1eaa6f947f 100644
--- a/examples/src/main/python/ml/simple_params_example.py
+++ b/examples/src/main/python/ml/simple_params_example.py
@@ -33,7 +33,7 @@
if __name__ == "__main__":
spark = SparkSession \
.builder \
- .appName("SimpleTextClassificationPipeline") \
+ .appName("SimpleParamsExample") \
.getOrCreate()
# prepare training data.
diff --git a/examples/src/main/python/ml/simple_text_classification_pipeline.py b/examples/src/main/python/ml/simple_text_classification_pipeline.py
index 886f43c0b08e..b528b59be962 100644
--- a/examples/src/main/python/ml/simple_text_classification_pipeline.py
+++ b/examples/src/main/python/ml/simple_text_classification_pipeline.py
@@ -48,7 +48,7 @@
# Configure an ML pipeline, which consists of tree stages: tokenizer, hashingTF, and lr.
tokenizer = Tokenizer(inputCol="text", outputCol="words")
- hashingTF = HashingTF(inputCol=tokenizer.getOutputCol(), outputCol="features")
+ hashingTF = HashingTF(numFeatures=1000, inputCol=tokenizer.getOutputCol(), outputCol="features")
lr = LogisticRegression(maxIter=10, regParam=0.001)
pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
diff --git a/examples/src/main/python/sql.py b/examples/src/main/python/sql.py
index ac7246938d3b..ea11d2c4c7b3 100644
--- a/examples/src/main/python/sql.py
+++ b/examples/src/main/python/sql.py
@@ -20,15 +20,20 @@
import os
import sys
+# $example on:init_session$
from pyspark.sql import SparkSession
+# $example off:init_session$
from pyspark.sql.types import Row, StructField, StructType, StringType, IntegerType
if __name__ == "__main__":
+ # $example on:init_session$
spark = SparkSession\
.builder\
.appName("PythonSQL")\
+ .config("spark.some.config.option", "some-value")\
.getOrCreate()
+ # $example off:init_session$
# A list of Rows. Infer schema from the first row, create a DataFrame and print the schema
rows = [Row(name="John", age=19), Row(name="Smith", age=23), Row(name="Sarah", age=18)]
diff --git a/examples/src/main/python/sql/streaming/structured_network_wordcount.py b/examples/src/main/python/sql/streaming/structured_network_wordcount.py
new file mode 100644
index 000000000000..32d63c52c919
--- /dev/null
+++ b/examples/src/main/python/sql/streaming/structured_network_wordcount.py
@@ -0,0 +1,76 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ Usage: structured_network_wordcount.py <hostname> <port>
+ <hostname> and <port> describe the TCP server that Structured Streaming
+ would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+ `$ nc -lk 9999`
+ and then run the example
+ `$ bin/spark-submit examples/src/main/python/sql/streaming/structured_network_wordcount.py
+ localhost 9999`
+"""
+from __future__ import print_function
+
+import sys
+
+from pyspark.sql import SparkSession
+from pyspark.sql.functions import explode
+from pyspark.sql.functions import split
+
+if __name__ == "__main__":
+ if len(sys.argv) != 3:
+ print("Usage: structured_network_wordcount.py ", file=sys.stderr)
+ exit(-1)
+
+ host = sys.argv[1]
+ port = int(sys.argv[2])
+
+ spark = SparkSession\
+ .builder\
+ .appName("StructuredNetworkWordCount")\
+ .getOrCreate()
+
+ # Create DataFrame representing the stream of input lines from connection to host:port
+ lines = spark\
+ .readStream\
+ .format('socket')\
+ .option('host', host)\
+ .option('port', port)\
+ .load()
+
+ # Split the lines into words
+ words = lines.select(
+ explode(
+ split(lines.value, ' ')
+ ).alias('word')
+ )
+
+ # Generate running word count
+ wordCounts = words.groupBy('word').count()
+
+ # Start running the query that prints the running counts to the console
+ query = wordCounts\
+ .writeStream\
+ .outputMode('complete')\
+ .format('console')\
+ .start()
+
+ query.awaitTermination()
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala
new file mode 100644
index 000000000000..8f8262db374b
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/AggregateMessagesExample.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.{Graph, VertexRDD}
+import org.apache.spark.graphx.util.GraphGenerators
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * An example that uses the [`aggregateMessages`][Graph.aggregateMessages] operator to
+ * compute the average age of the more senior followers of each user.
+ * Run with
+ * {{{
+ * bin/run-example graphx.AggregateMessagesExample
+ * }}}
+ */
+object AggregateMessagesExample {
+
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // Create a graph with "age" as the vertex property.
+ // Here we use a random graph for simplicity.
+ val graph: Graph[Double, Int] =
+ GraphGenerators.logNormalGraph(sc, numVertices = 100).mapVertices( (id, _) => id.toDouble )
+ // Compute the number of older followers and their total age
+ val olderFollowers: VertexRDD[(Int, Double)] = graph.aggregateMessages[(Int, Double)](
+ triplet => { // Map Function
+ if (triplet.srcAttr > triplet.dstAttr) {
+ // Send message to destination vertex containing counter and age
+ triplet.sendToDst(1, triplet.srcAttr)
+ }
+ },
+ // Add counter and age
+ (a, b) => (a._1 + b._1, a._2 + b._2) // Reduce Function
+ )
+ // Divide total age by number of older followers to get average age of older followers
+ val avgAgeOfOlderFollowers: VertexRDD[Double] =
+ olderFollowers.mapValues( (id, value) =>
+ value match { case (count, totalAge) => totalAge / count } )
+ // Display the results
+ avgAgeOfOlderFollowers.collect.foreach(println(_))
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala
new file mode 100644
index 000000000000..6598863bd2ea
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/ComprehensiveExample.scala
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.GraphLoader
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Suppose I want to build a graph from some text files, restrict the graph
+ * to important relationships and users, run page-rank on the sub-graph, and
+ * then finally return attributes associated with the top users.
+ * This example does all of this in just a few lines with GraphX.
+ *
+ * Run with
+ * {{{
+ * bin/run-example graphx.ComprehensiveExample
+ * }}}
+ */
+object ComprehensiveExample {
+
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // Load my user data and parse into tuples of user id and attribute list
+ val users = (sc.textFile("data/graphx/users.txt")
+ .map(line => line.split(",")).map( parts => (parts.head.toLong, parts.tail) ))
+
+ // Parse the edge data which is already in userId -> userId format
+ val followerGraph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
+
+ // Attach the user attributes
+ val graph = followerGraph.outerJoinVertices(users) {
+ case (uid, deg, Some(attrList)) => attrList
+ // Some users may not have attributes so we set them as empty
+ case (uid, deg, None) => Array.empty[String]
+ }
+
+ // Restrict the graph to users with usernames and names
+ val subgraph = graph.subgraph(vpred = (vid, attr) => attr.size == 2)
+
+ // Compute the PageRank
+ val pagerankGraph = subgraph.pageRank(0.001)
+
+ // Get the attributes of the top pagerank users
+ val userInfoWithPageRank = subgraph.outerJoinVertices(pagerankGraph.vertices) {
+ case (uid, attrList, Some(pr)) => (pr, attrList.toList)
+ case (uid, attrList, None) => (0.0, attrList.toList)
+ }
+
+ println(userInfoWithPageRank.vertices.top(5)(Ordering.by(_._2._1)).mkString("\n"))
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala
new file mode 100644
index 000000000000..5377ddb3594b
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/ConnectedComponentsExample.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.GraphLoader
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A connected components algorithm example.
+ * The connected components algorithm labels each connected component of the graph
+ * with the ID of its lowest-numbered vertex.
+ * For example, in a social network, connected components can approximate clusters.
+ * GraphX contains an implementation of the algorithm in the
+ * [`ConnectedComponents` object][ConnectedComponents],
+ * and we compute the connected components of the example social network dataset.
+ *
+ * Run with
+ * {{{
+ * bin/run-example graphx.ConnectedComponentsExample
+ * }}}
+ */
+object ConnectedComponentsExample {
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // Load the graph as in the PageRank example
+ val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
+ // Find the connected components
+ val cc = graph.connectedComponents().vertices
+ // Join the connected components with the usernames
+ val users = sc.textFile("data/graphx/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+ }
+ val ccByUsername = users.join(cc).map {
+ case (id, (username, cc)) => (username, cc)
+ }
+ // Print the result
+ println(ccByUsername.collect().mkString("\n"))
+ // $example off$
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala
new file mode 100644
index 000000000000..9e9affca07a1
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/PageRankExample.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.GraphLoader
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A PageRank example on social network dataset
+ * Run with
+ * {{{
+ * bin/run-example graphx.PageRankExample
+ * }}}
+ */
+object PageRankExample {
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // Load the edges as a graph
+ val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt")
+ // Run PageRank
+ val ranks = graph.pageRank(0.0001).vertices
+ // Join the ranks with the usernames
+ val users = sc.textFile("data/graphx/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+ }
+ val ranksByUsername = users.join(ranks).map {
+ case (id, (username, rank)) => (username, rank)
+ }
+ // Print the result
+ println(ranksByUsername.collect().mkString("\n"))
+ // $example off$
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala
new file mode 100644
index 000000000000..5e8b19671de7
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SSSPExample.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.{Graph, VertexId}
+import org.apache.spark.graphx.util.GraphGenerators
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * An example that uses the Pregel operator to express computation
+ * such as single source shortest path
+ * Run with
+ * {{{
+ * bin/run-example graphx.SSSPExample
+ * }}}
+ */
+object SSSPExample {
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // A graph with edge attributes containing distances
+ val graph: Graph[Long, Double] =
+ GraphGenerators.logNormalGraph(sc, numVertices = 100).mapEdges(e => e.attr.toDouble)
+ val sourceId: VertexId = 42 // The ultimate source
+ // Initialize the graph such that all vertices except the root have distance infinity.
+ val initialGraph = graph.mapVertices((id, _) =>
+ if (id == sourceId) 0.0 else Double.PositiveInfinity)
+ val sssp = initialGraph.pregel(Double.PositiveInfinity)(
+ (id, dist, newDist) => math.min(dist, newDist), // Vertex Program
+ triplet => { // Send Message
+ if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
+ Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
+ } else {
+ Iterator.empty
+ }
+ },
+ (a, b) => math.min(a, b) // Merge Message
+ )
+ println(sssp.vertices.collect.mkString("\n"))
+ // $example off$
+
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala
new file mode 100644
index 000000000000..b9bff69086cc
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/graphx/TriangleCountingExample.scala
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.graphx
+
+// $example on$
+import org.apache.spark.graphx.{GraphLoader, PartitionStrategy}
+// $example off$
+import org.apache.spark.sql.SparkSession
+
+/**
+ * A vertex is part of a triangle when it has two adjacent vertices with an edge between them.
+ * GraphX implements a triangle counting algorithm in the [`TriangleCount` object][TriangleCount]
+ * that determines the number of triangles passing through each vertex,
+ * providing a measure of clustering.
+ * We compute the triangle count of the social network dataset.
+ *
+ * Note that `TriangleCount` requires the edges to be in canonical orientation (`srcId < dstId`)
+ * and the graph to be partitioned using [`Graph.partitionBy`][Graph.partitionBy].
+ *
+ * Run with
+ * {{{
+ * bin/run-example graphx.TriangleCountingExample
+ * }}}
+ */
+object TriangleCountingExample {
+ def main(args: Array[String]): Unit = {
+ // Creates a SparkSession.
+ val spark = SparkSession
+ .builder
+ .appName(s"${this.getClass.getSimpleName}")
+ .getOrCreate()
+ val sc = spark.sparkContext
+
+ // $example on$
+ // Load the edges in canonical order and partition the graph for triangle count
+ val graph = GraphLoader.edgeListFile(sc, "data/graphx/followers.txt", true)
+ .partitionBy(PartitionStrategy.RandomVertexCut)
+ // Find the triangle count for each vertex
+ val triCounts = graph.triangleCount().vertices
+ // Join the triangle counts with the usernames
+ val users = sc.textFile("data/graphx/users.txt").map { line =>
+ val fields = line.split(",")
+ (fields(0).toLong, fields(1))
+ }
+ val triCountByUsername = users.join(triCounts).map { case (id, (username, tc)) =>
+ (username, tc)
+ }
+ // Print the result
+ println(triCountByUsername.collect().mkString("\n"))
+ // $example off$
+ spark.stop()
+ }
+}
+// scalastyle:on println
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
index 51aa5179fa4a..988d8941a4ce 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala
@@ -27,7 +27,7 @@ object CountVectorizerExample {
def main(args: Array[String]) {
val spark = SparkSession
.builder
- .appName("CounterVectorizerExample")
+ .appName("CountVectorizerExample")
.getOrCreate()
// $example on$
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
index 11faa6192b3f..38c1c1c1865b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala
@@ -20,7 +20,6 @@ package org.apache.spark.examples.ml
import java.io.File
-import com.google.common.io.Files
import scopt.OptionParser
import org.apache.spark.examples.mllib.AbstractParams
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala
index c484ee55569b..2c2bf421bc5d 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GaussianMixtureExample.scala
@@ -21,8 +21,8 @@ package org.apache.spark.examples.ml
// $example on$
import org.apache.spark.ml.clustering.GaussianMixture
-import org.apache.spark.sql.SparkSession
// $example off$
+import org.apache.spark.sql.SparkSession
/**
* An example demonstrating Gaussian Mixture Model (GMM).
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
index a59ba182fc20..7089a4bc87aa 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala
@@ -35,7 +35,7 @@ object NaiveBayesExample {
val data = spark.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt")
// Split the data into training and test sets (30% held out for testing)
- val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3))
+ val Array(trainingData, testData) = data.randomSplit(Array(0.7, 0.3), seed = 1234L)
// Train a NaiveBayes model.
val model = new NaiveBayes()
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
index 1b019fbb5177..deaa9f252b9b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/RDDRelation.scala
@@ -18,7 +18,10 @@
// scalastyle:off println
package org.apache.spark.examples.sql
-import org.apache.spark.sql.{SaveMode, SparkSession}
+import org.apache.spark.sql.SaveMode
+// $example on:init_session$
+import org.apache.spark.sql.SparkSession
+// $example off:init_session$
// One method for defining the schema of an RDD is to make a case class with the desired column
// names and types.
@@ -26,13 +29,16 @@ case class Record(key: Int, value: String)
object RDDRelation {
def main(args: Array[String]) {
+ // $example on:init_session$
val spark = SparkSession
.builder
- .appName("RDDRelation")
+ .appName("Spark Examples")
+ .config("spark.some.config.option", "some-value")
.getOrCreate()
// Importing the SparkSession gives access to all the SQL functions and implicit conversions.
import spark.implicits._
+ // $example off:init_session$
val df = spark.createDataFrame((1 to 100).map(i => Record(i, s"val_$i")))
// Any RDD containing case classes can be used to create a temporary view. The schema of the
diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala
new file mode 100644
index 000000000000..433f7a181bbf
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/sql/streaming/StructuredNetworkWordCount.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.sql.streaming
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.SparkSession
+
+/**
+ * Counts words in UTF8 encoded, '\n' delimited text received from the network every second.
+ *
+ * Usage: StructuredNetworkWordCount <host> <port>
+ * <host> and <port> describe the TCP server that Structured Streaming
+ * would connect to receive data.
+ *
+ * To run this on your local machine, you need to first run a Netcat server
+ * `$ nc -lk 9999`
+ * and then run the example
+ * `$ bin/run-example sql.streaming.StructuredNetworkWordCount
+ * localhost 9999`
+ */
+object StructuredNetworkWordCount {
+ def main(args: Array[String]) {
+ if (args.length < 2) {
+ System.err.println("Usage: StructuredNetworkWordCount <host> <port>")
+ System.exit(1)
+ }
+
+ val host = args(0)
+ val port = args(1).toInt
+
+ val spark = SparkSession
+ .builder
+ .appName("StructuredNetworkWordCount")
+ .getOrCreate()
+
+ import spark.implicits._
+
+ // Create DataFrame representing the stream of input lines from connection to host:port
+ val lines = spark.readStream
+ .format("socket")
+ .option("host", host)
+ .option("port", port)
+ .load().as[String]
+
+ // Split the lines into words
+ val words = lines.flatMap(_.split(" "))
+
+ // Generate running word count
+ val wordCounts = words.groupBy("value").count()
+
+ // Start running the query that prints the running counts to the console
+ val query = wordCounts.writeStream
+ .outputMode("complete")
+ .format("console")
+ .start()
+
+ query.awaitTermination()
+ }
+}
+// scalastyle:on println
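
The streaming query above uses the same DataFrame operations as a batch job; only the source (readStream on a socket) and the sink (writeStream to the console) differ. A batch-mode sketch for comparison, assuming a local text file such as README.md is available:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder.appName("BatchWordCount").getOrCreate()
    import spark.implicits._

    val lines = spark.read.textFile("README.md")   // Dataset[String], read once rather than streamed
    val words = lines.flatMap(_.split(" "))
    val wordCounts = words.groupBy("value").count()
    wordCounts.show()
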
diff --git a/external/kafka-0-10-assembly/pom.xml b/external/kafka-0-10-assembly/pom.xml
new file mode 100644
index 000000000000..59f41f1e17f3
--- /dev/null
+++ b/external/kafka-0-10-assembly/pom.xml
@@ -0,0 +1,176 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.0.1-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-streaming-kafka-0-10-assembly_2.11</artifactId>
+  <packaging>jar</packaging>
+  <name>Spark Integration for Kafka 0.10 Assembly</name>
+  <url>http://spark.apache.org/</url>
+
+  <properties>
+    <sbt.project.name>streaming-kafka-0-10-assembly</sbt.project.name>
+  </properties>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming-kafka-0-10_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>commons-codec</groupId>
+      <artifactId>commons-codec</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>commons-lang</groupId>
+      <artifactId>commons-lang</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>com.google.protobuf</groupId>
+      <artifactId>protobuf-java</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>net.jpountz.lz4</groupId>
+      <artifactId>lz4</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.hadoop</groupId>
+      <artifactId>hadoop-client</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.avro</groupId>
+      <artifactId>avro-mapred</artifactId>
+      <classifier>${avro.mapred.classifier}</classifier>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.curator</groupId>
+      <artifactId>curator-recipes</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.zookeeper</groupId>
+      <artifactId>zookeeper</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>log4j</groupId>
+      <artifactId>log4j</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>net.java.dev.jets3t</groupId>
+      <artifactId>jets3t</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scala-lang</groupId>
+      <artifactId>scala-library</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-api</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.slf4j</groupId>
+      <artifactId>slf4j-log4j12</artifactId>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.xerial.snappy</groupId>
+      <artifactId>snappy-java</artifactId>
+      <scope>provided</scope>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+    <plugins>
+      <plugin>
+        <groupId>org.apache.maven.plugins</groupId>
+        <artifactId>maven-shade-plugin</artifactId>
+        <configuration>
+          <shadedArtifactAttached>false</shadedArtifactAttached>
+          <artifactSet>
+            <includes>
+              <include>*:*</include>
+            </includes>
+          </artifactSet>
+          <filters>
+            <filter>
+              <artifact>*:*</artifact>
+              <excludes>
+                <exclude>META-INF/*.SF</exclude>
+                <exclude>META-INF/*.DSA</exclude>
+                <exclude>META-INF/*.RSA</exclude>
+              </excludes>
+            </filter>
+          </filters>
+        </configuration>
+        <executions>
+          <execution>
+            <phase>package</phase>
+            <goals>
+              <goal>shade</goal>
+            </goals>
+            <configuration>
+              <transformers>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ServicesResourceTransformer"/>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.AppendingTransformer">
+                  <resource>reference.conf</resource>
+                </transformer>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.DontIncludeResourceTransformer">
+                  <resource>log4j.properties</resource>
+                </transformer>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheLicenseResourceTransformer"/>
+                <transformer implementation="org.apache.maven.plugins.shade.resource.ApacheNoticeResourceTransformer"/>
+              </transformers>
+            </configuration>
+          </execution>
+        </executions>
+      </plugin>
+    </plugins>
+  </build>
+</project>
diff --git a/external/kafka-0-10/pom.xml b/external/kafka-0-10/pom.xml
new file mode 100644
index 000000000000..26965612cc0a
--- /dev/null
+++ b/external/kafka-0-10/pom.xml
@@ -0,0 +1,98 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+  ~ Licensed to the Apache Software Foundation (ASF) under one or more
+  ~ contributor license agreements.  See the NOTICE file distributed with
+  ~ this work for additional information regarding copyright ownership.
+  ~ The ASF licenses this file to You under the Apache License, Version 2.0
+  ~ (the "License"); you may not use this file except in compliance with
+  ~ the License.  You may obtain a copy of the License at
+  ~
+  ~    http://www.apache.org/licenses/LICENSE-2.0
+  ~
+  ~ Unless required by applicable law or agreed to in writing, software
+  ~ distributed under the License is distributed on an "AS IS" BASIS,
+  ~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  ~ See the License for the specific language governing permissions and
+  ~ limitations under the License.
+-->
+
+<project xmlns="http://maven.apache.org/POM/4.0.0"
+         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
+  <modelVersion>4.0.0</modelVersion>
+  <parent>
+    <groupId>org.apache.spark</groupId>
+    <artifactId>spark-parent_2.11</artifactId>
+    <version>2.0.1-SNAPSHOT</version>
+    <relativePath>../../pom.xml</relativePath>
+  </parent>
+
+  <groupId>org.apache.spark</groupId>
+  <artifactId>spark-streaming-kafka-0-10_2.11</artifactId>
+  <properties>
+    <sbt.project.name>streaming-kafka-0-10</sbt.project.name>
+  </properties>
+  <packaging>jar</packaging>
+  <name>Spark Integration for Kafka 0.10</name>
+  <url>http://spark.apache.org/</url>
+
+  <dependencies>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <scope>provided</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-core_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.kafka</groupId>
+      <artifactId>kafka_${scala.binary.version}</artifactId>
+      <version>0.10.0.0</version>
+      <exclusions>
+        <exclusion>
+          <groupId>com.sun.jmx</groupId>
+          <artifactId>jmxri</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>com.sun.jdmk</groupId>
+          <artifactId>jmxtools</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>net.sf.jopt-simple</groupId>
+          <artifactId>jopt-simple</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.slf4j</groupId>
+          <artifactId>slf4j-simple</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.apache.zookeeper</groupId>
+          <artifactId>zookeeper</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>net.sf.jopt-simple</groupId>
+      <artifactId>jopt-simple</artifactId>
+      <version>3.2</version>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.scalacheck</groupId>
+      <artifactId>scalacheck_${scala.binary.version}</artifactId>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-tags_${scala.binary.version}</artifactId>
+    </dependency>
+  </dependencies>
+
+  <build>
+    <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+    <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
+  </build>
+</project>
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
new file mode 100644
index 000000000000..fa3ea6131a50
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/CachedKafkaConsumer.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord, KafkaConsumer }
+import org.apache.kafka.common.{ KafkaException, TopicPartition }
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+
+
+/**
+ * Consumer of single topicpartition, intended for cached reuse.
+ * Underlying consumer is not threadsafe, so neither is this,
+ * but processing the same topicpartition and group id in multiple threads is usually bad anyway.
+ */
+private[kafka010]
+class CachedKafkaConsumer[K, V] private(
+ val groupId: String,
+ val topic: String,
+ val partition: Int,
+ val kafkaParams: ju.Map[String, Object]) extends Logging {
+
+ assert(groupId == kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG),
+ "groupId used for cache key must match the groupId in kafkaParams")
+
+ val topicPartition = new TopicPartition(topic, partition)
+
+ protected val consumer = {
+ val c = new KafkaConsumer[K, V](kafkaParams)
+ val tps = new ju.ArrayList[TopicPartition]()
+ tps.add(topicPartition)
+ c.assign(tps)
+ c
+ }
+
+ // TODO if the buffer was kept around as a random-access structure,
+ // could possibly optimize re-calculating of an RDD in the same batch
+ protected var buffer = ju.Collections.emptyList[ConsumerRecord[K, V]]().iterator
+ protected var nextOffset = -2L
+
+ def close(): Unit = consumer.close()
+
+ /**
+ * Get the record for the given offset, waiting up to timeout ms if IO is necessary.
+ * Sequential forward access will use buffers, but random access will be horribly inefficient.
+ */
+ def get(offset: Long, timeout: Long): ConsumerRecord[K, V] = {
+ logDebug(s"Get $groupId $topic $partition nextOffset $nextOffset requested $offset")
+ if (offset != nextOffset) {
+ logInfo(s"Initial fetch for $groupId $topic $partition $offset")
+ seek(offset)
+ poll(timeout)
+ }
+
+ if (!buffer.hasNext()) { poll(timeout) }
+ assert(buffer.hasNext(),
+ s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
+ var record = buffer.next()
+
+ if (record.offset != offset) {
+ logInfo(s"Buffer miss for $groupId $topic $partition $offset")
+ seek(offset)
+ poll(timeout)
+ assert(buffer.hasNext(),
+ s"Failed to get records for $groupId $topic $partition $offset after polling for $timeout")
+ record = buffer.next()
+ assert(record.offset == offset,
+ s"Got wrong record for $groupId $topic $partition even after seeking to offset $offset")
+ }
+
+ nextOffset = offset + 1
+ record
+ }
+
+ private def seek(offset: Long): Unit = {
+ logDebug(s"Seeking to $topicPartition $offset")
+ consumer.seek(topicPartition, offset)
+ }
+
+ private def poll(timeout: Long): Unit = {
+ val p = consumer.poll(timeout)
+ val r = p.records(topicPartition)
+ logDebug(s"Polled ${p.partitions()} ${r.size}")
+ buffer = r.iterator
+ }
+
+}
+
+private[kafka010]
+object CachedKafkaConsumer extends Logging {
+
+ private case class CacheKey(groupId: String, topic: String, partition: Int)
+
+ // Don't want to depend on guava, don't want a cleanup thread, use a simple LinkedHashMap
+ private var cache: ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]] = null
+
+ /** Must be called before get, once per JVM, to configure the cache. Further calls are ignored */
+ def init(
+ initialCapacity: Int,
+ maxCapacity: Int,
+ loadFactor: Float): Unit = CachedKafkaConsumer.synchronized {
+ if (null == cache) {
+ logInfo(s"Initializing cache $initialCapacity $maxCapacity $loadFactor")
+ cache = new ju.LinkedHashMap[CacheKey, CachedKafkaConsumer[_, _]](
+ initialCapacity, loadFactor, true) {
+ override def removeEldestEntry(
+ entry: ju.Map.Entry[CacheKey, CachedKafkaConsumer[_, _]]): Boolean = {
+ if (this.size > maxCapacity) {
+ try {
+ entry.getValue.consumer.close()
+ } catch {
+ case x: KafkaException =>
+ logError("Error closing oldest Kafka consumer", x)
+ }
+ true
+ } else {
+ false
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Get a cached consumer for groupId, assigned to topic and partition.
+ * If a matching consumer doesn't already exist, it will be created using kafkaParams.
+ */
+ def get[K, V](
+ groupId: String,
+ topic: String,
+ partition: Int,
+ kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
+ CachedKafkaConsumer.synchronized {
+ val k = CacheKey(groupId, topic, partition)
+ val v = cache.get(k)
+ if (null == v) {
+ logInfo(s"Cache miss for $k")
+ logDebug(cache.keySet.toString)
+ val c = new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
+ cache.put(k, c)
+ c
+ } else {
+ // any given topicpartition should have a consistent key and value type
+ v.asInstanceOf[CachedKafkaConsumer[K, V]]
+ }
+ }
+
+ /**
+ * Get a fresh new instance, unassociated with the global cache.
+ * Caller is responsible for closing it.
+ */
+ def getUncached[K, V](
+ groupId: String,
+ topic: String,
+ partition: Int,
+ kafkaParams: ju.Map[String, Object]): CachedKafkaConsumer[K, V] =
+ new CachedKafkaConsumer[K, V](groupId, topic, partition, kafkaParams)
+
+ /** remove consumer for given groupId, topic, and partition, if it exists */
+ def remove(groupId: String, topic: String, partition: Int): Unit = {
+ val k = CacheKey(groupId, topic, partition)
+ logInfo(s"Removing $k from cache")
+ val v = CachedKafkaConsumer.synchronized {
+ cache.remove(k)
+ }
+ if (null != v) {
+ v.close()
+ logInfo(s"Removed $k from cache")
+ }
+ }
+}
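
CachedKafkaConsumer is package-private and is normally driven by the KafkaRDD iterator added later in this patch. The sketch below only illustrates the call sequence; the group id, topic, partition, offset bounds and kafkaParams (a java.util.Map[String, Object]) are assumed to be in scope:

    // Configure the per-JVM cache once, then read one topicpartition sequentially.
    CachedKafkaConsumer.init(16, 64, 0.75f)
    val consumer = CachedKafkaConsumer.get[String, String](
      "example-group", "example-topic", 0, kafkaParams)
    var offset = fromOffset
    while (offset < untilOffset) {
      val record = consumer.get(offset, 512)   // sequential access is served from the buffer
      // use record.key and record.value here
      offset += 1
    }
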
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
new file mode 100644
index 000000000000..60255fc655e5
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/ConsumerStrategy.scala
@@ -0,0 +1,473 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ lang => jl, util => ju }
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.clients.consumer._
+import org.apache.kafka.clients.consumer.internals.NoOpConsumerRebalanceListener
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.internal.Logging
+
+/**
+ * :: Experimental ::
+ * Choice of how to create and configure underlying Kafka Consumers on driver and executors.
+ * See [[ConsumerStrategies]] to obtain instances.
+ * Kafka 0.10 consumers can require additional, sometimes complex, setup after object
+ * instantiation. This interface encapsulates that process, and allows it to be checkpointed.
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+@Experimental
+abstract class ConsumerStrategy[K, V] {
+ /**
+ * Kafka configuration parameters to be used on executors. Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ def executorKafkaParams: ju.Map[String, Object]
+
+ /**
+ * Must return a fully configured Kafka Consumer, including subscribed or assigned topics.
+ * See Kafka docs.
+ * This consumer will be used on the driver to query for offsets only, not messages.
+ * The consumer must be returned in a state that it is safe to call poll(0) on.
+ * @param currentOffsets A map from TopicPartition to offset, indicating how far the driver
+ * has successfully read. Will be empty on initial start, possibly non-empty on restart from
+ * checkpoint.
+ */
+ def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V]
+}
+
+/**
+ * Subscribe to a collection of topics.
+ * @param topics collection of topics to subscribe
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+private case class Subscribe[K, V](
+ topics: ju.Collection[jl.String],
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]
+ ) extends ConsumerStrategy[K, V] with Logging {
+
+ def executorKafkaParams: ju.Map[String, Object] = kafkaParams
+
+ def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = {
+ val consumer = new KafkaConsumer[K, V](kafkaParams)
+ consumer.subscribe(topics)
+ val toSeek = if (currentOffsets.isEmpty) {
+ offsets
+ } else {
+ currentOffsets
+ }
+ if (!toSeek.isEmpty) {
+ // work around KAFKA-3370 when reset is none
+ // poll will throw if no position, i.e. auto offset reset none and no explicit position
+ // but can't seek to a position before poll, because poll is what gets subscription partitions
+ // So, poll, suppress the first exception, then seek
+ val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)
+ val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE"
+ try {
+ consumer.poll(0)
+ } catch {
+ case x: NoOffsetForPartitionException if shouldSuppress =>
+ logWarning("Catching NoOffsetForPartitionException since " +
+ ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " is none. See KAFKA-3370")
+ }
+ toSeek.asScala.foreach { case (topicPartition, offset) =>
+ consumer.seek(topicPartition, offset)
+ }
+ }
+
+ consumer
+ }
+}
+
+/**
+ * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+ * The pattern matching will be done periodically against topics existing at the time of check.
+ * @param pattern pattern to subscribe to
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+private case class SubscribePattern[K, V](
+ pattern: ju.regex.Pattern,
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]
+ ) extends ConsumerStrategy[K, V] with Logging {
+
+ def executorKafkaParams: ju.Map[String, Object] = kafkaParams
+
+ def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = {
+ val consumer = new KafkaConsumer[K, V](kafkaParams)
+ consumer.subscribe(pattern, new NoOpConsumerRebalanceListener())
+ val toSeek = if (currentOffsets.isEmpty) {
+ offsets
+ } else {
+ currentOffsets
+ }
+ if (!toSeek.isEmpty) {
+ // work around KAFKA-3370 when reset is none, see explanation in Subscribe above
+ val aor = kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG)
+ val shouldSuppress = aor != null && aor.asInstanceOf[String].toUpperCase == "NONE"
+ try {
+ consumer.poll(0)
+ } catch {
+ case x: NoOffsetForPartitionException if shouldSuppress =>
+ logWarning("Catching NoOffsetForPartitionException since " +
+ ConsumerConfig.AUTO_OFFSET_RESET_CONFIG + " is none. See KAFKA-3370")
+ }
+ toSeek.asScala.foreach { case (topicPartition, offset) =>
+ consumer.seek(topicPartition, offset)
+ }
+ }
+
+ consumer
+ }
+}
+
+/**
+ * Assign a fixed collection of TopicPartitions
+ * @param topicPartitions collection of TopicPartitions to assign
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+private case class Assign[K, V](
+ topicPartitions: ju.Collection[TopicPartition],
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]
+ ) extends ConsumerStrategy[K, V] {
+
+ def executorKafkaParams: ju.Map[String, Object] = kafkaParams
+
+ def onStart(currentOffsets: ju.Map[TopicPartition, jl.Long]): Consumer[K, V] = {
+ val consumer = new KafkaConsumer[K, V](kafkaParams)
+ consumer.assign(topicPartitions)
+ val toSeek = if (currentOffsets.isEmpty) {
+ offsets
+ } else {
+ currentOffsets
+ }
+ if (!toSeek.isEmpty) {
+ // this doesn't need a KAFKA-3370 workaround, because partitions are known, no poll needed
+ toSeek.asScala.foreach { case (topicPartition, offset) =>
+ consumer.seek(topicPartition, offset)
+ }
+ }
+
+ consumer
+ }
+}
+
+/**
+ * :: Experimental ::
+ * object for obtaining instances of [[ConsumerStrategy]]
+ */
+@Experimental
+object ConsumerStrategies {
+ /**
+ * :: Experimental ::
+ * Subscribe to a collection of topics.
+ * @param topics collection of topics to subscribe
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def Subscribe[K, V](
+ topics: Iterable[jl.String],
+ kafkaParams: collection.Map[String, Object],
+ offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = {
+ new Subscribe[K, V](
+ new ju.ArrayList(topics.asJavaCollection),
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ }
+
+ /**
+ * :: Experimental ::
+ * Subscribe to a collection of topics.
+ * @param topics collection of topics to subscribe
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def Subscribe[K, V](
+ topics: Iterable[jl.String],
+ kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new Subscribe[K, V](
+ new ju.ArrayList(topics.asJavaCollection),
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+ /**
+ * :: Experimental ::
+ * Subscribe to a collection of topics.
+ * @param topics collection of topics to subscribe
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def Subscribe[K, V](
+ topics: ju.Collection[jl.String],
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = {
+ new Subscribe[K, V](topics, kafkaParams, offsets)
+ }
+
+ /**
+ * :: Experimental ::
+ * Subscribe to a collection of topics.
+ * @param topics collection of topics to subscribe
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def Subscribe[K, V](
+ topics: ju.Collection[jl.String],
+ kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new Subscribe[K, V](topics, kafkaParams, ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+ /** :: Experimental ::
+ * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+ * The pattern matching will be done periodically against topics existing at the time of check.
+ * @param pattern pattern to subscribe to
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def SubscribePattern[K, V](
+ pattern: ju.regex.Pattern,
+ kafkaParams: collection.Map[String, Object],
+ offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = {
+ new SubscribePattern[K, V](
+ pattern,
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ }
+
+ /** :: Experimental ::
+ * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+ * The pattern matching will be done periodically against topics existing at the time of check.
+ * @param pattern pattern to subscribe to
+ * @param kafkaParams Kafka
+ *
+ * configuration parameters to be used on driver. The same params will be used on executors,
+ * with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def SubscribePattern[K, V](
+ pattern: ju.regex.Pattern,
+ kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new SubscribePattern[K, V](
+ pattern,
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+ /** :: Experimental ::
+ * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+ * The pattern matching will be done periodically against topics existing at the time of check.
+ * @param pattern pattern to subscribe to
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def SubscribePattern[K, V](
+ pattern: ju.regex.Pattern,
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = {
+ new SubscribePattern[K, V](pattern, kafkaParams, offsets)
+ }
+
+ /** :: Experimental ::
+ * Subscribe to all topics matching specified pattern to get dynamically assigned partitions.
+ * The pattern matching will be done periodically against topics existing at the time of check.
+ * @param pattern pattern to subscribe to
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def SubscribePattern[K, V](
+ pattern: ju.regex.Pattern,
+ kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new SubscribePattern[K, V](
+ pattern,
+ kafkaParams,
+ ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+ /**
+ * :: Experimental ::
+ * Assign a fixed collection of TopicPartitions
+ * @param topicPartitions collection of TopicPartitions to assign
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def Assign[K, V](
+ topicPartitions: Iterable[TopicPartition],
+ kafkaParams: collection.Map[String, Object],
+ offsets: collection.Map[TopicPartition, Long]): ConsumerStrategy[K, V] = {
+ new Assign[K, V](
+ new ju.ArrayList(topicPartitions.asJavaCollection),
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ new ju.HashMap[TopicPartition, jl.Long](offsets.mapValues(l => new jl.Long(l)).asJava))
+ }
+
+ /**
+ * :: Experimental ::
+ * Assign a fixed collection of TopicPartitions
+ * @param topicPartitions collection of TopicPartitions to assign
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def Assign[K, V](
+ topicPartitions: Iterable[TopicPartition],
+ kafkaParams: collection.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new Assign[K, V](
+ new ju.ArrayList(topicPartitions.asJavaCollection),
+ new ju.HashMap[String, Object](kafkaParams.asJava),
+ ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+ /**
+ * :: Experimental ::
+ * Assign a fixed collection of TopicPartitions
+ * @param topicPartitions collection of TopicPartitions to assign
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsets: offsets to begin at on initial startup. If no offset is given for a
+ * TopicPartition, the committed offset (if applicable) or kafka param
+ * auto.offset.reset will be used.
+ */
+ @Experimental
+ def Assign[K, V](
+ topicPartitions: ju.Collection[TopicPartition],
+ kafkaParams: ju.Map[String, Object],
+ offsets: ju.Map[TopicPartition, jl.Long]): ConsumerStrategy[K, V] = {
+ new Assign[K, V](topicPartitions, kafkaParams, offsets)
+ }
+
+ /**
+ * :: Experimental ::
+ * Assign a fixed collection of TopicPartitions
+ * @param topicPartitions collection of TopicPartitions to assign
+ * @param kafkaParams Kafka configuration parameters to be used on driver. The same params
+ * will be used on executors, with minor automatic modifications applied.
+ * Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ */
+ @Experimental
+ def Assign[K, V](
+ topicPartitions: ju.Collection[TopicPartition],
+ kafkaParams: ju.Map[String, Object]): ConsumerStrategy[K, V] = {
+ new Assign[K, V](
+ topicPartitions,
+ kafkaParams,
+ ju.Collections.emptyMap[TopicPartition, jl.Long]())
+ }
+
+}
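
A driver-side sketch of building a strategy with ConsumerStrategies.Subscribe; the broker addresses, topics and group id are placeholders, and the stream-creation entry point that consumes the strategy is added elsewhere in this patch:

    import org.apache.kafka.common.serialization.StringDeserializer

    val kafkaParams = Map[String, Object](
      "bootstrap.servers" -> "host1:9092,host2:9092",
      "key.deserializer" -> classOf[StringDeserializer],
      "value.deserializer" -> classOf[StringDeserializer],
      "group.id" -> "example-group",
      "auto.offset.reset" -> "latest",
      "enable.auto.commit" -> (false: java.lang.Boolean))

    val strategy = ConsumerStrategies.Subscribe[String, String](List("topicA", "topicB"), kafkaParams)
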
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
new file mode 100644
index 000000000000..13827f68f2cb
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/DirectKafkaInputDStream.scala
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicReference
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters._
+import scala.collection.mutable
+
+import org.apache.kafka.clients.consumer._
+import org.apache.kafka.common.{ PartitionInfo, TopicPartition }
+
+import org.apache.spark.SparkException
+import org.apache.spark.internal.Logging
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.streaming.{StreamingContext, Time}
+import org.apache.spark.streaming.dstream._
+import org.apache.spark.streaming.scheduler.{RateController, StreamInputInfo}
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+
+/**
+ * A DStream where each given Kafka topic/partition corresponds to an RDD partition.
+ * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
+ * of messages per second that each '''partition''' will accept.
+ * @param locationStrategy In most cases, pass in [[PreferConsistent]],
+ * see [[LocationStrategy]] for more details.
+ * @param executorKafkaParams Kafka configuration parameters.
+ * Requires "bootstrap.servers" to be set with Kafka broker(s),
+ * NOT zookeeper servers, specified in host1:port1,host2:port2 form.
+ * @param consumerStrategy In most cases, pass in [[Subscribe]],
+ * see [[ConsumerStrategy]] for more details
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+private[spark] class DirectKafkaInputDStream[K, V](
+ _ssc: StreamingContext,
+ locationStrategy: LocationStrategy,
+ consumerStrategy: ConsumerStrategy[K, V]
+ ) extends InputDStream[ConsumerRecord[K, V]](_ssc) with Logging with CanCommitOffsets {
+
+ val executorKafkaParams = {
+ val ekp = new ju.HashMap[String, Object](consumerStrategy.executorKafkaParams)
+ KafkaUtils.fixKafkaParams(ekp)
+ ekp
+ }
+
+ protected var currentOffsets = Map[TopicPartition, Long]()
+
+ @transient private var kc: Consumer[K, V] = null
+ def consumer(): Consumer[K, V] = this.synchronized {
+ if (null == kc) {
+ kc = consumerStrategy.onStart(currentOffsets.mapValues(l => new java.lang.Long(l)).asJava)
+ }
+ kc
+ }
+
+ override def persist(newLevel: StorageLevel): DStream[ConsumerRecord[K, V]] = {
+ logError("Kafka ConsumerRecord is not serializable. " +
+ "Use .map to extract fields before calling .persist or .window")
+ super.persist(newLevel)
+ }
+
+ protected def getBrokers = {
+ val c = consumer
+ val result = new ju.HashMap[TopicPartition, String]()
+ val hosts = new ju.HashMap[TopicPartition, String]()
+ val assignments = c.assignment().iterator()
+ while (assignments.hasNext()) {
+ val tp: TopicPartition = assignments.next()
+ if (null == hosts.get(tp)) {
+ val infos = c.partitionsFor(tp.topic).iterator()
+ while (infos.hasNext()) {
+ val i = infos.next()
+ hosts.put(new TopicPartition(i.topic(), i.partition()), i.leader.host())
+ }
+ }
+ result.put(tp, hosts.get(tp))
+ }
+ result
+ }
+
+ protected def getPreferredHosts: ju.Map[TopicPartition, String] = {
+ locationStrategy match {
+ case PreferBrokers => getBrokers
+ case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]()
+ case PreferFixed(hostMap) => hostMap
+ }
+ }
+
+ // Keep this consistent with how other streams are named (e.g. "Flume polling stream [2]")
+ private[streaming] override def name: String = s"Kafka 0.10 direct stream [$id]"
+
+ protected[streaming] override val checkpointData =
+ new DirectKafkaInputDStreamCheckpointData
+
+
+ /**
+ * Asynchronously maintains & sends new rate limits to the receiver through the receiver tracker.
+ */
+ override protected[streaming] val rateController: Option[RateController] = {
+ if (RateController.isBackPressureEnabled(ssc.conf)) {
+ Some(new DirectKafkaRateController(id,
+ RateEstimator.create(ssc.conf, context.graph.batchDuration)))
+ } else {
+ None
+ }
+ }
+
+ private val maxRateLimitPerPartition: Int = context.sparkContext.getConf.getInt(
+ "spark.streaming.kafka.maxRatePerPartition", 0)
+
+ protected[streaming] def maxMessagesPerPartition(
+ offsets: Map[TopicPartition, Long]): Option[Map[TopicPartition, Long]] = {
+ val estimatedRateLimit = rateController.map(_.getLatestRate().toInt)
+
+ // calculate a per-partition rate limit based on current lag
+ val effectiveRateLimitPerPartition = estimatedRateLimit.filter(_ > 0) match {
+ case Some(rate) =>
+ val lagPerPartition = offsets.map { case (tp, offset) =>
+ tp -> Math.max(offset - currentOffsets(tp), 0)
+ }
+ val totalLag = lagPerPartition.values.sum
+
+ lagPerPartition.map { case (tp, lag) =>
+ val backpressureRate = Math.round(lag / totalLag.toFloat * rate)
+ tp -> (if (maxRateLimitPerPartition > 0) {
+ Math.min(backpressureRate, maxRateLimitPerPartition)} else backpressureRate)
+ }
+ case None => offsets.map { case (tp, offset) => tp -> maxRateLimitPerPartition }
+ }
+
+ if (effectiveRateLimitPerPartition.values.sum > 0) {
+ val secsPerBatch = context.graph.batchDuration.milliseconds.toDouble / 1000
+ Some(effectiveRateLimitPerPartition.map {
+ case (tp, limit) => tp -> (secsPerBatch * limit).toLong
+ })
+ } else {
+ None
+ }
+ }
+
+ /**
+ * Returns the latest (highest) available offsets, taking new partitions into account.
+ */
+ protected def latestOffsets(): Map[TopicPartition, Long] = {
+ val c = consumer
+ c.poll(0)
+ val parts = c.assignment().asScala
+
+ // make sure new partitions are reflected in currentOffsets
+ val newPartitions = parts.diff(currentOffsets.keySet)
+ // position for new partitions determined by auto.offset.reset if no commit
+ currentOffsets = currentOffsets ++ newPartitions.map(tp => tp -> c.position(tp)).toMap
+ // don't want to consume messages, so pause
+ c.pause(newPartitions.asJava)
+ // find latest available offsets
+ c.seekToEnd(currentOffsets.keySet.asJava)
+ parts.map(tp => tp -> c.position(tp)).toMap
+ }
+
+ // limits the maximum number of messages per partition
+ protected def clamp(
+ offsets: Map[TopicPartition, Long]): Map[TopicPartition, Long] = {
+
+ maxMessagesPerPartition(offsets).map { mmp =>
+ mmp.map { case (tp, messages) =>
+ val uo = offsets(tp)
+ tp -> Math.min(currentOffsets(tp) + messages, uo)
+ }
+ }.getOrElse(offsets)
+ }
+
+ override def compute(validTime: Time): Option[KafkaRDD[K, V]] = {
+ val untilOffsets = clamp(latestOffsets())
+ val offsetRanges = untilOffsets.map { case (tp, uo) =>
+ val fo = currentOffsets(tp)
+ OffsetRange(tp.topic, tp.partition, fo, uo)
+ }
+ val rdd = new KafkaRDD[K, V](
+ context.sparkContext, executorKafkaParams, offsetRanges.toArray, getPreferredHosts, true)
+
+ // Report the record number and metadata of this batch interval to InputInfoTracker.
+ val description = offsetRanges.filter { offsetRange =>
+ // Don't display empty ranges.
+ offsetRange.fromOffset != offsetRange.untilOffset
+ }.map { offsetRange =>
+ s"topic: ${offsetRange.topic}\tpartition: ${offsetRange.partition}\t" +
+ s"offsets: ${offsetRange.fromOffset} to ${offsetRange.untilOffset}"
+ }.mkString("\n")
+ // Copy offsetRanges to immutable.List to prevent from being modified by the user
+ val metadata = Map(
+ "offsets" -> offsetRanges.toList,
+ StreamInputInfo.METADATA_KEY_DESCRIPTION -> description)
+ val inputInfo = StreamInputInfo(id, rdd.count, metadata)
+ ssc.scheduler.inputInfoTracker.reportInfo(validTime, inputInfo)
+
+ currentOffsets = untilOffsets
+ commitAll()
+ Some(rdd)
+ }
+
+ override def start(): Unit = {
+ val c = consumer
+ c.poll(0)
+ if (currentOffsets.isEmpty) {
+ currentOffsets = c.assignment().asScala.map { tp =>
+ tp -> c.position(tp)
+ }.toMap
+ }
+
+ // don't actually want to consume any messages, so pause all partitions
+ c.pause(currentOffsets.keySet.asJava)
+ }
+
+ override def stop(): Unit = this.synchronized {
+ if (kc != null) {
+ kc.close()
+ }
+ }
+
+ protected val commitQueue = new ConcurrentLinkedQueue[OffsetRange]
+ protected val commitCallback = new AtomicReference[OffsetCommitCallback]
+
+ /**
+ * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
+ * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
+ */
+ def commitAsync(offsetRanges: Array[OffsetRange]): Unit = {
+ commitAsync(offsetRanges, null)
+ }
+
+ /**
+ * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
+ * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
+ * @param callback Only the most recently provided callback will be used at commit.
+ */
+ def commitAsync(offsetRanges: Array[OffsetRange], callback: OffsetCommitCallback): Unit = {
+ commitCallback.set(callback)
+ commitQueue.addAll(ju.Arrays.asList(offsetRanges: _*))
+ }
+
+ protected def commitAll(): Unit = {
+ val m = new ju.HashMap[TopicPartition, OffsetAndMetadata]()
+ val it = commitQueue.iterator()
+ while (it.hasNext) {
+ val osr = it.next
+ val tp = osr.topicPartition
+ val x = m.get(tp)
+ val offset = if (null == x) { osr.untilOffset } else { Math.max(x.offset, osr.untilOffset) }
+ m.put(tp, new OffsetAndMetadata(offset))
+ }
+ if (!m.isEmpty) {
+ consumer.commitAsync(m, commitCallback.get)
+ }
+ }
+
+ private[streaming]
+ class DirectKafkaInputDStreamCheckpointData extends DStreamCheckpointData(this) {
+ def batchForTime: mutable.HashMap[Time, Array[(String, Int, Long, Long)]] = {
+ data.asInstanceOf[mutable.HashMap[Time, Array[OffsetRange.OffsetRangeTuple]]]
+ }
+
+ override def update(time: Time): Unit = {
+ batchForTime.clear()
+ generatedRDDs.foreach { kv =>
+ val a = kv._2.asInstanceOf[KafkaRDD[K, V]].offsetRanges.map(_.toTuple).toArray
+ batchForTime += kv._1 -> a
+ }
+ }
+
+ override def cleanup(time: Time): Unit = { }
+
+ override def restore(): Unit = {
+ batchForTime.toSeq.sortBy(_._1)(Time.ordering).foreach { case (t, b) =>
+ logInfo(s"Restoring KafkaRDD for time $t ${b.mkString("[", ", ", "]")}")
+ generatedRDDs += t -> new KafkaRDD[K, V](
+ context.sparkContext,
+ executorKafkaParams,
+ b.map(OffsetRange(_)),
+ getPreferredHosts,
+ // during restore, it's possible the same partition will be consumed from multiple
+ // threads, so don't use the cache
+ false
+ )
+ }
+ }
+ }
+
+ /**
+ * A RateController to retrieve the rate from RateEstimator.
+ */
+ private[streaming] class DirectKafkaRateController(id: Int, estimator: RateEstimator)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit = ()
+ }
+}
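
A hedged sketch of the commit-after-processing pattern that CanCommitOffsets enables; stream is assumed to be the ConsumerRecord DStream created through this patch's API, and the queued ranges are actually committed during a later batch's compute():

    stream.foreachRDD { rdd =>
      val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
      // process rdd here, then queue its ranges for an asynchronous commit back to Kafka
      stream.asInstanceOf[CanCommitOffsets].commitAsync(offsetRanges)
    }
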
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
new file mode 100644
index 000000000000..5b5a9ac48c7c
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDD.scala
@@ -0,0 +1,232 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.kafka.clients.consumer.{ ConsumerConfig, ConsumerRecord }
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.{Partition, SparkContext, SparkException, TaskContext}
+import org.apache.spark.internal.Logging
+import org.apache.spark.partial.{BoundedDouble, PartialResult}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A batch-oriented interface for consuming from Kafka.
+ * Starting and ending offsets are specified in advance,
+ * so that you can control exactly-once semantics.
+ * @param kafkaParams Kafka configuration parameters. Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD
+ * @param preferredHosts map from TopicPartition to preferred host for processing that partition.
+ * In most cases, use [[PreferConsistent]].
+ * Use [[PreferBrokers]] if your executors are on the same nodes as your brokers.
+ * @param useConsumerCache whether to use a consumer from a per-jvm cache
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+private[spark] class KafkaRDD[K, V](
+ sc: SparkContext,
+ val kafkaParams: ju.Map[String, Object],
+ val offsetRanges: Array[OffsetRange],
+ val preferredHosts: ju.Map[TopicPartition, String],
+ useConsumerCache: Boolean
+) extends RDD[ConsumerRecord[K, V]](sc, Nil) with Logging with HasOffsetRanges {
+
+ assert("none" ==
+ kafkaParams.get(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG).asInstanceOf[String],
+ ConsumerConfig.AUTO_OFFSET_RESET_CONFIG +
+ " must be set to none for executor kafka params, else messages may not match offsetRange")
+
+ assert(false ==
+ kafkaParams.get(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG).asInstanceOf[Boolean],
+ ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG +
+ " must be set to false for executor kafka params, else offsets may commit before processing")
+
+ // TODO is it necessary to have separate configs for initial poll time vs ongoing poll time?
+ private val pollTimeout = conf.getLong("spark.streaming.kafka.consumer.poll.ms", 512)
+ private val cacheInitialCapacity =
+ conf.getInt("spark.streaming.kafka.consumer.cache.initialCapacity", 16)
+ private val cacheMaxCapacity =
+ conf.getInt("spark.streaming.kafka.consumer.cache.maxCapacity", 64)
+ private val cacheLoadFactor =
+ conf.getDouble("spark.streaming.kafka.consumer.cache.loadFactor", 0.75).toFloat
+
+ override def persist(newLevel: StorageLevel): this.type = {
+ logError("Kafka ConsumerRecord is not serializable. " +
+ "Use .map to extract fields before calling .persist or .window")
+ super.persist(newLevel)
+ }
+
+ override def getPartitions: Array[Partition] = {
+ offsetRanges.zipWithIndex.map { case (o, i) =>
+ new KafkaRDDPartition(i, o.topic, o.partition, o.fromOffset, o.untilOffset)
+ }.toArray
+ }
+
+ override def count(): Long = offsetRanges.map(_.count).sum
+
+ override def countApprox(
+ timeout: Long,
+ confidence: Double = 0.95
+ ): PartialResult[BoundedDouble] = {
+ val c = count
+ new PartialResult(new BoundedDouble(c, 1.0, c, c), true)
+ }
+
+ override def isEmpty(): Boolean = count == 0L
+
+ override def take(num: Int): Array[ConsumerRecord[K, V]] = {
+ val nonEmptyPartitions = this.partitions
+ .map(_.asInstanceOf[KafkaRDDPartition])
+ .filter(_.count > 0)
+
+ if (num < 1 || nonEmptyPartitions.isEmpty) {
+ return new Array[ConsumerRecord[K, V]](0)
+ }
+
+ // Determine in advance how many messages need to be taken from each partition
+ val parts = nonEmptyPartitions.foldLeft(Map[Int, Int]()) { (result, part) =>
+ val remain = num - result.values.sum
+ if (remain > 0) {
+ val taken = Math.min(remain, part.count)
+ result + (part.index -> taken.toInt)
+ } else {
+ result
+ }
+ }
+
+ val buf = new ArrayBuffer[ConsumerRecord[K, V]]
+ val res = context.runJob(
+ this,
+ (tc: TaskContext, it: Iterator[ConsumerRecord[K, V]]) =>
+ it.take(parts(tc.partitionId)).toArray, parts.keys.toArray
+ )
+ res.foreach(buf ++= _)
+ buf.toArray
+ }
+
+ private def executors(): Array[ExecutorCacheTaskLocation] = {
+ val bm = sparkContext.env.blockManager
+ bm.master.getPeers(bm.blockManagerId).toArray
+ .map(x => ExecutorCacheTaskLocation(x.host, x.executorId))
+ .sortWith(compareExecutors)
+ }
+
+ protected[kafka010] def compareExecutors(
+ a: ExecutorCacheTaskLocation,
+ b: ExecutorCacheTaskLocation): Boolean =
+ if (a.host == b.host) {
+ a.executorId > b.executorId
+ } else {
+ a.host > b.host
+ }
+
+ /**
+ * Non-negative modulus, from java 8 math
+ */
+ private def floorMod(a: Int, b: Int): Int = ((a % b) + b) % b
+
+ override def getPreferredLocations(thePart: Partition): Seq[String] = {
+ // The intention is best-effort consistent executor for a given topicpartition,
+ // so that caching consumers can be effective.
+ // TODO what about hosts specified by ip vs name
+ val part = thePart.asInstanceOf[KafkaRDDPartition]
+ val allExecs = executors()
+ val tp = part.topicPartition
+ val prefHost = preferredHosts.get(tp)
+ val prefExecs = if (null == prefHost) allExecs else allExecs.filter(_.host == prefHost)
+ val execs = if (prefExecs.isEmpty) allExecs else prefExecs
+ if (execs.isEmpty) {
+ Seq()
+ } else {
+ // execs is sorted, tp.hashCode depends only on topic and partition, so consistent index
+ val index = this.floorMod(tp.hashCode, execs.length)
+ val chosen = execs(index)
+ Seq(chosen.toString)
+ }
+ }
+
+ private def errBeginAfterEnd(part: KafkaRDDPartition): String =
+ s"Beginning offset ${part.fromOffset} is after the ending offset ${part.untilOffset} " +
+ s"for topic ${part.topic} partition ${part.partition}. " +
+ "You either provided an invalid fromOffset, or the Kafka topic has been damaged"
+
+ override def compute(thePart: Partition, context: TaskContext): Iterator[ConsumerRecord[K, V]] = {
+ val part = thePart.asInstanceOf[KafkaRDDPartition]
+ assert(part.fromOffset <= part.untilOffset, errBeginAfterEnd(part))
+ if (part.fromOffset == part.untilOffset) {
+ logInfo(s"Beginning offset ${part.fromOffset} is the same as ending offset " +
+ s"skipping ${part.topic} ${part.partition}")
+ Iterator.empty
+ } else {
+ new KafkaRDDIterator(part, context)
+ }
+ }
+
+ /**
+ * An iterator that fetches messages directly from Kafka for the offsets in partition.
+ * Uses a cached consumer where possible to take advantage of prefetching
+ */
+ private class KafkaRDDIterator(
+ part: KafkaRDDPartition,
+ context: TaskContext) extends Iterator[ConsumerRecord[K, V]] {
+
+ logInfo(s"Computing topic ${part.topic}, partition ${part.partition} " +
+ s"offsets ${part.fromOffset} -> ${part.untilOffset}")
+
+ val groupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG).asInstanceOf[String]
+
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
+
+ val consumer = if (useConsumerCache) {
+ CachedKafkaConsumer.init(cacheInitialCapacity, cacheMaxCapacity, cacheLoadFactor)
+ if (context.attemptNumber > 1) {
+ // just in case the prior attempt failures were cache related
+ CachedKafkaConsumer.remove(groupId, part.topic, part.partition)
+ }
+ CachedKafkaConsumer.get[K, V](groupId, part.topic, part.partition, kafkaParams)
+ } else {
+ CachedKafkaConsumer.getUncached[K, V](groupId, part.topic, part.partition, kafkaParams)
+ }
+
+ var requestOffset = part.fromOffset
+
+ def closeIfNeeded(): Unit = {
+ if (!useConsumerCache && consumer != null) {
+ consumer.close
+ }
+ }
+
+ override def hasNext(): Boolean = requestOffset < part.untilOffset
+
+ override def next(): ConsumerRecord[K, V] = {
+ assert(hasNext(), "Can't call next() once untilOffset has been reached")
+ val r = consumer.get(requestOffset, pollTimeout)
+ requestOffset += 1
+ r
+ }
+ }
+}
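
A standalone illustration of the placement rule in getPreferredLocations: a TopicPartition's hash selects a stable slot in the sorted executor list, so the same partition keeps landing on the same executor and its cached consumer stays warm. The executor names below are made up for the example:

    import org.apache.kafka.common.TopicPartition

    def floorMod(a: Int, b: Int): Int = ((a % b) + b) % b
    val executors = Seq("exec-1@host-a", "exec-2@host-b", "exec-3@host-c").sorted
    val tp = new TopicPartition("example-topic", 0)
    val chosen = executors(floorMod(tp.hashCode, executors.length))
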
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala
new file mode 100644
index 000000000000..95569b109f30
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaRDDPartition.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.Partition
+
+
+/**
+ * @param topic kafka topic name
+ * @param partition kafka partition id
+ * @param fromOffset inclusive starting offset
+ * @param untilOffset exclusive ending offset
+ */
+private[kafka010]
+class KafkaRDDPartition(
+ val index: Int,
+ val topic: String,
+ val partition: Int,
+ val fromOffset: Long,
+ val untilOffset: Long
+) extends Partition {
+ /** Number of messages this partition refers to */
+ def count(): Long = untilOffset - fromOffset
+
+ /** Kafka TopicPartition object, for convenience */
+ def topicPartition(): TopicPartition = new TopicPartition(topic, partition)
+
+}
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
new file mode 100644
index 000000000000..19192e4b9594
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaTestUtils.scala
@@ -0,0 +1,278 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.io.File
+import java.lang.{Integer => JInt}
+import java.net.InetSocketAddress
+import java.util.{Map => JMap, Properties}
+import java.util.concurrent.TimeoutException
+
+import scala.annotation.tailrec
+import scala.collection.JavaConverters._
+import scala.language.postfixOps
+import scala.util.control.NonFatal
+
+import kafka.admin.AdminUtils
+import kafka.api.Request
+import kafka.producer.{KeyedMessage, Producer, ProducerConfig}
+import kafka.serializer.StringEncoder
+import kafka.server.{KafkaConfig, KafkaServer}
+import kafka.utils.ZkUtils
+import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.streaming.Time
+import org.apache.spark.util.Utils
+
+/**
+ * This is a helper class for Kafka test suites. This has the functionality to set up
+ * and tear down local Kafka servers, and to push data using Kafka producers.
+ *
+ * The reason to put Kafka test utility class in src is to test Python related Kafka APIs.
+ */
+private[kafka010] class KafkaTestUtils extends Logging {
+
+ // Zookeeper related configurations
+ private val zkHost = "localhost"
+ private var zkPort: Int = 0
+ private val zkConnectionTimeout = 60000
+ private val zkSessionTimeout = 6000
+
+ private var zookeeper: EmbeddedZookeeper = _
+
+ private var zkUtils: ZkUtils = _
+
+ // Kafka broker related configurations
+ private val brokerHost = "localhost"
+ private var brokerPort = 0
+ private var brokerConf: KafkaConfig = _
+
+ // Kafka broker server
+ private var server: KafkaServer = _
+
+ // Kafka producer
+ private var producer: Producer[String, String] = _
+
+ // Flag to test whether the system is correctly started
+ private var zkReady = false
+ private var brokerReady = false
+
+ def zkAddress: String = {
+ assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper address")
+ s"$zkHost:$zkPort"
+ }
+
+ def brokerAddress: String = {
+ assert(brokerReady, "Kafka not setup yet or already torn down, cannot get broker address")
+ s"$brokerHost:$brokerPort"
+ }
+
+ def zookeeperClient: ZkUtils = {
+ assert(zkReady, "Zookeeper not setup yet or already torn down, cannot get zookeeper client")
+ Option(zkUtils).getOrElse(
+ throw new IllegalStateException("Zookeeper client is not yet initialized"))
+ }
+
+ // Set up the Embedded Zookeeper server and get the proper Zookeeper port
+ private def setupEmbeddedZookeeper(): Unit = {
+ // Zookeeper server startup
+ zookeeper = new EmbeddedZookeeper(s"$zkHost:$zkPort")
+ // Get the actual zookeeper binding port
+ zkPort = zookeeper.actualPort
+ zkUtils = ZkUtils(s"$zkHost:$zkPort", zkSessionTimeout, zkConnectionTimeout, false)
+ zkReady = true
+ }
+
+ // Set up the Embedded Kafka server
+ private def setupEmbeddedKafkaServer(): Unit = {
+ assert(zkReady, "Zookeeper should be set up beforehand")
+
+ // Kafka broker startup
+ Utils.startServiceOnPort(brokerPort, port => {
+ brokerPort = port
+ brokerConf = new KafkaConfig(brokerConfiguration, doLog = false)
+ server = new KafkaServer(brokerConf)
+ server.startup()
+ brokerPort = server.boundPort()
+ (server, brokerPort)
+ }, new SparkConf(), "KafkaBroker")
+
+ brokerReady = true
+ }
+
+ /** Set up the embedded servers, i.e. Zookeeper and the Kafka broker */
+ def setup(): Unit = {
+ setupEmbeddedZookeeper()
+ setupEmbeddedKafkaServer()
+ }
+
+ /** Tear down all the embedded servers, i.e. the Kafka broker and Zookeeper */
+ def teardown(): Unit = {
+ brokerReady = false
+ zkReady = false
+
+ if (producer != null) {
+ producer.close()
+ producer = null
+ }
+
+ if (server != null) {
+ server.shutdown()
+ server = null
+ }
+
+ brokerConf.logDirs.foreach { f => Utils.deleteRecursively(new File(f)) }
+
+ if (zkUtils != null) {
+ zkUtils.close()
+ zkUtils = null
+ }
+
+ if (zookeeper != null) {
+ zookeeper.shutdown()
+ zookeeper = null
+ }
+ }
+
+ /** Create a Kafka topic and wait until it is propagated to the whole cluster */
+ def createTopic(topic: String, partitions: Int): Unit = {
+ AdminUtils.createTopic(zkUtils, topic, partitions, 1)
+ // wait until metadata is propagated
+ (0 until partitions).foreach { p =>
+ waitUntilMetadataIsPropagated(topic, p)
+ }
+ }
+
+ /** Create a Kafka topic and wait until it is propagated to the whole cluster */
+ def createTopic(topic: String): Unit = {
+ createTopic(topic, 1)
+ }
+
+ /** Java-friendly function for sending messages to the Kafka broker */
+ def sendMessages(topic: String, messageToFreq: JMap[String, JInt]): Unit = {
+ sendMessages(topic, Map(messageToFreq.asScala.mapValues(_.intValue()).toSeq: _*))
+ }
+
+ /** Send the messages to the Kafka broker */
+ def sendMessages(topic: String, messageToFreq: Map[String, Int]): Unit = {
+ val messages = messageToFreq.flatMap { case (s, freq) => Seq.fill(freq)(s) }.toArray
+ sendMessages(topic, messages)
+ }
+
+ /** Send the array of messages to the Kafka broker */
+ def sendMessages(topic: String, messages: Array[String]): Unit = {
+ producer = new Producer[String, String](new ProducerConfig(producerConfiguration))
+ producer.send(messages.map { new KeyedMessage[String, String](topic, _ ) }: _*)
+ producer.close()
+ producer = null
+ }
+
+ private def brokerConfiguration: Properties = {
+ val props = new Properties()
+ props.put("broker.id", "0")
+ props.put("host.name", "localhost")
+ props.put("port", brokerPort.toString)
+ props.put("log.dir", Utils.createTempDir().getAbsolutePath)
+ props.put("zookeeper.connect", zkAddress)
+ props.put("log.flush.interval.messages", "1")
+ props.put("replica.socket.timeout.ms", "1500")
+ props
+ }
+
+ private def producerConfiguration: Properties = {
+ val props = new Properties()
+ props.put("metadata.broker.list", brokerAddress)
+ props.put("serializer.class", classOf[StringEncoder].getName)
+ // wait for all in-sync replicas to ack sends
+ props.put("request.required.acks", "-1")
+ props
+ }
+
+ // A simplified version of ScalaTest's eventually, rewritten here to avoid adding an extra
+ // test dependency
+ def eventually[T](timeout: Time, interval: Time)(func: => T): T = {
+ def makeAttempt(): Either[Throwable, T] = {
+ try {
+ Right(func)
+ } catch {
+ case e if NonFatal(e) => Left(e)
+ }
+ }
+
+ val startTime = System.currentTimeMillis()
+ @tailrec
+ def tryAgain(attempt: Int): T = {
+ makeAttempt() match {
+ case Right(result) => result
+ case Left(e) =>
+ val duration = System.currentTimeMillis() - startTime
+ if (duration < timeout.milliseconds) {
+ Thread.sleep(interval.milliseconds)
+ } else {
+ throw new TimeoutException(e.getMessage)
+ }
+
+ tryAgain(attempt + 1)
+ }
+ }
+
+ tryAgain(1)
+ }
+
+ private def waitUntilMetadataIsPropagated(topic: String, partition: Int): Unit = {
+ def isPropagated = server.apis.metadataCache.getPartitionInfo(topic, partition) match {
+ case Some(partitionState) =>
+ val leaderAndInSyncReplicas = partitionState.leaderIsrAndControllerEpoch.leaderAndIsr
+
+ zkUtils.getLeaderForPartition(topic, partition).isDefined &&
+ Request.isValidBrokerId(leaderAndInSyncReplicas.leader) &&
+ leaderAndInSyncReplicas.isr.size >= 1
+
+ case _ =>
+ false
+ }
+ eventually(Time(10000), Time(100)) {
+ assert(isPropagated, s"Partition [$topic, $partition] metadata not propagated after timeout")
+ }
+ }
+
+ private class EmbeddedZookeeper(val zkConnect: String) {
+ val snapshotDir = Utils.createTempDir()
+ val logDir = Utils.createTempDir()
+
+ val zookeeper = new ZooKeeperServer(snapshotDir, logDir, 500)
+ val (ip, port) = {
+ val splits = zkConnect.split(":")
+ (splits(0), splits(1).toInt)
+ }
+ val factory = new NIOServerCnxnFactory()
+ factory.configure(new InetSocketAddress(ip, port), 16)
+ factory.startup(zookeeper)
+
+ val actualPort = factory.getLocalPort
+
+ def shutdown() {
+ factory.shutdown()
+ Utils.deleteRecursively(snapshotDir)
+ Utils.deleteRecursively(logDir)
+ }
+ }
+}
+
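For orientation, here is a minimal Scala sketch of how a suite in the same org.apache.spark.streaming.kafka010 package might drive this utility (the class is private[kafka010]); the topic name and message counts are illustrative, not taken from the patch:

val testUtils = new KafkaTestUtils
testUtils.setup()
try {
  // create a topic with two partitions and push a few messages keyed by frequency
  testUtils.createTopic("sketch-topic", 2)
  testUtils.sendMessages("sketch-topic", Map("a" -> 3, "b" -> 2))
  // a test would now consume from testUtils.brokerAddress and assert on the data
} finally {
  testUtils.teardown()
}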
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
new file mode 100644
index 000000000000..b2190bfa05a3
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/KafkaUtils.scala
@@ -0,0 +1,179 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import org.apache.kafka.clients.consumer._
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{ JavaRDD, JavaSparkContext }
+import org.apache.spark.api.java.function.{ Function0 => JFunction0 }
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.StreamingContext
+import org.apache.spark.streaming.api.java.{ JavaInputDStream, JavaStreamingContext }
+import org.apache.spark.streaming.dstream._
+
+/**
+ * :: Experimental ::
+ * Object for constructing Kafka streams and RDDs.
+ */
+@Experimental
+object KafkaUtils extends Logging {
+ /**
+ * :: Experimental ::
+ * Scala constructor for a batch-oriented interface for consuming from Kafka.
+ * Starting and ending offsets are specified in advance,
+ * so that you can control exactly-once semantics.
+ * @param kafkaParams Kafka configuration parameters. Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD
+ * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+ * see [[LocationStrategies]] for more details.
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+ @Experimental
+ def createRDD[K, V](
+ sc: SparkContext,
+ kafkaParams: ju.Map[String, Object],
+ offsetRanges: Array[OffsetRange],
+ locationStrategy: LocationStrategy
+ ): RDD[ConsumerRecord[K, V]] = {
+ val preferredHosts = locationStrategy match {
+ case PreferBrokers =>
+ throw new AssertionError(
+ "If you want to prefer brokers, you must provide a mapping using PreferFixed " +
+ "A single KafkaRDD does not have a driver consumer and cannot look up brokers for you.")
+ case PreferConsistent => ju.Collections.emptyMap[TopicPartition, String]()
+ case PreferFixed(hostMap) => hostMap
+ }
+ val kp = new ju.HashMap[String, Object](kafkaParams)
+ fixKafkaParams(kp)
+ val osr = offsetRanges.clone()
+
+ new KafkaRDD[K, V](sc, kp, osr, preferredHosts, true)
+ }
+
+ /**
+ * :: Experimental ::
+ * Java constructor for a batch-oriented interface for consuming from Kafka.
+ * Starting and ending offsets are specified in advance,
+ * so that you can control exactly-once semantics.
+ * @param kafkaParams Kafka configuration parameters. Requires "bootstrap.servers" to be set
+ * with Kafka broker(s) specified in host1:port1,host2:port2 form.
+ * @param offsetRanges offset ranges that define the Kafka data belonging to this RDD
+ * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+ * see [[LocationStrategies]] for more details.
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+ @Experimental
+ def createRDD[K, V](
+ jsc: JavaSparkContext,
+ kafkaParams: ju.Map[String, Object],
+ offsetRanges: Array[OffsetRange],
+ locationStrategy: LocationStrategy
+ ): JavaRDD[ConsumerRecord[K, V]] = {
+
+ new JavaRDD(createRDD[K, V](jsc.sc, kafkaParams, offsetRanges, locationStrategy))
+ }
+
+ /**
+ * :: Experimental ::
+ * Scala constructor for a DStream where
+ * each given Kafka topic/partition corresponds to an RDD partition.
+ * The spark configuration spark.streaming.kafka.maxRatePerPartition gives the maximum number
+ * of messages per second that each '''partition''' will accept.
+ * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+ * see [[LocationStrategies]] for more details.
+ * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe,
+ * see [[ConsumerStrategies]] for more details
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+ @Experimental
+ def createDirectStream[K, V](
+ ssc: StreamingContext,
+ locationStrategy: LocationStrategy,
+ consumerStrategy: ConsumerStrategy[K, V]
+ ): InputDStream[ConsumerRecord[K, V]] = {
+ new DirectKafkaInputDStream[K, V](ssc, locationStrategy, consumerStrategy)
+ }
+
+ /**
+ * :: Experimental ::
+ * Java constructor for a DStream where
+ * each given Kafka topic/partition corresponds to an RDD partition.
+ * @param locationStrategy In most cases, pass in LocationStrategies.preferConsistent,
+ * see [[LocationStrategies]] for more details.
+ * @param consumerStrategy In most cases, pass in ConsumerStrategies.subscribe,
+ * see [[ConsumerStrategies]] for more details
+ * @tparam K type of Kafka message key
+ * @tparam V type of Kafka message value
+ */
+ @Experimental
+ def createDirectStream[K, V](
+ jssc: JavaStreamingContext,
+ locationStrategy: LocationStrategy,
+ consumerStrategy: ConsumerStrategy[K, V]
+ ): JavaInputDStream[ConsumerRecord[K, V]] = {
+ new JavaInputDStream(
+ createDirectStream[K, V](
+ jssc.ssc, locationStrategy, consumerStrategy))
+ }
+
+ /**
+ * Tweak kafka params to prevent issues on executors
+ */
+ private[kafka010] def fixKafkaParams(kafkaParams: ju.HashMap[String, Object]): Unit = {
+ logWarning(s"overriding ${ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG} to false for executor")
+ kafkaParams.put(ConsumerConfig.ENABLE_AUTO_COMMIT_CONFIG, false: java.lang.Boolean)
+
+ logWarning(s"overriding ${ConsumerConfig.AUTO_OFFSET_RESET_CONFIG} to none for executor")
+ kafkaParams.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "none")
+
+ // driver and executor should be in different consumer groups
+ val originalGroupId = kafkaParams.get(ConsumerConfig.GROUP_ID_CONFIG)
+ if (null == originalGroupId) {
+ logError(s"${ConsumerConfig.GROUP_ID_CONFIG} is null, you should probably set it")
+ }
+ val groupId = "spark-executor-" + originalGroupId
+ logWarning(s"overriding executor ${ConsumerConfig.GROUP_ID_CONFIG} to ${groupId}")
+ kafkaParams.put(ConsumerConfig.GROUP_ID_CONFIG, groupId)
+
+ // possible workaround for KAFKA-3135
+ val rbb = kafkaParams.get(ConsumerConfig.RECEIVE_BUFFER_CONFIG)
+ if (null == rbb || rbb.asInstanceOf[java.lang.Integer] < 65536) {
+ logWarning(s"overriding ${ConsumerConfig.RECEIVE_BUFFER_CONFIG} to 65536 see KAFKA-3135")
+ kafkaParams.put(ConsumerConfig.RECEIVE_BUFFER_CONFIG, 65536: java.lang.Integer)
+ }
+ }
+}
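To make the intended call pattern concrete, here is a hedged Scala sketch of createDirectStream, mirroring the test suites later in this patch; the master, app name, topic, and group id are illustrative, and ConsumerStrategies.Subscribe is the strategy factory defined elsewhere in this patch:

import org.apache.kafka.common.serialization.StringDeserializer
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Milliseconds, StreamingContext}

val conf = new SparkConf().setMaster("local[*]").setAppName("kafka010-sketch")
val ssc = new StreamingContext(conf, Milliseconds(500))
// plain Scala Map of consumer configs, as in the suites below
val kafkaParams = Map[String, Object](
  "bootstrap.servers" -> "host1:port1,host2:port2",
  "key.deserializer" -> classOf[StringDeserializer],
  "value.deserializer" -> classOf[StringDeserializer],
  "group.id" -> "sketch-group",
  "auto.offset.reset" -> "earliest")
val stream = KafkaUtils.createDirectStream[String, String](
  ssc,
  LocationStrategies.PreferConsistent,
  ConsumerStrategies.Subscribe[String, String](List("sketch-topic"), kafkaParams))
// each record is an org.apache.kafka.clients.consumer.ConsumerRecord
stream.map(r => (r.key, r.value)).print()
ssc.start()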
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala
new file mode 100644
index 000000000000..c9a8a13f51c3
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/LocationStrategy.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import scala.collection.JavaConverters._
+
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.annotation.Experimental
+
+
+/**
+ * :: Experimental ::
+ * Choice of how to schedule consumers for a given TopicPartition on an executor.
+ * See [[LocationStrategies]] to obtain instances.
+ * Kafka 0.10 consumers prefetch messages, so it's important for performance
+ * to keep cached consumers on appropriate executors, not recreate them for every partition.
+ * Choice of location is only a preference, not an absolute; partitions may be scheduled elsewhere.
+ */
+@Experimental
+sealed abstract class LocationStrategy
+
+private case object PreferBrokers extends LocationStrategy
+
+private case object PreferConsistent extends LocationStrategy
+
+private case class PreferFixed(hostMap: ju.Map[TopicPartition, String]) extends LocationStrategy
+
+/**
+ * :: Experimental ::
+ * Object to obtain instances of [[LocationStrategy]].
+ */
+@Experimental
+object LocationStrategies {
+ /**
+ * :: Experimental ::
+ * Use this only if your executors are on the same nodes as your Kafka brokers.
+ */
+ @Experimental
+ def PreferBrokers: LocationStrategy =
+ org.apache.spark.streaming.kafka010.PreferBrokers
+
+ /**
+ * :: Experimental ::
+ * Use this in most cases; it will consistently distribute partitions across all executors.
+ */
+ @Experimental
+ def PreferConsistent: LocationStrategy =
+ org.apache.spark.streaming.kafka010.PreferConsistent
+
+ /**
+ * :: Experimental ::
+ * Use this to place particular TopicPartitions on particular hosts if your load is uneven.
+ * Any TopicPartition not specified in the map will use a consistent location.
+ */
+ @Experimental
+ def PreferFixed(hostMap: collection.Map[TopicPartition, String]): LocationStrategy =
+ new PreferFixed(new ju.HashMap[TopicPartition, String](hostMap.asJava))
+
+ /**
+ * :: Experimental ::
+ * Use this to place particular TopicPartitions on particular hosts if your load is uneven.
+ * Any TopicPartition not specified in the map will use a consistent location.
+ */
+ @Experimental
+ def PreferFixed(hostMap: ju.Map[TopicPartition, String]): LocationStrategy =
+ new PreferFixed(hostMap)
+}
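As a brief illustration of choosing between the three strategies above (the topic, partition number, and host name below are made up):

import org.apache.kafka.common.TopicPartition

// Default choice: spread cached consumers consistently across executors.
val consistent = LocationStrategies.PreferConsistent

// Only when executors are co-located with the Kafka brokers.
val onBrokers = LocationStrategies.PreferBrokers

// Pin specific partitions to specific hosts when load is skewed; partitions
// not listed in the map fall back to the consistent placement.
val pinned = LocationStrategies.PreferFixed(
  Map(new TopicPartition("sketch-topic", 0) -> "busy-host"))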
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala
new file mode 100644
index 000000000000..c66d3c9b8d22
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/OffsetRange.scala
@@ -0,0 +1,153 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import org.apache.kafka.clients.consumer.OffsetCommitCallback
+import org.apache.kafka.common.TopicPartition
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * Represents any object that has a collection of [[OffsetRange]]s. This can be used to access the
+ * offset ranges in RDDs generated by the direct Kafka DStream (see
+ * [[KafkaUtils.createDirectStream]]).
+ * {{{
+ * KafkaUtils.createDirectStream(...).foreachRDD { rdd =>
+ * val offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ * ...
+ * }
+ * }}}
+ */
+trait HasOffsetRanges {
+ def offsetRanges: Array[OffsetRange]
+}
+
+/**
+ * :: Experimental ::
+ * Represents any object that can commit a collection of [[OffsetRange]]s.
+ * The direct Kafka DStream implements this interface (see
+ * [[KafkaUtils.createDirectStream]]).
+ * {{{
+ * val stream = KafkaUtils.createDirectStream(...)
+ * ...
+ * stream.asInstanceOf[CanCommitOffsets].commitAsync(offsets, new OffsetCommitCallback() {
+ * def onComplete(m: java.util.Map[TopicPartition, OffsetAndMetadata], e: Exception) {
+ * if (null != e) {
+ * // error
+ * } else {
+ * // success
+ * }
+ * }
+ * })
+ * }}}
+ */
+@Experimental
+trait CanCommitOffsets {
+ /**
+ * :: Experimental ::
+ * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
+ * This is only needed if you intend to store offsets in Kafka, instead of your own store.
+ * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
+ */
+ @Experimental
+ def commitAsync(offsetRanges: Array[OffsetRange]): Unit
+
+ /**
+ * :: Experimental ::
+ * Queue up offset ranges for commit to Kafka at a future time. Threadsafe.
+ * This is only needed if you intend to store offsets in Kafka, instead of your own store.
+ * @param offsetRanges The maximum untilOffset for a given partition will be used at commit.
+ * @param callback Only the most recently provided callback will be used at commit.
+ */
+ @Experimental
+ def commitAsync(offsetRanges: Array[OffsetRange], callback: OffsetCommitCallback): Unit
+}
+
+/**
+ * Represents a range of offsets from a single Kafka TopicPartition. Instances of this class
+ * can be created with `OffsetRange.create()`.
+ * @param topic Kafka topic name
+ * @param partition Kafka partition id
+ * @param fromOffset Inclusive starting offset
+ * @param untilOffset Exclusive ending offset
+ */
+final class OffsetRange private(
+ val topic: String,
+ val partition: Int,
+ val fromOffset: Long,
+ val untilOffset: Long) extends Serializable {
+ import OffsetRange.OffsetRangeTuple
+
+ /** Kafka TopicPartition object, for convenience */
+ def topicPartition(): TopicPartition = new TopicPartition(topic, partition)
+
+ /** Number of messages this OffsetRange refers to */
+ def count(): Long = untilOffset - fromOffset
+
+ override def equals(obj: Any): Boolean = obj match {
+ case that: OffsetRange =>
+ this.topic == that.topic &&
+ this.partition == that.partition &&
+ this.fromOffset == that.fromOffset &&
+ this.untilOffset == that.untilOffset
+ case _ => false
+ }
+
+ override def hashCode(): Int = {
+ toTuple.hashCode()
+ }
+
+ override def toString(): String = {
+ s"OffsetRange(topic: '$topic', partition: $partition, range: [$fromOffset -> $untilOffset])"
+ }
+
+ /** this is to avoid ClassNotFoundException during checkpoint restore */
+ private[streaming]
+ def toTuple: OffsetRangeTuple = (topic, partition, fromOffset, untilOffset)
+}
+
+/**
+ * Companion object that provides methods to create instances of [[OffsetRange]].
+ */
+object OffsetRange {
+ def create(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange =
+ new OffsetRange(topic, partition, fromOffset, untilOffset)
+
+ def create(
+ topicPartition: TopicPartition,
+ fromOffset: Long,
+ untilOffset: Long): OffsetRange =
+ new OffsetRange(topicPartition.topic, topicPartition.partition, fromOffset, untilOffset)
+
+ def apply(topic: String, partition: Int, fromOffset: Long, untilOffset: Long): OffsetRange =
+ new OffsetRange(topic, partition, fromOffset, untilOffset)
+
+ def apply(
+ topicPartition: TopicPartition,
+ fromOffset: Long,
+ untilOffset: Long): OffsetRange =
+ new OffsetRange(topicPartition.topic, topicPartition.partition, fromOffset, untilOffset)
+
+ /** this is to avoid ClassNotFoundException during checkpoint restore */
+ private[kafka010]
+ type OffsetRangeTuple = (String, Int, Long, Long)
+
+ private[kafka010]
+ def apply(t: OffsetRangeTuple) =
+ new OffsetRange(t._1, t._2, t._3, t._4)
+}
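For orientation, a small sketch of constructing offset ranges that could be handed to KafkaUtils.createRDD for a batch read; the topic name and offsets are illustrative:

import org.apache.kafka.common.TopicPartition

// Offsets [0, 100) of partition 0 and [0, 50) of partition 1 of a hypothetical topic.
val ranges = Array(
  OffsetRange("sketch-topic", 0, 0, 100),
  OffsetRange.create(new TopicPartition("sketch-topic", 1), 0, 50))

ranges.foreach { r =>
  // count() is simply untilOffset - fromOffset
  println(s"${r.topicPartition()} has ${r.count()} records in this range")
}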
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java
new file mode 100644
index 000000000000..ebfcf8764a32
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package-info.java
@@ -0,0 +1,21 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Spark Integration for Kafka 0.10
+ */
+package org.apache.spark.streaming.kafka010;
diff --git a/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala
new file mode 100644
index 000000000000..09db6d6062d8
--- /dev/null
+++ b/external/kafka-0-10/src/main/scala/org/apache/spark/streaming/kafka010/package.scala
@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming
+
+/**
+ * Spark Integration for Kafka 0.10
+ */
+package object kafka010 //scalastyle:ignore
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java
new file mode 100644
index 000000000000..ba57b6beb247
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaConsumerStrategySuite.java
@@ -0,0 +1,101 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.regex.Pattern;
+
+import scala.collection.JavaConverters;
+
+import org.apache.kafka.common.TopicPartition;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class JavaConsumerStrategySuite implements Serializable {
+
+ @Test
+ public void testConsumerStrategyConstructors() {
+ final String topic1 = "topic1";
+ final Pattern pat = Pattern.compile("top.*");
+ final Collection<String> topics = Arrays.asList(topic1);
+ final scala.collection.Iterable<String> sTopics =
+ JavaConverters.collectionAsScalaIterableConverter(topics).asScala();
+ final TopicPartition tp1 = new TopicPartition(topic1, 0);
+ final TopicPartition tp2 = new TopicPartition(topic1, 1);
+ final Collection<TopicPartition> parts = Arrays.asList(tp1, tp2);
+ final scala.collection.Iterable<TopicPartition> sParts =
+ JavaConverters.collectionAsScalaIterableConverter(parts).asScala();
+ final Map<String, Object> kafkaParams = new HashMap<>();
+ kafkaParams.put("bootstrap.servers", "not used");
+ final scala.collection.Map<String, Object> sKafkaParams =
+ JavaConverters.mapAsScalaMapConverter(kafkaParams).asScala();
+ final Map<TopicPartition, Long> offsets = new HashMap<>();
+ offsets.put(tp1, 23L);
+ final scala.collection.Map<TopicPartition, Object> sOffsets =
+ JavaConverters.mapAsScalaMapConverter(offsets).asScala().mapValues(
+ new scala.runtime.AbstractFunction1<Long, Object>() {
+ @Override
+ public Object apply(Long x) {
+ return (Object) x;
+ }
+ }
+ );
+
+ final ConsumerStrategy<String, String> sub1 =
+ ConsumerStrategies.<String, String>Subscribe(sTopics, sKafkaParams, sOffsets);
+ final ConsumerStrategy<String, String> sub2 =
+ ConsumerStrategies.<String, String>Subscribe(sTopics, sKafkaParams);
+ final ConsumerStrategy<String, String> sub3 =
+ ConsumerStrategies.<String, String>Subscribe(topics, kafkaParams, offsets);
+ final ConsumerStrategy<String, String> sub4 =
+ ConsumerStrategies.<String, String>Subscribe(topics, kafkaParams);
+
+ Assert.assertEquals(
+ sub1.executorKafkaParams().get("bootstrap.servers"),
+ sub3.executorKafkaParams().get("bootstrap.servers"));
+
+ final ConsumerStrategy<String, String> psub1 =
+ ConsumerStrategies.<String, String>SubscribePattern(pat, sKafkaParams, sOffsets);
+ final ConsumerStrategy<String, String> psub2 =
+ ConsumerStrategies.<String, String>SubscribePattern(pat, sKafkaParams);
+ final ConsumerStrategy<String, String> psub3 =
+ ConsumerStrategies.<String, String>SubscribePattern(pat, kafkaParams, offsets);
+ final ConsumerStrategy<String, String> psub4 =
+ ConsumerStrategies.<String, String>SubscribePattern(pat, kafkaParams);
+
+ Assert.assertEquals(
+ psub1.executorKafkaParams().get("bootstrap.servers"),
+ psub3.executorKafkaParams().get("bootstrap.servers"));
+
+ final ConsumerStrategy<String, String> asn1 =
+ ConsumerStrategies.<String, String>Assign(sParts, sKafkaParams, sOffsets);
+ final ConsumerStrategy<String, String> asn2 =
+ ConsumerStrategies.<String, String>Assign(sParts, sKafkaParams);
+ final ConsumerStrategy<String, String> asn3 =
+ ConsumerStrategies.<String, String>Assign(parts, kafkaParams, offsets);
+ final ConsumerStrategy<String, String> asn4 =
+ ConsumerStrategies.<String, String>Assign(parts, kafkaParams);
+
+ Assert.assertEquals(
+ asn1.executorKafkaParams().get("bootstrap.servers"),
+ asn3.executorKafkaParams().get("bootstrap.servers"));
+ }
+
+}
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java
new file mode 100644
index 000000000000..dc9c13ba863f
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaDirectKafkaStreamSuite.java
@@ -0,0 +1,180 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.streaming.Durations;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+
+public class JavaDirectKafkaStreamSuite implements Serializable {
+ private transient JavaStreamingContext ssc = null;
+ private transient KafkaTestUtils kafkaTestUtils = null;
+
+ @Before
+ public void setUp() {
+ kafkaTestUtils = new KafkaTestUtils();
+ kafkaTestUtils.setup();
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ ssc = new JavaStreamingContext(sparkConf, Durations.milliseconds(200));
+ }
+
+ @After
+ public void tearDown() {
+ if (ssc != null) {
+ ssc.stop();
+ ssc = null;
+ }
+
+ if (kafkaTestUtils != null) {
+ kafkaTestUtils.teardown();
+ kafkaTestUtils = null;
+ }
+ }
+
+ @Test
+ public void testKafkaStream() throws InterruptedException {
+ final String topic1 = "topic1";
+ final String topic2 = "topic2";
+ // hold a reference to the current offset ranges, so it can be used downstream
+ final AtomicReference<OffsetRange[]> offsetRanges = new AtomicReference<>();
+
+ String[] topic1data = createTopicAndSendData(topic1);
+ String[] topic2data = createTopicAndSendData(topic2);
+
+ Set<String> sent = new HashSet<>();
+ sent.addAll(Arrays.asList(topic1data));
+ sent.addAll(Arrays.asList(topic2data));
+
+ Random random = new Random();
+
+ final Map<String, Object> kafkaParams = new HashMap<>();
+ kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress());
+ kafkaParams.put("key.deserializer", StringDeserializer.class);
+ kafkaParams.put("value.deserializer", StringDeserializer.class);
+ kafkaParams.put("auto.offset.reset", "earliest");
+ kafkaParams.put("group.id", "java-test-consumer-" + random.nextInt() +
+ "-" + System.currentTimeMillis());
+
+ JavaInputDStream<ConsumerRecord<String, String>> istream1 = KafkaUtils.createDirectStream(
+ ssc,
+ LocationStrategies.PreferConsistent(),
+ ConsumerStrategies.<String, String>Subscribe(Arrays.asList(topic1), kafkaParams)
+ );
+
+ JavaDStream<String> stream1 = istream1.transform(
+ // Make sure you can get offset ranges from the rdd
+ new Function<JavaRDD<ConsumerRecord<String, String>>,
+ JavaRDD<ConsumerRecord<String, String>>>() {
+ @Override
+ public JavaRDD<ConsumerRecord<String, String>> call(
+ JavaRDD<ConsumerRecord<String, String>> rdd
+ ) {
+ OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+ offsetRanges.set(offsets);
+ Assert.assertEquals(topic1, offsets[0].topic());
+ return rdd;
+ }
+ }
+ ).map(
+ new Function<ConsumerRecord<String, String>, String>() {
+ @Override
+ public String call(ConsumerRecord<String, String> r) {
+ return r.value();
+ }
+ }
+ );
+
+ final Map<String, Object> kafkaParams2 = new HashMap<>(kafkaParams);
+ kafkaParams2.put("group.id", "java-test-consumer-" + random.nextInt() +
+ "-" + System.currentTimeMillis());
+
+ JavaInputDStream<ConsumerRecord<String, String>> istream2 = KafkaUtils.createDirectStream(
+ ssc,
+ LocationStrategies.PreferConsistent(),
+ ConsumerStrategies.<String, String>Subscribe(Arrays.asList(topic2), kafkaParams2)
+ );
+
+ JavaDStream<String> stream2 = istream2.transform(
+ // Make sure you can get offset ranges from the rdd
+ new Function<JavaRDD<ConsumerRecord<String, String>>,
+ JavaRDD<ConsumerRecord<String, String>>>() {
+ @Override
+ public JavaRDD<ConsumerRecord<String, String>> call(
+ JavaRDD<ConsumerRecord<String, String>> rdd
+ ) {
+ OffsetRange[] offsets = ((HasOffsetRanges) rdd.rdd()).offsetRanges();
+ offsetRanges.set(offsets);
+ Assert.assertEquals(topic2, offsets[0].topic());
+ return rdd;
+ }
+ }
+ ).map(
+ new Function<ConsumerRecord<String, String>, String>() {
+ @Override
+ public String call(ConsumerRecord<String, String> r) {
+ return r.value();
+ }
+ }
+ );
+
+ JavaDStream<String> unifiedStream = stream1.union(stream2);
+
+ final Set<String> result = Collections.synchronizedSet(new HashSet<String>());
+ unifiedStream.foreachRDD(new VoidFunction<JavaRDD<String>>() {
+ @Override
+ public void call(JavaRDD<String> rdd) {
+ result.addAll(rdd.collect());
+ }
+ }
+ );
+ ssc.start();
+ long startTime = System.currentTimeMillis();
+ boolean matches = false;
+ while (!matches && System.currentTimeMillis() - startTime < 20000) {
+ matches = sent.size() == result.size();
+ Thread.sleep(50);
+ }
+ Assert.assertEquals(sent, result);
+ ssc.stop();
+ }
+
+ private String[] createTopicAndSendData(String topic) {
+ String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+ kafkaTestUtils.createTopic(topic);
+ kafkaTestUtils.sendMessages(topic, data);
+ return data;
+ }
+}
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java
new file mode 100644
index 000000000000..87bfe1514e33
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaKafkaRDDSuite.java
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+
+import org.apache.kafka.common.serialization.StringDeserializer;
+import org.apache.kafka.common.TopicPartition;
+import org.apache.kafka.clients.consumer.ConsumerRecord;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+
+public class JavaKafkaRDDSuite implements Serializable {
+ private transient JavaSparkContext sc = null;
+ private transient KafkaTestUtils kafkaTestUtils = null;
+
+ @Before
+ public void setUp() {
+ kafkaTestUtils = new KafkaTestUtils();
+ kafkaTestUtils.setup();
+ SparkConf sparkConf = new SparkConf()
+ .setMaster("local[4]").setAppName(this.getClass().getSimpleName());
+ sc = new JavaSparkContext(sparkConf);
+ }
+
+ @After
+ public void tearDown() {
+ if (sc != null) {
+ sc.stop();
+ sc = null;
+ }
+
+ if (kafkaTestUtils != null) {
+ kafkaTestUtils.teardown();
+ kafkaTestUtils = null;
+ }
+ }
+
+ @Test
+ public void testKafkaRDD() throws InterruptedException {
+ String topic1 = "topic1";
+ String topic2 = "topic2";
+
+ Random random = new Random();
+
+ createTopicAndSendData(topic1);
+ createTopicAndSendData(topic2);
+
+ Map<String, Object> kafkaParams = new HashMap<>();
+ kafkaParams.put("bootstrap.servers", kafkaTestUtils.brokerAddress());
+ kafkaParams.put("key.deserializer", StringDeserializer.class);
+ kafkaParams.put("value.deserializer", StringDeserializer.class);
+ kafkaParams.put("group.id", "java-test-consumer-" + random.nextInt() +
+ "-" + System.currentTimeMillis());
+
+ OffsetRange[] offsetRanges = {
+ OffsetRange.create(topic1, 0, 0, 1),
+ OffsetRange.create(topic2, 0, 0, 1)
+ };
+
+ Map<TopicPartition, String> leaders = new HashMap<>();
+ String[] hostAndPort = kafkaTestUtils.brokerAddress().split(":");
+ String broker = hostAndPort[0];
+ leaders.put(offsetRanges[0].topicPartition(), broker);
+ leaders.put(offsetRanges[1].topicPartition(), broker);
+
+ Function<ConsumerRecord<String, String>, String> handler =
+ new Function<ConsumerRecord<String, String>, String>() {
+ @Override
+ public String call(ConsumerRecord<String, String> r) {
+ return r.value();
+ }
+ };
+
+ JavaRDD<String> rdd1 = KafkaUtils.<String, String>createRDD(
+ sc,
+ kafkaParams,
+ offsetRanges,
+ LocationStrategies.PreferFixed(leaders)
+ ).map(handler);
+
+ JavaRDD<String> rdd2 = KafkaUtils.<String, String>createRDD(
+ sc,
+ kafkaParams,
+ offsetRanges,
+ LocationStrategies.PreferConsistent()
+ ).map(handler);
+
+ // just making sure the java user apis work; the scala tests handle logic corner cases
+ long count1 = rdd1.count();
+ long count2 = rdd2.count();
+ Assert.assertTrue(count1 > 0);
+ Assert.assertEquals(count1, count2);
+ }
+
+ private String[] createTopicAndSendData(String topic) {
+ String[] data = { topic + "-1", topic + "-2", topic + "-3"};
+ kafkaTestUtils.createTopic(topic);
+ kafkaTestUtils.sendMessages(topic, data);
+ return data;
+ }
+}
diff --git a/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java
new file mode 100644
index 000000000000..41ccb0ebe7bf
--- /dev/null
+++ b/external/kafka-0-10/src/test/java/org/apache/spark/streaming/kafka010/JavaLocationStrategySuite.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010;
+
+import java.io.Serializable;
+import java.util.*;
+
+import scala.collection.JavaConverters;
+
+import org.apache.kafka.common.TopicPartition;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+public class JavaLocationStrategySuite implements Serializable {
+
+ @Test
+ public void testLocationStrategyConstructors() {
+ final String topic1 = "topic1";
+ final TopicPartition tp1 = new TopicPartition(topic1, 0);
+ final TopicPartition tp2 = new TopicPartition(topic1, 1);
+ final Map<TopicPartition, String> hosts = new HashMap<>();
+ hosts.put(tp1, "node1");
+ hosts.put(tp2, "node2");
+ final scala.collection.Map<TopicPartition, String> sHosts =
+ JavaConverters.mapAsScalaMapConverter(hosts).asScala();
+
+ // make sure constructors can be called from java
+ final LocationStrategy c1 = LocationStrategies.PreferConsistent();
+ final LocationStrategy c2 = LocationStrategies.PreferConsistent();
+ Assert.assertSame(c1, c2);
+
+ final LocationStrategy c3 = LocationStrategies.PreferBrokers();
+ final LocationStrategy c4 = LocationStrategies.PreferBrokers();
+ Assert.assertSame(c3, c4);
+
+ Assert.assertNotSame(c1, c3);
+
+ final LocationStrategy c5 = LocationStrategies.PreferFixed(hosts);
+ final LocationStrategy c6 = LocationStrategies.PreferFixed(sHosts);
+ Assert.assertEquals(c5, c6);
+ }
+
+}
diff --git a/external/kafka-0-10/src/test/resources/log4j.properties b/external/kafka-0-10/src/test/resources/log4j.properties
new file mode 100644
index 000000000000..75e3b53a093f
--- /dev/null
+++ b/external/kafka-0-10/src/test/resources/log4j.properties
@@ -0,0 +1,28 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the file target/unit-tests.log
+log4j.rootCategory=INFO, file
+log4j.appender.file=org.apache.log4j.FileAppender
+log4j.appender.file.append=true
+log4j.appender.file.file=target/unit-tests.log
+log4j.appender.file.layout=org.apache.log4j.PatternLayout
+log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n
+
+# Ignore messages below warning level from Jetty, because it's a bit verbose
+log4j.logger.org.spark-project.jetty=WARN
+
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
new file mode 100644
index 000000000000..c9e15bcba0a9
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/DirectKafkaStreamSuite.scala
@@ -0,0 +1,695 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.io.File
+import java.lang.{ Long => JLong }
+import java.util.{ Arrays, HashMap => JHashMap, Map => JMap }
+import java.util.concurrent.atomic.AtomicLong
+import java.util.concurrent.ConcurrentLinkedQueue
+
+import scala.collection.JavaConverters._
+import scala.concurrent.duration._
+import scala.language.postfixOps
+import scala.util.Random
+
+import org.apache.kafka.clients.consumer._
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.StringDeserializer
+import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
+import org.scalatest.concurrent.Eventually
+
+import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time}
+import org.apache.spark.streaming.dstream.DStream
+import org.apache.spark.streaming.scheduler._
+import org.apache.spark.streaming.scheduler.rate.RateEstimator
+import org.apache.spark.util.Utils
+
+class DirectKafkaStreamSuite
+ extends SparkFunSuite
+ with BeforeAndAfter
+ with BeforeAndAfterAll
+ with Eventually
+ with Logging {
+ val sparkConf = new SparkConf()
+ .setMaster("local[4]")
+ .setAppName(this.getClass.getSimpleName)
+
+ private var sc: SparkContext = _
+ private var ssc: StreamingContext = _
+ private var testDir: File = _
+
+ private var kafkaTestUtils: KafkaTestUtils = _
+
+ override def beforeAll {
+ kafkaTestUtils = new KafkaTestUtils
+ kafkaTestUtils.setup()
+ }
+
+ override def afterAll {
+ if (kafkaTestUtils != null) {
+ kafkaTestUtils.teardown()
+ kafkaTestUtils = null
+ }
+ }
+
+ after {
+ if (ssc != null) {
+ ssc.stop()
+ sc = null
+ }
+ if (sc != null) {
+ sc.stop()
+ }
+ if (testDir != null) {
+ Utils.deleteRecursively(testDir)
+ }
+ }
+
+ def getKafkaParams(extra: (String, Object)*): JHashMap[String, Object] = {
+ val kp = new JHashMap[String, Object]()
+ kp.put("bootstrap.servers", kafkaTestUtils.brokerAddress)
+ kp.put("key.deserializer", classOf[StringDeserializer])
+ kp.put("value.deserializer", classOf[StringDeserializer])
+ kp.put("group.id", s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}")
+ extra.foreach(e => kp.put(e._1, e._2))
+ kp
+ }
+
+ val preferredHosts = LocationStrategies.PreferConsistent
+
+ test("basic stream receiving with multiple topics and smallest starting offset") {
+ val topics = List("basic1", "basic2", "basic3")
+ val data = Map("a" -> 7, "b" -> 9)
+ topics.foreach { t =>
+ kafkaTestUtils.createTopic(t)
+ kafkaTestUtils.sendMessages(t, data)
+ }
+ val offsets = Map(new TopicPartition("basic3", 0) -> 2L)
+ // one topic is starting 2 messages later
+ val expectedTotal = (data.values.sum * topics.size) - 2
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](topics, kafkaParams.asScala, offsets))
+ }
+ val allReceived = new ConcurrentLinkedQueue[(String, String)]()
+
+ // hold a reference to the current offset ranges, so it can be used downstream
+ var offsetRanges = Array[OffsetRange]()
+ val tf = stream.transform { rdd =>
+ // Get the offset ranges in the RDD
+ offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ rdd.map(r => (r.key, r.value))
+ }
+
+ tf.foreachRDD { rdd =>
+ for (o <- offsetRanges) {
+ logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
+ }
+ val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
+ // For each partition, get size of the range in the partition,
+ // and the number of items in the partition
+ val off = offsetRanges(i)
+ val all = iter.toSeq
+ val partSize = all.size
+ val rangeSize = off.untilOffset - off.fromOffset
+ Iterator((partSize, rangeSize))
+ }.collect
+
+ // Verify whether number of elements in each partition
+ // matches with the corresponding offset range
+ collected.foreach { case (partSize, rangeSize) =>
+ assert(partSize === rangeSize, "offset ranges are wrong")
+ }
+ }
+
+ stream.foreachRDD { rdd =>
+ allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*))
+ }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(allReceived.size === expectedTotal,
+ "didn't get expected number of messages, messages:\n" +
+ allReceived.asScala.mkString("\n"))
+ }
+ ssc.stop()
+ }
+
+ test("pattern based subscription") {
+ val topics = List("pat1", "pat2", "advanced3")
+ // Should match 2 out of 3 topics
+ val pat = """pat\d""".r.pattern
+ val data = Map("a" -> 7, "b" -> 9)
+ topics.foreach { t =>
+ kafkaTestUtils.createTopic(t)
+ kafkaTestUtils.sendMessages(t, data)
+ }
+ val offsets = Map(new TopicPartition("pat2", 0) -> 3L)
+ // 2 matching topics, one of which starts 3 messages later
+ val expectedTotal = (data.values.sum * 2) - 3
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.SubscribePattern[String, String](pat, kafkaParams.asScala, offsets))
+ }
+ val allReceived = new ConcurrentLinkedQueue[(String, String)]()
+
+ // hold a reference to the current offset ranges, so it can be used downstream
+ var offsetRanges = Array[OffsetRange]()
+ val tf = stream.transform { rdd =>
+ // Get the offset ranges in the RDD
+ offsetRanges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ rdd.map(r => (r.key, r.value))
+ }
+
+ tf.foreachRDD { rdd =>
+ for (o <- offsetRanges) {
+ logInfo(s"${o.topic} ${o.partition} ${o.fromOffset} ${o.untilOffset}")
+ }
+ val collected = rdd.mapPartitionsWithIndex { (i, iter) =>
+ // For each partition, get size of the range in the partition,
+ // and the number of items in the partition
+ val off = offsetRanges(i)
+ val all = iter.toSeq
+ val partSize = all.size
+ val rangeSize = off.untilOffset - off.fromOffset
+ Iterator((partSize, rangeSize))
+ }.collect
+
+ // Verify whether number of elements in each partition
+ // matches with the corresponding offset range
+ collected.foreach { case (partSize, rangeSize) =>
+ assert(partSize === rangeSize, "offset ranges are wrong")
+ }
+ }
+
+ stream.foreachRDD { rdd =>
+ allReceived.addAll(Arrays.asList(rdd.map(r => (r.key, r.value)).collect(): _*))
+ }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(allReceived.size === expectedTotal,
+ "didn't get expected number of messages, messages:\n" +
+ allReceived.asScala.mkString("\n"))
+ }
+ ssc.stop()
+ }
+
+
+ test("receiving from largest starting offset") {
+ val topic = "latest"
+ val topicPartition = new TopicPartition(topic, 0)
+ val data = Map("a" -> 10)
+ kafkaTestUtils.createTopic(topic)
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest")
+ val kc = new KafkaConsumer(kafkaParams)
+ kc.assign(Arrays.asList(topicPartition))
+ def getLatestOffset(): Long = {
+ kc.seekToEnd(Arrays.asList(topicPartition))
+ kc.position(topicPartition)
+ }
+
+ // Send some initial messages before starting context
+ kafkaTestUtils.sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() > 3)
+ }
+ val offsetBeforeStart = getLatestOffset()
+ kc.close()
+
+ // Setup context and kafka stream with largest offset
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ val s = new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala))
+ s.consumer.poll(0)
+ assert(
+ s.consumer.position(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+ s
+ }
+
+ val collectedData = new ConcurrentLinkedQueue[String]()
+ stream.map { _.value }.foreachRDD { rdd =>
+ collectedData.addAll(Arrays.asList(rdd.collect(): _*))
+ }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ kafkaTestUtils.sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+
+ test("creating stream by offset") {
+ val topic = "offset"
+ val topicPartition = new TopicPartition(topic, 0)
+ val data = Map("a" -> 10)
+ kafkaTestUtils.createTopic(topic)
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "latest")
+ val kc = new KafkaConsumer(kafkaParams)
+ kc.assign(Arrays.asList(topicPartition))
+ def getLatestOffset(): Long = {
+ kc.seekToEnd(Arrays.asList(topicPartition))
+ kc.position(topicPartition)
+ }
+
+ // Send some initial messages before starting context
+ kafkaTestUtils.sendMessages(topic, data)
+ eventually(timeout(10 seconds), interval(20 milliseconds)) {
+ assert(getLatestOffset() >= 10)
+ }
+ val offsetBeforeStart = getLatestOffset()
+ kc.close()
+
+ // Setup context and kafka stream with largest offset
+ kafkaParams.put("auto.offset.reset", "none")
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val stream = withClue("Error creating direct stream") {
+ val s = new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Assign[String, String](
+ List(topicPartition),
+ kafkaParams.asScala,
+ Map(topicPartition -> 11L)))
+ s.consumer.poll(0)
+ assert(
+ s.consumer.position(topicPartition) >= offsetBeforeStart,
+ "Start offset not from latest"
+ )
+ s
+ }
+
+ val collectedData = new ConcurrentLinkedQueue[String]()
+ stream.map(_.value).foreachRDD { rdd => collectedData.addAll(Arrays.asList(rdd.collect(): _*)) }
+ ssc.start()
+ val newData = Map("b" -> 10)
+ kafkaTestUtils.sendMessages(topic, newData)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ collectedData.contains("b")
+ }
+ assert(!collectedData.contains("a"))
+ }
+
+ // Test to verify the offset ranges can be recovered from the checkpoints
+ test("offset recovery") {
+ val topic = "recovery"
+ kafkaTestUtils.createTopic(topic)
+ testDir = Utils.createTempDir()
+
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+ // Send data to Kafka
+ def sendData(data: Seq[Int]) {
+ val strings = data.map { _.toString}
+ kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
+ }
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(100))
+ val kafkaStream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala))
+ }
+ val keyedStream = kafkaStream.map { r => "key" -> r.value.toInt }
+ val stateStream = keyedStream.updateStateByKey { (values: Seq[Int], state: Option[Int]) =>
+ Some(values.sum + state.getOrElse(0))
+ }
+ ssc.checkpoint(testDir.getAbsolutePath)
+
+ // This is to ensure that all the data is eventually received only once
+ stateStream.foreachRDD { (rdd: RDD[(String, Int)]) =>
+ rdd.collect().headOption.foreach { x =>
+ DirectKafkaStreamSuite.total.set(x._2)
+ }
+ }
+
+ ssc.start()
+
+ // Send some data
+ for (i <- (1 to 10).grouped(4)) {
+ sendData(i)
+ }
+
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(DirectKafkaStreamSuite.total.get === (1 to 10).sum)
+ }
+
+ ssc.stop()
+
+ // Verify that offset ranges were generated
+ val offsetRangesBeforeStop = getOffsetRanges(kafkaStream)
+ assert(offsetRangesBeforeStop.size >= 1, "No offset ranges generated")
+ assert(
+ offsetRangesBeforeStop.head._2.forall { _.fromOffset === 0 },
+ "starting offset not zero"
+ )
+
+ logInfo("====== RESTARTING ========")
+
+ // Recover context from checkpoints
+ ssc = new StreamingContext(testDir.getAbsolutePath)
+ val recoveredStream =
+ ssc.graph.getInputStreams().head.asInstanceOf[DStream[ConsumerRecord[String, String]]]
+
+ // Verify offset ranges have been recovered
+ val recoveredOffsetRanges = getOffsetRanges(recoveredStream).map { x => (x._1, x._2.toSet) }
+ assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
+ val earlierOffsetRanges = offsetRangesBeforeStop.map { x => (x._1, x._2.toSet) }
+ assert(
+ recoveredOffsetRanges.forall { or =>
+ earlierOffsetRanges.contains((or._1, or._2))
+ },
+ "Recovered ranges are not the same as the ones generated\n" +
+ earlierOffsetRanges + "\n" + recoveredOffsetRanges
+ )
+ // Restart context, give more data and verify the total at the end
+ // If the total is right, that means each record has been received only once
+ ssc.start()
+ for (i <- (11 to 20).grouped(4)) {
+ sendData(i)
+ }
+
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(DirectKafkaStreamSuite.total.get === (1 to 20).sum)
+ }
+ ssc.stop()
+ }
+
+ // Test to verify the offsets can be recovered from Kafka
+ test("offset recovery from kafka") {
+ val topic = "recoveryfromkafka"
+ kafkaTestUtils.createTopic(topic)
+
+ val kafkaParams = getKafkaParams(
+ "auto.offset.reset" -> "earliest",
+ ("enable.auto.commit", false: java.lang.Boolean)
+ )
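+ // Auto-commit is disabled so offsets are committed only through commitAsync below.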
+
+ val collectedData = new ConcurrentLinkedQueue[String]()
+ val committed = new JHashMap[TopicPartition, OffsetAndMetadata]()
+
+ // Send data to Kafka and wait for it to be received
+ def sendDataAndWaitForReceive(data: Seq[Int]) {
+ val strings = data.map { _.toString}
+ kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(strings.forall { collectedData.contains })
+ }
+ }
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(100))
+ withClue("Error creating direct stream") {
+ val kafkaStream = KafkaUtils.createDirectStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala))
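+ // For each batch, record the received values and commit that batch's offsets back to Kafka.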
+ kafkaStream.foreachRDD { (rdd: RDD[ConsumerRecord[String, String]], time: Time) =>
+ val offsets = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ val data = rdd.map(_.value).collect()
+ collectedData.addAll(Arrays.asList(data: _*))
+ kafkaStream.asInstanceOf[CanCommitOffsets]
+ .commitAsync(offsets, new OffsetCommitCallback() {
+ def onComplete(m: JMap[TopicPartition, OffsetAndMetadata], e: Exception) {
+ if (null != e) {
+ logError("commit failed", e)
+ } else {
+ committed.putAll(m)
+ }
+ }
+ })
+ }
+ }
+ ssc.start()
+ // Send some data and wait for them to be received
+ for (i <- (1 to 10).grouped(4)) {
+ sendDataAndWaitForReceive(i)
+ }
+ ssc.stop()
+ assert(!committed.isEmpty)
+ val consumer = new KafkaConsumer[String, String](kafkaParams)
+ consumer.subscribe(Arrays.asList(topic))
+ consumer.poll(0)
+ committed.asScala.foreach {
+ case (k, v) =>
+ // commits are async, not exactly once
+ assert(v.offset > 0)
+ assert(consumer.position(k) >= v.offset)
+ }
+ }
+
+
+ test("Direct Kafka stream report input information") {
+ val topic = "report-test"
+ val data = Map("a" -> 7, "b" -> 9)
+ kafkaTestUtils.createTopic(topic)
+ kafkaTestUtils.sendMessages(topic, data)
+
+ val totalSent = data.values.sum
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+
+ import DirectKafkaStreamSuite._
+ ssc = new StreamingContext(sparkConf, Milliseconds(200))
+ val collector = new InputInfoCollector
+ ssc.addStreamingListener(collector)
+
+ val stream = withClue("Error creating direct stream") {
+ KafkaUtils.createDirectStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala))
+ }
+
+ val allReceived = new ConcurrentLinkedQueue[(String, String)]
+
+ stream.map(r => (r.key, r.value))
+ .foreachRDD { rdd => allReceived.addAll(Arrays.asList(rdd.collect(): _*)) }
+ ssc.start()
+ eventually(timeout(20000.milliseconds), interval(200.milliseconds)) {
+ assert(allReceived.size === totalSent,
+ "didn't get expected number of messages, messages:\n" +
+ allReceived.asScala.mkString("\n"))
+
+ // Verify the record counts collected by the StreamingListener.
+ assert(collector.numRecordsSubmitted.get() === totalSent)
+ assert(collector.numRecordsStarted.get() === totalSent)
+ assert(collector.numRecordsCompleted.get() === totalSent)
+ }
+ ssc.stop()
+ }
+
+ test("maxMessagesPerPartition with backpressure disabled") {
+ val topic = "maxMessagesPerPartition"
+ val kafkaStream = getDirectKafkaStream(topic, None)
+
+ val input = Map(new TopicPartition(topic, 0) -> 50L, new TopicPartition(topic, 1) -> 50L)
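+ // With no rate controller, the cap is spark.streaming.kafka.maxRatePerPartition
+ // (100 records/sec) times the 100 ms batch interval, i.e. 10 records per partition.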
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L))
+ }
+
+ test("maxMessagesPerPartition with no lag") {
+ val topic = "maxMessagesPerPartition"
+ val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 100))
+ val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+ val input = Map(new TopicPartition(topic, 0) -> 0L, new TopicPartition(topic, 1) -> 0L)
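+ // With zero lag on every partition there is nothing to limit, so no maximum is computed.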
+ assert(kafkaStream.maxMessagesPerPartition(input).isEmpty)
+ }
+
+ test("maxMessagesPerPartition respects max rate") {
+ val topic = "maxMessagesPerPartition"
+ val rateController = Some(new ConstantRateController(0, new ConstantEstimator(100), 1000))
+ val kafkaStream = getDirectKafkaStream(topic, rateController)
+
+ val input = Map(new TopicPartition(topic, 0) -> 1000L, new TopicPartition(topic, 1) -> 1000L)
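+ // The rate controller reports 1000 records/sec, but maxRatePerPartition (100 records/sec)
+ // still caps each partition at 10 records per 100 ms batch.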
+ assert(kafkaStream.maxMessagesPerPartition(input).get ==
+ Map(new TopicPartition(topic, 0) -> 10L, new TopicPartition(topic, 1) -> 10L))
+ }
+
+ test("using rate controller") {
+ val topic = "backpressure"
+ val topicPartitions = Set(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
+ kafkaTestUtils.createTopic(topic, 2)
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+ val executorKafkaParams = new JHashMap[String, Object](kafkaParams)
+ KafkaUtils.fixKafkaParams(executorKafkaParams)
+
+ val batchIntervalMilliseconds = 100
+ val estimator = new ConstantEstimator(100)
+ val messages = Map("foo" -> 200)
+ kafkaTestUtils.sendMessages(topic, messages)
+
+ val sparkConf = new SparkConf()
+ // Safe, even with streaming, because we're using the direct API.
+ // Using 1 core is useful to make the test more predictable.
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+ val kafkaStream = withClue("Error creating direct stream") {
+ new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ ConsumerStrategies.Subscribe[String, String](List(topic), kafkaParams.asScala)) {
+ override protected[streaming] val rateController =
+ Some(new DirectKafkaRateController(id, estimator))
+ }.map(r => (r.key, r.value))
+ }
+
+ val collectedData = new ConcurrentLinkedQueue[Array[String]]()
+
+ // Used for assertion failure messages.
+ def dataToString: String =
+ collectedData.asScala.map(_.mkString("[", ",", "]")).mkString("{", ", ", "}")
+
+ // This is to collect the raw data received from Kafka
+ kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
+ val data = rdd.map { _._2 }.collect()
+ collectedData.add(data)
+ }
+
+ ssc.start()
+
+ // Try different rate limits.
+ // Wait for arrays of data to appear matching the rate.
+ Seq(100, 50, 20).foreach { rate =>
+ collectedData.clear() // Empty this buffer on each pass.
+ estimator.updateRate(rate) // Set a new rate.
+ // Expect blocks of data equal to "rate", scaled by the interval length in secs.
+ val expectedSize = Math.round(rate * batchIntervalMilliseconds * 0.001)
+ eventually(timeout(5.seconds), interval(batchIntervalMilliseconds.milliseconds)) {
+ // Assert that rate estimator values are used to determine maxMessagesPerPartition.
+ // Funky "-" in message makes the complete assertion message read better.
+ assert(collectedData.asScala.exists(_.size == expectedSize),
+ s" - No arrays of size $expectedSize for rate $rate found in $dataToString")
+ }
+ }
+
+ ssc.stop()
+ }
+
+ /** Get the generated offset ranges from the DirectKafkaStream */
+ private def getOffsetRanges[K, V](
+ kafkaStream: DStream[ConsumerRecord[K, V]]): Seq[(Time, Array[OffsetRange])] = {
+ kafkaStream.generatedRDDs.mapValues { rdd =>
+ rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ }.toSeq.sortBy { _._1 }
+ }
+
+ private def getDirectKafkaStream(topic: String, mockRateController: Option[RateController]) = {
+ val batchIntervalMilliseconds = 100
+
+ val sparkConf = new SparkConf()
+ .setMaster("local[1]")
+ .setAppName(this.getClass.getSimpleName)
+ .set("spark.streaming.kafka.maxRatePerPartition", "100")
+
+ // Setup the streaming context
+ ssc = new StreamingContext(sparkConf, Milliseconds(batchIntervalMilliseconds))
+
+ val kafkaParams = getKafkaParams("auto.offset.reset" -> "earliest")
+ val ekp = new JHashMap[String, Object](kafkaParams)
+ KafkaUtils.fixKafkaParams(ekp)
+
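+ // A bare-bones ConsumerStrategy that assigns partitions 0 and 1 of the topic and seeks
+ // them to offset 0, giving the stream deterministic starting offsets without a subscription.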
+ val s = new DirectKafkaInputDStream[String, String](
+ ssc,
+ preferredHosts,
+ new ConsumerStrategy[String, String] {
+ def executorKafkaParams = ekp
+ def onStart(currentOffsets: JMap[TopicPartition, JLong]): Consumer[String, String] = {
+ val consumer = new KafkaConsumer[String, String](kafkaParams)
+ val tps = List(new TopicPartition(topic, 0), new TopicPartition(topic, 1))
+ consumer.assign(Arrays.asList(tps: _*))
+ tps.foreach(tp => consumer.seek(tp, 0))
+ consumer
+ }
+ }
+ ) {
+ override protected[streaming] val rateController = mockRateController
+ }
+ // Manual start is necessary because we aren't consuming the stream, just checking its state
+ s.start()
+ s
+ }
+}
+
+object DirectKafkaStreamSuite {
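+ // Kept in the companion object so foreachRDD closures, including ones recovered from a
+ // checkpoint, all update the same counter.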
+ val total = new AtomicLong(-1L)
+
+ class InputInfoCollector extends StreamingListener {
+ val numRecordsSubmitted = new AtomicLong(0L)
+ val numRecordsStarted = new AtomicLong(0L)
+ val numRecordsCompleted = new AtomicLong(0L)
+
+ override def onBatchSubmitted(batchSubmitted: StreamingListenerBatchSubmitted): Unit = {
+ numRecordsSubmitted.addAndGet(batchSubmitted.batchInfo.numRecords)
+ }
+
+ override def onBatchStarted(batchStarted: StreamingListenerBatchStarted): Unit = {
+ numRecordsStarted.addAndGet(batchStarted.batchInfo.numRecords)
+ }
+
+ override def onBatchCompleted(batchCompleted: StreamingListenerBatchCompleted): Unit = {
+ numRecordsCompleted.addAndGet(batchCompleted.batchInfo.numRecords)
+ }
+ }
+}
+
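+/** A rate estimator stub that simply returns whatever rate the test last set. */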
+private[streaming] class ConstantEstimator(@volatile private var rate: Long)
+ extends RateEstimator {
+
+ def updateRate(newRate: Long): Unit = {
+ rate = newRate
+ }
+
+ def compute(
+ time: Long,
+ elements: Long,
+ processingDelay: Long,
+ schedulingDelay: Long): Option[Double] = Some(rate)
+}
+
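+/** A rate controller stub that publishes nothing and reports a fixed latest rate. */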
+private[streaming] class ConstantRateController(id: Int, estimator: RateEstimator, rate: Long)
+ extends RateController(id, estimator) {
+ override def publish(rate: Long): Unit = ()
+ override def getLatestRate(): Long = rate
+}
diff --git a/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
new file mode 100644
index 000000000000..be373af0599c
--- /dev/null
+++ b/external/kafka-0-10/src/test/scala/org/apache/spark/streaming/kafka010/KafkaRDDSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.streaming.kafka010
+
+import java.{ util => ju }
+
+import scala.collection.JavaConverters._
+import scala.util.Random
+
+import org.apache.kafka.common.TopicPartition
+import org.apache.kafka.common.serialization.StringDeserializer
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark._
+import org.apache.spark.scheduler.ExecutorCacheTaskLocation
+
+class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
+
+ private var kafkaTestUtils: KafkaTestUtils = _
+
+ private val sparkConf = new SparkConf().setMaster("local[4]")
+ .setAppName(this.getClass.getSimpleName)
+ private var sc: SparkContext = _
+
+ override def beforeAll {
+ sc = new SparkContext(sparkConf)
+ kafkaTestUtils = new KafkaTestUtils
+ kafkaTestUtils.setup()
+ }
+
+ override def afterAll {
+ if (sc != null) {
+ sc.stop
+ sc = null
+ }
+
+ if (kafkaTestUtils != null) {
+ kafkaTestUtils.teardown()
+ kafkaTestUtils = null
+ }
+ }
+
+ private def getKafkaParams() = Map[String, Object](
+ "bootstrap.servers" -> kafkaTestUtils.brokerAddress,
+ "key.deserializer" -> classOf[StringDeserializer],
+ "value.deserializer" -> classOf[StringDeserializer],
+ "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}"
+ ).asJava
+
+ private val preferredHosts = LocationStrategies.PreferConsistent
+
+ test("basic usage") {
+ val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}"
+ kafkaTestUtils.createTopic(topic)
+ val messages = Array("the", "quick", "brown", "fox")
+ kafkaTestUtils.sendMessages(topic, messages)
+
+ val kafkaParams = getKafkaParams()
+
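+ // A single range over partition 0 covering exactly the messages just sent: offsets [0, messages.size).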
+ val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
+
+ val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams, offsetRanges, preferredHosts)
+ .map(_.value)
+
+ val received = rdd.collect.toSet
+ assert(received === messages.toSet)
+
+ // size-related method optimizations return sane results
+ assert(rdd.count === messages.size)
+ assert(rdd.countApprox(0).getFinalValue.mean === messages.size)
+ assert(!rdd.isEmpty)
+ assert(rdd.take(1).size === 1)
+ assert(rdd.take(1).head === messages.head)
+ assert(rdd.take(messages.size + 10).size === messages.size)
+
+ val emptyRdd = KafkaUtils.createRDD[String, String](
+ sc, kafkaParams, Array(OffsetRange(topic, 0, 0, 0)), preferredHosts)
+
+ assert(emptyRdd.isEmpty)
+
+ // invalid offset ranges throw exceptions
+ val badRanges = Array(OffsetRange(topic, 0, 0, messages.size + 1))
+ intercept[SparkException] {
+ val result = KafkaUtils.createRDD[String, String](sc, kafkaParams, badRanges, preferredHosts)
+ .map(_.value)
+ .collect()
+ }
+ }
+
+ test("iterator boundary conditions") {
+ // the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
+ val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}"
+ val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
+ kafkaTestUtils.createTopic(topic)
+
+ val kafkaParams = getKafkaParams()
+
+ // this is the "lots of messages" case
+ kafkaTestUtils.sendMessages(topic, sent)
+ var sentCount = sent.values.sum
+
+ val rdd = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+ Array(OffsetRange(topic, 0, 0, sentCount)), preferredHosts)
+
+ val ranges = rdd.asInstanceOf[HasOffsetRanges].offsetRanges
+ val rangeCount = ranges.map(o => o.untilOffset - o.fromOffset).sum
+
+ assert(rangeCount === sentCount, "offset range didn't include all sent messages")
+ assert(rdd.map(_.offset).collect.sorted === (0 until sentCount).toArray,
+ "didn't get all sent messages")
+
+ // this is the "0 messages" case
+ val rdd2 = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+ Array(OffsetRange(topic, 0, sentCount, sentCount)), preferredHosts)
+
+ // shouldn't get anything, since message is sent after rdd was defined
+ val sentOnlyOne = Map("d" -> 1)
+
+ kafkaTestUtils.sendMessages(topic, sentOnlyOne)
+
+ assert(rdd2.map(_.value).collect.size === 0, "got messages when there shouldn't be any")
+
+ // this is the "exactly 1 message" case, namely the single message from sentOnlyOne above
+ val rdd3 = KafkaUtils.createRDD[String, String](sc, kafkaParams,
+ Array(OffsetRange(topic, 0, sentCount, sentCount + 1)), preferredHosts)
+
+ // send lots of messages after rdd was defined, they shouldn't show up
+ kafkaTestUtils.sendMessages(topic, Map("extra" -> 22))
+
+ assert(rdd3.map(_.value).collect.head === sentOnlyOne.keys.head,
+ "didn't get exactly one message")
+ }
+
+ test("executor sorting") {
+ val kafkaParams = new ju.HashMap[String, Object](getKafkaParams())
+ kafkaParams.put("auto.offset.reset", "none")
+ val rdd = new KafkaRDD[String, String](
+ sc,
+ kafkaParams,
+ Array(OffsetRange("unused", 0, 1, 2)),
+ ju.Collections.emptyMap[TopicPartition, String](),
+ true)
+ val a3 = ExecutorCacheTaskLocation("a", "3")
+ val a4 = ExecutorCacheTaskLocation("a", "4")
+ val b1 = ExecutorCacheTaskLocation("b", "1")
+ val b2 = ExecutorCacheTaskLocation("b", "2")
+
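+ // Every permutation of these locations must sort back to the same order: host "b" before
+ // host "a", and within a host the higher executor id first.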
+ val correct = Array(b2, b1, a4, a3)
+
+ correct.permutations.foreach { p =>
+ assert(p.sortWith(rdd.compareExecutors) === correct)
+ }
+ }
+}
diff --git a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
index d9d4240c056a..abfd7aad4c5c 100644
--- a/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
+++ b/external/kafka-0-8/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala
@@ -35,6 +35,7 @@ import kafka.serializer.StringEncoder
import kafka.server.{KafkaConfig, KafkaServer}
import kafka.utils.{ZKStringSerializer, ZkUtils}
import org.I0Itec.zkclient.ZkClient
+import org.apache.commons.lang3.RandomUtils
import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer}
import org.apache.spark.SparkConf
@@ -62,7 +63,8 @@ private[kafka] class KafkaTestUtils extends Logging {
// Kafka broker related configurations
private val brokerHost = "localhost"
- private var brokerPort = 9092
+ // 0.8.2 server doesn't have a boundPort method, so can't use 0 for a random port
+ private var brokerPort = RandomUtils.nextInt(1024, 65536)
private var brokerConf: KafkaConfig = _
// Kafka broker server
@@ -112,7 +114,7 @@ private[kafka] class KafkaTestUtils extends Logging {
brokerConf = new KafkaConfig(brokerConfiguration)
server = new KafkaServer(brokerConf)
server.startup()
- (server, port)
+ (server, brokerPort)
}, new SparkConf(), "KafkaBroker")
brokerReady = true
diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
index cb782d27fe22..ab1c5055a253 100644
--- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
+++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala
@@ -244,12 +244,9 @@ class DirectKafkaStreamSuite
)
// Send data to Kafka and wait for it to be received
- def sendDataAndWaitForReceive(data: Seq[Int]) {
+ def sendData(data: Seq[Int]) {
val strings = data.map { _.toString}
kafkaTestUtils.sendMessages(topic, strings.map { _ -> 1}.toMap)
- eventually(timeout(10 seconds), interval(50 milliseconds)) {
- assert(strings.forall { DirectKafkaStreamSuite.collectedData.contains })
- }
}
// Setup the streaming context
@@ -264,21 +261,21 @@ class DirectKafkaStreamSuite
}
ssc.checkpoint(testDir.getAbsolutePath)
- // This is to collect the raw data received from Kafka
- kafkaStream.foreachRDD { (rdd: RDD[(String, String)], time: Time) =>
- val data = rdd.map { _._2 }.collect()
- DirectKafkaStreamSuite.collectedData.addAll(Arrays.asList(data: _*))
- }
-
// This is ensure all the data is eventually receiving only once
stateStream.foreachRDD { (rdd: RDD[(String, Int)]) =>
- rdd.collect().headOption.foreach { x => DirectKafkaStreamSuite.total = x._2 }
+ rdd.collect().headOption.foreach { x =>
+ DirectKafkaStreamSuite.total.set(x._2)
+ }
}
ssc.start()
- // Send some data and wait for them to be received
+ // Send some data
for (i <- (1 to 10).grouped(4)) {
- sendDataAndWaitForReceive(i)
+ sendData(i)
+ }
+
+ eventually(timeout(10 seconds), interval(50 milliseconds)) {
+ assert(DirectKafkaStreamSuite.total.get === (1 to 10).sum)
}
ssc.stop()
@@ -302,23 +299,26 @@ class DirectKafkaStreamSuite
val recoveredStream = ssc.graph.getInputStreams().head.asInstanceOf[DStream[(String, String)]]
// Verify offset ranges have been recovered
- val recoveredOffsetRanges = getOffsetRanges(recoveredStream)
+ val recoveredOffsetRanges = getOffsetRanges(recoveredStream).map { x => (x._1, x._2.toSet) }
assert(recoveredOffsetRanges.size > 0, "No offset ranges recovered")
- val earlierOffsetRangesAsSets = offsetRangesAfterStop.map { x => (x._1, x._2.toSet) }
+ val earlierOffsetRanges = offsetRangesAfterStop.map { x => (x._1, x._2.toSet) }
assert(
recoveredOffsetRanges.forall { or =>
- earlierOffsetRangesAsSets.contains((or._1, or._2.toSet))
+ earlierOffsetRanges.contains((or._1, or._2))
},
"Recovered ranges are not the same as the ones generated\n" +
s"recoveredOffsetRanges: $recoveredOffsetRanges\n" +
- s"earlierOffsetRangesAsSets: $earlierOffsetRangesAsSets"
+ s"earlierOffsetRanges: $earlierOffsetRanges"
)
// Restart context, give more data and verify the total at the end
// If the total is write that means each records has been received only once
ssc.start()
- sendDataAndWaitForReceive(11 to 20)
+ for (i <- (11 to 20).grouped(4)) {
+ sendData(i)
+ }
+
eventually(timeout(10 seconds), interval(50 milliseconds)) {
- assert(DirectKafkaStreamSuite.total === (1 to 20).sum)
+ assert(DirectKafkaStreamSuite.total.get === (1 to 20).sum)
}
ssc.stop()
}
@@ -488,8 +488,7 @@ class DirectKafkaStreamSuite
}
object DirectKafkaStreamSuite {
- val collectedData = new ConcurrentLinkedQueue[String]()
- @volatile var total = -1L
+ val total = new AtomicLong(-1L)
class InputInfoCollector extends StreamingListener {
val numRecordsSubmitted = new AtomicLong(0L)
diff --git a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
index 5e539c1d790c..809699a73996 100644
--- a/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
+++ b/external/kafka-0-8/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala
@@ -53,13 +53,13 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
}
test("basic usage") {
- val topic = s"topicbasic-${Random.nextInt}"
+ val topic = s"topicbasic-${Random.nextInt}-${System.currentTimeMillis}"
kafkaTestUtils.createTopic(topic)
val messages = Array("the", "quick", "brown", "fox")
kafkaTestUtils.sendMessages(topic, messages)
val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
- "group.id" -> s"test-consumer-${Random.nextInt}")
+ "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}")
val offsetRanges = Array(OffsetRange(topic, 0, 0, messages.size))
@@ -92,12 +92,12 @@ class KafkaRDDSuite extends SparkFunSuite with BeforeAndAfterAll {
test("iterator boundary conditions") {
// the idea is to find e.g. off-by-one errors between what kafka has available and the rdd
- val topic = s"topicboundary-${Random.nextInt}"
+ val topic = s"topicboundary-${Random.nextInt}-${System.currentTimeMillis}"
val sent = Map("a" -> 5, "b" -> 3, "c" -> 10)
kafkaTestUtils.createTopic(topic)
val kafkaParams = Map("metadata.broker.list" -> kafkaTestUtils.brokerAddress,
- "group.id" -> s"test-consumer-${Random.nextInt}")
+ "group.id" -> s"test-consumer-${Random.nextInt}-${System.currentTimeMillis}")
val kc = new KafkaCluster(kafkaParams)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 25e56d70c233..a1d08b3a6e78 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -44,7 +44,10 @@ abstract class PipelineStage extends Params with Logging {
/**
* :: DeveloperApi ::
*
- * Derives the output schema from the input schema.
+ * Check transform validity and derive the output schema from the input schema.
+ *
+ * Typical implementation should first conduct verification on schema change and parameter
+ * validity, including complex parameter interaction checks.
*/
@DeveloperApi
def transformSchema(schema: StructType): StructType
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 7c340312df3e..c99ae30155e3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -28,8 +28,9 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.classification.{NaiveBayes => OldNaiveBayes}
import org.apache.spark.mllib.classification.{NaiveBayesModel => OldNaiveBayesModel}
import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
+import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.{Dataset, Row}
/**
* Params for Naive Bayes Classifiers.
@@ -275,9 +276,11 @@ object NaiveBayesModel extends MLReadable[NaiveBayesModel] {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
val dataPath = new Path(path, "data").toString
- val data = sparkSession.read.parquet(dataPath).select("pi", "theta").head()
- val pi = data.getAs[Vector](0)
- val theta = data.getAs[Matrix](1)
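+ // Convert pi/theta columns that may still use the old mllib vector/matrix types to the
+ // new ml types before extracting them from the row.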
+ val data = sparkSession.read.parquet(dataPath)
+ val vecConverted = MLUtils.convertVectorColumnsToML(data, "pi")
+ val Row(pi: Vector, theta: Matrix) = MLUtils.convertMatrixColumnsToML(vecConverted, "theta")
+ .select("pi", "theta")
+ .head()
val model = new NaiveBayesModel(metadata.uid, pi, theta)
DefaultParamsReader.getAndSetParams(model, metadata)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index b333d5925823..778cd0fee71c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -880,11 +880,13 @@ class LDA @Since("1.6.0") (
}
}
-
-private[clustering] object LDA extends DefaultParamsReadable[LDA] {
+@Since("2.0.0")
+object LDA extends DefaultParamsReadable[LDA] {
/** Get dataset for spark.mllib LDA */
- def getOldDataset(dataset: Dataset[_], featuresCol: String): RDD[(Long, OldVector)] = {
+ private[clustering] def getOldDataset(
+ dataset: Dataset[_],
+ featuresCol: String): RDD[(Long, OldVector)] = {
dataset
.withColumn("docId", monotonicallyIncreasingId())
.select("docId", featuresCol)
@@ -894,6 +896,6 @@ private[clustering] object LDA extends DefaultParamsReadable[LDA] {
}
}
- @Since("1.6.0")
+ @Since("2.0.0")
override def load(path: String): LDA = super.load(path)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 72167b50e384..ef8b08545db2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -206,24 +206,22 @@ object PCAModel extends MLReadable[PCAModel] {
override def load(path: String): PCAModel = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
- // explainedVariance field is not present in Spark <= 1.6
- val versionRegex = "([0-9]+)\\.([0-9]+).*".r
- val hasExplainedVariance = metadata.sparkVersion match {
- case versionRegex(major, minor) =>
- major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
- case _ => false
- }
+ val versionRegex = "([0-9]+)\\.(.+)".r
+ val versionRegex(major, _) = metadata.sparkVersion
val dataPath = new Path(path, "data").toString
- val model = if (hasExplainedVariance) {
+ val model = if (major.toInt >= 2) {
val Row(pc: DenseMatrix, explainedVariance: DenseVector) =
sparkSession.read.parquet(dataPath)
.select("pc", "explainedVariance")
.head()
new PCAModel(metadata.uid, pc, explainedVariance)
} else {
- val Row(pc: DenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
- new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
+ // pc field is the old matrix format in Spark <= 1.6
+ // explainedVariance field is not present in Spark <= 1.6
+ val Row(pc: OldDenseMatrix) = sparkSession.read.parquet(dataPath).select("pc").head()
+ new PCAModel(metadata.uid, pc.asML,
+ Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector])
}
DefaultParamsReader.getAndSetParams(model, metadata)
model
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f4819f77ebdb..a80cca70f4b2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -1127,7 +1127,7 @@ private[python] class PythonMLLibAPI extends Serializable {
* Wrapper around RowMatrix constructor.
*/
def createRowMatrix(rows: JavaRDD[Vector], numRows: Long, numCols: Int): RowMatrix = {
- new RowMatrix(rows.rdd.retag(classOf[Vector]), numRows, numCols)
+ new RowMatrix(rows.rdd, numRows, numCols)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 2f52825c6cb0..f2211df3f943 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -629,14 +629,16 @@ object Word2VecModel extends Loader[Word2VecModel] {
("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
- // We want to partition the model in partitions of size 32MB
- val partitionSize = (1L << 25)
+ // We want to partition the model in partitions smaller than
+ // spark.kryoserializer.buffer.max
+ val bufferSize = Utils.byteStringAsBytes(
+ spark.conf.get("spark.kryoserializer.buffer.max", "64m"))
// We calculate the approximate size of the model
- // We only calculate the array size, not considering
- // the string size, the formula is:
- // floatSize * numWords * vectorSize
- val approxSize = 4L * numWords * vectorSize
- val nPartitions = ((approxSize / partitionSize) + 1).toInt
+ // We only calculate the array size, considering an
+ // average string size of 15 bytes, the formula is:
+ // (floatSize * vectorSize + 15) * numWords
+ val approxSize = (4L * vectorSize + 15) * numWords
+ val nPartitions = ((approxSize / bufferSize) + 1).toInt
val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
spark.createDataFrame(dataArray).repartition(nPartitions).write.parquet(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index cd5209d0ebe2..ec32e37afb79 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -537,7 +537,7 @@ class RowMatrix @Since("1.0.0") (
def tallSkinnyQR(computeQ: Boolean = false): QRDecomposition[RowMatrix, Matrix] = {
val col = numCols().toInt
// split rows horizontally into smaller matrices, and compute QR for each of them
- val blockQRs = rows.glom().map { partRows =>
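+ // retag handles RDDs created through the Java API; empty partitions are skipped so they
+ // don't contribute degenerate zero-row blocks to the QR reduction (SPARK-16369).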
+ val blockQRs = rows.retag(classOf[Vector]).glom().filter(_.length != 0).map { partRows =>
val bdm = BDM.zeros[Double](partRows.length, col)
var i = 0
partRows.foreach { row =>
@@ -548,10 +548,11 @@ class RowMatrix @Since("1.0.0") (
}
// combine the R part from previous results vertically into a tall matrix
- val combinedR = blockQRs.treeReduce{ (r1, r2) =>
+ val combinedR = blockQRs.treeReduce { (r1, r2) =>
val stackedR = BDM.vertcat(r1, r2)
breeze.linalg.qr.reduced(stackedR).r
}
+
val finalR = Matrices.fromBreeze(combinedR.toDenseMatrix)
val finalQ = if (computeQ) {
try {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java
new file mode 100644
index 000000000000..c01af405491b
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/distributed/JavaRowMatrixSuite.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.linalg.distributed;
+
+import java.util.Arrays;
+
+import org.junit.Test;
+
+import org.apache.spark.SharedSparkSession;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.mllib.linalg.Matrix;
+import org.apache.spark.mllib.linalg.QRDecomposition;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaRowMatrixSuite extends SharedSparkSession {
+
+ @Test
+ public void rowMatrixQRDecomposition() {
+ Vector v1 = Vectors.dense(1.0, 10.0, 100.0);
+ Vector v2 = Vectors.dense(2.0, 20.0, 200.0);
+ Vector v3 = Vectors.dense(3.0, 30.0, 300.0);
+
+ JavaRDD<Vector> rows = jsc.parallelize(Arrays.asList(v1, v2, v3), 1);
+ RowMatrix mat = new RowMatrix(rows.rdd());
+
+ QRDecomposition<RowMatrix, Matrix> result = mat.tallSkinnyQR(true);
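+ // No assertion needed: completing tallSkinnyQR verifies QR decomposition works on a
+ // RowMatrix backed by a Java-created RDD.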
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index c9fb9768c1b4..22de4c4ac40e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -91,11 +91,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
}
- ignore("big model load / save") {
- // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25
- val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*)
+ test("big model load / save") {
+ // back up the old buffer config values so they can be restored after the test
+ val oldBufferConfValue = spark.conf.get("spark.kryoserializer.buffer", "64k")
+ val oldBufferMaxConfValue = spark.conf.get("spark.kryoserializer.buffer.max", "64m")
+
+ // setting test values to trigger partitioning
+ spark.conf.set("spark.kryoserializer.buffer", "50b")
+ spark.conf.set("spark.kryoserializer.buffer.max", "50b")
+
+ // create a model bigger than 50 Bytes
+ val word2VecMap = Map((0 to 10).map(i => s"$i" -> Array.fill(10)(0.1f)): _*)
val model = new Word2VecModel(word2VecMap)
+ // est. size of this model, given the formula:
+ // (floatSize * vectorSize + 15) * numWords
+ // (4 * 10 + 15) * 11 = 605
+ // therefore it should generate multiple partitions
val tempDir = Utils.createTempDir()
val path = tempDir.toURI.toString
@@ -103,9 +115,16 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
model.save(sc, path)
val sameModel = Word2VecModel.load(sc, path)
assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+ }
+ catch {
+ case t: Throwable => fail("exception thrown persisting a model " +
+ "that spans over multiple partitions", t)
} finally {
Utils.deleteRecursively(tempDir)
+ spark.conf.set("spark.kryoserializer.buffer", oldBufferConfValue)
+ spark.conf.set("spark.kryoserializer.buffer.max", oldBufferMaxConfValue)
}
+
}
test("test similarity for word vectors with large values is not Infinity or NaN") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
index 7c4c6d8409c6..7c9e14f8cee7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors}
import org.apache.spark.mllib.random.RandomRDDs
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -281,6 +282,22 @@ class RowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(cov(i, j) === cov(j, i))
}
}
+
+ test("QR decomposition should aware of empty partition (SPARK-16369)") {
+ val mat: RowMatrix = new RowMatrix(sc.parallelize(denseData, 1))
+ val qrResult = mat.tallSkinnyQR(true)
+
+ val matWithEmptyPartition = new RowMatrix(sc.parallelize(denseData, 8))
+ val qrResult2 = matWithEmptyPartition.tallSkinnyQR(true)
+
+ assert(qrResult.Q.numCols() === qrResult2.Q.numCols(), "Q matrix ncol not match")
+ assert(qrResult.Q.numRows() === qrResult2.Q.numRows(), "Q matrix nrow not match")
+ qrResult.Q.rows.collect().zip(qrResult2.Q.rows.collect())
+ .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "Q matrix not match"))
+
+ qrResult.R.toArray.zip(qrResult2.R.toArray)
+ .foreach(x => assert(x._1 ~== x._2 relTol 1E-8, "R matrix not match"))
+ }
}
class RowMatrixClusterSuite extends SparkFunSuite with LocalClusterSparkContext {
diff --git a/pom.xml b/pom.xml
index e2730ee1c74c..9f3d7f003584 100644
--- a/pom.xml
+++ b/pom.xml
@@ -109,6 +109,8 @@
launcher
external/kafka-0-8
external/kafka-0-8-assembly
+ external/kafka-0-10
+ external/kafka-0-10-assembly
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 4c01ad3c3371..b1a9f393423b 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -44,9 +44,9 @@ object BuildCommons {
).map(ProjectRef(buildLocation, _))
val streamingProjects@Seq(
- streaming, streamingFlumeSink, streamingFlume, streamingKafka
+ streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingKafka010
) = Seq(
- "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8"
+ "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-8", "streaming-kafka-0-10"
).map(ProjectRef(buildLocation, _))
val allProjects@Seq(
@@ -61,8 +61,8 @@ object BuildCommons {
Seq("yarn", "java8-tests", "ganglia-lgpl", "streaming-kinesis-asl",
"docker-integration-tests").map(ProjectRef(buildLocation, _))
- val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKinesisAslAssembly) =
- Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kinesis-asl-assembly")
+ val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) =
+ Seq("network-yarn", "streaming-flume-assembly", "streaming-kafka-0-8-assembly", "streaming-kafka-0-10-assembly", "streaming-kinesis-asl-assembly")
.map(ProjectRef(buildLocation, _))
val copyJarsProjects@Seq(assembly, examples) = Seq("assembly", "examples")
@@ -352,7 +352,7 @@ object SparkBuild extends PomBuild {
val mimaProjects = allProjects.filterNot { x =>
Seq(
spark, hive, hiveThriftServer, catalyst, repl, networkCommon, networkShuffle, networkYarn,
- unsafe, tags, sketch, mllibLocal
+ unsafe, tags, sketch, mllibLocal, streamingKafka010
).contains(x)
}
@@ -608,7 +608,7 @@ object Assembly {
.getOrElse(SbtPomKeys.effectivePom.value.getProperties.get("hadoop.version").asInstanceOf[String])
},
jarName in assembly <<= (version, moduleName, hadoopVersion) map { (v, mName, hv) =>
- if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
+ if (mName.contains("streaming-flume-assembly") || mName.contains("streaming-kafka-0-8-assembly") || mName.contains("streaming-kafka-0-10-assembly") || mName.contains("streaming-kinesis-asl-assembly")) {
// This must match the same name used in maven (see external/kafka-0-8-assembly/pom.xml)
s"${mName}-${v}.jar"
} else {
@@ -701,15 +701,29 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/sql/hive/test")))
}
+ private def ignoreClasspaths(classpaths: Seq[Classpath]): Seq[Classpath] = {
+ classpaths
+ .map(_.filterNot(_.data.getCanonicalPath.matches(""".*kafka-clients-0\.10.*""")))
+ .map(_.filterNot(_.data.getCanonicalPath.matches(""".*kafka_2\..*-0\.10.*""")))
+ }
+
val unidocSourceBase = settingKey[String]("Base URL of source links in Scaladoc.")
lazy val settings = scalaJavaUnidocSettings ++ Seq (
publish := {},
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags),
+ inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, tags, streamingKafka010),
+
+ unidocAllClasspaths in (ScalaUnidoc, unidoc) := {
+ ignoreClasspaths((unidocAllClasspaths in (ScalaUnidoc, unidoc)).value)
+ },
+
+ unidocAllClasspaths in (JavaUnidoc, unidoc) := {
+ ignoreClasspaths((unidocAllClasspaths in (JavaUnidoc, unidoc)).value)
+ },
// Skip actual catalyst, but include the subproject.
// Catalyst is not public API and contains quasiquotes which break scaladoc.
@@ -723,8 +737,7 @@ object Unidoc {
.map(_.filterNot(_.getCanonicalPath.contains("org/apache/hadoop")))
},
- // Javadoc options: create a window title, and group key packages on index page
- javacOptions in doc := Seq(
+ javacOptions in (JavaUnidoc, unidoc) := Seq(
"-windowtitle", "Spark " + version.value.replaceAll("-SNAPSHOT", "") + " JavaDoc",
"-public",
"-noqualifier", "java.lang"
diff --git a/python/docs/pyspark.ml.rst b/python/docs/pyspark.ml.rst
index 26f7415e1a42..2658b8463222 100644
--- a/python/docs/pyspark.ml.rst
+++ b/python/docs/pyspark.ml.rst
@@ -80,3 +80,11 @@ pyspark.ml.evaluation module
:members:
:undoc-members:
:inherited-members:
+
+pyspark.ml.stat module
+----------------------------
+
+.. automodule:: pyspark.ml.stat
+ :members:
+ :undoc-members:
+ :inherited-members:
diff --git a/python/docs/pyspark.sql.rst b/python/docs/pyspark.sql.rst
index 6259379ed05b..3be9533c126d 100644
--- a/python/docs/pyspark.sql.rst
+++ b/python/docs/pyspark.sql.rst
@@ -21,3 +21,9 @@ pyspark.sql.functions module
.. automodule:: pyspark.sql.functions
:members:
:undoc-members:
+
+pyspark.sql.streaming module
+----------------------------
+.. automodule:: pyspark.sql.streaming
+ :members:
+ :undoc-members:
diff --git a/python/pyspark/ml/common.py b/python/pyspark/ml/common.py
index 256e91e14165..7d449aaccb44 100644
--- a/python/pyspark/ml/common.py
+++ b/python/pyspark/ml/common.py
@@ -63,7 +63,7 @@ def _to_java_object_rdd(rdd):
RDD is serialized in batch or not.
"""
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
- return rdd.ctx._jvm.MLSerDe.pythonToJava(rdd._jrdd, True)
+ return rdd.ctx._jvm.org.apache.spark.ml.python.MLSerDe.pythonToJava(rdd._jrdd, True)
def _py2java(sc, obj):
@@ -82,7 +82,7 @@ def _py2java(sc, obj):
pass
else:
data = bytearray(PickleSerializer().dumps(obj))
- obj = sc._jvm.MLSerDe.loads(data)
+ obj = sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(data)
return obj
@@ -95,17 +95,17 @@ def _java2py(sc, r, encoding="bytes"):
clsName = 'JavaRDD'
if clsName == 'JavaRDD':
- jrdd = sc._jvm.MLSerDe.javaToPython(r)
+ jrdd = sc._jvm.org.apache.spark.ml.python.MLSerDe.javaToPython(r)
return RDD(jrdd, sc)
if clsName == 'Dataset':
return DataFrame(r, SQLContext.getOrCreate(sc))
if clsName in _picklable_classes:
- r = sc._jvm.MLSerDe.dumps(r)
+ r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
elif isinstance(r, (JavaArray, JavaList)):
try:
- r = sc._jvm.MLSerDe.dumps(r)
+ r = sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(r)
except Py4JJavaError:
pass # not pickable
diff --git a/python/pyspark/ml/stat/__init__.py b/python/pyspark/ml/stat/__init__.py
new file mode 100644
index 000000000000..6d2f6bea46b7
--- /dev/null
+++ b/python/pyspark/ml/stat/__init__.py
@@ -0,0 +1,27 @@
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Python package for statistical functions in ML.
+"""
+
+
+from pyspark.ml.stat.distribution import MultivariateGaussian
+
+
+__all__ = ["MultivariateGaussian"]
diff --git a/python/pyspark/ml/stat/distribution.py b/python/pyspark/ml/stat/distribution.py
new file mode 100644
index 000000000000..6f5c3282af0f
--- /dev/null
+++ b/python/pyspark/ml/stat/distribution.py
@@ -0,0 +1,173 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+from pyspark.ml.linalg import DenseVector, DenseMatrix, Vector, Vectors
+
+__all__ = ['MultivariateGaussian']
+
+
+class MultivariateGaussian():
+ """
+ This class provides basic functionality for a Multivariate Gaussian (Normal) Distribution. In
+ the event that the covariance matrix is singular, the density will be computed in a
+ reduced dimensional subspace under which the distribution is supported.
+ (see `http://en.wikipedia.org/wiki/Multivariate_normal_distribution#Degenerate_case` )
+
+ :param mu: The mean vector of the distribution
+ :param sigma: The covariance matrix of the distribution
+
+ >>> mu = Vectors.dense([0.0, 0.0])
+ >>> sigma = DenseMatrix(2, 2, [1.0, 1.0, 1.0, 1.0])
+ >>> x = Vectors.dense([1.0, 1.0])
+ >>> m = MultivariateGaussian(mu, sigma)
+ >>> m.pdf(x)
+ 0.06825868114860217
+
+ """
+
+ def __init__(self, mu, sigma):
+ """
+ __init__(self, mu, sigma)
+
+ :param mu: The mean vector of the distribution
+ :param sigma: The covariance matrix of the distribution
+
+ mu and sigma must be instances of DenseVector and DenseMatrix respectively.
+
+ """
+ assert (isinstance(mu, DenseVector)), "mu must be a DenseVector Object"
+ assert (isinstance(sigma, DenseMatrix)), "sigma must be a DenseMatrix Object"
+ assert (sigma.numRows == sigma.numCols), "Covariance matrix must be square"
+ assert (sigma.numRows == mu.size), "Mean vector length must match covariance matrix size"
+
+ # initialize eagerly precomputed attributes
+ self.mu = mu
+
+ # storing sigma as numpy.ndarray
+ # further calculations are done on ndarray only
+ self.sigma = sigma.toArray()
+
+ # initialize attributes to be computed later
+ self.prec_U = None
+ self.log_det_cov = None
+
+ # compute distribution dependent constants
+ self.__calculateCovarianceConstants()
+
+ def pdf(self, x):
+ """
+ Returns density of this multivariate Gaussian at a point given by Vector x
+ """
+ assert (isinstance(x, Vector)), "x must be of Type Vector"
+ assert (self.mu.size == x.size), "Length of vector x must match that of Mean"
+ return float(self.__pdf(x))
+
+ def logpdf(self, x):
+ """
+ Returns the log-density of this multivariate Gaussian at a point given by Vector x
+ """
+ assert (isinstance(x, Vector)), "x must be of Vector Type"
+ assert (self.mu.size == x.size), "Length of vector x must match that of Mean"
+ return float(self.__logpdf(x))
+
+ def __calculateCovarianceConstants(self):
+ """
+ Calculates distribution dependent components used for the density function
+ based on scipy multivariate library.
+ For further understanding on covariance constants and calculations,
+ see `https://github.com/scipy/scipy/blob/master/scipy/stats/_multivariate.py`
+ """
+ # calculating the eigenvalues and eigenvectors of covariance matrix
+ # s = eigen values
+ # u = eigen vectors
+ s, u = np.linalg.eigh(self.sigma)
+
+ # Singular values are considered to be non-zero only if
+ # they exceed a tolerance based on machine precision, matrix size, and
+ # relation to the maximum singular value (same tolerance used by, e.g., Octave).
+
+ # calculation for machine precision
+ t = u.dtype.char.lower()
+ factor = {'f': 1E3, 'd': 1E6}
+ cond = factor[t] * np.finfo(t).eps
+
+ eps = cond * np.max(abs(s))
+
+ # check that the covariance matrix is positive semidefinite,
+ # i.e. that it has no significantly negative eigenvalues
+ if np.min(s) < -eps:
+ raise ValueError("Covariance matrix must be positive semidefinite")
+
+ # computing the pseudoinverse of s (creates a copy)
+ # elements of vector s smaller than eps are considered negligible
+ # while remaining elements are inverted
+ s_pinv = np.array([0 if abs(x) < eps else 1/x for x in s], dtype=float)
+
+ # prec_U ndarray
+ # A decomposition such that np.dot(prec_U, prec_U.T)
+ # is the precision matrix, i.e. inverse of the covariance matrix.
+ self.prec_U = u * np.sqrt(s_pinv)
+
+ # log_det_cov : float
+ # Logarithm of the determinant of the covariance matrix
+ self.log_det_cov = np.sum(np.log(s[s > eps]))
+
+ def __pdf(self, x):
+ """
+ Calculates density at point x using precomputed Constants
+ x Points at which to evaluate the probability density function
+ """
+ return np.exp(self.__logpdf(x))
+
+ def __logpdf(self, x):
+ """
+ Calculates log-density at point x using precomputed Constants
+ x Points at which to evaluate the log of the probability
+ density function
+ """
+ dim = x.size
+ delta = x - self.mu
+ maha = np.sum(np.square(np.dot(delta, self.prec_U)), axis=-1)
+ return -0.5 * (dim * np.log(2 * np.pi) + self.log_det_cov + maha)
+
+if __name__ == '__main__':
+ import doctest
+ import pyspark.ml.stat.distribution
+ from pyspark.sql import SparkSession
+ globs = pyspark.ml.stat.distribution.__dict__.copy()
+ # The small batch size here ensures that we see multiple batches,
+ # even in these small test examples:
+ spark = SparkSession.builder\
+ .master("local[2]")\
+ .appName("ml.stat.distribution tests")\
+ .getOrCreate()
+ sc = spark.sparkContext
+ globs['sc'] = sc
+ globs['spark'] = spark
+ import tempfile
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ spark.stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
+ if failure_count:
+ exit(-1)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 981ed9dda042..d8f619df12dc 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -59,6 +59,7 @@
from pyspark.ml.recommendation import ALS
from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor, \
GeneralizedLinearRegression
+from pyspark.ml.stat.distribution import MultivariateGaussian
from pyspark.ml.tuning import *
from pyspark.ml.wrapper import JavaParams
from pyspark.ml.common import _java2py
@@ -69,6 +70,7 @@
from pyspark.storagelevel import *
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
+
ser = PickleSerializer()
@@ -1195,12 +1197,12 @@ class VectorTests(MLlibTestCase):
def _test_serialize(self, v):
self.assertEqual(v, ser.loads(ser.dumps(v)))
- jvec = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(v)))
- nv = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvec)))
+ jvec = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(v)))
+ nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvec)))
self.assertEqual(v, nv)
vs = [v] * 100
- jvecs = self.sc._jvm.MLSerDe.loads(bytearray(ser.dumps(vs)))
- nvs = ser.loads(bytes(self.sc._jvm.MLSerDe.dumps(jvecs)))
+ jvecs = self.sc._jvm.org.apache.spark.ml.python.MLSerDe.loads(bytearray(ser.dumps(vs)))
+ nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.ml.python.MLSerDe.dumps(jvecs)))
self.assertEqual(vs, nvs)
def test_serialize(self):
@@ -1518,6 +1520,72 @@ def test_infer_schema(self):
raise ValueError("Expected a matrix but got type %r" % type(m))
+class MultiVariateGaussianTests(PySparkTestCase):
+ def test_univariate(self):
+ x1 = Vectors.dense([0.0])
+ x2 = Vectors.dense([1.5])
+
+ mu = Vectors.dense([0.0])
+ sigma1 = DenseMatrix(1, 1, [1.0])
+ dist1 = MultivariateGaussian(mu, sigma1)
+
+ self.assertAlmostEqual(dist1.pdf(x1), 0.39894, 5)
+ self.assertAlmostEqual(dist1.pdf(x2), 0.12952, 5)
+
+ sigma2 = DenseMatrix(1, 1, [4.0])
+ dist2 = MultivariateGaussian(mu, sigma2)
+
+ self.assertAlmostEqual(dist2.pdf(x1), 0.19947, 5)
+ self.assertAlmostEqual(dist2.pdf(x2), 0.15057, 5)
+
+ def test_multivariate(self):
+ x1 = Vectors.dense([0.0, 0.0])
+ x2 = Vectors.dense([1.0, 1.0])
+
+ mu = Vectors.dense([0.0, 0.0])
+ sigma1 = DenseMatrix(2, 2, [1.0, 0.0, 0.0, 1.0])
+ dist1 = MultivariateGaussian(mu, sigma1)
+
+ self.assertAlmostEqual(dist1.pdf(x1), 0.159154, 5)
+ self.assertAlmostEqual(dist1.pdf(x2), 0.05855, 5)
+
+ sigma2 = DenseMatrix(2, 2, [4.0, -1.0, -1.0, 2.0])
+ dist2 = MultivariateGaussian(mu, sigma2)
+
+ self.assertAlmostEqual(dist2.pdf(x1), 0.060155, 5)
+ self.assertAlmostEqual(dist2.pdf(x2), 0.0339717, 5)
+
+ def test_multivariate_degenerate(self):
+ x1 = Vectors.dense([0.0, 0.0])
+ x2 = Vectors.dense([1.0, 1.0])
+
+ mu = Vectors.dense([0.0, 0.0])
+ sigma1 = DenseMatrix(2, 2, [1.0, 1.0, 1.0, 1.0])
+ dist1 = MultivariateGaussian(mu, sigma1)
+
+ self.assertAlmostEqual(dist1.pdf(x1), 0.11254, 5)
+ self.assertAlmostEqual(dist1.pdf(x2), 0.068259, 5)
+
+ def test_SPARK_11302(self):
+ x = Vectors.dense([629, 640, 1.7188, 618.19])
+
+ mu = Vectors.dense(
+ [
+ 1055.3910505836575, 1070.489299610895,
+ 1.39020554474708, 1040.5907503867697])
+ sigma = DenseMatrix(
+ 4, 4, [
+ 166769.00466698944, 169336.6705268059,
+ 12.820670788921873, 164243.93314092053,
+ 169336.6705268059, 172041.5670061245,
+ 21.62590020524533, 166678.01075856484,
+ 12.820670788921873, 21.62590020524533,
+ 0.872524191943962, 4.283255814732373,
+ 164243.93314092053, 166678.01075856484,
+ 4.283255814732373, 161848.9196719207])
+ dist = MultivariateGaussian(mu, sigma)
+ self.assertAlmostEqual(dist.pdf(x), 0.00007154782, 9)
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner:
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 95f7278dc64c..c38c543972d1 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -507,7 +507,7 @@ def load(cls, sc, path):
Path to where the model is stored.
"""
model = cls._load_java(sc, path)
- wrapper = sc._jvm.GaussianMixtureModelWrapper(model)
+ wrapper = sc._jvm.org.apache.spark.mllib.api.python.GaussianMixtureModelWrapper(model)
return cls(wrapper)
@@ -571,14 +571,14 @@ class PowerIterationClusteringModel(JavaModelWrapper, JavaSaveable, JavaLoader):
>>> import math
>>> def genCircle(r, n):
- ... points = []
- ... for i in range(0, n):
- ... theta = 2.0 * math.pi * i / n
- ... points.append((r * math.cos(theta), r * math.sin(theta)))
- ... return points
+ ... points = []
+ ... for i in range(0, n):
+ ... theta = 2.0 * math.pi * i / n
+ ... points.append((r * math.cos(theta), r * math.sin(theta)))
+ ... return points
>>> def sim(x, y):
- ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1])
- ... return math.exp(-dist2 / 2.0)
+ ... dist2 = (x[0] - y[0]) * (x[0] - y[0]) + (x[1] - y[1]) * (x[1] - y[1])
+ ... return math.exp(-dist2 / 2.0)
>>> r1 = 1.0
>>> n1 = 10
>>> r2 = 4.0
@@ -638,7 +638,8 @@ def load(cls, sc, path):
Load a model from the given path.
"""
model = cls._load_java(sc, path)
- wrapper = sc._jvm.PowerIterationClusteringModelWrapper(model)
+ wrapper =\
+ sc._jvm.org.apache.spark.mllib.api.python.PowerIterationClusteringModelWrapper(model)
return PowerIterationClusteringModel(wrapper)
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 31afdf576b67..21f0e09ea774 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -66,7 +66,7 @@ def _to_java_object_rdd(rdd):
RDD is serialized in batch or not.
"""
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
- return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
+ return rdd.ctx._jvm.org.apache.spark.mllib.api.python.SerDe.pythonToJava(rdd._jrdd, True)
def _py2java(sc, obj):
@@ -85,7 +85,7 @@ def _py2java(sc, obj):
pass
else:
data = bytearray(PickleSerializer().dumps(obj))
- obj = sc._jvm.SerDe.loads(data)
+ obj = sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(data)
return obj
@@ -98,17 +98,17 @@ def _java2py(sc, r, encoding="bytes"):
clsName = 'JavaRDD'
if clsName == 'JavaRDD':
- jrdd = sc._jvm.SerDe.javaToPython(r)
+ jrdd = sc._jvm.org.apache.spark.mllib.api.python.SerDe.javaToPython(r)
return RDD(jrdd, sc)
if clsName == 'Dataset':
return DataFrame(r, SQLContext.getOrCreate(sc))
if clsName in _picklable_classes:
- r = sc._jvm.SerDe.dumps(r)
+ r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
elif isinstance(r, (JavaArray, JavaList)):
try:
- r = sc._jvm.SerDe.dumps(r)
+ r = sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(r)
except Py4JJavaError:
pass # not pickable
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index e31c75c1e867..aef91a8ddc1f 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -553,7 +553,7 @@ def load(cls, sc, path):
"""
jmodel = sc._jvm.org.apache.spark.mllib.feature \
.Word2VecModel.load(sc._jsc.sc(), path)
- model = sc._jvm.Word2VecModelWrapper(jmodel)
+ model = sc._jvm.org.apache.spark.mllib.api.python.Word2VecModelWrapper(jmodel)
return Word2VecModel(model)
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index ab4066f7d68b..fb226e84e5d5 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -64,7 +64,7 @@ def load(cls, sc, path):
Load a model from the given path.
"""
model = cls._load_java(sc, path)
- wrapper = sc._jvm.FPGrowthModelWrapper(model)
+ wrapper = sc._jvm.org.apache.spark.mllib.api.python.FPGrowthModelWrapper(model)
return FPGrowthModel(wrapper)
diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py
index 3a345b2b5638..15dc53a959d6 100644
--- a/python/pyspark/mllib/linalg/__init__.py
+++ b/python/pyspark/mllib/linalg/__init__.py
@@ -39,6 +39,7 @@
import numpy as np
from pyspark import since
+from pyspark.ml import linalg as newlinalg
from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
IntegerType, ByteType, BooleanType
@@ -247,6 +248,15 @@ def toArray(self):
"""
raise NotImplementedError
+ def asML(self):
+ """
+ Convert this vector to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :return: :py:class:`pyspark.ml.linalg.Vector`
+ """
+ raise NotImplementedError
+
class DenseVector(Vector):
"""
@@ -408,6 +418,17 @@ def toArray(self):
"""
return self.array
+ def asML(self):
+ """
+ Convert this vector to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :return: :py:class:`pyspark.ml.linalg.DenseVector`
+
+ .. versionadded:: 2.0.0
+ """
+ return newlinalg.DenseVector(self.array)
+
@property
def values(self):
"""
@@ -737,6 +758,17 @@ def toArray(self):
arr[self.indices] = self.values
return arr
+ def asML(self):
+ """
+ Convert this vector to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :return: :py:class:`pyspark.ml.linalg.SparseVector`
+
+ .. versionadded:: 2.0.0
+ """
+ return newlinalg.SparseVector(self.size, self.indices, self.values)
+
def __len__(self):
return self.size
@@ -845,6 +877,24 @@ def dense(*elements):
elements = elements[0]
return DenseVector(elements)
+ @staticmethod
+ def fromML(vec):
+ """
+ Convert a vector from the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :param vec: a :py:class:`pyspark.ml.linalg.Vector`
+ :return: a :py:class:`pyspark.mllib.linalg.Vector`
+
+ .. versionadded:: 2.0.0
+ """
+ if isinstance(vec, newlinalg.DenseVector):
+ return DenseVector(vec.array)
+ elif isinstance(vec, newlinalg.SparseVector):
+ return SparseVector(vec.size, vec.indices, vec.values)
+ else:
+ raise TypeError("Unsupported vector type %s" % type(vec))
+
@staticmethod
def stringify(vector):
"""
@@ -945,6 +995,13 @@ def toArray(self):
"""
raise NotImplementedError
+ def asML(self):
+ """
+ Convert this matrix to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+ """
+ raise NotImplementedError
+
@staticmethod
def _convert_to_array(array_like, dtype):
"""
@@ -1044,6 +1101,17 @@ def toSparse(self):
return SparseMatrix(self.numRows, self.numCols, colPtrs, rowIndices, values)
+ def asML(self):
+ """
+ Convert this matrix to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :return: :py:class:`pyspark.ml.linalg.DenseMatrix`
+
+ .. versionadded:: 2.0.0
+ """
+ return newlinalg.DenseMatrix(self.numRows, self.numCols, self.values, self.isTransposed)
+
def __getitem__(self, indices):
i, j = indices
if i < 0 or i >= self.numRows:
@@ -1216,6 +1284,18 @@ def toDense(self):
densevals = np.ravel(self.toArray(), order='F')
return DenseMatrix(self.numRows, self.numCols, densevals)
+ def asML(self):
+ """
+ Convert this matrix to the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :return: :py:class:`pyspark.ml.linalg.SparseMatrix`
+
+ .. versionadded:: 2.0.0
+ """
+ return newlinalg.SparseMatrix(self.numRows, self.numCols, self.colPtrs, self.rowIndices,
+ self.values, self.isTransposed)
+
# TODO: More efficient implementation:
def __eq__(self, other):
return np.all(self.toArray() == other.toArray())
@@ -1236,6 +1316,25 @@ def sparse(numRows, numCols, colPtrs, rowIndices, values):
"""
return SparseMatrix(numRows, numCols, colPtrs, rowIndices, values)
+ @staticmethod
+ def fromML(mat):
+ """
+ Convert a matrix from the new mllib-local representation.
+ This does NOT copy the data; it copies references.
+
+ :param mat: a :py:class:`pyspark.ml.linalg.Matrix`
+ :return: a :py:class:`pyspark.mllib.linalg.Matrix`
+
+ .. versionadded:: 2.0.0
+ """
+ if isinstance(mat, newlinalg.DenseMatrix):
+ return DenseMatrix(mat.numRows, mat.numCols, mat.values, mat.isTransposed)
+ elif isinstance(mat, newlinalg.SparseMatrix):
+ return SparseMatrix(mat.numRows, mat.numCols, mat.colPtrs, mat.rowIndices,
+ mat.values, mat.isTransposed)
+ else:
+ raise TypeError("Unsupported matrix type %s" % type(mat))
+
class QRDecomposition(object):
"""
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 7e60255d43ea..732300ee9c2c 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -207,7 +207,7 @@ def rank(self):
def load(cls, sc, path):
"""Load a model from the given path"""
model = cls._load_java(sc, path)
- wrapper = sc._jvm.MatrixFactorizationModelWrapper(model)
+ wrapper = sc._jvm.org.apache.spark.mllib.api.python.MatrixFactorizationModelWrapper(model)
return MatrixFactorizationModel(wrapper)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 74cf7bb8eaf9..99bf50b5a164 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -49,6 +49,7 @@
import unittest
from pyspark import SparkContext
+import pyspark.ml.linalg as newlinalg
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.clustering import StreamingKMeans, StreamingKMeansModel
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
@@ -149,12 +150,12 @@ class VectorTests(MLlibTestCase):
def _test_serialize(self, v):
self.assertEqual(v, ser.loads(ser.dumps(v)))
- jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
- nv = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvec)))
+ jvec = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(v)))
+ nv = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvec)))
self.assertEqual(v, nv)
vs = [v] * 100
- jvecs = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(vs)))
- nvs = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jvecs)))
+ jvecs = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(vs)))
+ nvs = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jvecs)))
self.assertEqual(vs, nvs)
def test_serialize(self):
@@ -423,6 +424,74 @@ def test_norms(self):
tmp = SparseVector(4, [0, 2], [3, 0])
self.assertEqual(tmp.numNonzeros(), 1)
+ def test_ml_mllib_vector_conversion(self):
+ # to ml
+ # dense
+ mllibDV = Vectors.dense([1, 2, 3])
+ mlDV1 = newlinalg.Vectors.dense([1, 2, 3])
+ mlDV2 = mllibDV.asML()
+ self.assertEqual(mlDV2, mlDV1)
+ # sparse
+ mllibSV = Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV1 = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV2 = mllibSV.asML()
+ self.assertEqual(mlSV2, mlSV1)
+ # from ml
+ # dense
+ mllibDV1 = Vectors.dense([1, 2, 3])
+ mlDV = newlinalg.Vectors.dense([1, 2, 3])
+ mllibDV2 = Vectors.fromML(mlDV)
+ self.assertEqual(mllibDV1, mllibDV2)
+ # sparse
+ mllibSV1 = Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mlSV = newlinalg.Vectors.sparse(4, {1: 1.0, 3: 5.5})
+ mllibSV2 = Vectors.fromML(mlSV)
+ self.assertEqual(mllibSV1, mllibSV2)
+
+ def test_ml_mllib_matrix_conversion(self):
+ # to ml
+ # dense
+ mllibDM = Matrices.dense(2, 2, [0, 1, 2, 3])
+ mlDM1 = newlinalg.Matrices.dense(2, 2, [0, 1, 2, 3])
+ mlDM2 = mllibDM.asML()
+ self.assertEqual(mlDM2, mlDM1)
+ # transposed
+ mllibDMt = DenseMatrix(2, 2, [0, 1, 2, 3], True)
+ mlDMt1 = newlinalg.DenseMatrix(2, 2, [0, 1, 2, 3], True)
+ mlDMt2 = mllibDMt.asML()
+ self.assertEqual(mlDMt2, mlDMt1)
+ # sparse
+ mllibSM = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM1 = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM2 = mllibSM.asML()
+ self.assertEqual(mlSM2, mlSM1)
+ # transposed
+ mllibSMt = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt1 = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt2 = mllibSMt.asML()
+ self.assertEqual(mlSMt2, mlSMt1)
+ # from ml
+ # dense
+ mllibDM1 = Matrices.dense(2, 2, [1, 2, 3, 4])
+ mlDM = newlinalg.Matrices.dense(2, 2, [1, 2, 3, 4])
+ mllibDM2 = Matrices.fromML(mlDM)
+ self.assertEqual(mllibDM1, mllibDM2)
+ # transposed
+ mllibDMt1 = DenseMatrix(2, 2, [1, 2, 3, 4], True)
+ mlDMt = newlinalg.DenseMatrix(2, 2, [1, 2, 3, 4], True)
+ mllibDMt2 = Matrices.fromML(mlDMt)
+ self.assertEqual(mllibDMt1, mllibDMt2)
+ # sparse
+ mllibSM1 = Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mlSM = newlinalg.Matrices.sparse(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4])
+ mllibSM2 = Matrices.fromML(mlSM)
+ self.assertEqual(mllibSM1, mllibSM2)
+ # transposed
+ mllibSMt1 = SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mlSMt = newlinalg.SparseMatrix(2, 2, [0, 2, 3], [0, 1, 1], [2, 3, 4], True)
+ mllibSMt2 = Matrices.fromML(mlSMt)
+ self.assertEqual(mllibSMt1, mllibSMt2)
+
class ListTests(MLlibTestCase):
@@ -1581,8 +1650,8 @@ class ALSTests(MLlibTestCase):
def test_als_ratings_serialize(self):
r = Rating(7, 1123, 3.14)
- jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r)))
- nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr)))
+ jr = self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads(bytearray(ser.dumps(r)))
+ nr = ser.loads(bytes(self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.dumps(jr)))
self.assertEqual(r.user, nr.user)
self.assertEqual(r.product, nr.product)
self.assertAlmostEqual(r.rating, nr.rating, 2)
@@ -1590,7 +1659,8 @@ def test_als_ratings_serialize(self):
def test_als_ratings_id_long_error(self):
r = Rating(1205640308657491975, 50233468418, 1.0)
# rating user id exceeds max int value, should fail when pickled
- self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+ self.assertRaises(Py4JJavaError, self.sc._jvm.org.apache.spark.mllib.api.python.SerDe.loads,
+ bytearray(ser.dumps(r)))
class HashingTFTest(MLlibTestCase):
diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py
index 3033f147bc96..4af930a3cd56 100644
--- a/python/pyspark/sql/catalog.py
+++ b/python/pyspark/sql/catalog.py
@@ -232,6 +232,11 @@ def clearCache(self):
"""Removes all cached tables from the in-memory cache."""
self._jcatalog.clearCache()
+ @since(2.0)
+ def refreshTable(self, tableName):
+ """Invalidate and refresh all the cached metadata of the given table."""
+ self._jcatalog.refreshTable(tableName)
+
def _reset(self):
"""(Internal use only) Drop all existing databases (except "default"), tables,
partitions and functions, and set the current database to "default".
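A usage sketch for the new Catalog.refreshTable hook (not part of the patch; the table name is a placeholder for an existing metastore table whose underlying files changed outside Spark):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("refresh-demo").getOrCreate()
# Placeholder table name -- refreshTable invalidates its cached metadata so the
# next query re-reads the current file listing and schema.
spark.catalog.refreshTable("my_table")
spark.stop()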
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index 8a1a874884e2..4cfdf799f6f4 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -27,6 +27,7 @@
from pyspark.sql.session import _monkey_patch_RDD, SparkSession
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import Row, StringType
from pyspark.sql.utils import install_exception_handler
@@ -438,8 +439,12 @@ def readStream(self):
.. note:: Experimental.
:return: :class:`DataStreamReader`
+
+ >>> text_sdf = sqlContext.readStream.text(tempfile.mkdtemp())
+ >>> text_sdf.isStreaming
+ True
"""
- return DataStreamReader(self._wrapped)
+ return DataStreamReader(self)
@property
@since(2.0)
@@ -487,7 +492,7 @@ def _createForTesting(cls, sparkContext):
confusing error messages.
"""
jsc = sparkContext._jsc.sc()
- jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc)
+ jtestHive = sparkContext._jvm.org.apache.spark.sql.hive.test.TestHiveContext(jsc, False)
return cls(sparkContext, jtestHive)
def refreshTable(self, tableName):
@@ -515,6 +520,7 @@ def register(self, name, f, returnType=StringType()):
def _test():
import os
import doctest
+ import tempfile
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
import pyspark.sql.context
@@ -523,6 +529,8 @@ def _test():
globs = pyspark.sql.context.__dict__.copy()
sc = SparkContext('local[4]', 'PythonTest')
+ globs['tempfile'] = tempfile
+ globs['os'] = os
globs['sc'] = sc
globs['sqlContext'] = SQLContext(sc)
globs['rdd'] = rdd = sc.parallelize(
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index acf9d08b23a2..c7d704a18ada 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -33,7 +33,8 @@
from pyspark.traceback_utils import SCCallSiteSync
from pyspark.sql.types import _parse_datatype_json_string
from pyspark.sql.column import Column, _to_seq, _to_list, _to_java_column
-from pyspark.sql.readwriter import DataFrameWriter, DataStreamWriter
+from pyspark.sql.readwriter import DataFrameWriter
+from pyspark.sql.streaming import DataStreamWriter
from pyspark.sql.types import *
__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"]
@@ -257,8 +258,8 @@ def isLocal(self):
def isStreaming(self):
"""Returns true if this :class:`Dataset` contains one or more sources that continuously
return data as it arrives. A :class:`Dataset` that reads data from a streaming source
- must be executed as a :class:`StreamingQuery` using the :func:`startStream` method in
- :class:`DataFrameWriter`. Methods that return a single answer, (e.g., :func:`count` or
+ must be executed as a :class:`StreamingQuery` using the :func:`start` method in
+ :class:`DataStreamWriter`. Methods that return a single answer, (e.g., :func:`count` or
:func:`collect`) will throw an :class:`AnalysisException` when there is a streaming
source present.
@@ -1032,10 +1033,10 @@ def dropDuplicates(self, subset=None):
:func:`drop_duplicates` is an alias for :func:`dropDuplicates`.
>>> from pyspark.sql import Row
- >>> df = sc.parallelize([ \
- Row(name='Alice', age=5, height=80), \
- Row(name='Alice', age=5, height=80), \
- Row(name='Alice', age=10, height=80)]).toDF()
+ >>> df = sc.parallelize([ \\
+ ... Row(name='Alice', age=5, height=80), \\
+ ... Row(name='Alice', age=5, height=80), \\
+ ... Row(name='Alice', age=10, height=80)]).toDF()
>>> df.dropDuplicates().show()
+---+------+-----+
|age|height| name|
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 15cefc8cf112..92d709ee40e1 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1550,8 +1550,8 @@ def translate(srcCol, matching, replace):
The translate will happen when any character in the string matching with the character
in the `matching`.
- >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123")\
- .alias('r')).collect()
+ >>> spark.createDataFrame([('translate',)], ['a']).select(translate('a', "rnlt", "123") \\
+ ... .alias('r')).collect()
[Row(r=u'1a2s3ae')]
"""
sc = SparkContext._active_spark_context
@@ -1637,6 +1637,27 @@ def explode(col):
return Column(jc)
+@since(2.1)
+def posexplode(col):
+ """Returns a new row for each element with position in the given array or map.
+
+ >>> from pyspark.sql import Row
+ >>> eDF = spark.createDataFrame([Row(a=1, intlist=[1,2,3], mapfield={"a": "b"})])
+ >>> eDF.select(posexplode(eDF.intlist)).collect()
+ [Row(pos=0, col=1), Row(pos=1, col=2), Row(pos=2, col=3)]
+
+ >>> eDF.select(posexplode(eDF.mapfield)).show()
+ +---+---+-----+
+ |pos|key|value|
+ +---+---+-----+
+ | 0| a| b|
+ +---+---+-----+
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.posexplode(_to_java_column(col))
+ return Column(jc)
+
+
@ignore_unicode_prefix
@since(1.6)
def get_json_object(col, path):
@@ -1649,8 +1670,8 @@ def get_json_object(col, path):
>>> data = [("1", '''{"f1": "value1", "f2": "value2"}'''), ("2", '''{"f1": "value12"}''')]
>>> df = spark.createDataFrame(data, ("key", "jstring"))
- >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \
- get_json_object(df.jstring, '$.f2').alias("c1") ).collect()
+ >>> df.select(df.key, get_json_object(df.jstring, '$.f1').alias("c0"), \\
+ ... get_json_object(df.jstring, '$.f2').alias("c1") ).collect()
[Row(key=u'1', c0=u'value1', c1=u'value2'), Row(key=u'2', c0=u'value12', c1=None)]
"""
sc = SparkContext._active_spark_context
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index a4232065540e..f2092f9c6305 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -179,10 +179,12 @@ def pivot(self, pivot_col, values=None):
:param values: List of values that will be translated to columns in the output DataFrame.
# Compute the sum of earnings for each year by course with each course as a separate column
+
>>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
# Or without specifying column values (less efficient)
+
>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py
index ccbf895c2d88..f7c354f51330 100644
--- a/python/pyspark/sql/readwriter.py
+++ b/python/pyspark/sql/readwriter.py
@@ -28,7 +28,7 @@
from pyspark.sql.types import *
from pyspark.sql import utils
-__all__ = ["DataFrameReader", "DataFrameWriter", "DataStreamReader", "DataStreamWriter"]
+__all__ = ["DataFrameReader", "DataFrameWriter"]
def to_str(value):
@@ -44,84 +44,20 @@ def to_str(value):
return str(value)
-class ReaderUtils(object):
+class OptionUtils(object):
- def _set_json_opts(self, schema, primitivesAsString, prefersDecimal,
- allowComments, allowUnquotedFieldNames, allowSingleQuotes,
- allowNumericLeadingZero, allowBackslashEscapingAnyCharacter,
- mode, columnNameOfCorruptRecord):
+ def _set_opts(self, schema=None, **options):
"""
- Set options based on the Json optional parameters
+ Set named options (filter out those whose value is None)
"""
if schema is not None:
self.schema(schema)
- if primitivesAsString is not None:
- self.option("primitivesAsString", primitivesAsString)
- if prefersDecimal is not None:
- self.option("prefersDecimal", prefersDecimal)
- if allowComments is not None:
- self.option("allowComments", allowComments)
- if allowUnquotedFieldNames is not None:
- self.option("allowUnquotedFieldNames", allowUnquotedFieldNames)
- if allowSingleQuotes is not None:
- self.option("allowSingleQuotes", allowSingleQuotes)
- if allowNumericLeadingZero is not None:
- self.option("allowNumericLeadingZero", allowNumericLeadingZero)
- if allowBackslashEscapingAnyCharacter is not None:
- self.option("allowBackslashEscapingAnyCharacter", allowBackslashEscapingAnyCharacter)
- if mode is not None:
- self.option("mode", mode)
- if columnNameOfCorruptRecord is not None:
- self.option("columnNameOfCorruptRecord", columnNameOfCorruptRecord)
-
- def _set_csv_opts(self, schema, sep, encoding, quote, escape,
- comment, header, inferSchema, ignoreLeadingWhiteSpace,
- ignoreTrailingWhiteSpace, nullValue, nanValue, positiveInf, negativeInf,
- dateFormat, maxColumns, maxCharsPerColumn, maxMalformedLogPerPartition, mode):
- """
- Set options based on the CSV optional parameters
- """
- if schema is not None:
- self.schema(schema)
- if sep is not None:
- self.option("sep", sep)
- if encoding is not None:
- self.option("encoding", encoding)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if comment is not None:
- self.option("comment", comment)
- if header is not None:
- self.option("header", header)
- if inferSchema is not None:
- self.option("inferSchema", inferSchema)
- if ignoreLeadingWhiteSpace is not None:
- self.option("ignoreLeadingWhiteSpace", ignoreLeadingWhiteSpace)
- if ignoreTrailingWhiteSpace is not None:
- self.option("ignoreTrailingWhiteSpace", ignoreTrailingWhiteSpace)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if nanValue is not None:
- self.option("nanValue", nanValue)
- if positiveInf is not None:
- self.option("positiveInf", positiveInf)
- if negativeInf is not None:
- self.option("negativeInf", negativeInf)
- if dateFormat is not None:
- self.option("dateFormat", dateFormat)
- if maxColumns is not None:
- self.option("maxColumns", maxColumns)
- if maxCharsPerColumn is not None:
- self.option("maxCharsPerColumn", maxCharsPerColumn)
- if maxMalformedLogPerPartition is not None:
- self.option("maxMalformedLogPerPartition", maxMalformedLogPerPartition)
- if mode is not None:
- self.option("mode", mode)
-
-
-class DataFrameReader(ReaderUtils):
+ for k, v in options.items():
+ if v is not None:
+ self.option(k, v)
+
+
+class DataFrameReader(OptionUtils):
"""
Interface used to load a :class:`DataFrame` from external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`spark.read`
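The refactoring above collapses the per-format if-chains into one keyword-driven helper. A toy sketch of the behaviour (not part of the patch; FakeReader is a made-up stand-in exposing the same schema()/option() surface the mixin relies on):

class FakeReader(object):
    def __init__(self):
        self.captured = {}

    def schema(self, s):
        self.captured["schema"] = s

    def option(self, k, v):
        self.captured[k] = v

    # Same idea as OptionUtils._set_opts: forward keywords, dropping None values.
    def _set_opts(self, schema=None, **options):
        if schema is not None:
            self.schema(schema)
        for k, v in options.items():
            if v is not None:
                self.option(k, v)

r = FakeReader()
r._set_opts(schema=None, header="true", sep=None, nullValue="NA")
assert r.captured == {"header": "true", "nullValue": "NA"}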
@@ -207,7 +143,9 @@ def load(self, path=None, format=None, schema=None, **options):
if schema is not None:
self.schema(schema)
self.options(**options)
- if path is not None:
+ if isinstance(path, basestring):
+ return self._df(self._jreader.load(path))
+ elif path is not None:
if type(path) != list:
path = [path]
return self._df(self._jreader.load(self._spark._sc._jvm.PythonUtils.toSeq(path)))
@@ -270,7 +208,7 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
[('age', 'bigint'), ('name', 'string')]
"""
- self._set_json_opts(
+ self._set_opts(
schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
@@ -413,7 +351,7 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non
>>> df.dtypes
[('_c0', 'string'), ('_c1', 'string')]
"""
- self._set_csv_opts(
+ self._set_opts(
schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
@@ -484,7 +422,7 @@ def jdbc(self, url, table, column=None, lowerBound=None, upperBound=None, numPar
return self._df(self._jreader.jdbc(url, table, jprop))
-class DataFrameWriter(object):
+class DataFrameWriter(OptionUtils):
"""
Interface used to write a :class:`DataFrame` to external storage systems
(e.g. file systems, key-value stores, etc). Use :func:`DataFrame.write`
@@ -649,8 +587,7 @@ def json(self, path, mode=None, compression=None):
>>> df.write.json(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
- if compression is not None:
- self.option("compression", compression)
+ self._set_opts(compression=compression)
self._jwrite.json(path)
@since(1.4)
@@ -676,8 +613,7 @@ def parquet(self, path, mode=None, partitionBy=None, compression=None):
self.mode(mode)
if partitionBy is not None:
self.partitionBy(partitionBy)
- if compression is not None:
- self.option("compression", compression)
+ self._set_opts(compression=compression)
self._jwrite.parquet(path)
@since(1.6)
@@ -692,13 +628,12 @@ def text(self, path, compression=None):
The DataFrame must have only one column that is of string type.
Each row becomes a new line in the output file.
"""
- if compression is not None:
- self.option("compression", compression)
+ self._set_opts(compression=compression)
self._jwrite.text(path)
@since(2.0)
def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=None,
- header=None, nullValue=None, escapeQuotes=None):
+ header=None, nullValue=None, escapeQuotes=None, quoteAll=None):
"""Saves the content of the :class:`DataFrame` in CSV format at the specified path.
:param path: the path in any Hadoop supported file system
@@ -723,6 +658,9 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
:param escapeQuotes: A flag indicating whether values containing quotes should always
be enclosed in quotes. If None is set, it uses the default value
``true``, escaping all values containing a quote character.
+ :param quoteAll: A flag indicating whether all values should always be enclosed in
+ quotes. If None is set, it uses the default value ``false``,
+ only escaping values containing a quote character.
:param header: writes the names of columns as the first line. If None is set, it uses
the default value, ``false``.
:param nullValue: sets the string representation of a null value. If None is set, it uses
@@ -731,20 +669,8 @@ def csv(self, path, mode=None, compression=None, sep=None, quote=None, escape=No
>>> df.write.csv(os.path.join(tempfile.mkdtemp(), 'data'))
"""
self.mode(mode)
- if compression is not None:
- self.option("compression", compression)
- if sep is not None:
- self.option("sep", sep)
- if quote is not None:
- self.option("quote", quote)
- if escape is not None:
- self.option("escape", escape)
- if header is not None:
- self.option("header", header)
- if nullValue is not None:
- self.option("nullValue", nullValue)
- if escapeQuotes is not None:
- self.option("escapeQuotes", nullValue)
+ self._set_opts(compression=compression, sep=sep, quote=quote, escape=escape, header=header,
+ nullValue=nullValue, escapeQuotes=escapeQuotes, quoteAll=quoteAll)
self._jwrite.csv(path)
@since(1.5)
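A quick sketch exercising the new quoteAll flag through the simplified csv writer (not part of the patch; assumes a pyspark build containing this change, data and path are illustrative):

import os
import tempfile
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").appName("quoteAll-demo").getOrCreate()
df = spark.createDataFrame([("a,b", "plain")], ["c1", "c2"])
out = os.path.join(tempfile.mkdtemp(), "data")
# With quoteAll=True every value is wrapped in quotes, not only the ones
# containing the quote or separator character.
df.write.csv(out, header=True, quoteAll=True)
spark.stop()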
@@ -772,8 +698,7 @@ def orc(self, path, mode=None, partitionBy=None, compression=None):
self.mode(mode)
if partitionBy is not None:
self.partitionBy(partitionBy)
- if compression is not None:
- self.option("compression", compression)
+ self._set_opts(compression=compression)
self._jwrite.orc(path)
@since(1.4)
@@ -803,494 +728,6 @@ def jdbc(self, url, table, mode=None, properties=None):
self._jwrite.mode(mode).jdbc(url, table, jprop)
-class DataStreamReader(ReaderUtils):
- """
- Interface used to load a streaming :class:`DataFrame` from external storage systems
- (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
- to access this.
-
- .. note:: Experimental.
-
- .. versionadded:: 2.0
- """
-
- def __init__(self, spark):
- self._jreader = spark._ssql_ctx.readStream()
- self._spark = spark
-
- def _df(self, jdf):
- from pyspark.sql.dataframe import DataFrame
- return DataFrame(jdf, self._spark)
-
- @since(2.0)
- def format(self, source):
- """Specifies the input data source format.
-
- .. note:: Experimental.
-
- :param source: string, name of the data source, e.g. 'json', 'parquet'.
-
- >>> s = spark.readStream.format("text")
- """
- self._jreader = self._jreader.format(source)
- return self
-
- @since(2.0)
- def schema(self, schema):
- """Specifies the input schema.
-
- Some data sources (e.g. JSON) can infer the input schema automatically from data.
- By specifying the schema here, the underlying data source can skip the schema
- inference step, and thus speed up data loading.
-
- .. note:: Experimental.
-
- :param schema: a StructType object
-
- >>> s = spark.readStream.schema(sdf_schema)
- """
- if not isinstance(schema, StructType):
- raise TypeError("schema should be StructType")
- jschema = self._spark._ssql_ctx.parseDataType(schema.json())
- self._jreader = self._jreader.schema(jschema)
- return self
-
- @since(2.0)
- def option(self, key, value):
- """Adds an input option for the underlying data source.
-
- .. note:: Experimental.
-
- >>> s = spark.readStream.option("x", 1)
- """
- self._jreader = self._jreader.option(key, to_str(value))
- return self
-
- @since(2.0)
- def options(self, **options):
- """Adds input options for the underlying data source.
-
- .. note:: Experimental.
-
- >>> s = spark.readStream.options(x="1", y=2)
- """
- for k in options:
- self._jreader = self._jreader.option(k, to_str(options[k]))
- return self
-
- @since(2.0)
- def load(self, path=None, format=None, schema=None, **options):
- """Loads a data stream from a data source and returns it as a :class`DataFrame`.
-
- .. note:: Experimental.
-
- :param path: optional string for file-system backed data sources.
- :param format: optional string for format of the data source. Default to 'parquet'.
- :param schema: optional :class:`StructType` for the input schema.
- :param options: all other string options
-
- >>> json_sdf = spark.readStream.format("json")\
- .schema(sdf_schema)\
- .load(os.path.join(tempfile.mkdtemp(),'data'))
- >>> json_sdf.isStreaming
- True
- >>> json_sdf.schema == sdf_schema
- True
- """
- if format is not None:
- self.format(format)
- if schema is not None:
- self.schema(schema)
- self.options(**options)
- if path is not None:
- if type(path) != str or len(path.strip()) == 0:
- raise ValueError("If the path is provided for stream, it needs to be a " +
- "non-empty string. List of paths are not supported.")
- return self._df(self._jreader.load(path))
- else:
- return self._df(self._jreader.load())
-
- @since(2.0)
- def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
- allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
- allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
- mode=None, columnNameOfCorruptRecord=None):
- """
- Loads a JSON file stream (one object per line) and returns a :class`DataFrame`.
-
- If the ``schema`` parameter is not specified, this function goes
- through the input once to determine the input schema.
-
- .. note:: Experimental.
-
- :param path: string represents path to the JSON dataset,
- or RDD of Strings storing JSON objects.
- :param schema: an optional :class:`StructType` for the input schema.
- :param primitivesAsString: infers all primitive values as a string type. If None is set,
- it uses the default value, ``false``.
- :param prefersDecimal: infers all floating-point values as a decimal type. If the values
- do not fit in decimal, then it infers them as doubles. If None is
- set, it uses the default value, ``false``.
- :param allowComments: ignores Java/C++ style comment in JSON records. If None is set,
- it uses the default value, ``false``.
- :param allowUnquotedFieldNames: allows unquoted JSON field names. If None is set,
- it uses the default value, ``false``.
- :param allowSingleQuotes: allows single quotes in addition to double quotes. If None is
- set, it uses the default value, ``true``.
- :param allowNumericLeadingZero: allows leading zeros in numbers (e.g. 00012). If None is
- set, it uses the default value, ``false``.
- :param allowBackslashEscapingAnyCharacter: allows accepting quoting of all character
- using backslash quoting mechanism. If None is
- set, it uses the default value, ``false``.
- :param mode: allows a mode for dealing with corrupt records during parsing. If None is
- set, it uses the default value, ``PERMISSIVE``.
-
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
- record and puts the malformed string into a new field configured by \
- ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \
- ``null`` for extra fields.
- * ``DROPMALFORMED`` : ignores the whole corrupted records.
- * ``FAILFAST`` : throws an exception when it meets corrupted records.
-
- :param columnNameOfCorruptRecord: allows renaming the new field having malformed string
- created by ``PERMISSIVE`` mode. This overrides
- ``spark.sql.columnNameOfCorruptRecord``. If None is set,
- it uses the value specified in
- ``spark.sql.columnNameOfCorruptRecord``.
-
- >>> json_sdf = spark.readStream.json(os.path.join(tempfile.mkdtemp(), 'data'), \
- schema = sdf_schema)
- >>> json_sdf.isStreaming
- True
- >>> json_sdf.schema == sdf_schema
- True
- """
- self._set_json_opts(
- schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
- allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
- allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
- allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
- mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord)
- if isinstance(path, basestring):
- return self._df(self._jreader.json(path))
- else:
- raise TypeError("path can be only a single string")
-
- @since(2.0)
- def parquet(self, path):
- """Loads a Parquet file stream, returning the result as a :class:`DataFrame`.
-
- You can set the following Parquet-specific option(s) for reading Parquet files:
- * ``mergeSchema``: sets whether we should merge schemas collected from all \
- Parquet part-files. This will override ``spark.sql.parquet.mergeSchema``. \
- The default value is specified in ``spark.sql.parquet.mergeSchema``.
-
- .. note:: Experimental.
-
- >>> parquet_sdf = spark.readStream.schema(sdf_schema)\
- .parquet(os.path.join(tempfile.mkdtemp()))
- >>> parquet_sdf.isStreaming
- True
- >>> parquet_sdf.schema == sdf_schema
- True
- """
- if isinstance(path, basestring):
- return self._df(self._jreader.parquet(path))
- else:
- raise TypeError("path can be only a single string")
-
- @ignore_unicode_prefix
- @since(2.0)
- def text(self, path):
- """
- Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
- string column named "value", and followed by partitioned columns if there
- are any.
-
- Each line in the text file is a new row in the resulting DataFrame.
-
- .. note:: Experimental.
-
- :param paths: string, or list of strings, for input path(s).
-
- >>> text_sdf = spark.readStream.text(os.path.join(tempfile.mkdtemp(), 'data'))
- >>> text_sdf.isStreaming
- True
- >>> "value" in str(text_sdf.schema)
- True
- """
- if isinstance(path, basestring):
- return self._df(self._jreader.text(path))
- else:
- raise TypeError("path can be only a single string")
-
- @since(2.0)
- def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
- comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
- ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
- negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
- maxMalformedLogPerPartition=None, mode=None):
- """Loads a CSV file stream and returns the result as a :class:`DataFrame`.
-
- This function will go through the input once to determine the input schema if
- ``inferSchema`` is enabled. To avoid going through the entire data once, disable
- ``inferSchema`` option or specify the schema explicitly using ``schema``.
-
- .. note:: Experimental.
-
- :param path: string, or list of strings, for input path(s).
- :param schema: an optional :class:`StructType` for the input schema.
- :param sep: sets the single character as a separator for each field and value.
- If None is set, it uses the default value, ``,``.
- :param encoding: decodes the CSV files by the given encoding type. If None is set,
- it uses the default value, ``UTF-8``.
- :param quote: sets the single character used for escaping quoted values where the
- separator can be part of the value. If None is set, it uses the default
- value, ``"``. If you would like to turn off quotations, you need to set an
- empty string.
- :param escape: sets the single character used for escaping quotes inside an already
- quoted value. If None is set, it uses the default value, ``\``.
- :param comment: sets the single character used for skipping lines beginning with this
- character. By default (None), it is disabled.
- :param header: uses the first line as names of columns. If None is set, it uses the
- default value, ``false``.
- :param inferSchema: infers the input schema automatically from data. It requires one extra
- pass over the data. If None is set, it uses the default value, ``false``.
- :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values
- being read should be skipped. If None is set, it uses
- the default value, ``false``.
- :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values
- being read should be skipped. If None is set, it uses
- the default value, ``false``.
- :param nullValue: sets the string representation of a null value. If None is set, it uses
- the default value, empty string.
- :param nanValue: sets the string representation of a non-number value. If None is set, it
- uses the default value, ``NaN``.
- :param positiveInf: sets the string representation of a positive infinity value. If None
- is set, it uses the default value, ``Inf``.
- :param negativeInf: sets the string representation of a negative infinity value. If None
- is set, it uses the default value, ``Inf``.
- :param dateFormat: sets the string that indicates a date format. Custom date formats
- follow the formats at ``java.text.SimpleDateFormat``. This
- applies to both date type and timestamp type. By default, it is None
- which means trying to parse times and date by
- ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
- :param maxColumns: defines a hard limit of how many columns a record can have. If None is
- set, it uses the default value, ``20480``.
- :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
- value being read. If None is set, it uses the default value,
- ``1000000``.
- :param mode: allows a mode for dealing with corrupt records during parsing. If None is
- set, it uses the default value, ``PERMISSIVE``.
-
- * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
- When a schema is set by user, it sets ``null`` for extra fields.
- * ``DROPMALFORMED`` : ignores the whole corrupted records.
- * ``FAILFAST`` : throws an exception when it meets corrupted records.
-
- >>> csv_sdf = spark.readStream.csv(os.path.join(tempfile.mkdtemp(), 'data'), \
- schema = sdf_schema)
- >>> csv_sdf.isStreaming
- True
- >>> csv_sdf.schema == sdf_schema
- True
- """
- self._set_csv_opts(
- schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
- header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
- ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
- nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
- dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
- maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
- if isinstance(path, basestring):
- return self._df(self._jreader.csv(path))
- else:
- raise TypeError("path can be only a single string")
-
-
-class DataStreamWriter(object):
- """
- Interface used to write a streaming :class:`DataFrame` to external storage systems
- (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream`
- to access this.
-
- .. note:: Experimental.
-
- .. versionadded:: 2.0
- """
-
- def __init__(self, df):
- self._df = df
- self._spark = df.sql_ctx
- self._jwrite = df._jdf.writeStream()
-
- def _sq(self, jsq):
- from pyspark.sql.streaming import StreamingQuery
- return StreamingQuery(jsq)
-
- @since(2.0)
- def outputMode(self, outputMode):
- """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
-
- Options include:
-
- * `append`:Only the new rows in the streaming DataFrame/Dataset will be written to
- the sink
- * `complete`:All the rows in the streaming DataFrame/Dataset will be written to the sink
- every time these is some updates
-
- .. note:: Experimental.
-
- >>> writer = sdf.writeStream.outputMode('append')
- """
- if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0:
- raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode)
- self._jwrite = self._jwrite.outputMode(outputMode)
- return self
-
- @since(2.0)
- def format(self, source):
- """Specifies the underlying output data source.
-
- .. note:: Experimental.
-
- :param source: string, name of the data source, e.g. 'json', 'parquet'.
-
- >>> writer = sdf.writeStream.format('json')
- """
- self._jwrite = self._jwrite.format(source)
- return self
-
- @since(2.0)
- def option(self, key, value):
- """Adds an output option for the underlying data source.
-
- .. note:: Experimental.
- """
- self._jwrite = self._jwrite.option(key, to_str(value))
- return self
-
- @since(2.0)
- def options(self, **options):
- """Adds output options for the underlying data source.
-
- .. note:: Experimental.
- """
- for k in options:
- self._jwrite = self._jwrite.option(k, to_str(options[k]))
- return self
-
- @since(2.0)
- def partitionBy(self, *cols):
- """Partitions the output by the given columns on the file system.
-
- If specified, the output is laid out on the file system similar
- to Hive's partitioning scheme.
-
- .. note:: Experimental.
-
- :param cols: name of columns
-
- """
- if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
- cols = cols[0]
- self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
- return self
-
- @since(2.0)
- def queryName(self, queryName):
- """Specifies the name of the :class:`StreamingQuery` that can be started with
- :func:`start`. This name must be unique among all the currently active queries
- in the associated SparkSession.
-
- .. note:: Experimental.
-
- :param queryName: unique name for the query
-
- >>> writer = sdf.writeStream.queryName('streaming_query')
- """
- if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
- raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName)
- self._jwrite = self._jwrite.queryName(queryName)
- return self
-
- @keyword_only
- @since(2.0)
- def trigger(self, processingTime=None):
- """Set the trigger for the stream query. If this is not set it will run the query as fast
- as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``.
-
- .. note:: Experimental.
-
- :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'.
-
- >>> # trigger the query for execution every 5 seconds
- >>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
- """
- from pyspark.sql.streaming import ProcessingTime
- trigger = None
- if processingTime is not None:
- if type(processingTime) != str or len(processingTime.strip()) == 0:
- raise ValueError('The processing time must be a non empty string. Got: %s' %
- processingTime)
- trigger = ProcessingTime(processingTime)
- if trigger is None:
- raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
- self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
- return self
-
- @ignore_unicode_prefix
- @since(2.0)
- def start(self, path=None, format=None, partitionBy=None, queryName=None, **options):
- """Streams the contents of the :class:`DataFrame` to a data source.
-
- The data source is specified by the ``format`` and a set of ``options``.
- If ``format`` is not specified, the default data source configured by
- ``spark.sql.sources.default`` will be used.
-
- .. note:: Experimental.
-
- :param path: the path in a Hadoop supported file system
- :param format: the format used to save
-
- * ``append``: Append contents of this :class:`DataFrame` to existing data.
- * ``overwrite``: Overwrite existing data.
- * ``ignore``: Silently ignore this operation if data already exists.
- * ``error`` (default case): Throw an exception if data already exists.
- :param partitionBy: names of partitioning columns
- :param queryName: unique name for the query
- :param options: All other string options. You may want to provide a `checkpointLocation`
- for most streams, however it is not required for a `memory` stream.
-
- >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
- >>> sq.isActive
- True
- >>> sq.name
- u'this_query'
- >>> sq.stop()
- >>> sq.isActive
- False
- >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
- ... queryName='that_query', format='memory')
- >>> sq.name
- u'that_query'
- >>> sq.isActive
- True
- >>> sq.stop()
- """
- self.options(**options)
- if partitionBy is not None:
- self.partitionBy(partitionBy)
- if format is not None:
- self.format(format)
- if queryName is not None:
- self.queryName(queryName)
- if path is None:
- return self._sq(self._jwrite.start())
- else:
- return self._sq(self._jwrite.start(path))
-
-
def _test():
import doctest
import os
@@ -1314,9 +751,6 @@ def _test():
globs['sc'] = sc
globs['spark'] = spark
globs['df'] = spark.read.parquet('python/test_support/sql/parquet_partitioned')
- globs['sdf'] = \
- spark.readStream.format('text').load('python/test_support/sql/streaming')
- globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
(failure_count, test_count) = doctest.testmod(
pyspark.sql.readwriter, globs=globs,
optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE | doctest.REPORT_NDIFF)
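Since DataStreamReader/DataStreamWriter now live in pyspark.sql.streaming, downstream imports change accordingly; a small sketch (not part of the patch, assumes a build with this change):

from pyspark.sql import SparkSession
from pyspark.sql.streaming import DataStreamReader  # moved here from pyspark.sql.readwriter

spark = SparkSession.builder.master("local[1]").appName("stream-imports").getOrCreate()
assert isinstance(spark.readStream, DataStreamReader)
spark.stop()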
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index b4152a34ad97..a360fbefa492 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -31,7 +31,8 @@
from pyspark.sql.catalog import Catalog
from pyspark.sql.conf import RuntimeConfig
from pyspark.sql.dataframe import DataFrame
-from pyspark.sql.readwriter import DataFrameReader, DataStreamReader
+from pyspark.sql.readwriter import DataFrameReader
+from pyspark.sql.streaming import DataStreamReader
from pyspark.sql.types import Row, DataType, StringType, StructType, _verify_type, \
_infer_schema, _has_nulltype, _merge_type, _create_converter, _parse_datatype_string
from pyspark.sql.utils import install_exception_handler
@@ -65,12 +66,11 @@ class SparkSession(object):
tables, execute SQL over tables, cache tables, and read parquet files.
To create a SparkSession, use the following builder pattern:
- >>> spark = SparkSession.builder \
- .master("local") \
- .appName("Word Count") \
- .config("spark.some.config.option", "some-value") \
- .getOrCreate()
-
+ >>> spark = SparkSession.builder \\
+ ... .master("local") \\
+ ... .appName("Word Count") \\
+ ... .config("spark.some.config.option", "some-value") \\
+ ... .getOrCreate()
"""
class Builder(object):
@@ -86,11 +86,13 @@ def config(self, key=None, value=None, conf=None):
both :class:`SparkConf` and :class:`SparkSession`'s own configuration.
For an existing SparkConf, use `conf` parameter.
+
>>> from pyspark.conf import SparkConf
>>> SparkSession.builder.config(conf=SparkConf())
>> SparkSession.builder.config("spark.some.config.option", "some-value")
= '3':
intlike = int
+ basestring = unicode = str
else:
intlike = (int, long)
from abc import ABCMeta, abstractmethod
-from pyspark import since
+from pyspark import since, keyword_only
from pyspark.rdd import ignore_unicode_prefix
+from pyspark.sql.readwriter import OptionUtils, to_str
+from pyspark.sql.types import *
-__all__ = ["StreamingQuery"]
+__all__ = ["StreamingQuery", "StreamingQueryManager", "DataStreamReader", "DataStreamWriter"]
class StreamingQuery(object):
@@ -118,7 +121,7 @@ def __init__(self, jsqm):
def active(self):
"""Returns a list of active queries associated with this SQLContext
- >>> sq = df.writeStream.format('memory').queryName('this_query').start()
+ >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sqm = spark.streams
>>> # get the list of active streaming queries
>>> [q.name for q in sqm.active]
@@ -133,7 +136,7 @@ def get(self, id):
"""Returns an active query from this SQLContext or throws exception if an active query
with this name doesn't exist.
- >>> sq = df.writeStream.format('memory').queryName('this_query').start()
+ >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
>>> sq.name
u'this_query'
>>> sq = spark.streams.get(sq.id)
@@ -224,6 +227,491 @@ def _to_java_trigger(self, sqlContext):
self.interval)
+class DataStreamReader(OptionUtils):
+ """
+ Interface used to load a streaming :class:`DataFrame` from external storage systems
+ (e.g. file systems, key-value stores, etc). Use :func:`spark.readStream`
+ to access this.
+
+ .. note:: Experimental.
+
+ .. versionadded:: 2.0
+ """
+
+ def __init__(self, spark):
+ self._jreader = spark._ssql_ctx.readStream()
+ self._spark = spark
+
+ def _df(self, jdf):
+ from pyspark.sql.dataframe import DataFrame
+ return DataFrame(jdf, self._spark)
+
+ @since(2.0)
+ def format(self, source):
+ """Specifies the input data source format.
+
+ .. note:: Experimental.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> s = spark.readStream.format("text")
+ """
+ self._jreader = self._jreader.format(source)
+ return self
+
+ @since(2.0)
+ def schema(self, schema):
+ """Specifies the input schema.
+
+ Some data sources (e.g. JSON) can infer the input schema automatically from data.
+ By specifying the schema here, the underlying data source can skip the schema
+ inference step, and thus speed up data loading.
+
+ .. note:: Experimental.
+
+ :param schema: a StructType object
+
+ >>> s = spark.readStream.schema(sdf_schema)
+ """
+ if not isinstance(schema, StructType):
+ raise TypeError("schema should be StructType")
+ jschema = self._spark._ssql_ctx.parseDataType(schema.json())
+ self._jreader = self._jreader.schema(jschema)
+ return self
+
+ @since(2.0)
+ def option(self, key, value):
+ """Adds an input option for the underlying data source.
+
+ .. note:: Experimental.
+
+ >>> s = spark.readStream.option("x", 1)
+ """
+ self._jreader = self._jreader.option(key, to_str(value))
+ return self
+
+ @since(2.0)
+ def options(self, **options):
+ """Adds input options for the underlying data source.
+
+ .. note:: Experimental.
+
+ >>> s = spark.readStream.options(x="1", y=2)
+ """
+ for k in options:
+ self._jreader = self._jreader.option(k, to_str(options[k]))
+ return self
+
+ @since(2.0)
+ def load(self, path=None, format=None, schema=None, **options):
+ """Loads a data stream from a data source and returns it as a :class`DataFrame`.
+
+ .. note:: Experimental.
+
+ :param path: optional string for file-system backed data sources.
+ :param format: optional string for format of the data source. Defaults to 'parquet'.
+ :param schema: optional :class:`StructType` for the input schema.
+ :param options: all other string options
+
+ >>> json_sdf = spark.readStream.format("json") \\
+ ... .schema(sdf_schema) \\
+ ... .load(tempfile.mkdtemp())
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
+ """
+ if format is not None:
+ self.format(format)
+ if schema is not None:
+ self.schema(schema)
+ self.options(**options)
+ if path is not None:
+ if type(path) != str or len(path.strip()) == 0:
+ raise ValueError("If the path is provided for stream, it needs to be a " +
+ "non-empty string. List of paths are not supported.")
+ return self._df(self._jreader.load(path))
+ else:
+ return self._df(self._jreader.load())
+
+ @since(2.0)
+ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None,
+ allowComments=None, allowUnquotedFieldNames=None, allowSingleQuotes=None,
+ allowNumericLeadingZero=None, allowBackslashEscapingAnyCharacter=None,
+ mode=None, columnNameOfCorruptRecord=None):
+ """
+ Loads a JSON file stream (one object per line) and returns a :class:`DataFrame`.
+
+ If the ``schema`` parameter is not specified, this function goes
+ through the input once to determine the input schema.
+
+ .. note:: Experimental.
+
+ :param path: string representing the path to the JSON dataset.
+ :param schema: an optional :class:`StructType` for the input schema.
+ :param primitivesAsString: infers all primitive values as a string type. If None is set,
+ it uses the default value, ``false``.
+ :param prefersDecimal: infers all floating-point values as a decimal type. If the values
+ do not fit in decimal, then it infers them as doubles. If None is
+ set, it uses the default value, ``false``.
+ :param allowComments: ignores Java/C++ style comment in JSON records. If None is set,
+ it uses the default value, ``false``.
+ :param allowUnquotedFieldNames: allows unquoted JSON field names. If None is set,
+ it uses the default value, ``false``.
+ :param allowSingleQuotes: allows single quotes in addition to double quotes. If None is
+ set, it uses the default value, ``true``.
+ :param allowNumericLeadingZero: allows leading zeros in numbers (e.g. 00012). If None is
+ set, it uses the default value, ``false``.
+ :param allowBackslashEscapingAnyCharacter: allows accepting quoting of all character
+ using backslash quoting mechanism. If None is
+ set, it uses the default value, ``false``.
+ :param mode: allows a mode for dealing with corrupt records during parsing. If None is
+ set, it uses the default value, ``PERMISSIVE``.
+
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted \
+ record and puts the malformed string into a new field configured by \
+ ``columnNameOfCorruptRecord``. When a schema is set by user, it sets \
+ ``null`` for extra fields.
+ * ``DROPMALFORMED`` : ignores the whole corrupted records.
+ * ``FAILFAST`` : throws an exception when it meets corrupted records.
+
+ :param columnNameOfCorruptRecord: allows renaming the new field having malformed string
+ created by ``PERMISSIVE`` mode. This overrides
+ ``spark.sql.columnNameOfCorruptRecord``. If None is set,
+ it uses the value specified in
+ ``spark.sql.columnNameOfCorruptRecord``.
+
+ >>> json_sdf = spark.readStream.json(tempfile.mkdtemp(), schema=sdf_schema)
+ >>> json_sdf.isStreaming
+ True
+ >>> json_sdf.schema == sdf_schema
+ True
+ """
+ self._set_opts(
+ schema=schema, primitivesAsString=primitivesAsString, prefersDecimal=prefersDecimal,
+ allowComments=allowComments, allowUnquotedFieldNames=allowUnquotedFieldNames,
+ allowSingleQuotes=allowSingleQuotes, allowNumericLeadingZero=allowNumericLeadingZero,
+ allowBackslashEscapingAnyCharacter=allowBackslashEscapingAnyCharacter,
+ mode=mode, columnNameOfCorruptRecord=columnNameOfCorruptRecord)
+ if isinstance(path, basestring):
+ return self._df(self._jreader.json(path))
+ else:
+ raise TypeError("path can be only a single string")
+
+ @since(2.0)
+ def parquet(self, path):
+ """Loads a Parquet file stream, returning the result as a :class:`DataFrame`.
+
+ You can set the following Parquet-specific option(s) for reading Parquet files:
+ * ``mergeSchema``: sets whether we should merge schemas collected from all \
+ Parquet part-files. This will override ``spark.sql.parquet.mergeSchema``. \
+ The default value is specified in ``spark.sql.parquet.mergeSchema``.
+
+ .. note:: Experimental.
+
+ >>> parquet_sdf = spark.readStream.schema(sdf_schema).parquet(tempfile.mkdtemp())
+ >>> parquet_sdf.isStreaming
+ True
+ >>> parquet_sdf.schema == sdf_schema
+ True
+ """
+ if isinstance(path, basestring):
+ return self._df(self._jreader.parquet(path))
+ else:
+ raise TypeError("path can be only a single string")
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def text(self, path):
+ """
+ Loads a text file stream and returns a :class:`DataFrame` whose schema starts with a
+ string column named "value", followed by partitioned columns if there
+ are any.
+
+ Each line in the text file is a new row in the resulting DataFrame.
+
+ .. note:: Experimental.
+
+ :param path: string for the input path.
+
+ >>> text_sdf = spark.readStream.text(tempfile.mkdtemp())
+ >>> text_sdf.isStreaming
+ True
+ >>> "value" in str(text_sdf.schema)
+ True
+ """
+ if isinstance(path, basestring):
+ return self._df(self._jreader.text(path))
+ else:
+ raise TypeError("path can be only a single string")
+
+ @since(2.0)
+ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=None,
+ comment=None, header=None, inferSchema=None, ignoreLeadingWhiteSpace=None,
+ ignoreTrailingWhiteSpace=None, nullValue=None, nanValue=None, positiveInf=None,
+ negativeInf=None, dateFormat=None, maxColumns=None, maxCharsPerColumn=None,
+ maxMalformedLogPerPartition=None, mode=None):
+ """Loads a CSV file stream and returns the result as a :class:`DataFrame`.
+
+ This function will go through the input once to determine the input schema if
+ ``inferSchema`` is enabled. To avoid going through the entire data once, disable
+ ``inferSchema`` option or specify the schema explicitly using ``schema``.
+
+ .. note:: Experimental.
+
+ :param path: string for the input path.
+ :param schema: an optional :class:`StructType` for the input schema.
+ :param sep: sets the single character as a separator for each field and value.
+ If None is set, it uses the default value, ``,``.
+ :param encoding: decodes the CSV files by the given encoding type. If None is set,
+ it uses the default value, ``UTF-8``.
+ :param quote: sets the single character used for escaping quoted values where the
+ separator can be part of the value. If None is set, it uses the default
+ value, ``"``. If you would like to turn off quotations, you need to set an
+ empty string.
+ :param escape: sets the single character used for escaping quotes inside an already
+ quoted value. If None is set, it uses the default value, ``\``.
+ :param comment: sets the single character used for skipping lines beginning with this
+ character. By default (None), it is disabled.
+ :param header: uses the first line as names of columns. If None is set, it uses the
+ default value, ``false``.
+ :param inferSchema: infers the input schema automatically from data. It requires one extra
+ pass over the data. If None is set, it uses the default value, ``false``.
+ :param ignoreLeadingWhiteSpace: defines whether or not leading whitespaces from values
+ being read should be skipped. If None is set, it uses
+ the default value, ``false``.
+ :param ignoreTrailingWhiteSpace: defines whether or not trailing whitespaces from values
+ being read should be skipped. If None is set, it uses
+ the default value, ``false``.
+ :param nullValue: sets the string representation of a null value. If None is set, it uses
+ the default value, empty string.
+ :param nanValue: sets the string representation of a non-number value. If None is set, it
+ uses the default value, ``NaN``.
+ :param positiveInf: sets the string representation of a positive infinity value. If None
+ is set, it uses the default value, ``Inf``.
+ :param negativeInf: sets the string representation of a negative infinity value. If None
+ is set, it uses the default value, ``-Inf``.
+ :param dateFormat: sets the string that indicates a date format. Custom date formats
+ follow the formats at ``java.text.SimpleDateFormat``. This
+ applies to both date type and timestamp type. By default, it is None,
+ which means trying to parse times and dates by
+ ``java.sql.Timestamp.valueOf()`` and ``java.sql.Date.valueOf()``.
+ :param maxColumns: defines a hard limit of how many columns a record can have. If None is
+ set, it uses the default value, ``20480``.
+ :param maxCharsPerColumn: defines the maximum number of characters allowed for any given
+ value being read. If None is set, it uses the default value,
+ ``1000000``.
+ :param mode: allows a mode for dealing with corrupt records during parsing. If None is
+ set, it uses the default value, ``PERMISSIVE``.
+
+ * ``PERMISSIVE`` : sets other fields to ``null`` when it meets a corrupted record.
+ When a schema is set by a user, it sets ``null`` for extra fields.
+ * ``DROPMALFORMED`` : ignores the whole corrupted records.
+ * ``FAILFAST`` : throws an exception when it meets corrupted records.
+
+ >>> csv_sdf = spark.readStream.csv(tempfile.mkdtemp(), schema=sdf_schema)
+ >>> csv_sdf.isStreaming
+ True
+ >>> csv_sdf.schema == sdf_schema
+ True
+ """
+ self._set_opts(
+ schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment,
+ header=header, inferSchema=inferSchema, ignoreLeadingWhiteSpace=ignoreLeadingWhiteSpace,
+ ignoreTrailingWhiteSpace=ignoreTrailingWhiteSpace, nullValue=nullValue,
+ nanValue=nanValue, positiveInf=positiveInf, negativeInf=negativeInf,
+ dateFormat=dateFormat, maxColumns=maxColumns, maxCharsPerColumn=maxCharsPerColumn,
+ maxMalformedLogPerPartition=maxMalformedLogPerPartition, mode=mode)
+ if isinstance(path, basestring):
+ return self._df(self._jreader.csv(path))
+ else:
+ raise TypeError("path can be only a single string")
+
+
+class DataStreamWriter(object):
+ """
+ Interface used to write a streaming :class:`DataFrame` to external storage systems
+ (e.g. file systems, key-value stores, etc). Use :func:`DataFrame.writeStream`
+ to access this.
+
+ .. note:: Experimental.
+
+ .. versionadded:: 2.0
+ """
+
+ def __init__(self, df):
+ self._df = df
+ self._spark = df.sql_ctx
+ self._jwrite = df._jdf.writeStream()
+
+ def _sq(self, jsq):
+ from pyspark.sql.streaming import StreamingQuery
+ return StreamingQuery(jsq)
+
+ @since(2.0)
+ def outputMode(self, outputMode):
+ """Specifies how data of a streaming DataFrame/Dataset is written to a streaming sink.
+
+ Options include:
+
+ * `append`: Only the new rows in the streaming DataFrame/Dataset will be written to
+ the sink
+ * `complete`: All the rows in the streaming DataFrame/Dataset will be written to the sink
+ every time there are some updates
+
+ .. note:: Experimental.
+
+ >>> writer = sdf.writeStream.outputMode('append')
+ """
+ if not outputMode or type(outputMode) != str or len(outputMode.strip()) == 0:
+ raise ValueError('The output mode must be a non-empty string. Got: %s' % outputMode)
+ self._jwrite = self._jwrite.outputMode(outputMode)
+ return self
+
+ @since(2.0)
+ def format(self, source):
+ """Specifies the underlying output data source.
+
+ .. note:: Experimental.
+
+ :param source: string, name of the data source, e.g. 'json', 'parquet'.
+
+ >>> writer = sdf.writeStream.format('json')
+ """
+ self._jwrite = self._jwrite.format(source)
+ return self
+
+ @since(2.0)
+ def option(self, key, value):
+ """Adds an output option for the underlying data source.
+
+ .. note:: Experimental.
+ """
+ self._jwrite = self._jwrite.option(key, to_str(value))
+ return self
+
+ @since(2.0)
+ def options(self, **options):
+ """Adds output options for the underlying data source.
+
+ .. note:: Experimental.
+ """
+ for k in options:
+ self._jwrite = self._jwrite.option(k, to_str(options[k]))
+ return self
+
+ @since(2.0)
+ def partitionBy(self, *cols):
+ """Partitions the output by the given columns on the file system.
+
+ If specified, the output is laid out on the file system similar
+ to Hive's partitioning scheme.
+
+ .. note:: Experimental.
+
+ :param cols: name of columns
+
+ """
+ if len(cols) == 1 and isinstance(cols[0], (list, tuple)):
+ cols = cols[0]
+ self._jwrite = self._jwrite.partitionBy(_to_seq(self._spark._sc, cols))
+ return self
+
+ @since(2.0)
+ def queryName(self, queryName):
+ """Specifies the name of the :class:`StreamingQuery` that can be started with
+ :func:`start`. This name must be unique among all the currently active queries
+ in the associated SparkSession.
+
+ .. note:: Experimental.
+
+ :param queryName: unique name for the query
+
+ >>> writer = sdf.writeStream.queryName('streaming_query')
+ """
+ if not queryName or type(queryName) != str or len(queryName.strip()) == 0:
+ raise ValueError('The queryName must be a non-empty string. Got: %s' % queryName)
+ self._jwrite = self._jwrite.queryName(queryName)
+ return self
+
+ @keyword_only
+ @since(2.0)
+ def trigger(self, processingTime=None):
+ """Set the trigger for the stream query. If this is not set it will run the query as fast
+ as possible, which is equivalent to setting the trigger to ``processingTime='0 seconds'``.
+
+ .. note:: Experimental.
+
+ :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'.
+
+ >>> # trigger the query for execution every 5 seconds
+ >>> writer = sdf.writeStream.trigger(processingTime='5 seconds')
+ """
+ from pyspark.sql.streaming import ProcessingTime
+ trigger = None
+ if processingTime is not None:
+ if type(processingTime) != str or len(processingTime.strip()) == 0:
+ raise ValueError('The processing time must be a non-empty string. Got: %s' %
+ processingTime)
+ trigger = ProcessingTime(processingTime)
+ if trigger is None:
+ raise ValueError('A trigger was not provided. Supported triggers: processingTime.')
+ self._jwrite = self._jwrite.trigger(trigger._to_java_trigger(self._spark))
+ return self
+
+ @ignore_unicode_prefix
+ @since(2.0)
+ def start(self, path=None, format=None, partitionBy=None, queryName=None, **options):
+ """Streams the contents of the :class:`DataFrame` to a data source.
+
+ The data source is specified by the ``format`` and a set of ``options``.
+ If ``format`` is not specified, the default data source configured by
+ ``spark.sql.sources.default`` will be used.
+
+ .. note:: Experimental.
+
+ :param path: the path in a Hadoop supported file system
+ :param format: the format used to save
+ :param partitionBy: names of partitioning columns
+ :param queryName: unique name for the query
+ :param options: All other string options. You may want to provide a `checkpointLocation`
+ for most streams; however, it is not required for a `memory` stream.
+
+ >>> sq = sdf.writeStream.format('memory').queryName('this_query').start()
+ >>> sq.isActive
+ True
+ >>> sq.name
+ u'this_query'
+ >>> sq.stop()
+ >>> sq.isActive
+ False
+ >>> sq = sdf.writeStream.trigger(processingTime='5 seconds').start(
+ ... queryName='that_query', format='memory')
+ >>> sq.name
+ u'that_query'
+ >>> sq.isActive
+ True
+ >>> sq.stop()
+ """
+ self.options(**options)
+ if partitionBy is not None:
+ self.partitionBy(partitionBy)
+ if format is not None:
+ self.format(format)
+ if queryName is not None:
+ self.queryName(queryName)
+ if path is None:
+ return self._sq(self._jwrite.start())
+ else:
+ return self._sq(self._jwrite.start(path))
+
+
def _test():
import doctest
import os
@@ -243,6 +731,9 @@ def _test():
globs['os'] = os
globs['spark'] = spark
globs['sqlContext'] = SQLContext.getOrCreate(spark.sparkContext)
+ globs['sdf'] = \
+ spark.readStream.format('text').load('python/test_support/sql/streaming')
+ globs['sdf_schema'] = StructType([StructField("data", StringType(), False)])
globs['df'] = \
globs['spark'].readStream.format('text').load('python/test_support/sql/streaming')
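Taken together, the DataStreamReader and DataStreamWriter changes above give PySpark a complete structured-streaming round trip. The following is a minimal sketch of that flow, not part of the patch; it assumes an active SparkSession named `spark` and a hypothetical input directory `/tmp/people` containing JSON files.

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    # Supplying the schema up front avoids the extra pass the reader would
    # otherwise make over the input to infer it.
    schema = StructType([
        StructField("name", StringType(), True),
        StructField("age", IntegerType(), True),
    ])

    # DataStreamReader: the path must be a single string (see the TypeError guard above).
    people = spark.readStream.schema(schema).json("/tmp/people")
    assert people.isStreaming

    # DataStreamWriter: the memory sink needs no checkpointLocation; file sinks usually do.
    query = (people.writeStream
             .outputMode("append")
             .trigger(processingTime="5 seconds")
             .format("memory")
             .queryName("people_stream")
             .start())

    spark.sql("SELECT * FROM people_stream").show()  # the query name doubles as the table name
    query.stop()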
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index a3679873e1d8..eea80684e2df 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -486,8 +486,8 @@ def add(self, field, data_type=None, nullable=True, metadata=None):
DataType object.
>>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
- >>> struct2 = StructType([StructField("f1", StringType(), True),\
- StructField("f2", StringType(), True, None)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True), \\
+ ... StructField("f2", StringType(), True, None)])
>>> struct1 == struct2
True
>>> struct1 = StructType().add(StructField("f1", StringType(), True))
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
index af61e2011f40..0e4264fe8dfb 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java
@@ -45,7 +45,13 @@ public BufferHolder(UnsafeRow row) {
}
public BufferHolder(UnsafeRow row, int initialSize) {
- this.fixedSize = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()) + 8 * row.numFields();
+ int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields());
+ if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) {
+ throw new UnsupportedOperationException(
+ "Cannot create BufferHolder for input UnsafeRow because there are " +
+ "too many fields (number of fields: " + row.numFields() + ")");
+ }
+ this.fixedSize = bitsetWidthInBytes + 8 * row.numFields();
this.buffer = new byte[fixedSize + initialSize];
this.row = row;
this.row.pointTo(buffer, buffer.length);
@@ -55,10 +61,16 @@ public BufferHolder(UnsafeRow row, int initialSize) {
* Grows the buffer by at least neededSize and points the row to the buffer.
*/
public void grow(int neededSize) {
+ if (neededSize > Integer.MAX_VALUE - totalSize()) {
+ throw new UnsupportedOperationException(
+ "Cannot grow BufferHolder by size " + neededSize + " because the size after growing " +
+ "exceeds size limitation " + Integer.MAX_VALUE);
+ }
final int length = totalSize() + neededSize;
if (buffer.length < length) {
// This will not happen frequently, because the buffer is re-used.
- final byte[] tmp = new byte[length * 2];
+ int newLength = length < Integer.MAX_VALUE / 2 ? length * 2 : Integer.MAX_VALUE;
+ final byte[] tmp = new byte[newLength];
Platform.copyMemory(
buffer,
Platform.BYTE_ARRAY_OFFSET,
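The two new guards above protect BufferHolder against 32-bit overflow: construction fails early when the fixed-size computation would overflow, and grow() doubles the buffer only while the doubled length still fits in a Java int. A language-neutral sketch of that growth policy, written here in Python with illustrative names (not part of the patch):

    JAVA_MAX_INT = 2**31 - 1  # Integer.MAX_VALUE, the hard cap on a Java byte[] length

    def grown_buffer_length(total_size, needed_size):
        """Capped-doubling policy mirroring the BufferHolder.grow change."""
        if needed_size > JAVA_MAX_INT - total_size:
            raise OverflowError("cannot grow BufferHolder: size limit exceeded")
        length = total_size + needed_size
        # Double while the doubled length still fits in an int; otherwise settle for the cap.
        return length * 2 if length < JAVA_MAX_INT // 2 else JAVA_MAX_INT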
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
new file mode 100644
index 000000000000..01a11f9bdca2
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtil.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml;
+
+import java.io.IOException;
+import java.io.Reader;
+import java.io.StringReader;
+
+import javax.xml.namespace.QName;
+import javax.xml.xpath.XPath;
+import javax.xml.xpath.XPathConstants;
+import javax.xml.xpath.XPathExpression;
+import javax.xml.xpath.XPathExpressionException;
+import javax.xml.xpath.XPathFactory;
+
+import org.w3c.dom.Node;
+import org.w3c.dom.NodeList;
+import org.xml.sax.InputSource;
+
+/**
+ * Utility class for all XPath UDFs. Each UDF instance should keep an instance of this class.
+ *
+ * This is based on Hive's UDFXPathUtil implementation.
+ */
+public class UDFXPathUtil {
+ private XPath xpath = XPathFactory.newInstance().newXPath();
+ private ReusableStringReader reader = new ReusableStringReader();
+ private InputSource inputSource = new InputSource(reader);
+ private XPathExpression expression = null;
+ private String oldPath = null;
+
+ public Object eval(String xml, String path, QName qname) {
+ if (xml == null || path == null || qname == null) {
+ return null;
+ }
+
+ if (xml.length() == 0 || path.length() == 0) {
+ return null;
+ }
+
+ if (!path.equals(oldPath)) {
+ try {
+ expression = xpath.compile(path);
+ } catch (XPathExpressionException e) {
+ expression = null;
+ }
+ oldPath = path;
+ }
+
+ if (expression == null) {
+ return null;
+ }
+
+ reader.set(xml);
+
+ try {
+ return expression.evaluate(inputSource, qname);
+ } catch (XPathExpressionException e) {
+ throw new RuntimeException ("Invalid expression '" + oldPath + "'", e);
+ }
+ }
+
+ public Boolean evalBoolean(String xml, String path) {
+ return (Boolean) eval(xml, path, XPathConstants.BOOLEAN);
+ }
+
+ public String evalString(String xml, String path) {
+ return (String) eval(xml, path, XPathConstants.STRING);
+ }
+
+ public Double evalNumber(String xml, String path) {
+ return (Double) eval(xml, path, XPathConstants.NUMBER);
+ }
+
+ public Node evalNode(String xml, String path) {
+ return (Node) eval(xml, path, XPathConstants.NODE);
+ }
+
+ public NodeList evalNodeList(String xml, String path) {
+ return (NodeList) eval(xml, path, XPathConstants.NODESET);
+ }
+
+ /**
+ * Reusable, non-threadsafe version of {@link StringReader}.
+ */
+ public static class ReusableStringReader extends Reader {
+
+ private String str = null;
+ private int length = -1;
+ private int next = 0;
+ private int mark = 0;
+
+ public ReusableStringReader() {
+ }
+
+ public void set(String s) {
+ this.str = s;
+ this.length = s.length();
+ this.mark = 0;
+ this.next = 0;
+ }
+
+ /** Check to make sure that the stream has not been closed */
+ private void ensureOpen() throws IOException {
+ if (str == null)
+ throw new IOException("Stream closed");
+ }
+
+ @Override
+ public int read() throws IOException {
+ ensureOpen();
+ if (next >= length)
+ return -1;
+ return str.charAt(next++);
+ }
+
+ @Override
+ public int read(char cbuf[], int off, int len) throws IOException {
+ ensureOpen();
+ if ((off < 0) || (off > cbuf.length) || (len < 0)
+ || ((off + len) > cbuf.length) || ((off + len) < 0)) {
+ throw new IndexOutOfBoundsException();
+ } else if (len == 0) {
+ return 0;
+ }
+ if (next >= length)
+ return -1;
+ int n = Math.min(length - next, len);
+ str.getChars(next, next + n, cbuf, off);
+ next += n;
+ return n;
+ }
+
+ @Override
+ public long skip(long ns) throws IOException {
+ ensureOpen();
+ if (next >= length)
+ return 0;
+ // Bound skip by beginning and end of the source
+ long n = Math.min(length - next, ns);
+ n = Math.max(-next, n);
+ next += n;
+ return n;
+ }
+
+ @Override
+ public boolean ready() throws IOException {
+ ensureOpen();
+ return true;
+ }
+
+ @Override
+ public boolean markSupported() {
+ return true;
+ }
+
+ @Override
+ public void mark(int readAheadLimit) throws IOException {
+ if (readAheadLimit < 0) {
+ throw new IllegalArgumentException("Read-ahead limit < 0");
+ }
+ ensureOpen();
+ mark = next;
+ }
+
+ @Override
+ public void reset() throws IOException {
+ ensureOpen();
+ next = mark;
+ }
+
+ @Override
+ public void close() {
+ str = null;
+ }
+ }
+}
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 37fbad47c145..c9a1f2293a6c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -74,6 +74,8 @@ public UnsafeExternalRowSorter(
prefixComparator,
/* initialSize */ 4096,
pageSizeBytes,
+ SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
canUseRadixSort
);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 96f2e38946f1..d1d2c59caed9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1836,13 +1836,25 @@ class Analyzer(
}
private def commonNaturalJoinProcessing(
- left: LogicalPlan,
- right: LogicalPlan,
- joinType: JoinType,
- joinNames: Seq[String],
- condition: Option[Expression]) = {
- val leftKeys = joinNames.map(keyName => left.output.find(_.name == keyName).get)
- val rightKeys = joinNames.map(keyName => right.output.find(_.name == keyName).get)
+ left: LogicalPlan,
+ right: LogicalPlan,
+ joinType: JoinType,
+ joinNames: Seq[String],
+ condition: Option[Expression]) = {
+ val leftKeys = joinNames.map { keyName =>
+ val joinColumn = left.output.find(attr => resolver(attr.name, keyName))
+ assert(
+ joinColumn.isDefined,
+ s"$keyName should exist in ${left.output.map(_.name).mkString(",")}")
+ joinColumn.get
+ }
+ val rightKeys = joinNames.map { keyName =>
+ val joinColumn = right.output.find(attr => resolver(attr.name, keyName))
+ assert(
+ joinColumn.isDefined,
+ s"$keyName should exist in ${right.output.map(_.name).mkString(",")}")
+ joinColumn.get
+ }
val joinPairs = leftKeys.zip(rightKeys)
val newCondition = (condition ++ joinPairs.map(EqualTo.tupled)).reduceOption(And)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index ac9693e079f5..7b30fcc6c531 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -206,7 +206,6 @@ trait CheckAnalysis extends PredicateHelper {
"Add to group by or wrap in first() (or first_value) if you don't care " +
"which value you get.")
case e if groupingExprs.exists(_.semanticEquals(e)) => // OK
- case e if e.references.isEmpty => // OK
case e => e.children.foreach(checkValidAggregateExpression)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 42a8faa412a3..c8bbbf88532d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
@@ -164,19 +165,24 @@ object FunctionRegistry {
expression[Explode]("explode"),
expression[Greatest]("greatest"),
expression[If]("if"),
+ expression[Inline]("inline"),
expression[IsNaN]("isnan"),
expression[IfNull]("ifnull"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
expression[CreateMap]("map"),
+ expression[MapKeys]("map_keys"),
+ expression[MapValues]("map_values"),
expression[CreateNamedStruct]("named_struct"),
expression[NaNvl]("nanvl"),
expression[NullIf]("nullif"),
expression[Nvl]("nvl"),
expression[Nvl2]("nvl2"),
+ expression[PosExplode]("posexplode"),
expression[Rand]("rand"),
expression[Randn]("randn"),
+ expression[Stack]("stack"),
expression[CreateStruct]("struct"),
expression[CaseWhen]("when"),
@@ -248,6 +254,7 @@ object FunctionRegistry {
expression[Average]("mean"),
expression[Min]("min"),
expression[Skewness]("skewness"),
+ expression[StddevSamp]("std"),
expression[StddevSamp]("stddev"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
@@ -264,6 +271,7 @@ object FunctionRegistry {
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
expression[Decode]("decode"),
+ expression[Elt]("elt"),
expression[Encode]("encode"),
expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
@@ -280,6 +288,7 @@ object FunctionRegistry {
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
expression[JsonTuple]("json_tuple"),
+ expression[ParseUrl]("parse_url"),
expression[FormatString]("printf"),
expression[RegExpExtract]("regexp_extract"),
expression[RegExpReplace]("regexp_replace"),
@@ -288,6 +297,7 @@ object FunctionRegistry {
expression[RLike]("rlike"),
expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
+ expression[Sentences]("sentences"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
expression[StringSplit]("split"),
@@ -300,6 +310,7 @@ object FunctionRegistry {
expression[UnBase64]("unbase64"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),
+ expression[XPathBoolean]("xpath_boolean"),
// datetime functions
expression[AddMonths]("add_months"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
index 689e016a5a1d..f6e32e29ebca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala
@@ -30,7 +30,7 @@ object UnsupportedOperationChecker {
def checkForBatch(plan: LogicalPlan): Unit = {
plan.foreachUp {
case p if p.isStreaming =>
- throwError("Queries with streaming sources must be executed with write.startStream()")(p)
+ throwError("Queries with streaming sources must be executed with writeStream.start()")(p)
case _ =>
}
@@ -40,7 +40,7 @@ object UnsupportedOperationChecker {
if (!plan.isStreaming) {
throwError(
- "Queries without streaming sources cannot be executed with write.startStream()")(plan)
+ "Queries without streaming sources cannot be executed with writeStream.start()")(plan)
}
// Disallow multiple streaming aggregations
@@ -154,7 +154,7 @@ object UnsupportedOperationChecker {
case ReturnAnswer(child) if child.isStreaming =>
throwError("Cannot return immediate result on streaming DataFrames/Dataset. Queries " +
- "with streaming DataFrames/Datasets must be executed with write.startStream().")
+ "with streaming DataFrames/Datasets must be executed with writeStream.start().")
case _ =>
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index b883546135f0..609089a302c8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -215,23 +215,20 @@ abstract class Star extends LeafExpression with NamedExpression {
case class UnresolvedStar(target: Option[Seq[String]]) extends Star with Unevaluable {
override def expand(input: LogicalPlan, resolver: Resolver): Seq[NamedExpression] = {
+ // If there is no table specified, use all input attributes.
+ if (target.isEmpty) return input.output
- // First try to expand assuming it is table.*.
- val expandedAttributes: Seq[Attribute] = target match {
- // If there is no table specified, use all input attributes.
- case None => input.output
- // If there is a table, pick out attributes that are part of this table.
- case Some(t) => if (t.size == 1) {
- input.output.filter(_.qualifier.exists(resolver(_, t.head)))
+ val expandedAttributes =
+ if (target.get.size == 1) {
+ // If there is a table, pick out attributes that are part of this table.
+ input.output.filter(_.qualifier.exists(resolver(_, target.get.head)))
} else {
List()
}
- }
if (expandedAttributes.nonEmpty) return expandedAttributes
// Try to resolve it as a struct expansion. If there is a conflict and both are possible,
// (i.e. [name].* is both a table and a struct), the struct path can always be qualified.
- require(target.isDefined)
val attribute = input.resolve(target.get, resolver)
if (attribute.isDefined) {
// This target resolved to an attribute in child. It must be a struct. Expand it.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
index 8c620d36e567..e1d49912c311 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala
@@ -462,17 +462,17 @@ class SessionCatalog(
}
}
- // TODO: It's strange that we have both refresh and invalidate here.
-
/**
* Refresh the cache entry for a metastore table, if any.
*/
- def refreshTable(name: TableIdentifier): Unit = { /* no-op */ }
-
- /**
- * Invalidate the cache entry for a metastore table, if any.
- */
- def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ }
+ def refreshTable(name: TableIdentifier): Unit = {
+ // Go through temporary tables and invalidate them.
+ // If the database is defined, this is definitely not a temp table.
+ // If the database is not defined, there is a good chance this is a temp table.
+ if (name.database.isEmpty) {
+ tempTables.get(name.table).foreach(_.refresh())
+ }
+ }
/**
* Drop all existing temporary tables.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
index c15a2df50855..98f25a9ad759 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala
@@ -57,7 +57,8 @@ trait ExpectsInputTypes extends Expression {
/**
- * A mixin for the analyzer to perform implicit type casting using [[ImplicitTypeCasts]].
+ * A mixin for the analyzer to perform implicit type casting using
+ * [[org.apache.spark.sql.catalyst.analysis.TypeCoercion.ImplicitTypeCasts]].
*/
trait ImplicitCastInputTypes extends ExpectsInputTypes {
// No other methods
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 6392ff42d709..16fb1f683710 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -17,11 +17,16 @@
package org.apache.spark.sql.catalyst.expressions.codegen
+import java.io.ByteArrayInputStream
+import java.util.{Map => JavaMap}
+
+import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import com.google.common.cache.{CacheBuilder, CacheLoader}
-import org.codehaus.janino.ClassBodyEvaluator
+import org.codehaus.janino.{ByteArrayClassLoader, ClassBodyEvaluator, SimpleCompiler}
+import org.codehaus.janino.util.ClassFile
import scala.language.existentials
import org.apache.spark.SparkEnv
@@ -876,6 +881,7 @@ object CodeGenerator extends Logging {
try {
evaluator.cook("generated.java", code.body)
+ recordCompilationStats(evaluator)
} catch {
case e: Exception =>
val msg = s"failed to compile: $e\n$formatted"
@@ -885,6 +891,38 @@ object CodeGenerator extends Logging {
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}
+ /**
+ * Records the generated class and method bytecode sizes by inspecting janino private fields.
+ */
+ private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = {
+ // First retrieve the generated classes.
+ val classes = {
+ val resultField = classOf[SimpleCompiler].getDeclaredField("result")
+ resultField.setAccessible(true)
+ val loader = resultField.get(evaluator).asInstanceOf[ByteArrayClassLoader]
+ val classesField = loader.getClass.getDeclaredField("classes")
+ classesField.setAccessible(true)
+ classesField.get(loader).asInstanceOf[JavaMap[String, Array[Byte]]].asScala
+ }
+
+ // Then walk the classes to get at the method bytecode.
+ val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute")
+ val codeAttrField = codeAttr.getDeclaredField("code")
+ codeAttrField.setAccessible(true)
+ classes.foreach { case (_, classBytes) =>
+ CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length)
+ val cf = new ClassFile(new ByteArrayInputStream(classBytes))
+ cf.methodInfos.asScala.foreach { method =>
+ method.getAttributes().foreach { a =>
+ if (a.getClass.getName == codeAttr.getName) {
+ CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(
+ codeAttrField.get(a).asInstanceOf[Array[Byte]].length)
+ }
+ }
+ }
+ }
+ }
+
/**
* A cache of generated classes.
*
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
index c71cb73d65bf..2e8ea1107cee 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
@@ -43,6 +43,54 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType
}
}
+/**
+ * Returns an unordered array containing the keys of the map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.",
+ extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [1,2]")
+case class MapKeys(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].keyType)
+
+ override def nullSafeEval(map: Any): Any = {
+ map.asInstanceOf[MapData].keyArray()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).keyArray();")
+ }
+
+ override def prettyName: String = "map_keys"
+}
+
+/**
+ * Returns an unordered array containing the values of the map.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.",
+ extended = " > SELECT _FUNC_(map(1, 'a', 2, 'b'));\n [\"a\",\"b\"]")
+case class MapValues(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
+
+ override def dataType: DataType = ArrayType(child.dataType.asInstanceOf[MapType].valueType)
+
+ override def nullSafeEval(map: Any): Any = {
+ map.asInstanceOf[MapData].valueArray()
+ }
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).valueArray();")
+ }
+
+ override def prettyName: String = "map_values"
+}
+
/**
* Sorts the input array in ascending / descending order according to the natural ordering of
* the array elements and returns it.
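Because `map_keys` and `map_values` are also registered in FunctionRegistry above, they become callable from SQL. A quick PySpark sketch, assuming an active SparkSession named `spark`; the arrays come back unordered, so the outputs shown are only one possible ordering.

    # Both functions take a MapType column and return an ArrayType column.
    spark.sql("SELECT map_keys(map(1, 'a', 2, 'b'))").show()    # e.g. [1, 2]
    spark.sql("SELECT map_values(map(1, 'a', 2, 'b'))").show()  # e.g. [a, b]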
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 12c35644e564..9d5c856a23e2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -94,13 +94,63 @@ case class UserDefinedGenerator(
}
/**
- * Given an input array produces a sequence of rows for each value in the array.
+ * Separate v1, ..., vk into n rows. Each row will have k/n columns. n must be constant.
+ * {{{
+ * SELECT stack(2, 1, 2, 3) ->
+ * 1 2
+ * 3 NULL
+ * }}}
*/
-// scalastyle:off line.size.limit
@ExpressionDescription(
- usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.")
-// scalastyle:on line.size.limit
-case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+ usage = "_FUNC_(n, v1, ..., vk) - Separate v1, ..., vk into n rows.",
+ extended = "> SELECT _FUNC_(2, 1, 2, 3);\n [1,2]\n [3,null]")
+case class Stack(children: Seq[Expression])
+ extends Expression with Generator with CodegenFallback {
+
+ private lazy val numRows = children.head.eval().asInstanceOf[Int]
+ private lazy val numFields = Math.ceil((children.length - 1.0) / numRows).toInt
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.length <= 1) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName requires at least 2 arguments.")
+ } else if (children.head.dataType != IntegerType || !children.head.foldable || numRows < 1) {
+ TypeCheckResult.TypeCheckFailure("The number of rows must be a positive constant integer.")
+ } else {
+ for (i <- 1 until children.length) {
+ val j = (i - 1) % numFields
+ if (children(i).dataType != elementSchema.fields(j).dataType) {
+ return TypeCheckResult.TypeCheckFailure(
+ s"Argument ${j + 1} (${elementSchema.fields(j).dataType}) != " +
+ s"Argument $i (${children(i).dataType})")
+ }
+ }
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ override def elementSchema: StructType =
+ StructType(children.tail.take(numFields).zipWithIndex.map {
+ case (e, index) => StructField(s"col$index", e.dataType)
+ })
+
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+ val values = children.tail.map(_.eval(input)).toArray
+ for (row <- 0 until numRows) yield {
+ val fields = new Array[Any](numFields)
+ for (col <- 0 until numFields) {
+ val index = row * numFields + col
+ fields.update(col, if (index < values.length) values(index) else null)
+ }
+ InternalRow(fields: _*)
+ }
+ }
+}
+
+/**
+ * A base class for Explode and PosExplode
+ */
+abstract class ExplodeBase(child: Expression, position: Boolean)
+ extends UnaryExpression with Generator with CodegenFallback with Serializable {
override def children: Seq[Expression] = child :: Nil
@@ -115,9 +165,26 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
// hive-compatible default alias for explode function ("col" for array, "key", "value" for map)
override def elementSchema: StructType = child.dataType match {
- case ArrayType(et, containsNull) => new StructType().add("col", et, containsNull)
+ case ArrayType(et, containsNull) =>
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("col", et, containsNull)
+ } else {
+ new StructType()
+ .add("col", et, containsNull)
+ }
case MapType(kt, vt, valueContainsNull) =>
- new StructType().add("key", kt, false).add("value", vt, valueContainsNull)
+ if (position) {
+ new StructType()
+ .add("pos", IntegerType, false)
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ } else {
+ new StructType()
+ .add("key", kt, false)
+ .add("value", vt, valueContainsNull)
+ }
}
override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
@@ -129,7 +196,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
} else {
val rows = new Array[InternalRow](inputArray.numElements())
inputArray.foreach(et, (i, e) => {
- rows(i) = InternalRow(e)
+ rows(i) = if (position) InternalRow(i, e) else InternalRow(e)
})
rows
}
@@ -141,7 +208,7 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
val rows = new Array[InternalRow](inputMap.numElements())
var i = 0
inputMap.foreach(kt, vt, (k, v) => {
- rows(i) = InternalRow(k, v)
+ rows(i) = if (position) InternalRow(i, k, v) else InternalRow(k, v)
i += 1
})
rows
@@ -149,3 +216,70 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit
}
}
}
+
+/**
+ * Given an input array produces a sequence of rows for each value in the array.
+ *
+ * {{{
+ * SELECT explode(array(10,20)) ->
+ * 10
+ * 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of map a into multiple rows and columns.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 10\n 20")
+// scalastyle:on line.size.limit
+case class Explode(child: Expression) extends ExplodeBase(child, position = false)
+
+/**
+ * Given an input array produces a sequence of rows for each position and value in the array.
+ *
+ * {{{
+ * SELECT posexplode(array(10,20)) ->
+ * 0 10
+ * 1 20
+ * }}}
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Separates the elements of array a into multiple rows with positions, or the elements of a map into multiple rows and columns with positions.",
+ extended = "> SELECT _FUNC_(array(10,20));\n 0\t10\n 1\t20")
+// scalastyle:on line.size.limit
+case class PosExplode(child: Expression) extends ExplodeBase(child, position = true)
+
+/**
+ * Explodes an array of structs into a table.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(a) - Explodes an array of structs into a table.",
+ extended = "> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));\n [1,a]\n [2,b]")
+case class Inline(child: Expression) extends UnaryExpression with Generator with CodegenFallback {
+
+ override def children: Seq[Expression] = child :: Nil
+
+ override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
+ case ArrayType(et, _) if et.isInstanceOf[StructType] =>
+ TypeCheckResult.TypeCheckSuccess
+ case _ =>
+ TypeCheckResult.TypeCheckFailure(
+ s"input to function $prettyName should be array of struct type, not ${child.dataType}")
+ }
+
+ override def elementSchema: StructType = child.dataType match {
+ case ArrayType(et : StructType, _) => et
+ }
+
+ private lazy val numFields = elementSchema.fields.length
+
+ override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
+ val inputArray = child.eval(input).asInstanceOf[ArrayData]
+ if (inputArray == null) {
+ Nil
+ } else {
+ for (i <- 0 until inputArray.numElements())
+ yield inputArray.getStruct(i, numFields)
+ }
+ }
+}
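All four generators defined above (the refactored explode plus the new posexplode, stack, and inline) are exposed through FunctionRegistry, so they can be exercised directly from SQL. A short PySpark sketch assuming an active SparkSession named `spark`; the expected rows follow the extended usage strings above.

    # explode / posexplode: one row per array element; posexplode also emits the position.
    spark.sql("SELECT explode(array(10, 20))").show()     # rows: 10, 20
    spark.sql("SELECT posexplode(array(10, 20))").show()  # rows: (0, 10), (1, 20)

    # stack: spread the k values over n rows, padding the last row with NULLs.
    spark.sql("SELECT stack(2, 1, 2, 3)").show()           # rows: (1, 2), (3, NULL)

    # inline: explode an array of structs into a table of their fields.
    spark.sql("SELECT inline(array(struct(1, 'a'), struct(2, 'b')))").show()  # rows: (1, a), (2, b)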
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 44ff7fda8ef4..61549c9a2368 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -17,12 +17,17 @@
package org.apache.spark.sql.catalyst.expressions
-import java.text.{DecimalFormat, DecimalFormatSymbols}
+import java.net.{MalformedURLException, URL}
+import java.text.{BreakIterator, DecimalFormat, DecimalFormatSymbols}
import java.util.{HashMap, Locale, Map => JMap}
+import java.util.regex.Pattern
+
+import scala.collection.mutable.ArrayBuffer
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
@@ -162,6 +167,46 @@ case class ConcatWs(children: Seq[Expression])
}
}
+@ExpressionDescription(
+ usage = "_FUNC_(n, str1, str2, ...) - returns the n-th string, e.g. returns str2 when n is 2",
+ extended = "> SELECT _FUNC_(1, 'scala', 'java') FROM src LIMIT 1;\n" + "'scala'")
+case class Elt(children: Seq[Expression])
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+ private lazy val indexExpr = children.head
+ private lazy val stringExprs = children.tail.toArray
+
+ /** This expression is always nullable because it returns null if index is out of range. */
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = StringType
+
+ override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size < 2) {
+ TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments")
+ } else {
+ super[ImplicitCastInputTypes].checkInputDataTypes()
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val indexObj = indexExpr.eval(input)
+ if (indexObj == null) {
+ null
+ } else {
+ val index = indexObj.asInstanceOf[Int]
+ if (index <= 0 || index > stringExprs.length) {
+ null
+ } else {
+ stringExprs(index - 1).eval(input)
+ }
+ }
+ }
+}
+
+
trait String2StringExpression extends ImplicitCastInputTypes {
self: UnaryExpression =>
@@ -611,6 +656,154 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
override def prettyName: String = "rpad"
}
+object ParseUrl {
+ private val HOST = UTF8String.fromString("HOST")
+ private val PATH = UTF8String.fromString("PATH")
+ private val QUERY = UTF8String.fromString("QUERY")
+ private val REF = UTF8String.fromString("REF")
+ private val PROTOCOL = UTF8String.fromString("PROTOCOL")
+ private val FILE = UTF8String.fromString("FILE")
+ private val AUTHORITY = UTF8String.fromString("AUTHORITY")
+ private val USERINFO = UTF8String.fromString("USERINFO")
+ private val REGEXPREFIX = "(&|^)"
+ private val REGEXSUBFIX = "=([^&]*)"
+}
+
+/**
+ * Extracts a part from a URL
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(url, partToExtract[, key]) - extracts a part from a URL",
+ extended = """Parts: HOST, PATH, QUERY, REF, PROTOCOL, AUTHORITY, FILE, USERINFO.
+ Key specifies which query to extract.
+ Examples:
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'HOST')
+ 'spark.apache.org'
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY')
+ 'query=1'
+ > SELECT _FUNC_('http://spark.apache.org/path?query=1', 'QUERY', 'query')
+ '1'""")
+case class ParseUrl(children: Seq[Expression])
+ extends Expression with ExpectsInputTypes with CodegenFallback {
+
+ override def nullable: Boolean = true
+ override def inputTypes: Seq[DataType] = Seq.fill(children.size)(StringType)
+ override def dataType: DataType = StringType
+ override def prettyName: String = "parse_url"
+
+ // If the url is a constant, cache the URL object so that we don't need to convert url
+ // from UTF8String to String to URL for every row.
+ @transient private lazy val cachedUrl = children(0) match {
+ case Literal(url: UTF8String, _) if url ne null => getUrl(url)
+ case _ => null
+ }
+
+ // If the key is a constant, cache the Pattern object so that we don't need to convert key
+ // from UTF8String to String to StringBuilder to String to Pattern for every row.
+ @transient private lazy val cachedPattern = children(2) match {
+ case Literal(key: UTF8String, _) if key ne null => getPattern(key)
+ case _ => null
+ }
+
+ // If the partToExtract is a constant, cache the Extract part function so that we don't need
+ // to check the partToExtract for every row.
+ @transient private lazy val cachedExtractPartFunc = children(1) match {
+ case Literal(part: UTF8String, _) => getExtractPartFunc(part)
+ case _ => null
+ }
+
+ import ParseUrl._
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.size > 3 || children.size < 2) {
+ TypeCheckResult.TypeCheckFailure(s"$prettyName function requires two or three arguments")
+ } else {
+ super[ExpectsInputTypes].checkInputDataTypes()
+ }
+ }
+
+ private def getPattern(key: UTF8String): Pattern = {
+ Pattern.compile(REGEXPREFIX + key.toString + REGEXSUBFIX)
+ }
+
+ private def getUrl(url: UTF8String): URL = {
+ try {
+ new URL(url.toString)
+ } catch {
+ case e: MalformedURLException => null
+ }
+ }
+
+ private def getExtractPartFunc(partToExtract: UTF8String): URL => String = {
+ partToExtract match {
+ case HOST => _.getHost
+ case PATH => _.getPath
+ case QUERY => _.getQuery
+ case REF => _.getRef
+ case PROTOCOL => _.getProtocol
+ case FILE => _.getFile
+ case AUTHORITY => _.getAuthority
+ case USERINFO => _.getUserInfo
+ case _ => (url: URL) => null
+ }
+ }
+
+ private def extractValueFromQuery(query: UTF8String, pattern: Pattern): UTF8String = {
+ val m = pattern.matcher(query.toString)
+ if (m.find()) {
+ UTF8String.fromString(m.group(2))
+ } else {
+ null
+ }
+ }
+
+ private def extractFromUrl(url: URL, partToExtract: UTF8String): UTF8String = {
+ if (cachedExtractPartFunc ne null) {
+ UTF8String.fromString(cachedExtractPartFunc.apply(url))
+ } else {
+ UTF8String.fromString(getExtractPartFunc(partToExtract).apply(url))
+ }
+ }
+
+ private def parseUrlWithoutKey(url: UTF8String, partToExtract: UTF8String): UTF8String = {
+ if (cachedUrl ne null) {
+ extractFromUrl(cachedUrl, partToExtract)
+ } else {
+ val currentUrl = getUrl(url)
+ if (currentUrl ne null) {
+ extractFromUrl(currentUrl, partToExtract)
+ } else {
+ null
+ }
+ }
+ }
+
+ override def eval(input: InternalRow): Any = {
+ val evaluated = children.map{e => e.eval(input).asInstanceOf[UTF8String]}
+ if (evaluated.contains(null)) return null
+ if (evaluated.size == 2) {
+ parseUrlWithoutKey(evaluated(0), evaluated(1))
+ } else {
+ // 3-arg, i.e. QUERY with key
+ assert(evaluated.size == 3)
+ if (evaluated(1) != QUERY) {
+ return null
+ }
+
+ val query = parseUrlWithoutKey(evaluated(0), evaluated(1))
+ if (query eq null) {
+ return null
+ }
+
+ if (cachedPattern ne null) {
+ extractValueFromQuery(query, cachedPattern)
+ } else {
+ extractValueFromQuery(query, getPattern(evaluated(2)))
+ }
+ }
+ }
+}
+
/**
* Returns the input formatted according do printf-style format strings
*/
@@ -1147,3 +1340,65 @@ case class FormatNumber(x: Expression, d: Expression)
override def prettyName: String = "format_number"
}
+
+/**
+ * Splits a string into an array of sentences, where each sentence is an array of words.
+ * The 'lang' and 'country' arguments are optional, and if omitted, the default locale is used.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(str[, lang, country]) - Splits str into an array of array of words.",
+ extended = "> SELECT _FUNC_('Hi there! Good morning.');\n [['Hi','there'], ['Good','morning']]")
+case class Sentences(
+ str: Expression,
+ language: Expression = Literal(""),
+ country: Expression = Literal(""))
+ extends Expression with ImplicitCastInputTypes with CodegenFallback {
+
+ def this(str: Expression) = this(str, Literal(""), Literal(""))
+ def this(str: Expression, language: Expression) = this(str, language, Literal(""))
+
+ override def nullable: Boolean = true
+ override def dataType: DataType =
+ ArrayType(ArrayType(StringType, containsNull = false), containsNull = false)
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
+ override def children: Seq[Expression] = str :: language :: country :: Nil
+
+ override def eval(input: InternalRow): Any = {
+ val string = str.eval(input)
+ if (string == null) {
+ null
+ } else {
+ val languageStr = language.eval(input).asInstanceOf[UTF8String]
+ val countryStr = country.eval(input).asInstanceOf[UTF8String]
+ val locale = if (languageStr != null && countryStr != null) {
+ new Locale(languageStr.toString, countryStr.toString)
+ } else {
+ Locale.getDefault
+ }
+ getSentences(string.asInstanceOf[UTF8String].toString, locale)
+ }
+ }
+
+ private def getSentences(sentences: String, locale: Locale) = {
+ val bi = BreakIterator.getSentenceInstance(locale)
+ bi.setText(sentences)
+ var idx = 0
+ val result = new ArrayBuffer[GenericArrayData]
+ while (bi.next != BreakIterator.DONE) {
+ val sentence = sentences.substring(idx, bi.current)
+ idx = bi.current
+
+ val wi = BreakIterator.getWordInstance(locale)
+ var widx = 0
+ wi.setText(sentence)
+ val words = new ArrayBuffer[UTF8String]
+ while (wi.next != BreakIterator.DONE) {
+ val word = sentence.substring(widx, wi.current)
+ widx = wi.current
+ if (Character.isLetterOrDigit(word.charAt(0))) words += UTF8String.fromString(word)
+ }
+ result += new GenericArrayData(words)
+ }
+ new GenericArrayData(result)
+ }
+}
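The three string expressions added above (elt, parse_url, sentences) are likewise registered as SQL functions. A PySpark sketch assuming an active SparkSession named `spark`; outputs follow the extended usage strings.

    # elt: return the n-th string argument, or NULL when the index is out of range.
    spark.sql("SELECT elt(1, 'scala', 'java')").show()  # scala

    # parse_url: extract one URL component; a third argument selects a key within QUERY.
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'HOST')").show()            # spark.apache.org
    spark.sql("SELECT parse_url('http://spark.apache.org/path?query=1', 'QUERY', 'query')").show()  # 1

    # sentences: split text into sentences, each returned as an array of words.
    spark.sql("SELECT sentences('Hi there! Good morning.')").show(truncate=False)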
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
new file mode 100644
index 000000000000..2a5256c7f56f
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathBoolean.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.types.{AbstractDataType, BooleanType, DataType, StringType}
+import org.apache.spark.unsafe.types.UTF8String
+
+
+@ExpressionDescription(
+ usage = "_FUNC_(xml, xpath) - Evaluates a boolean xpath expression.",
+ extended = "> SELECT _FUNC_('1','a/b');\ntrue")
+case class XPathBoolean(xml: Expression, path: Expression)
+ extends BinaryExpression with ExpectsInputTypes with CodegenFallback {
+
+ @transient private lazy val xpathUtil = new UDFXPathUtil
+
+ // If the path is a constant, cache the path string so that we don't need to convert path
+ // from UTF8String to String for every row.
+ @transient lazy val pathLiteral: String = path match {
+ case Literal(str: UTF8String, _) => str.toString
+ case _ => null
+ }
+
+ override def prettyName: String = "xpath_boolean"
+
+ override def dataType: DataType = BooleanType
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+ override def left: Expression = xml
+ override def right: Expression = path
+
+ override protected def nullSafeEval(xml: Any, path: Any): Any = {
+ val xmlString = xml.asInstanceOf[UTF8String].toString
+ if (pathLiteral ne null) {
+ xpathUtil.evalBoolean(xmlString, pathLiteral)
+ } else {
+ xpathUtil.evalBoolean(xmlString, path.asInstanceOf[UTF8String].toString)
+ }
+ }
+}
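As with the other new expressions, xpath_boolean is registered in FunctionRegistry, so it can be called from SQL. A PySpark sketch assuming an active SparkSession named `spark`:

    # True when the XPath expression matches something in the XML string, false otherwise.
    spark.sql("SELECT xpath_boolean('<a><b>1</b></a>', 'a/b')").show()  # true
    spark.sql("SELECT xpath_boolean('<a><b>1</b></a>', 'a/c')").show()  # false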
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 4984f235b412..d0b2b5d7b2df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -265,6 +265,11 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
s"Reference '$name' is ambiguous, could be: $referenceNames.")
}
}
+
+ /**
+ * Refreshes (or invalidates) any metadata/data cached in the plan recursively.
+ */
+ def refresh(): Unit = children.foreach(_.refresh())
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
index 436512ff6933..55fdfbe3e046 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala
@@ -298,6 +298,12 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
Utils.truncatedString(fieldTypes, "struct<", ",", ">")
}
+ override def catalogString: String = {
+ // in catalogString, we should not truncate
+ val fieldTypes = fields.map(field => s"${field.name}:${field.dataType.catalogString}")
+ s"struct<${fieldTypes.mkString(",")}>"
+ }
+
override def sql: String = {
val fieldTypes = fields.map(f => s"${quoteIdentifier(f.name)}: ${f.dataType.sql}")
s"STRUCT<${fieldTypes.mkString(", ")}>"
@@ -378,10 +384,10 @@ object StructType extends AbstractDataType {
StructType(fields.asScala)
}
- protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
+ private[sql] def fromAttributes(attributes: Seq[Attribute]): StructType =
StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))
- def removeMetadata(key: String, dt: DataType): DataType =
+ private[sql] def removeMetadata(key: String, dt: DataType): DataType =
dt match {
case StructType(fields) =>
val newFields = fields.map { f =>
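A sketch of why catalogString is introduced, assuming a struct wider than the truncation threshold that simpleString honours via Utils.truncatedString:

    import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
    // Hypothetical 100-field struct, for illustration only.
    val wide = StructType((1 to 100).map(i => StructField(s"c$i", IntegerType)))
    wide.simpleString   // truncated once the field count exceeds the configured maximum
    wide.catalogString  // always the full struct<c1:int,...,c100:int>, so the catalog never stores a truncated type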
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index a41383fbf656..a9cde1e19efc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count}
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max}
import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData}
@@ -162,6 +162,16 @@ class AnalysisErrorSuite extends AnalysisTest {
UnspecifiedFrame)).as('window)),
"Distinct window functions are not supported" :: Nil)
+ errorTest(
+ "nested aggregate functions",
+ testRelation.groupBy('a)(
+ AggregateExpression(
+ Max(AggregateExpression(Count(Literal(1)), Complete, isDistinct = false)),
+ Complete,
+ isDistinct = false)),
+ "not allowed to use an aggregate function in the argument of another aggregate function." :: Nil
+ )
+
errorTest(
"offset window function",
testRelation2.select(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
index 54436ea9a4a7..76e42d9afa4c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala
@@ -166,6 +166,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite {
assertError(new Murmur3Hash(Nil), "function hash requires at least one argument")
assertError(Explode('intField),
"input to function explode should be array or map type")
+ assertError(PosExplode('intField),
+ "input to function explode should be array or map type")
}
test("check types for CreateNamedStruct") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
index 748579df4158..100ec4d53fb8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveNaturalJoinSuite.scala
@@ -113,4 +113,34 @@ class ResolveNaturalJoinSuite extends AnalysisTest {
assert(error.message.contains(
"using columns ['d] can not be resolved given input columns: [b, a, c]"))
}
+
+ test("using join with a case sensitive analyzer") {
+ val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
+
+ {
+ val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
+ checkAnalysis(usingPlan, expected, caseSensitive = true)
+ }
+
+ {
+ val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None)
+ assertAnalysisError(
+ usingPlan,
+ Seq("using columns ['A] can not be resolved given input columns: [b, a, c, a]"))
+ }
+ }
+
+ test("using join with a case insensitive analyzer") {
+ val expected = r1.join(r2, Inner, Some(EqualTo(a, a))).select(a, b, c)
+
+ {
+ val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("a"))), None)
+ checkAnalysis(usingPlan, expected, caseSensitive = false)
+ }
+
+ {
+ val usingPlan = r1.join(r2, UsingJoin(Inner, Seq(UnresolvedAttribute("A"))), None)
+ checkAnalysis(usingPlan, expected, caseSensitive = false)
+ }
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
index c21ad5e03a48..6df47acaba85 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala
@@ -53,12 +53,12 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
assertNotSupportedInBatchPlan(
"streaming source",
streamRelation,
- Seq("with streaming source", "startStream"))
+ Seq("with streaming source", "start"))
assertNotSupportedInBatchPlan(
"select on streaming source",
streamRelation.select($"count(*)"),
- Seq("with streaming source", "startStream"))
+ Seq("with streaming source", "start"))
/*
@@ -70,7 +70,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite {
// Batch plan in streaming query
testError(
"streaming plan - no streaming source",
- Seq("without streaming source", "startStream")) {
+ Seq("without streaming source", "start")) {
UnsupportedOperationChecker.checkForStreaming(batchRelation.select($"count(*)"), Append)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 60dd03f5d0c1..8ea8f6115084 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -53,9 +53,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
test("metrics are recorded on compile") {
val startCount1 = CodegenMetrics.METRIC_COMPILATION_TIME.getCount()
val startCount2 = CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount()
+ val startCount3 = CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount()
+ val startCount4 = CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount()
GenerateOrdering.generate(Add(Literal(123), Literal(1)).asc :: Nil)
assert(CodegenMetrics.METRIC_COMPILATION_TIME.getCount() == startCount1 + 1)
assert(CodegenMetrics.METRIC_SOURCE_CODE_SIZE.getCount() == startCount2 + 1)
+ assert(CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.getCount() > startCount3)
+ assert(CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.getCount() > startCount4)
}
test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
index 1aae4678d627..a5f784fdcc13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionFunctionsSuite.scala
@@ -44,6 +44,19 @@ class CollectionFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Literal.create(null, ArrayType(StringType)), null)
}
+ test("MapKeys/MapValues") {
+ val m0 = Literal.create(Map("a" -> "1", "b" -> "2"), MapType(StringType, StringType))
+ val m1 = Literal.create(Map[String, String](), MapType(StringType, StringType))
+ val m2 = Literal.create(null, MapType(StringType, StringType))
+
+ checkEvaluation(MapKeys(m0), Seq("a", "b"))
+ checkEvaluation(MapValues(m0), Seq("1", "2"))
+ checkEvaluation(MapKeys(m1), Seq())
+ checkEvaluation(MapValues(m1), Seq())
+ checkEvaluation(MapKeys(m2), null)
+ checkEvaluation(MapValues(m2), null)
+ }
+
test("Sort Array") {
val a0 = Literal.create(Seq(2, 1, 3), ArrayType(IntegerType))
val a1 = Literal.create(Seq[Integer](), ArrayType(IntegerType))
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
new file mode 100644
index 000000000000..e29dfa41f1cc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/GeneratorExpressionSuite.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
+
+class GeneratorExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+ private def checkTuple(actual: Expression, expected: Seq[InternalRow]): Unit = {
+ assert(actual.eval(null).asInstanceOf[TraversableOnce[InternalRow]].toSeq === expected)
+ }
+
+ private final val empty_array = CreateArray(Seq.empty)
+ private final val int_array = CreateArray(Seq(1, 2, 3).map(Literal(_)))
+ private final val str_array = CreateArray(Seq("a", "b", "c").map(Literal(_)))
+
+ test("explode") {
+ val int_correct_answer = Seq(create_row(1), create_row(2), create_row(3))
+ val str_correct_answer = Seq(create_row("a"), create_row("b"), create_row("c"))
+
+ checkTuple(Explode(empty_array), Seq.empty)
+ checkTuple(Explode(int_array), int_correct_answer)
+ checkTuple(Explode(str_array), str_correct_answer)
+ }
+
+ test("posexplode") {
+ val int_correct_answer = Seq(create_row(0, 1), create_row(1, 2), create_row(2, 3))
+ val str_correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
+
+ checkTuple(PosExplode(empty_array), Seq.empty)
+ checkTuple(PosExplode(int_array), int_correct_answer)
+ checkTuple(PosExplode(str_array), str_correct_answer)
+ }
+
+ test("inline") {
+ val correct_answer = Seq(create_row(0, "a"), create_row(1, "b"), create_row(2, "c"))
+
+ checkTuple(
+ Inline(Literal.create(Array(), ArrayType(new StructType().add("id", LongType)))),
+ Seq.empty)
+
+ checkTuple(
+ Inline(CreateArray(Seq(
+ CreateStruct(Seq(Literal(0), Literal("a"))),
+ CreateStruct(Seq(Literal(1), Literal("b"))),
+ CreateStruct(Seq(Literal(2), Literal("c")))
+ ))),
+ correct_answer)
+ }
+
+ test("stack") {
+ checkTuple(Stack(Seq(1, 1).map(Literal(_))), Seq(create_row(1)))
+ checkTuple(Stack(Seq(1, 1, 2).map(Literal(_))), Seq(create_row(1, 2)))
+ checkTuple(Stack(Seq(2, 1, 2).map(Literal(_))), Seq(create_row(1), create_row(2)))
+ checkTuple(Stack(Seq(2, 1, 2, 3).map(Literal(_))), Seq(create_row(1, 2), create_row(3, null)))
+ checkTuple(Stack(Seq(3, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3).map(create_row(_)))
+ checkTuple(Stack(Seq(4, 1, 2, 3).map(Literal(_))), Seq(1, 2, 3, null).map(create_row(_)))
+
+ checkTuple(
+ Stack(Seq(3, 1, 1.0, "a", 2, 2.0, "b", 3, 3.0, "c").map(Literal(_))),
+ Seq(create_row(1, 1.0, "a"), create_row(2, 2.0, "b"), create_row(3, 3.0, "c")))
+
+ assert(Stack(Seq(Literal(1))).checkInputDataTypes().isFailure)
+ assert(Stack(Seq(Literal(1.0))).checkInputDataTypes().isFailure)
+ assert(Stack(Seq(Literal(1), Literal(1), Literal(1.0))).checkInputDataTypes().isSuccess)
+ assert(Stack(Seq(Literal(2), Literal(1), Literal(1.0))).checkInputDataTypes().isFailure)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 29bf15bf524b..8f7b1041fad3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -75,6 +75,29 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
// scalastyle:on
}
+ test("elt") {
+ def testElt(result: String, n: java.lang.Integer, args: String*): Unit = {
+ checkEvaluation(
+ Elt(Literal.create(n, IntegerType) +: args.map(Literal.create(_, StringType))),
+ result)
+ }
+
+ testElt("hello", 1, "hello", "world")
+ testElt(null, 1, null, "world")
+ testElt(null, null, "hello", "world")
+
+ // Invalid ranges
+ testElt(null, 3, "hello", "world")
+ testElt(null, 0, "hello", "world")
+ testElt(null, -1, "hello", "world")
+
+ // type checking
+ assert(Elt(Seq.empty).checkInputDataTypes().isFailure)
+ assert(Elt(Seq(Literal(1))).checkInputDataTypes().isFailure)
+ assert(Elt(Seq(Literal(1), Literal("A"))).checkInputDataTypes().isSuccess)
+ assert(Elt(Seq(Literal(1), Literal(2))).checkInputDataTypes().isFailure)
+ }
+
test("StringComparison") {
val row = create_row("abc", null)
val c1 = 'a.string.at(0)
@@ -702,4 +725,78 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(FindInSet(Literal("abf"), Literal("abc,b,ab,c,def")), 0)
checkEvaluation(FindInSet(Literal("ab,"), Literal("abc,b,ab,c,def")), 0)
}
+
+ test("ParseUrl") {
+ def checkParseUrl(expected: String, urlStr: String, partToExtract: String): Unit = {
+ checkEvaluation(
+ ParseUrl(Seq(Literal(urlStr), Literal(partToExtract))), expected)
+ }
+ def checkParseUrlWithKey(
+ expected: String,
+ urlStr: String,
+ partToExtract: String,
+ key: String): Unit = {
+ checkEvaluation(
+ ParseUrl(Seq(Literal(urlStr), Literal(partToExtract), Literal(key))), expected)
+ }
+
+ checkParseUrl("spark.apache.org", "http://spark.apache.org/path?query=1", "HOST")
+ checkParseUrl("/path", "http://spark.apache.org/path?query=1", "PATH")
+ checkParseUrl("query=1", "http://spark.apache.org/path?query=1", "QUERY")
+ checkParseUrl("Ref", "http://spark.apache.org/path?query=1#Ref", "REF")
+ checkParseUrl("http", "http://spark.apache.org/path?query=1", "PROTOCOL")
+ checkParseUrl("/path?query=1", "http://spark.apache.org/path?query=1", "FILE")
+ checkParseUrl("spark.apache.org:8080", "http://spark.apache.org:8080/path?query=1", "AUTHORITY")
+ checkParseUrl("userinfo", "http://userinfo@spark.apache.org/path?query=1", "USERINFO")
+ checkParseUrlWithKey("1", "http://spark.apache.org/path?query=1", "QUERY", "query")
+
+ // Null checking
+ checkParseUrl(null, null, "HOST")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", null)
+ checkParseUrl(null, null, null)
+ checkParseUrl(null, "test", "HOST")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", "NO")
+ checkParseUrl(null, "http://spark.apache.org/path?query=1", "USERINFO")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "HOST", "query")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "quer")
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", null)
+ checkParseUrlWithKey(null, "http://spark.apache.org/path?query=1", "QUERY", "")
+
+ // exceptional cases
+ intercept[java.util.regex.PatternSyntaxException] {
+ evaluate(ParseUrl(Seq(Literal("http://spark.apache.org/path?"),
+ Literal("QUERY"), Literal("???"))))
+ }
+
+ // arguments checking
+ assert(ParseUrl(Seq(Literal("1"))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal("3"), Literal("4")))
+ .checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal(2))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal(1), Literal("2"))).checkInputDataTypes().isFailure)
+ assert(ParseUrl(Seq(Literal("1"), Literal("2"), Literal(3))).checkInputDataTypes().isFailure)
+ }
+
+ test("Sentences") {
+ val nullString = Literal.create(null, StringType)
+ checkEvaluation(Sentences(nullString, nullString, nullString), null)
+ checkEvaluation(Sentences(nullString, nullString), null)
+ checkEvaluation(Sentences(nullString), null)
+ checkEvaluation(Sentences(Literal.create(null, NullType)), null)
+ checkEvaluation(Sentences("", nullString, nullString), Seq.empty)
+ checkEvaluation(Sentences("", nullString), Seq.empty)
+ checkEvaluation(Sentences(""), Seq.empty)
+
+ val answer = Seq(
+ Seq("Hi", "there"),
+ Seq("The", "price", "was"),
+ Seq("But", "not", "now"))
+
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now."), answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en"), answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "en", "US"),
+ answer)
+ checkEvaluation(Sentences("Hi there! The price was $1,234.56.... But, not now.", "XXX", "YYY"),
+ answer)
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
new file mode 100644
index 000000000000..c7c386b5b838
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.codegen
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+
+class BufferHolderSuite extends SparkFunSuite {
+
+ test("SPARK-16071 Check the size limit to avoid integer overflow") {
+ var e = intercept[UnsupportedOperationException] {
+ new BufferHolder(new UnsafeRow(Int.MaxValue / 8))
+ }
+ assert(e.getMessage.contains("too many fields"))
+
+ val holder = new BufferHolder(new UnsafeRow(1000))
+ holder.reset()
+ holder.grow(1000)
+ e = intercept[UnsupportedOperationException] {
+ holder.grow(Integer.MAX_VALUE)
+ }
+ assert(e.getMessage.contains("exceeds size limitation"))
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala
new file mode 100644
index 000000000000..e06d209c474b
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/ReusableStringReaderSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import java.io.IOException
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.expressions.xml.UDFXPathUtil.ReusableStringReader
+
+/**
+ * Unit tests for [[UDFXPathUtil.ReusableStringReader]].
+ *
+ * Loosely based on Hive's TestReusableStringReader.java.
+ */
+class ReusableStringReaderSuite extends SparkFunSuite {
+
+ private val fox = "Quick brown fox jumps over the lazy dog."
+
+ test("empty reader") {
+ val reader = new ReusableStringReader
+
+ intercept[IOException] {
+ reader.read()
+ }
+
+ intercept[IOException] {
+ reader.ready()
+ }
+
+ reader.close()
+ }
+
+ test("mark reset") {
+ val reader = new ReusableStringReader
+
+ if (reader.markSupported()) {
+ reader.asInstanceOf[ReusableStringReader].set(fox)
+ assert(reader.ready())
+
+ val cc = new Array[Char](6)
+ var read = reader.read(cc)
+ assert(read == 6)
+ assert("Quick " == new String(cc))
+
+ reader.mark(100)
+
+ read = reader.read(cc)
+ assert(read == 6)
+ assert("brown " == new String(cc))
+
+ reader.reset()
+ read = reader.read(cc)
+ assert(read == 6)
+ assert("brown " == new String(cc))
+ }
+ reader.close()
+ }
+
+ test("skip") {
+ val reader = new ReusableStringReader
+ reader.asInstanceOf[ReusableStringReader].set(fox)
+
+ // skip the entire data:
+ var skipped = reader.skip(fox.length() + 1)
+ assert(fox.length() == skipped)
+ assert(-1 == reader.read())
+
+ reader.asInstanceOf[ReusableStringReader].set(fox) // reset the data
+ val cc = new Array[Char](6)
+ var read = reader.read(cc)
+ assert(read == 6)
+ assert("Quick " == new String(cc))
+
+ // skip some piece of data:
+ skipped = reader.skip(30)
+ assert(skipped == 30)
+ read = reader.read(cc)
+ assert(read == 4)
+ assert("dog." == new String(cc, 0, read))
+
+ // skip when already at EOF:
+ skipped = reader.skip(300)
+ assert(skipped == 0, skipped)
+ assert(reader.read() == -1)
+
+ reader.close()
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
new file mode 100644
index 000000000000..a5614f83844e
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/UDFXPathUtilSuite.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import javax.xml.xpath.XPathConstants.STRING
+
+import org.w3c.dom.Node
+import org.w3c.dom.NodeList
+
+import org.apache.spark.SparkFunSuite
+
+/**
+ * Unit tests for [[UDFXPathUtil]]. Loosely based on Hive's TestUDFXPathUtil.java.
+ */
+class UDFXPathUtilSuite extends SparkFunSuite {
+
+ private lazy val util = new UDFXPathUtil
+
+ test("illegal arguments") {
+ // null args
+ assert(util.eval(null, "a/text()", STRING) == null)
+ assert(util.eval("b1b2b3c1c2", null, STRING) == null)
+ assert(
+ util.eval("b1b2b3c1c2", "a/text()", null) == null)
+
+ // empty String args
+ assert(util.eval("", "a/text()", STRING) == null)
+ assert(util.eval("b1b2b3c1c2", "", STRING) == null)
+
+ // wrong expression:
+ assert(
+ util.eval("b1b2b3c1c2", "a/text(", STRING) == null)
+ }
+
+ test("generic eval") {
+ val ret =
+ util.eval("b1b2b3c1c2", "a/c[2]/text()", STRING)
+ assert(ret == "c2")
+ }
+
+ test("boolean eval") {
+ var ret =
+ util.evalBoolean("truefalseb3c1c2", "a/b[1]/text()")
+ assert(ret == true)
+
+ ret = util.evalBoolean("truefalseb3c1c2", "a/b[4]")
+ assert(ret == false)
+ }
+
+ test("string eval") {
+ var ret =
+ util.evalString("truefalseb3c1c2", "a/b[3]/text()")
+ assert(ret == "b3")
+
+ ret =
+ util.evalString("truefalseb3c1c2", "a/b[4]/text()")
+ assert(ret == "")
+
+ ret = util.evalString(
+ "trueFALSEb3c1c2", "a/b[2]/@k")
+ assert(ret == "foo")
+ }
+
+ test("number eval") {
+ var ret =
+ util.evalNumber("truefalseb3c1-77", "a/c[2]")
+ assert(ret == -77.0d)
+
+ ret = util.evalNumber(
+ "trueFALSEb3c1c2", "a/b[2]/@k")
+ assert(ret.isNaN)
+ }
+
+ test("node eval") {
+ val ret = util.evalNode("truefalseb3c1-77", "a/c[2]")
+ assert(ret != null && ret.isInstanceOf[Node])
+ }
+
+ test("node list eval") {
+ val ret = util.evalNodeList("truefalseb3c1-77", "a/*")
+ assert(ret != null && ret.isInstanceOf[NodeList])
+ assert(ret.asInstanceOf[NodeList].getLength == 5)
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
new file mode 100644
index 000000000000..f7c65c667efb
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/xml/XPathExpressionSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.xml
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, Literal}
+import org.apache.spark.sql.types.StringType
+
+/**
+ * Test suite for various xpath functions.
+ */
+class XPathExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ private def testBoolean[T](xml: String, path: String, expected: T): Unit = {
+ checkEvaluation(
+ XPathBoolean(Literal.create(xml, StringType), Literal.create(path, StringType)),
+ expected)
+ }
+
+ test("xpath_boolean") {
+ testBoolean("b", "a/b", true)
+ testBoolean("b", "a/c", false)
+ testBoolean("b", "a/b = \"b\"", true)
+ testBoolean("b", "a/b = \"c\"", false)
+ testBoolean("10", "a/b < 10", false)
+ testBoolean("10", "a/b = 10", true)
+
+ // null input
+ testBoolean(null, null, null)
+ testBoolean(null, "a", null)
+ testBoolean("10", null, null)
+
+ // exception handling for invalid input
+ intercept[Exception] {
+ testBoolean("/a>", "a", null)
+ }
+ }
+
+ test("xpath_boolean path cache invalidation") {
+ // This is a test to ensure the expression is not reusing the path for different strings
+ val expr = XPathBoolean(Literal("b"), 'path.string.at(0))
+ checkEvaluation(expr, true, create_row("a/b"))
+ checkEvaluation(expr, false, create_row("a/c"))
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
index 1f1b5389aa7d..cd521c52d1b2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java
@@ -29,6 +29,7 @@
import org.apache.spark.unsafe.KVIterator;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.map.BytesToBytesMap;
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
/**
* Unsafe-based HashMap for performing aggregations where the aggregated values are fixed-width.
@@ -246,6 +247,8 @@ public UnsafeKVExternalSorter destructAndCreateExternalSorter() throws IOExcepti
SparkEnv.get().blockManager(),
SparkEnv.get().serializerManager(),
map.getPageSizeBytes(),
+ SparkEnv.get().conf().getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
map);
}
}
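The new constructor argument is read from the existing spark.shuffle.spill.numElementsForceSpillThreshold setting referenced above; a sketch of lowering it when building a session, mainly useful for exercising the spill path in tests (the value is illustrative):

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]")
      .config("spark.shuffle.spill.numElementsForceSpillThreshold", "100")  // force a spill after 100 records
      .getOrCreate()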
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
index bb823cd07be5..82ee5b0d7771 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java
@@ -54,8 +54,10 @@ public UnsafeKVExternalSorter(
StructType valueSchema,
BlockManager blockManager,
SerializerManager serializerManager,
- long pageSizeBytes) throws IOException {
- this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes, null);
+ long pageSizeBytes,
+ long numElementsForSpillThreshold) throws IOException {
+ this(keySchema, valueSchema, blockManager, serializerManager, pageSizeBytes,
+ numElementsForSpillThreshold, null);
}
public UnsafeKVExternalSorter(
@@ -64,6 +66,7 @@ public UnsafeKVExternalSorter(
BlockManager blockManager,
SerializerManager serializerManager,
long pageSizeBytes,
+ long numElementsForSpillThreshold,
@Nullable BytesToBytesMap map) throws IOException {
this.keySchema = keySchema;
this.valueSchema = valueSchema;
@@ -88,6 +91,7 @@ public UnsafeKVExternalSorter(
prefixComparator,
/* initialSize */ 4096,
pageSizeBytes,
+ numElementsForSpillThreshold,
canUseRadixSort);
} else {
// The array will be used to do in-place sort, which requires half of the space to be empty.
@@ -132,6 +136,7 @@ public UnsafeKVExternalSorter(
prefixComparator,
/* initialSize */ 4096,
pageSizeBytes,
+ numElementsForSpillThreshold,
inMemSorter);
// reset the map, so we can re-use it to insert new records. the inMemSorter will not be used
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 9f35107e5bb6..a46d1949e94a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -159,6 +159,7 @@ class Column(protected[sql] val expr: Expression) extends Logging {
// Leave an unaliased generator with an empty list of names since the analyzer will generate
// the correct defaults after the nested expression's type has been resolved.
case explode: Explode => MultiAlias(explode, Nil)
+ case explode: PosExplode => MultiAlias(explode, Nil)
case jt: JsonTuple => MultiAlias(jt, Nil)
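With PosExplode handled like Explode, an unaliased posexplode call gets a MultiAlias and the analyzer fills in its default output names (pos and col for an array input); a sketch, assuming the matching posexplode helper is added to org.apache.spark.sql.functions elsewhere in this patch:

    import org.apache.spark.sql.functions.{col, posexplode}
    // Produces two columns with analyzer-generated names: pos (the index) and col (the element).
    df.select(posexplode(col("values")))   // df and its array column "values" are illustrative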
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
index 35ba52278633..e8c2885d7737 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala
@@ -177,7 +177,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* clause expressions used to split the column `columnName` evenly.
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
- * should be included.
+ * should be included. "fetchsize" can be used to control the
+ * number of rows per fetch.
* @since 1.4.0
*/
def jdbc(
@@ -207,7 +208,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging {
* @param predicates Condition in the where clause for each partition.
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
- * should be included.
+ * should be included. "fetchsize" can be used to control the
+ * number of rows per fetch.
* @since 1.4.0
*/
def jdbc(
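A read-side sketch of the documented fetchsize knob (URL, table, and credentials are illustrative):

    val props = new java.util.Properties()
    props.setProperty("user", "dbuser")
    props.setProperty("password", "secret")
    props.setProperty("fetchsize", "1000")   // rows fetched per round trip; 0 lets the driver estimate
    val jdbcDF = spark.read.jdbc("jdbc:postgresql://host/db", "people", props)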
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
index ca3972d62dfb..12b304623d30 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala
@@ -391,7 +391,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* @param table Name of the table in the external database.
* @param connectionProperties JDBC database connection arguments, a list of arbitrary string
* tag/value. Normally at least a "user" and "password" property
- * should be included.
+ * should be included. "batchsize" can be used to control the
+ * number of rows per insert.
* @since 1.4.0
*/
def jdbc(url: String, table: String, connectionProperties: Properties): Unit = {
@@ -536,6 +537,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) {
* `escapeQuotes` (default `true`): a flag indicating whether values containing
* quotes should always be enclosed in quotes. Default is to escape all values containing
* a quote character.
+ * `quoteAll` (default `false`): A flag indicating whether all values should always be
+ * enclosed in quotes. Default is to only escape values containing a quote character.
* `header` (default `false`): writes the names of columns as the first line.
* `nullValue` (default empty string): sets the string representation of a null value.
* `compression` (default `null`): compression codec to use when saving to file. This can be
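A write-side sketch of the new quoteAll CSV option alongside the existing ones (the output path is illustrative):

    df.write
      .option("header", "true")
      .option("quoteAll", "true")   // quote every value, not only values containing a quote character
      .csv("/tmp/quoted_csv")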
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 85d060639c7f..067cbec4bf61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -455,8 +455,8 @@ class Dataset[T] private[sql](
/**
* Returns true if this Dataset contains one or more sources that continuously
* return data as it arrives. A Dataset that reads data from a streaming source
- * must be executed as a [[StreamingQuery]] using the `startStream()` method in
- * [[DataFrameWriter]]. Methods that return a single answer, e.g. `count()` or
+ * must be executed as a [[StreamingQuery]] using the `start()` method in
+ * [[DataStreamWriter]]. Methods that return a single answer, e.g. `count()` or
* `collect()`, will throw an [[AnalysisException]] when there is a streaming
* source present.
*
@@ -2350,14 +2350,14 @@ class Dataset[T] private[sql](
}
/**
- * Returns the content of the Dataset as a [[JavaRDD]] of [[Row]]s.
+ * Returns the content of the Dataset as a [[JavaRDD]] of [[T]]s.
* @group basic
* @since 1.6.0
*/
def toJavaRDD: JavaRDD[T] = rdd.toJavaRDD()
/**
- * Returns the content of the Dataset as a [[JavaRDD]] of [[Row]]s.
+ * Returns the content of the Dataset as a [[JavaRDD]] of [[T]]s.
* @group basic
* @since 1.6.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
index 083a63c98c43..91ed9b3258a1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalog/Catalog.scala
@@ -214,7 +214,7 @@ abstract class Catalog {
def clearCache(): Unit
/**
- * Invalidate and refresh all the cached the metadata of the given table. For performance reasons,
+ * Invalidate and refresh all the cached metadata of the given table. For performance reasons,
* Spark SQL or the external data source library it uses might cache certain metadata about a
* table, such as the location of blocks. When those change outside of Spark SQL, users should
* call this function to invalidate the cache.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
index 97bbab65af1d..e01094a7c8e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WindowExec.scala
@@ -345,6 +345,8 @@ case class WindowExec(
null,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
false)
rows.foreach { r =>
sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
index fc00912bf9f5..226f61ef404a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala
@@ -206,7 +206,7 @@ case class DropTableCommand(
} catch {
case NonFatal(e) => log.warn(e.toString, e)
}
- catalog.invalidateTable(tableName)
+ catalog.refreshTable(tableName)
catalog.dropTable(tableName, ifExists)
}
Seq.empty[Row]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
index 30dc7e81e9ee..14836044cabe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala
@@ -172,7 +172,7 @@ case class AlterTableRenameCommand(
}
// Invalidate the table last, otherwise uncaching the table would load the logical plan
// back into the hive metastore cache
- catalog.invalidateTable(oldName)
+ catalog.refreshTable(oldName)
catalog.renameTable(oldName, newName)
if (wasCached) {
sparkSession.catalog.cacheTable(newName.unquotedString)
@@ -373,7 +373,7 @@ case class TruncateTableCommand(
}
// After deleting the data, invalidate the table to make sure we don't keep around a stale
// file relation in the metastore cache.
- spark.sessionState.invalidateTable(tableName.unquotedString)
+ spark.sessionState.refreshTable(tableName.unquotedString)
// Also try to drop the contents of the table from the columnar cache
try {
spark.sharedState.cacheManager.uncacheQuery(spark.table(tableName.quotedString))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
index 088f684365db..6533d796e806 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala
@@ -88,7 +88,11 @@ case class CreateViewCommand(
qe.assertAnalyzed()
val analyzedPlan = qe.analyzed
- require(tableDesc.schema == Nil || tableDesc.schema.length == analyzedPlan.output.length)
+ if (tableDesc.schema != Nil && tableDesc.schema.length != analyzedPlan.output.length) {
+ throw new AnalysisException(s"The number of columns produced by the SELECT clause " +
+ s"(num: `${analyzedPlan.output.length}`) does not match the number of column names " +
+ s"specified by CREATE VIEW (num: `${tableDesc.schema.length}`).")
+ }
val sessionState = sparkSession.sessionState
if (isTemporary) {
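A sketch of a statement that now fails with the user-facing AnalysisException instead of a bare require failure (view and column names are illustrative):

    // Two column names are declared, but the SELECT produces only one column.
    spark.sql("CREATE VIEW v (col1, col2) AS SELECT 1")
    // AnalysisException: The number of columns produced by the SELECT clause (num: `1`) does not
    // match the number of column names specified by CREATE VIEW (num: `2`).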
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
index 557445c2bc91..f572b93991e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala
@@ -203,6 +203,18 @@ case class DataSource(
val path = caseInsensitiveOptions.getOrElse("path", {
throw new IllegalArgumentException("'path' is not specified")
})
+
+ // Check whether the path exists if it is not a glob pattern.
+ // For glob pattern, we do not check it because the glob pattern might only make sense
+ // once the streaming job starts and some upstream source starts dropping data.
+ val hdfsPath = new Path(path)
+ if (!SparkHadoopUtil.get.isGlobPath(hdfsPath)) {
+ val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf())
+ if (!fs.exists(hdfsPath)) {
+ throw new AnalysisException(s"Path does not exist: $path")
+ }
+ }
+
val isSchemaInferenceEnabled = sparkSession.conf.get(SQLConf.STREAMING_SCHEMA_INFERENCE)
val isTextSource = providingClass == classOf[text.TextFileFormat]
// If the schema inference is disabled, only text sources require schema to be specified
@@ -364,7 +376,8 @@ case class DataSource(
}
val fileCatalog =
- new ListingFileCatalog(sparkSession, globbedPaths, options, partitionSchema)
+ new ListingFileCatalog(
+ sparkSession, globbedPaths, options, partitionSchema, !checkPathExist)
val dataSchema = userSpecifiedSchema.map { schema =>
val equality =
@@ -472,12 +485,11 @@ case class DataSource(
data.logicalPlan,
mode)
sparkSession.sessionState.executePlan(plan).toRdd
+ // Replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
+ copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
case _ =>
sys.error(s"${providingClass.getCanonicalName} does not allow create table as select.")
}
-
- // We replace the schema with that of the DataFrame we just wrote out to avoid re-inferring it.
- copy(userSpecifiedSchema = Some(data.schema.asNullable)).resolveRelation()
}
}
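A sketch of the new eager check for streaming sources, assuming a concrete (non-glob) path that is missing; glob patterns such as /data/*/logs are still left to resolve after the query starts:

    // Fails immediately with AnalysisException("Path does not exist: /data/missing")
    // rather than only after the streaming query is started.
    val missing = spark.readStream
      .schema(userSchema)        // userSchema is illustrative
      .json("/data/missing")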
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
index f7f68b1eb90d..1314c94d42cf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala
@@ -111,7 +111,20 @@ class FileScanRDD(
currentFile = files.next()
logInfo(s"Reading File $currentFile")
InputFileNameHolder.setInputFileName(currentFile.filePath)
- currentIterator = readFunction(currentFile)
+
+ try {
+ currentIterator = readFunction(currentFile)
+ } catch {
+ case e: java.io.FileNotFoundException =>
+ throw new java.io.FileNotFoundException(
+ e.getMessage + "\n" +
+ "It is possible the underlying files have been updated. " +
+ "You can explicitly invalidate the cache in Spark by " +
+ "running 'REFRESH TABLE tableName' command in SQL or " +
+ "by recreating the Dataset/DataFrame involved."
+ )
+ }
+
hasNext
} else {
currentFile = null
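The enriched FileNotFoundException points at the two usual remedies; a sketch of both, with an illustrative table name and path:

    // Refresh the cached file listing for the table...
    spark.sql("REFRESH TABLE my_table")
    // ...or rebuild the DataFrame so the stale file list is re-resolved.
    val fresh = spark.read.parquet("/data/my_table")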
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala
index 675e755cb2d0..706ec6b9b36c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ListingFileCatalog.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources
+import java.io.FileNotFoundException
+
import scala.collection.mutable
import scala.util.Try
@@ -35,12 +37,16 @@ import org.apache.spark.sql.types.StructType
* @param paths a list of paths to scan
* @param partitionSchema an optional partition schema that will be use to provide types for the
* discovered partitions
+ * @param ignoreFileNotFound if true, return empty file list when encountering a
+ * [[FileNotFoundException]] in file listing. Note that this is a hack
+ * for SPARK-16313. We should get rid of this flag in the future.
*/
class ListingFileCatalog(
sparkSession: SparkSession,
override val paths: Seq[Path],
parameters: Map[String, String],
- partitionSchema: Option[StructType])
+ partitionSchema: Option[StructType],
+ ignoreFileNotFound: Boolean = false)
extends PartitioningAwareFileCatalog(sparkSession, parameters, partitionSchema) {
@volatile private var cachedLeafFiles: mutable.LinkedHashMap[Path, FileStatus] = _
@@ -77,10 +83,12 @@ class ListingFileCatalog(
* List leaf files of given paths. This method will submit a Spark job to do parallel
* listing whenever there is a path having more files than the parallel partition discovery
* threshold.
+ *
+ * This is publicly visible for testing.
*/
- protected[spark] def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
+ def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = {
if (paths.length >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) {
- HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession)
+ HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sparkSession, ignoreFileNotFound)
} else {
// Right now, the number of paths is less than the value of
// parallelPartitionDiscoveryThreshold. So, we will list file statues at the driver.
@@ -96,8 +104,12 @@ class ListingFileCatalog(
logTrace(s"Listing $path on driver")
val childStatuses = {
- // TODO: We need to avoid of using Try at here.
- val stats = Try(fs.listStatus(path)).getOrElse(Array.empty[FileStatus])
+ val stats =
+ try {
+ fs.listStatus(path)
+ } catch {
+ case e: FileNotFoundException if ignoreFileNotFound => Array.empty[FileStatus]
+ }
if (pathFilter != null) stats.filter(f => pathFilter.accept(f.getPath)) else stats
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 39c8606fd14b..90711f2b1dde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -85,5 +85,10 @@ case class LogicalRelation(
expectedOutputAttributes,
metastoreTableIdentifier).asInstanceOf[this.type]
+ override def refresh(): Unit = relation match {
+ case fs: HadoopFsRelation => fs.refresh()
+ case _ => // Do nothing.
+ }
+
override def simpleString: String = s"Relation[${Utils.truncatedString(output, ",")}] $relation"
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
index 388df7002dc3..c3561099d684 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala
@@ -351,7 +351,7 @@ private[sql] object PartitioningUtils {
}
}
- if (partitionColumns.size == schema.fields.size) {
+ if (partitionColumns.nonEmpty && partitionColumns.size == schema.fields.length) {
throw new AnalysisException(s"Cannot use all columns for partition columns")
}
}
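The added nonEmpty guard keeps a zero-column schema with no partition columns from tripping this error, while the degenerate case is still rejected; a sketch of the rejected case (path and column names are illustrative):

    import spark.implicits._   // assuming an active SparkSession named spark
    val df = Seq((1, "a")).toDF("id", "name")
    df.write.partitionBy("id", "name").parquet("/tmp/out")
    // AnalysisException: Cannot use all columns for partition columns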
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index f56b50a54385..9a0b46c1a4a5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.execution.command.CreateDataSourceTableUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.apache.spark.util.{SerializableConfiguration, Utils}
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/** A container for all the details required when writing to a table. */
@@ -389,7 +390,9 @@ private[sql] class DynamicPartitionWriterContainer(
StructType.fromAttributes(dataColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
while (iterator.hasNext) {
val currentRow = iterator.next()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
index 581eda7e09a3..22fb8163b1c0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVOptions.scala
@@ -115,6 +115,8 @@ private[sql] class CSVOptions(@transient private val parameters: Map[String, Str
val maxMalformedLogPerPartition = getInt("maxMalformedLogPerPartition", 10)
+ val quoteAll = getBool("quoteAll", false)
+
val inputBufferSize = 128
val isCommentSet = this.comment != '\u0000'
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
index b06f12369dd0..7929ebbd90f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVParser.scala
@@ -73,7 +73,7 @@ private[sql] class LineCsvWriter(params: CSVOptions, headers: Seq[String]) exten
writerSettings.setNullValue(params.nullValue)
writerSettings.setEmptyValue(params.nullValue)
writerSettings.setSkipEmptyLines(true)
- writerSettings.setQuoteAllFields(false)
+ writerSettings.setQuoteAllFields(params.quoteAll)
writerSettings.setHeaders(headers: _*)
writerSettings.setQuoteEscapingEnabled(params.escapeQuotes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
index 20399e190f43..0b5a19fe9384 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala
@@ -439,7 +439,8 @@ private[sql] object HadoopFsRelation extends Logging {
def listLeafFilesInParallel(
paths: Seq[Path],
hadoopConf: Configuration,
- sparkSession: SparkSession): mutable.LinkedHashSet[FileStatus] = {
+ sparkSession: SparkSession,
+ ignoreFileNotFound: Boolean): mutable.LinkedHashSet[FileStatus] = {
assert(paths.size >= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold)
logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}")
@@ -460,9 +461,11 @@ private[sql] object HadoopFsRelation extends Logging {
val pathFilter = FileInputFormat.getInputPathFilter(jobConf)
paths.map(new Path(_)).flatMap { path =>
val fs = path.getFileSystem(serializableConfiguration.value)
- // TODO: We need to avoid of using Try at here.
- Try(listLeafFiles(fs, fs.getFileStatus(path), pathFilter))
- .getOrElse(Array.empty[FileStatus])
+ try {
+ listLeafFiles(fs, fs.getFileStatus(path), pathFilter)
+ } catch {
+ case e: java.io.FileNotFoundException if ignoreFileNotFound => Array.empty[FileStatus]
+ }
}
}.map { status =>
val blockLocations = status match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
index 44cfbb9fbd81..24e2c1a5fd2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala
@@ -390,7 +390,11 @@ private[sql] class JDBCRDD(
val sqlText = s"SELECT $columnList FROM $fqTable $myWhereClause"
val stmt = conn.prepareStatement(sqlText,
ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
- val fetchSize = properties.getProperty("fetchsize", "0").toInt
+ val fetchSize = properties.getProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt
+ require(fetchSize >= 0,
+ s"Invalid value `${fetchSize.toString}` for parameter " +
+ s"`${JdbcUtils.JDBC_BATCH_FETCH_SIZE}`. The minimum value is 0. When the value is 0, " +
+ "the JDBC driver ignores the value and does the estimates.")
stmt.setFetchSize(fetchSize)
val rs = stmt.executeQuery()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
index 065c8572b06a..d3e1efc56277 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala
@@ -34,6 +34,10 @@ import org.apache.spark.sql.types._
*/
object JdbcUtils extends Logging {
+ // the property names are case sensitive
+ val JDBC_BATCH_FETCH_SIZE = "fetchsize"
+ val JDBC_BATCH_INSERT_SIZE = "batchsize"
+
/**
* Returns a factory for creating connections to the given JDBC URL.
*
@@ -96,8 +100,9 @@ object JdbcUtils extends Logging {
/**
* Returns a PreparedStatement that inserts a row into table via conn.
*/
- def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = {
- val columns = rddSchema.fields.map(_.name).mkString(",")
+ def insertStatement(conn: Connection, table: String, rddSchema: StructType, dialect: JdbcDialect)
+ : PreparedStatement = {
+ val columns = rddSchema.fields.map(x => dialect.quoteIdentifier(x.name)).mkString(",")
val placeholders = rddSchema.fields.map(_ => "?").mkString(",")
val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)"
conn.prepareStatement(sql)
@@ -154,6 +159,10 @@ object JdbcUtils extends Logging {
nullTypes: Array[Int],
batchSize: Int,
dialect: JdbcDialect): Iterator[Byte] = {
+ require(batchSize >= 1,
+ s"Invalid value `${batchSize.toString}` for parameter " +
+ s"`${JdbcUtils.JDBC_BATCH_INSERT_SIZE}`. The minimum value is 1.")
+
val conn = getConnection()
var committed = false
val supportsTransactions = try {
@@ -169,7 +178,7 @@ object JdbcUtils extends Logging {
if (supportsTransactions) {
conn.setAutoCommit(false) // Everything in the same db transaction.
}
- val stmt = insertStatement(conn, table, rddSchema)
+ val stmt = insertStatement(conn, table, rddSchema, dialect)
try {
var rowCount = 0
while (iterator.hasNext) {
@@ -252,7 +261,7 @@ object JdbcUtils extends Logging {
val sb = new StringBuilder()
val dialect = JdbcDialects.get(url)
df.schema.fields foreach { field =>
- val name = field.name
+ val name = dialect.quoteIdentifier(field.name)
val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition
val nullable = if (field.nullable) "" else "NOT NULL"
sb.append(s", $name $typ $nullable")
@@ -275,7 +284,7 @@ object JdbcUtils extends Logging {
val rddSchema = df.schema
val getConnection: () => Connection = createConnectionFactory(url, properties)
- val batchSize = properties.getProperty("batchsize", "1000").toInt
+ val batchSize = properties.getProperty(JDBC_BATCH_INSERT_SIZE, "1000").toInt
df.foreachPartition { iterator =>
savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect)
}
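
A write-path sketch (the JDBC URL, table name, and connection details are placeholders) showing both changes from user code: a reserved-word column such as `order` is now quoted by the dialect, and the insert batch size is validated to be at least 1.

```scala
import java.util.Properties

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("jdbc-write-example").getOrCreate()
import spark.implicits._

// `order` is a reserved word in most databases; with dialect quoting the
// generated CREATE TABLE / INSERT statements escape it instead of failing.
val df = Seq((1, "a"), (2, "b"), (3, "c")).toDF("id", "order")

val props = new Properties()                  // placeholder connection properties
props.setProperty("batchsize", "500")         // values below 1 now fail fast with the new require()

df.write.jdbc("jdbc:h2:mem:testdb", "TEST.ORDERS", props)  // placeholder URL and table
```
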
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index f38bf81e52c0..8cbdaebac179 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -436,7 +436,7 @@ private[sql] class ParquetOutputWriterFactory(
ParquetOutputFormat.setWriteSupportClass(job, classOf[ParquetWriteSupport])
// We want to clear this temporary metadata from saving into Parquet file.
- // This metadata is only useful for detecting optional columns when pushdowning filters.
+ // This metadata is only useful for detecting optional columns when pushing down filters.
val dataSchemaToWrite = StructType.removeMetadata(
StructType.metadataKeyForOptionalField,
dataSchema).asInstanceOf[StructType]
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
index 95afdc789f32..70ae829219d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala
@@ -215,10 +215,13 @@ private[sql] object ParquetFilters {
*/
private def getFieldMap(dataType: DataType): Array[(String, DataType)] = dataType match {
case StructType(fields) =>
+ // Here we don't flatten the fields in the nested schema but only look at the
+ // root fields. Currently, accessing nested fields does not push down filters,
+ // and creating filters for them is not supported.
fields.filter { f =>
!f.metadata.contains(StructType.metadataKeyForOptionalField) ||
!f.metadata.getBoolean(StructType.metadataKeyForOptionalField)
- }.map(f => f.name -> f.dataType) ++ fields.flatMap { f => getFieldMap(f.dataType) }
+ }.map(f => f.name -> f.dataType)
case _ => Array.empty[(String, DataType)]
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
index 88f78a7a73bc..3a0b6efdfc91 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProductExec.scala
@@ -49,6 +49,8 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField
null,
1024,
SparkEnv.get.memoryManager.pageSizeBytes,
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD),
false)
val partition = split.asInstanceOf[CartesianPartition]
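
The force-spill threshold read here is an ordinary Spark configuration key, so it can be tuned when the session is built (a sketch; the value shown is only an example).

```scala
import org.apache.spark.sql.SparkSession

// Lower the element count after which the UnsafeExternalSorter force-spills to disk,
// e.g. to bound memory usage for very large cartesian products. When unset, the
// default is UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD.
val spark = SparkSession.builder()
  .appName("spill-threshold-example")
  .config("spark.shuffle.spill.numElementsForceSpillThreshold", "4000000")  // example value
  .getOrCreate()
```
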
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
index efb04912d76b..117d6672ee2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala
@@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, PartitioningUtils}
import org.apache.spark.sql.types.{StringType, StructType}
import org.apache.spark.util.SerializableConfiguration
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
object FileStreamSink {
// The name of the subdirectory that is used to store metadata about which files are valid.
@@ -209,7 +210,9 @@ class FileStreamSinkWriter(
StructType.fromAttributes(writeColumns),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
while (iterator.hasNext) {
val currentRow = iterator.next()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 11bf3c0bd2e0..72b335a42ed3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.streaming
-import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
import org.apache.hadoop.fs.Path
@@ -46,6 +46,9 @@ class FileStreamSource(
private val metadataLog = new HDFSMetadataLog[Seq[String]](sparkSession, metadataPath)
private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L)
+ /** Maximum number of new files to be considered in each batch */
+ private val maxFilesPerBatch = getMaxFilesPerBatch()
+
private val seenFiles = new OpenHashSet[String]
metadataLog.get(None, Some(maxBatchId)).foreach { case (batchId, files) =>
files.foreach(seenFiles.add)
@@ -58,19 +61,17 @@ class FileStreamSource(
* there is no race here, so the cost of `synchronized` should be rare.
*/
private def fetchMaxOffset(): LongOffset = synchronized {
- val filesPresent = fetchAllFiles()
- val newFiles = new ArrayBuffer[String]()
- filesPresent.foreach { file =>
- if (!seenFiles.contains(file)) {
- logDebug(s"new file: $file")
- newFiles.append(file)
- seenFiles.add(file)
- } else {
- logDebug(s"old file: $file")
- }
+ val newFiles = fetchAllFiles().filter(!seenFiles.contains(_))
+ val batchFiles =
+ if (maxFilesPerBatch.nonEmpty) newFiles.take(maxFilesPerBatch.get) else newFiles
+ batchFiles.foreach { file =>
+ seenFiles.add(file)
+ logDebug(s"New file: $file")
}
-
- if (newFiles.nonEmpty) {
+ logTrace(s"Number of new files = ${newFiles.size})")
+ logTrace(s"Number of files selected for batch = ${batchFiles.size}")
+ logTrace(s"Number of seen files = ${seenFiles.size}")
+ if (batchFiles.nonEmpty) {
maxBatchId += 1
metadataLog.add(maxBatchId, newFiles)
logInfo(s"Max batch id increased to $maxBatchId with ${newFiles.size} new files")
@@ -118,7 +119,7 @@ class FileStreamSource(
val startTime = System.nanoTime
val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath)
val catalog = new ListingFileCatalog(sparkSession, globbedPaths, options, Some(new StructType))
- val files = catalog.allFiles().map(_.getPath.toUri.toString)
+ val files = catalog.allFiles().sortBy(_.getModificationTime).map(_.getPath.toUri.toString)
val endTime = System.nanoTime
val listingTimeMs = (endTime.toDouble - startTime) / 1000000
if (listingTimeMs > 2000) {
@@ -131,6 +132,17 @@ class FileStreamSource(
files
}
+ private def getMaxFilesPerBatch(): Option[Int] = {
+ new CaseInsensitiveMap(options)
+ .get("maxFilesPerTrigger")
+ .map { str =>
+ Try(str.toInt).toOption.filter(_ > 0).getOrElse {
+ throw new IllegalArgumentException(
+ s"Invalid value '$str' for option 'maxFilesPerBatch', must be a positive integer")
+ }
+ }
+ }
+
override def getOffset: Option[Offset] = Some(fetchMaxOffset()).filterNot(_.offset == -1)
override def toString: String = s"FileStreamSource[$qualifiedBasePath]"
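
A usage sketch of the new option (paths are placeholders); each trigger then processes at most the given number of newly discovered files.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("max-files-per-trigger-example").getOrCreate()

val lines = spark.readStream
  .format("text")
  .option("maxFilesPerTrigger", "10")  // must be a positive integer; invalid values are rejected
  .load("/path/to/input/dir")          // placeholder input directory

val query = lines.writeStream
  .format("console")
  .option("checkpointLocation", "/path/to/checkpoint")  // placeholder checkpoint directory
  .start()
```
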
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
index 14b9b1cb0931..082664aa23f0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ForeachSink.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.execution.streaming
import org.apache.spark.TaskContext
-import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Dataset, Encoder, ForeachWriter}
+import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde
/**
* A [[Sink]] that forwards all data into [[ForeachWriter]] according to the contract defined by
@@ -30,7 +32,41 @@ import org.apache.spark.sql.{DataFrame, Encoder, ForeachWriter}
class ForeachSink[T : Encoder](writer: ForeachWriter[T]) extends Sink with Serializable {
override def addBatch(batchId: Long, data: DataFrame): Unit = {
- data.as[T].foreachPartition { iter =>
+ // TODO: Refine this method when SPARK-16264 is resolved; see comments below.
+
+ // This logic should've been as simple as:
+ // ```
+ // data.as[T].foreachPartition { iter => ... }
+ // ```
+ //
+ // Unfortunately, doing that would just break the incremental planning. The reason is that
+ // `Dataset.foreachPartition()` would further call `Dataset.rdd()`, but `Dataset.rdd()` just
+ // does not support `IncrementalExecution`.
+ //
+ // So as a provisional fix, below we've made a special version of `Dataset` with its `rdd()`
+ // method supporting incremental planning. But in the long run, we should generally make newly
+ // created Datasets use `IncrementalExecution` where necessary (which is what SPARK-16264
+ // tries to resolve).
+
+ val datasetWithIncrementalExecution =
+ new Dataset(data.sparkSession, data.logicalPlan, implicitly[Encoder[T]]) {
+ override lazy val rdd: RDD[T] = {
+ val objectType = exprEnc.deserializer.dataType
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+
+ // was originally: sparkSession.sessionState.executePlan(deserialized) ...
+ val incrementalExecution = new IncrementalExecution(
+ this.sparkSession,
+ deserialized,
+ data.queryExecution.asInstanceOf[IncrementalExecution].outputMode,
+ data.queryExecution.asInstanceOf[IncrementalExecution].checkpointLocation,
+ data.queryExecution.asInstanceOf[IncrementalExecution].currentBatchId)
+ incrementalExecution.toRdd.mapPartitions { rows =>
+ rows.map(_.get(0, objectType))
+ }.asInstanceOf[RDD[T]]
+ }
+ }
+ datasetWithIncrementalExecution.foreachPartition { iter =>
if (writer.open(TaskContext.getPartitionId(), batchId)) {
var isFailed = false
try {
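
For context, the sink is driven through the public `foreach(...)` API with a user-supplied `ForeachWriter`; a minimal sketch follows (the writer just prints each value, and `ints` is assumed to be a streaming `Dataset[Int]`).

```scala
import org.apache.spark.sql.ForeachWriter

// A trivial writer: open() decides whether to process a partition/version,
// process() handles each row, and close() releases any resources.
class PrintlnForeachWriter extends ForeachWriter[Int] {
  override def open(partitionId: Long, version: Long): Boolean = true
  override def process(value: Int): Unit = println(s"got value: $value")
  override def close(errorOrNull: Throwable): Unit = ()
}

// Usage sketch, assuming `ints` is a streaming Dataset[Int]:
//   val query = ints.writeStream
//     .outputMode("append")
//     .foreach(new PrintlnForeachWriter)
//     .start()
```
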
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
index 0ce00552bf6c..7367c68d0a0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala
@@ -30,8 +30,8 @@ import org.apache.spark.sql.streaming.OutputMode
class IncrementalExecution private[sql](
sparkSession: SparkSession,
logicalPlan: LogicalPlan,
- outputMode: OutputMode,
- checkpointLocation: String,
+ val outputMode: OutputMode,
+ val checkpointLocation: String,
val currentBatchId: Long)
extends QueryExecution(sparkSession, logicalPlan) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index e8bd489be341..c8782df146df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -2721,6 +2721,14 @@ object functions {
*/
def explode(e: Column): Column = withExpr { Explode(e.expr) }
+ /**
+ * Creates a new row for each element with position in the given array or map column.
+ *
+ * @group collection_funcs
+ * @since 2.1.0
+ */
+ def posexplode(e: Column): Column = withExpr { PosExplode(e.expr) }
+
/**
* Extracts json object from a json string based on json path specified, and returns json string
* of the extracted json object. It will return null if the input json string is invalid.
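
A quick usage sketch of the new function (column names are arbitrary).

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.posexplode

val spark = SparkSession.builder().appName("posexplode-example").getOrCreate()
import spark.implicits._

val df = Seq((1, Seq("a", "b", "c"))).toDF("id", "letters")

// Produces one row per element together with its position:
// (1, 0, "a"), (1, 1, "b"), (1, 2, "c")
df.select($"id", posexplode($"letters")).show()
```
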
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 5f5cf5c6d30c..01cc13f9df88 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -166,8 +166,8 @@ private[sql] class SessionState(sparkSession: SparkSession) {
def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(sparkSession, plan)
- def invalidateTable(tableName: String): Unit = {
- catalog.invalidateTable(sqlParser.parseTableIdentifier(tableName))
+ def refreshTable(tableName: String): Unit = {
+ catalog.refreshTable(sqlParser.parseTableIdentifier(tableName))
}
def addJar(path: String): Unit = {
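
From user code the rename is visible through the catalog and SQL surfaces (a sketch, assuming an active SparkSession `spark` and a placeholder view name).

```scala
// After files backing a cached relation change on disk, refresh its metadata:
spark.catalog.refreshTable("my_view")   // placeholder table/view name

// Equivalent SQL form:
spark.sql("REFRESH TABLE my_view")
```
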
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
index 2d6c3974a833..6baf1b6f16cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala
@@ -89,7 +89,7 @@ private object PostgresDialect extends JdbcDialect {
//
// See: https://jdbc.postgresql.org/documentation/head/query.html#query-with-cursor
//
- if (properties.getOrElse("fetchsize", "0").toInt > 0) {
+ if (properties.getOrElse(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "0").toInt > 0) {
connection.setAutoCommit(false)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
index 248247a257d9..2e606b21bdf3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala
@@ -161,6 +161,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* schema in advance, use the version that specifies the schema to avoid the extra scan.
*
* You can set the following JSON-specific options to deal with non-standard JSON files:
+ * `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.
* `primitivesAsString` (default `false`): infers all primitive values as a string type
* `prefersDecimal` (default `false`): infers all floating-point values as a decimal
* type. If the values do not fit in decimal, then it infers them as doubles.
@@ -199,6 +201,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* specify the schema explicitly using [[schema]].
*
* You can set the following CSV-specific options to deal with CSV files:
+ * `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.
* `sep` (default `,`): sets the single character as a separator for each
* field and value.
* `encoding` (default `UTF-8`): decodes the CSV files by the given encoding
@@ -251,6 +255,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* Loads a Parquet file stream, returning the result as a [[DataFrame]].
*
* You can set the following Parquet-specific option(s) for reading Parquet files:
+ * `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.
* `mergeSchema` (default is the value specified in `spark.sql.parquet.mergeSchema`): sets
* whether we should merge schemas collected from all Parquet part-files. This will override
* `spark.sql.parquet.mergeSchema`.
@@ -276,6 +282,10 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo
* spark.readStream().text("/path/to/directory/")
* }}}
*
+ * You can set the following text-specific options to deal with text files:
+ * `maxFilesPerTrigger` (default: no max limit): sets the maximum number of new files to be
+ * considered in every trigger.
+ *
* @since 2.0.0
*/
@Experimental
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
index d4b0a3cca240..d38e3e58125d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala
@@ -109,7 +109,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) {
/**
* :: Experimental ::
- * Specifies the name of the [[StreamingQuery]] that can be started with `startStream()`.
+ * Specifies the name of the [[StreamingQuery]] that can be started with `start()`.
* This name must be unique among all the currently active queries in the associated SQLContext.
*
* @since 2.0.0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
index 19d1ecf740d0..91f0a1e3446a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQuery.scala
@@ -31,8 +31,8 @@ trait StreamingQuery {
/**
* Returns the name of the query. This name is unique across all active queries. This can be
- * set in the[[org.apache.spark.sql.DataFrameWriter DataFrameWriter]] as
- * `dataframe.write().queryName("query").startStream()`.
+ * set in the [[org.apache.spark.sql.DataStreamWriter DataStreamWriter]] as
+ * `dataframe.writeStream.queryName("query").start()`.
* @since 2.0.0
*/
def name: String
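
The corrected wording matches the current API shape; a sketch (the source path, checkpoint path, and query name are placeholders).

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("query-name-example").getOrCreate()

val ds = spark.readStream.format("text").load("/path/to/input")  // placeholder source

val query = ds.writeStream
  .queryName("my_query")                                // unique among active queries
  .format("memory")                                     // in-memory sink, table named after the query
  .option("checkpointLocation", "/path/to/checkpoint")  // placeholder
  .start()

assert(query.name == "my_query")
```
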
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
index c43de58faa80..3b3cead3a66d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryListener.scala
@@ -35,9 +35,9 @@ abstract class StreamingQueryListener {
/**
* Called when a query is started.
* @note This is called synchronously with
- * [[org.apache.spark.sql.DataFrameWriter `DataFrameWriter.startStream()`]],
+ * [[org.apache.spark.sql.DataStreamWriter `DataStreamWriter.start()`]],
* that is, `onQueryStart` will be called on all listeners before
- * `DataFrameWriter.startStream()` returns the corresponding [[StreamingQuery]]. Please
+ * `DataStreamWriter.start()` returns the corresponding [[StreamingQuery]]. Please
* don't block this method as it will block your query.
* @since 2.0.0
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index a66c83dea00b..a170fae577c1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -122,66 +122,6 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
assert(newCol.expr.asInstanceOf[NamedExpression].metadata.getString("key") === "value")
}
- test("single explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- checkAnswer(
- df.select(explode('intList)),
- Row(1) :: Row(2) :: Row(3) :: Nil)
- }
-
- test("explode and other columns") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select($"a", explode('intList)),
- Row(1, 1) ::
- Row(1, 2) ::
- Row(1, 3) :: Nil)
-
- checkAnswer(
- df.select($"*", explode('intList)),
- Row(1, Seq(1, 2, 3), 1) ::
- Row(1, Seq(1, 2, 3), 2) ::
- Row(1, Seq(1, 2, 3), 3) :: Nil)
- }
-
- test("aliased explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
-
- checkAnswer(
- df.select(explode('intList).as('int)).select('int),
- Row(1) :: Row(2) :: Row(3) :: Nil)
-
- checkAnswer(
- df.select(explode('intList).as('int)).select(sum('int)),
- Row(6) :: Nil)
- }
-
- test("explode on map") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map)),
- Row("a", "b"))
- }
-
- test("explode on map with aliases") {
- val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
-
- checkAnswer(
- df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
- Row("a", "b"))
- }
-
- test("self join explode") {
- val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
- val exploded = df.select(explode('intList).as('i))
-
- checkAnswer(
- exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
- Row(3) :: Nil)
- }
-
test("collect on column produced by a binary operator") {
val df = Seq((1, 2, 3)).toDF("a", "b", "c")
checkAnswer(df.select(df("a") + df("b")), Seq(Row(3)))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 73d77651a027..0f6c49e75959 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -352,6 +352,22 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
)
}
+ test("map_keys/map_values function") {
+ val df = Seq(
+ (Map[Int, Int](1 -> 100, 2 -> 200), "x"),
+ (Map[Int, Int](), "y"),
+ (Map[Int, Int](1 -> 100, 2 -> 200, 3 -> 300), "z")
+ ).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("map_keys(a)"),
+ Seq(Row(Seq(1, 2)), Row(Seq.empty), Row(Seq(1, 2, 3)))
+ )
+ checkAnswer(
+ df.selectExpr("map_values(a)"),
+ Seq(Row(Seq(100, 200)), Row(Seq.empty), Row(Seq(100, 200, 300)))
+ )
+ }
+
test("array contains function") {
val df = Seq(
(Seq[Int](1, 2), "x"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
new file mode 100644
index 000000000000..aedc0a8d6f70
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/GeneratorFunctionSuite.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class GeneratorFunctionSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("stack") {
+ val df = spark.range(1)
+
+ // An empty DataFrame suppresses result generation
+ checkAnswer(spark.emptyDataFrame.selectExpr("stack(1, 1, 2, 3)"), Nil)
+
+ // Rows & columns
+ checkAnswer(df.selectExpr("stack(1, 1, 2, 3)"), Row(1, 2, 3) :: Nil)
+ checkAnswer(df.selectExpr("stack(2, 1, 2, 3)"), Row(1, 2) :: Row(3, null) :: Nil)
+ checkAnswer(df.selectExpr("stack(3, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Nil)
+ checkAnswer(df.selectExpr("stack(4, 1, 2, 3)"), Row(1) :: Row(2) :: Row(3) :: Row(null) :: Nil)
+
+ // Various column types
+ checkAnswer(df.selectExpr("stack(3, 1, 1.1, 'a', 2, 2.2, 'b', 3, 3.3, 'c')"),
+ Row(1, 1.1, "a") :: Row(2, 2.2, "b") :: Row(3, 3.3, "c") :: Nil)
+
+ // Repeat generation at every input row
+ checkAnswer(spark.range(2).selectExpr("stack(2, 1, 2, 3)"),
+ Row(1, 2) :: Row(3, null) :: Row(1, 2) :: Row(3, null) :: Nil)
+
+ // The first argument must be a positive constant integer.
+ val m = intercept[AnalysisException] {
+ df.selectExpr("stack(1.1, 1, 2, 3)")
+ }.getMessage
+ assert(m.contains("The number of rows must be a positive constant integer."))
+ val m2 = intercept[AnalysisException] {
+ df.selectExpr("stack(-1, 1, 2, 3)")
+ }.getMessage
+ assert(m2.contains("The number of rows must be a positive constant integer."))
+
+ // The data for the same column should have the same type.
+ val m3 = intercept[AnalysisException] {
+ df.selectExpr("stack(2, 1, '2.2')")
+ }.getMessage
+ assert(m3.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (StringType)"))
+
+ // stack on column data
+ val df2 = Seq((2, 1, 2, 3)).toDF("n", "a", "b", "c")
+ checkAnswer(df2.selectExpr("stack(2, a, b, c)"), Row(1, 2) :: Row(3, null) :: Nil)
+
+ val m4 = intercept[AnalysisException] {
+ df2.selectExpr("stack(n, a, b, c)")
+ }.getMessage
+ assert(m4.contains("The number of rows must be a positive constant integer."))
+
+ val df3 = Seq((2, 1, 2.0)).toDF("n", "a", "b")
+ val m5 = intercept[AnalysisException] {
+ df3.selectExpr("stack(2, a, b)")
+ }.getMessage
+ assert(m5.contains("data type mismatch: Argument 1 (IntegerType) != Argument 2 (DoubleType)"))
+
+ }
+
+ test("single explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(explode('intList)),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+ }
+
+ test("single posexplode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ checkAnswer(
+ df.select(posexplode('intList)),
+ Row(0, 1) :: Row(1, 2) :: Row(2, 3) :: Nil)
+ }
+
+ test("explode and other columns") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select($"a", explode('intList)),
+ Row(1, 1) ::
+ Row(1, 2) ::
+ Row(1, 3) :: Nil)
+
+ checkAnswer(
+ df.select($"*", explode('intList)),
+ Row(1, Seq(1, 2, 3), 1) ::
+ Row(1, Seq(1, 2, 3), 2) ::
+ Row(1, Seq(1, 2, 3), 3) :: Nil)
+ }
+
+ test("aliased explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select('int),
+ Row(1) :: Row(2) :: Row(3) :: Nil)
+
+ checkAnswer(
+ df.select(explode('intList).as('int)).select(sum('int)),
+ Row(6) :: Nil)
+ }
+
+ test("explode on map") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map)),
+ Row("a", "b"))
+ }
+
+ test("explode on map with aliases") {
+ val df = Seq((1, Map("a" -> "b"))).toDF("a", "map")
+
+ checkAnswer(
+ df.select(explode('map).as("key1" :: "value1" :: Nil)).select("key1", "value1"),
+ Row("a", "b"))
+ }
+
+ test("self join explode") {
+ val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList")
+ val exploded = df.select(explode('intList).as('i))
+
+ checkAnswer(
+ exploded.join(exploded, exploded("i") === exploded("i")).agg(count("*")),
+ Row(3) :: Nil)
+ }
+
+ test("inline raises exception on array of null type") {
+ val m = intercept[AnalysisException] {
+ spark.range(2).selectExpr("inline(array())")
+ }.getMessage
+ assert(m.contains("data type mismatch"))
+ }
+
+ test("inline with empty table") {
+ checkAnswer(
+ spark.range(0).selectExpr("inline(array(struct(10, 100)))"),
+ Nil)
+ }
+
+ test("inline on literal") {
+ checkAnswer(
+ spark.range(2).selectExpr("inline(array(struct(10, 100), struct(20, 200), struct(30, 300)))"),
+ Row(10, 100) :: Row(20, 200) :: Row(30, 300) ::
+ Row(10, 100) :: Row(20, 200) :: Row(30, 300) :: Nil)
+ }
+
+ test("inline on column") {
+ val df = Seq((1, 2)).toDF("a", "b")
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), struct(a)))"),
+ Row(1) :: Row(1) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a, b), struct(a, b)))"),
+ Row(1, 2) :: Row(1, 2) :: Nil)
+
+ // Spark thinks the two structs are heterogeneous due to the name difference.
+ val m = intercept[AnalysisException] {
+ df.selectExpr("inline(array(struct(a), struct(b)))")
+ }.getMessage
+ assert(m.contains("data type mismatch"))
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), named_struct('a', b)))"),
+ Row(1) :: Row(2) :: Nil)
+
+ // Spark thinks the two structs are heterogeneous due to the name difference.
+ val m2 = intercept[AnalysisException] {
+ df.selectExpr("inline(array(struct(a), struct(2)))")
+ }.getMessage
+ assert(m2.contains("data type mismatch"))
+
+ checkAnswer(
+ df.selectExpr("inline(array(struct(a), named_struct('a', 2)))"),
+ Row(1) :: Row(2) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("struct(a)").selectExpr("inline(array(*))"),
+ Row(1) :: Nil)
+
+ checkAnswer(
+ df.selectExpr("array(struct(a), named_struct('a', b))").selectExpr("inline(*)"),
+ Row(1) :: Row(2) :: Nil)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala
new file mode 100644
index 000000000000..3f8cc8164d04
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MetadataCacheSuite.scala
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.io.File
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * Test suite for handling metadata cache related issues.
+ */
+class MetadataCacheSuite extends QueryTest with SharedSQLContext {
+
+ /** Removes one data file in the given directory. */
+ private def deleteOneFileInDirectory(dir: File): Unit = {
+ assert(dir.isDirectory)
+ val oneFile = dir.listFiles().find { file =>
+ !file.getName.startsWith("_") && !file.getName.startsWith(".")
+ }
+ assert(oneFile.isDefined)
+ oneFile.foreach(_.delete())
+ }
+
+ test("SPARK-16336 Suggest doing table refresh when encountering FileNotFoundException") {
+ withTempPath { (location: File) =>
+ // Create a Parquet directory
+ spark.range(start = 0, end = 100, step = 1, numPartitions = 3)
+ .write.parquet(location.getAbsolutePath)
+
+ // Read the directory in
+ val df = spark.read.parquet(location.getAbsolutePath)
+ assert(df.count() == 100)
+
+ // Delete a file
+ deleteOneFileInDirectory(location)
+
+ // Read it again and now we should see a FileNotFoundException
+ val e = intercept[SparkException] {
+ df.count()
+ }
+ assert(e.getMessage.contains("FileNotFoundException"))
+ assert(e.getMessage.contains("REFRESH"))
+ }
+ }
+
+ test("SPARK-16337 temporary view refresh") {
+ withTempTable("view_refresh") { withTempPath { (location: File) =>
+ // Create a Parquet directory
+ spark.range(start = 0, end = 100, step = 1, numPartitions = 3)
+ .write.parquet(location.getAbsolutePath)
+
+ // Read the directory in
+ spark.read.parquet(location.getAbsolutePath).createOrReplaceTempView("view_refresh")
+ assert(sql("select count(*) from view_refresh").first().getLong(0) == 100)
+
+ // Delete a file
+ deleteOneFileInDirectory(location)
+
+ // Read it again and now we should see a FileNotFoundException
+ val e = intercept[SparkException] {
+ sql("select count(*) from view_refresh").first()
+ }
+ assert(e.getMessage.contains("FileNotFoundException"))
+ assert(e.getMessage.contains("REFRESH"))
+
+ // Refresh and we should be able to read it again.
+ spark.catalog.refreshTable("view_refresh")
+ val newCount = sql("select count(*) from view_refresh").first().getLong(0)
+ assert(newCount > 0 && newCount < 100)
+ }}
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b1dbf21d4b80..dca9e5e503c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -21,10 +21,8 @@ import java.math.MathContext
import java.sql.Timestamp
import org.apache.spark.AccumulatorSuite
-import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
-import org.apache.spark.sql.catalyst.catalog.{CatalogTestUtils, ExternalCatalog, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{ExpressionDescription, SortOrder}
+import org.apache.spark.sql.catalyst.expressions.SortOrder
import org.apache.spark.sql.catalyst.plans.logical.Aggregate
import org.apache.spark.sql.catalyst.util.StringUtils
import org.apache.spark.sql.execution.aggregate
@@ -2117,6 +2115,37 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
}
}
+ test("Star Expansion - table with zero column") {
+ withTempTable("temp_table_no_cols") {
+ val rddNoCols = sparkContext.parallelize(1 to 10).map(_ => Row.empty)
+ val dfNoCols = spark.createDataFrame(rddNoCols, StructType(Seq.empty))
+ dfNoCols.createTempView("temp_table_no_cols")
+
+ // ResolvedStar
+ checkAnswer(
+ dfNoCols,
+ dfNoCols.select(dfNoCols.col("*")))
+
+ // UnresolvedStar
+ checkAnswer(
+ dfNoCols,
+ sql("SELECT * FROM temp_table_no_cols"))
+ checkAnswer(
+ dfNoCols,
+ dfNoCols.select($"*"))
+
+ var e = intercept[AnalysisException] {
+ sql("SELECT a.* FROM temp_table_no_cols a")
+ }.getMessage
+ assert(e.contains("cannot resolve 'a.*' give input columns ''"))
+
+ e = intercept[AnalysisException] {
+ dfNoCols.select($"b.*")
+ }.getMessage
+ assert(e.contains("cannot resolve 'b.*' give input columns ''"))
+ }
+ }
+
test("Common subexpression elimination") {
// TODO: support subexpression elimination in whole stage codegen
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 1de2d9b5adab..cbe480b52564 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -48,6 +48,20 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("a||b"))
}
+ test("string elt") {
+ val df = Seq[(String, String, String, Int)](("hello", "world", null, 15))
+ .toDF("a", "b", "c", "d")
+
+ checkAnswer(
+ df.selectExpr("elt(0, a, b, c)", "elt(1, a, b, c)", "elt(4, a, b, c)"),
+ Row(null, "hello", null))
+
+ // check implicit type cast
+ checkAnswer(
+ df.selectExpr("elt(4, a, b, c, d)", "elt('2', a, b, c, d)"),
+ Row("15", "world"))
+ }
+
test("string Levenshtein distance") {
val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
checkAnswer(df.select(levenshtein($"l", $"r")), Seq(Row(3), Row(1)))
@@ -212,6 +226,21 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
Row("???hi", "hi???", "h", "h"))
}
+ test("string parse_url function") {
+ val df = Seq[String](("http://userinfo@spark.apache.org/path?query=1#Ref"))
+ .toDF("url")
+
+ checkAnswer(
+ df.selectExpr(
+ "parse_url(url, 'HOST')", "parse_url(url, 'PATH')",
+ "parse_url(url, 'QUERY')", "parse_url(url, 'REF')",
+ "parse_url(url, 'PROTOCOL')", "parse_url(url, 'FILE')",
+ "parse_url(url, 'AUTHORITY')", "parse_url(url, 'USERINFO')",
+ "parse_url(url, 'QUERY', 'query')"),
+ Row("spark.apache.org", "/path", "query=1", "Ref",
+ "http", "/path?query=1", "userinfo@spark.apache.org", "userinfo", "1"))
+ }
+
test("string repeat function") {
val df = Seq(("hi", 2)).toDF("a", "b")
@@ -333,4 +362,24 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext {
df2.filter("b>0").selectExpr("format_number(a, b)"),
Row("5.0000") :: Row("4.000") :: Row("4.000") :: Row("4.000") :: Row("3.00") :: Nil)
}
+
+ test("string sentences function") {
+ val df = Seq(("Hi there! The price was $1,234.56.... But, not now.", "en", "US"))
+ .toDF("str", "language", "country")
+
+ checkAnswer(
+ df.selectExpr("sentences(str, language, country)"),
+ Row(Seq(Seq("Hi", "there"), Seq("The", "price", "was"), Seq("But", "not", "now"))))
+
+ // Type coercion
+ checkAnswer(
+ df.selectExpr("sentences(null)", "sentences(10)", "sentences(3.14)"),
+ Row(null, Seq(Seq("10")), Seq(Seq("3.14"))))
+
+ // Argument number exception
+ val m = intercept[AnalysisException] {
+ df.selectExpr("sentences()")
+ }.getMessage
+ assert(m.contains("Invalid number of arguments for function sentences"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
new file mode 100644
index 000000000000..532d48cc265a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/XmlFunctionsSuite.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+/**
+ * End-to-end tests for XML expressions.
+ */
+class XmlFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
+
+ test("xpath_boolean") {
+ val df = Seq("b" -> "a/b").toDF("xml", "path")
+ checkAnswer(df.selectExpr("xpath_boolean(xml, path)"), Row(true))
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index 03d4be8ee528..3d869c77e960 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
@@ -123,7 +124,8 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
metricsSystem = null))
val sorter = new UnsafeKVExternalSorter(
- keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager, pageSize)
+ keySchema, valueSchema, SparkEnv.get.blockManager, SparkEnv.get.serializerManager,
+ pageSize, UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD)
// Insert the keys and values into the sorter
inputData.foreach { case (k, v) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 0ee8d179d79e..7d1f1d1e62fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -1314,6 +1314,29 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach {
}
}
+ test("create temporary view with mismatched schema") {
+ withTable("tab1") {
+ spark.range(10).write.saveAsTable("tab1")
+ withView("view1") {
+ val e = intercept[AnalysisException] {
+ sql("CREATE TEMPORARY VIEW view1 (col1, col3) AS SELECT * FROM tab1")
+ }.getMessage
+ assert(e.contains("the SELECT clause (num: `1`) does not match")
+ && e.contains("CREATE VIEW (num: `2`)"))
+ }
+ }
+ }
+
+ test("create temporary view with specified schema") {
+ withView("view1") {
+ sql("CREATE TEMPORARY VIEW view1 (col1, col2) AS SELECT 1, 2")
+ checkAnswer(
+ sql("SELECT * FROM view1"),
+ Row(1, 2) :: Nil
+ )
+ }
+ }
+
test("truncate table - external table, temporary table, view (not allowed)") {
import testImplicits._
val path = Utils.createTempDir().getAbsolutePath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
index f170065132ac..311f1fa8d2af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala
@@ -366,6 +366,32 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils {
}
}
+ test("save csv with quoteAll enabled") {
+ withTempDir { dir =>
+ val csvDir = new File(dir, "csv").getCanonicalPath
+
+ val data = Seq(("test \"quote\"", 123, "it \"works\"!", "\"very\" well"))
+ val df = spark.createDataFrame(data)
+
+ // quoteAll should wrap every field in quotes, including values without special characters
+ df.coalesce(1).write
+ .format("csv")
+ .option("quote", "\"")
+ .option("escape", "\"")
+ .option("quoteAll", "true")
+ .save(csvDir)
+
+ val results = spark.read
+ .format("text")
+ .load(csvDir)
+ .collect()
+
+ val expected = "\"test \"\"quote\"\"\",\"123\",\"it \"\"works\"\"!\",\"\"\"very\"\" well\""
+
+ assert(results.toSeq.map(_.toSeq) === Seq(Seq(expected)))
+ }
+ }
+
test("save csv with quote escaping enabled") {
withTempDir { dir =>
val csvDir = new File(dir, "csv").getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 9f35c02d4876..6c72019702c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -847,7 +847,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
sql(
s"""
- |CREATE TEMPORARY TABLE jsonTableSQL
+ |CREATE TEMPORARY VIEW jsonTableSQL
|USING org.apache.spark.sql.json
|OPTIONS (
| path '$path'
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 45fd6a5d80de..2a89773cf534 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -545,4 +545,18 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex
}
}
}
+
+ test("SPARK-16371 Do not push down filters when inner name and outer name are the same") {
+ withParquetDataFrame((1 to 4).map(i => Tuple1(Tuple1(i)))) { implicit df =>
+ // Here the schema becomes as below:
+ //
+ // root
+ // |-- _1: struct (nullable = true)
+ // | |-- _1: integer (nullable = true)
+ //
+ // The inner column name `_1` and the outer column name `_1` are the same.
+ // Obviously, filters should not be pushed down here because the outer column is a struct.
+ assert(df.filter("_1 IS NOT NULL").count() === 4)
+ }
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
index 6ff597c16bb2..7928b8e8775c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/ForeachSinkSuite.scala
@@ -24,7 +24,7 @@ import scala.collection.mutable
import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.ForeachWriter
-import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.streaming.{OutputMode, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAfter {
@@ -35,35 +35,103 @@ class ForeachSinkSuite extends StreamTest with SharedSQLContext with BeforeAndAf
sqlContext.streams.active.foreach(_.stop())
}
- test("foreach") {
+ test("foreach() with `append` output mode") {
withTempDir { checkpointDir =>
val input = MemoryStream[Int]
val query = input.toDS().repartition(2).writeStream
.option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .outputMode(OutputMode.Append)
.foreach(new TestForeachWriter())
.start()
+
+ // -- batch 0 ---------------------------------------
input.addData(1, 2, 3, 4)
query.processAllAvailable()
- val expectedEventsForPartition0 = Seq(
+ var expectedEventsForPartition0 = Seq(
ForeachSinkSuite.Open(partition = 0, version = 0),
ForeachSinkSuite.Process(value = 1),
ForeachSinkSuite.Process(value = 3),
ForeachSinkSuite.Close(None)
)
- val expectedEventsForPartition1 = Seq(
+ var expectedEventsForPartition1 = Seq(
ForeachSinkSuite.Open(partition = 1, version = 0),
ForeachSinkSuite.Process(value = 2),
ForeachSinkSuite.Process(value = 4),
ForeachSinkSuite.Close(None)
)
- val allEvents = ForeachSinkSuite.allEvents()
+ var allEvents = ForeachSinkSuite.allEvents()
+ assert(allEvents.size === 2)
+ assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+
+ ForeachSinkSuite.clear()
+
+ // -- batch 1 ---------------------------------------
+ input.addData(5, 6, 7, 8)
+ query.processAllAvailable()
+
+ expectedEventsForPartition0 = Seq(
+ ForeachSinkSuite.Open(partition = 0, version = 1),
+ ForeachSinkSuite.Process(value = 5),
+ ForeachSinkSuite.Process(value = 7),
+ ForeachSinkSuite.Close(None)
+ )
+ expectedEventsForPartition1 = Seq(
+ ForeachSinkSuite.Open(partition = 1, version = 1),
+ ForeachSinkSuite.Process(value = 6),
+ ForeachSinkSuite.Process(value = 8),
+ ForeachSinkSuite.Close(None)
+ )
+
+ allEvents = ForeachSinkSuite.allEvents()
assert(allEvents.size === 2)
- assert {
- allEvents === Seq(expectedEventsForPartition0, expectedEventsForPartition1) ||
- allEvents === Seq(expectedEventsForPartition1, expectedEventsForPartition0)
- }
+ assert(allEvents.toSet === Set(expectedEventsForPartition0, expectedEventsForPartition1))
+
+ query.stop()
+ }
+ }
+
+ test("foreach() with `complete` output mode") {
+ withTempDir { checkpointDir =>
+ val input = MemoryStream[Int]
+
+ val query = input.toDS()
+ .groupBy().count().as[Long].map(_.toInt)
+ .writeStream
+ .option("checkpointLocation", checkpointDir.getCanonicalPath)
+ .outputMode(OutputMode.Complete)
+ .foreach(new TestForeachWriter())
+ .start()
+
+ // -- batch 0 ---------------------------------------
+ input.addData(1, 2, 3, 4)
+ query.processAllAvailable()
+
+ var allEvents = ForeachSinkSuite.allEvents()
+ assert(allEvents.size === 1)
+ var expectedEvents = Seq(
+ ForeachSinkSuite.Open(partition = 0, version = 0),
+ ForeachSinkSuite.Process(value = 4),
+ ForeachSinkSuite.Close(None)
+ )
+ assert(allEvents === Seq(expectedEvents))
+
+ ForeachSinkSuite.clear()
+
+ // -- batch 1 ---------------------------------------
+ input.addData(5, 6, 7, 8)
+ query.processAllAvailable()
+
+ allEvents = ForeachSinkSuite.allEvents()
+ assert(allEvents.size === 1)
+ expectedEvents = Seq(
+ ForeachSinkSuite.Open(partition = 0, version = 1),
+ ForeachSinkSuite.Process(value = 8),
+ ForeachSinkSuite.Close(None)
+ )
+ assert(allEvents === Seq(expectedEvents))
+
query.stop()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index fd6671a39b6e..228e4250f3c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -24,12 +24,13 @@ import java.util.{Calendar, GregorianCalendar, Properties}
import org.h2.jdbc.JdbcSQLException
import org.scalatest.{BeforeAndAfter, PrivateMethodTester}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.execution.DataSourceScanExec
import org.apache.spark.sql.execution.command.ExplainCommand
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.sources._
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -83,7 +84,7 @@ class JDBCSuite extends SparkFunSuite
|CREATE TEMPORARY TABLE fetchtwo
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass',
- | fetchSize '2')
+ | ${JdbcUtils.JDBC_BATCH_FETCH_SIZE} '2')
""".stripMargin.replaceAll("\n", " "))
sql(
@@ -348,38 +349,49 @@ class JDBCSuite extends SparkFunSuite
test("Basic API") {
assert(spark.read.jdbc(
- urlWithUserAndPass, "TEST.PEOPLE", new Properties).collect().length === 3)
+ urlWithUserAndPass, "TEST.PEOPLE", new Properties()).collect().length === 3)
+ }
+
+ test("Basic API with illegal FetchSize") {
+ val properties = new Properties()
+ properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, "-1")
+ val e = intercept[SparkException] {
+ spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", properties).collect()
+ }.getMessage
+ assert(e.contains("Invalid value `-1` for parameter `fetchsize`"))
}
test("Basic API with FetchSize") {
- val properties = new Properties
- properties.setProperty("fetchSize", "2")
- assert(spark.read.jdbc(
- urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
+ (0 to 4).foreach { size =>
+ val properties = new Properties()
+ properties.setProperty(JdbcUtils.JDBC_BATCH_FETCH_SIZE, size.toString)
+ assert(spark.read.jdbc(
+ urlWithUserAndPass, "TEST.PEOPLE", properties).collect().length === 3)
+ }
}
test("Partitioning via JDBCPartitioningInfo API") {
assert(
- spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", "THEID", 0, 4, 3, new Properties())
.collect().length === 3)
}
test("Partitioning via list-of-where-clauses API") {
val parts = Array[String]("THEID < 2", "THEID >= 2")
- assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties)
+ assert(spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", parts, new Properties())
.collect().length === 3)
}
test("Partitioning on column that might have null values.") {
assert(
- spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "theid", 0, 4, 3, new Properties())
.collect().length === 4)
assert(
- spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", "THEID", 0, 4, 3, new Properties())
.collect().length === 4)
// partitioning on a nullable quoted column
assert(
- spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties)
+ spark.read.jdbc(urlWithUserAndPass, "TEST.EMP", """"Dept"""", 0, 4, 3, new Properties())
.collect().length === 4)
}
@@ -391,7 +403,7 @@ class JDBCSuite extends SparkFunSuite
lowerBound = 0,
upperBound = 4,
numPartitions = 0,
- connectionProperties = new Properties
+ connectionProperties = new Properties()
)
assert(res.count() === 8)
}
@@ -404,7 +416,7 @@ class JDBCSuite extends SparkFunSuite
lowerBound = 1,
upperBound = 5,
numPartitions = 10,
- connectionProperties = new Properties
+ connectionProperties = new Properties()
)
assert(res.count() === 8)
}
@@ -417,7 +429,7 @@ class JDBCSuite extends SparkFunSuite
lowerBound = 5,
upperBound = 5,
numPartitions = 4,
- connectionProperties = new Properties
+ connectionProperties = new Properties()
)
assert(res.count() === 8)
}
@@ -431,7 +443,7 @@ class JDBCSuite extends SparkFunSuite
lowerBound = 5,
upperBound = 1,
numPartitions = 3,
- connectionProperties = new Properties
+ connectionProperties = new Properties()
)
}.getMessage
assert(e.contains("Operation not allowed: the lower bound of partitioning column " +
@@ -495,8 +507,8 @@ class JDBCSuite extends SparkFunSuite
test("test DATE types") {
val rows = spark.read.jdbc(
- urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ urlWithUserAndPass, "TEST.TIMETYPES", new Properties()).collect()
+ val cachedRows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties())
.cache().collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(1).getAs[java.sql.Date](1) === null)
@@ -504,8 +516,8 @@ class JDBCSuite extends SparkFunSuite
}
test("test DATE types in cache") {
- val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties).collect()
- spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val rows = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties()).collect()
+ spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties())
.cache().createOrReplaceTempView("mycached_date")
val cachedRows = sql("select * from mycached_date").collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
@@ -514,7 +526,7 @@ class JDBCSuite extends SparkFunSuite
test("test types for null value") {
val rows = spark.read.jdbc(
- urlWithUserAndPass, "TEST.NULLTYPES", new Properties).collect()
+ urlWithUserAndPass, "TEST.NULLTYPES", new Properties()).collect()
assert((0 to 14).forall(i => rows(0).isNullAt(i)))
}
@@ -560,7 +572,7 @@ class JDBCSuite extends SparkFunSuite
test("Remap types via JdbcDialects") {
JdbcDialects.registerDialect(testH2Dialect)
- val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties())
assert(df.schema.filter(_.dataType != org.apache.spark.sql.types.StringType).isEmpty)
val rows = df.collect()
assert(rows(0).get(0).isInstanceOf[String])
@@ -694,7 +706,7 @@ class JDBCSuite extends SparkFunSuite
// Regression test for bug SPARK-11788
val timestamp = java.sql.Timestamp.valueOf("2001-02-20 11:22:33.543543");
val date = java.sql.Date.valueOf("1995-01-01")
- val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties)
+ val jdbcDf = spark.read.jdbc(urlWithUserAndPass, "TEST.TIMETYPES", new Properties())
val rows = jdbcDf.where($"B" > date && $"C" > timestamp).collect()
assert(rows(0).getAs[java.sql.Date](1) === java.sql.Date.valueOf("1996-01-01"))
assert(rows(0).getAs[java.sql.Timestamp](2)
@@ -714,7 +726,7 @@ class JDBCSuite extends SparkFunSuite
}
test("test credentials in the connection url are not in the plan output") {
- val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties)
+ val df = spark.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties())
val explain = ExplainCommand(df.queryExecution.logical, extended = true)
spark.sessionState.executePlan(explain).executedPlan.executeCollect().foreach {
r => assert(!List("testPass", "testUser").exists(r.toString.contains))
@@ -746,10 +758,16 @@ class JDBCSuite extends SparkFunSuite
urlWithUserAndPass,
"TEST.PEOPLE",
predicates = Array[String](jdbcPartitionWhereClause),
- new Properties)
+ new Properties())
df.createOrReplaceTempView("tempFrame")
assertEmptyQuery(s"SELECT * FROM tempFrame where $FALSE2")
}
}
+
+ test("SPARK-16387: Reserved SQL words are not escaped by JDBC writer") {
+ val df = spark.createDataset(Seq("a", "b", "c")).toDF("order")
+ val schema = JdbcUtils.schemaString(df, "jdbc:mysql://localhost:3306/temp")
+ assert(schema.contains("`order` TEXT"))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 48fa5f98223b..2c6449fa6870 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -22,7 +22,9 @@ import java.util.Properties
import org.scalatest.BeforeAndAfter
+import org.apache.spark.SparkException
import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
@@ -57,14 +59,14 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
sql(
s"""
- |CREATE TEMPORARY TABLE PEOPLE
+ |CREATE OR REPLACE TEMPORARY VIEW PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
sql(
s"""
- |CREATE TEMPORARY TABLE PEOPLE1
+ |CREATE OR REPLACE TEMPORARY VIEW PEOPLE1
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE1', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
@@ -90,10 +92,34 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
test("Basic CREATE") {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
- df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties)
- assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).count)
+ df.write.jdbc(url, "TEST.BASICCREATETEST", new Properties())
+ assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count())
assert(
- 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties).collect()(0).length)
+ 2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).collect()(0).length)
+ }
+
+ test("Basic CREATE with illegal batchsize") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ (-1 to 0).foreach { size =>
+ val properties = new Properties()
+ properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString)
+ val e = intercept[SparkException] {
+ df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties)
+ }.getMessage
+ assert(e.contains(s"Invalid value `$size` for parameter `batchsize`"))
+ }
+ }
+
+ test("Basic CREATE with batchsize") {
+ val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
+
+ (1 to 3).foreach { size =>
+ val properties = new Properties()
+ properties.setProperty(JdbcUtils.JDBC_BATCH_INSERT_SIZE, size.toString)
+ df.write.mode(SaveMode.Overwrite).jdbc(url, "TEST.BASICCREATETEST", properties)
+ assert(2 === spark.read.jdbc(url, "TEST.BASICCREATETEST", new Properties()).count())
+ }
}
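
Note: the two batch-size tests above correspond to the user-facing "batchsize" JDBC write option (the same option they set via JdbcUtils.JDBC_BATCH_INSERT_SIZE). A minimal usage sketch; the URL, table name and credentials are placeholders:

// Usage sketch for the JDBC insert batch size. A non-positive batchsize is rejected at
// write time; positive values control how many rows are grouped into one JDBC batch.
import java.util.Properties
import org.apache.spark.sql.{SaveMode, SparkSession}

object BatchSizeWriteSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("batchsize-sketch").getOrCreate()
    val df = spark.range(0, 1000).toDF("id")

    val props = new Properties()
    props.setProperty("user", "testUser")
    props.setProperty("password", "testPass")
    props.setProperty("batchsize", "500") // rows per JDBC batch

    df.write
      .mode(SaveMode.Overwrite)
      .jdbc("jdbc:h2:mem:testdb", "TEST.BATCHED", props)
  }
}
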
test("CREATE with overwrite") {
@@ -101,11 +127,11 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
df.write.jdbc(url1, "TEST.DROPTEST", properties)
- assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count())
assert(3 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.DROPTEST", properties)
- assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count)
+ assert(1 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.DROPTEST", properties).collect()(0).length)
}
@@ -113,10 +139,10 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr1x2), schema2)
- df.write.jdbc(url, "TEST.APPENDTEST", new Properties)
- df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties)
- assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).count)
- assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties).collect()(0).length)
+ df.write.jdbc(url, "TEST.APPENDTEST", new Properties())
+ df2.write.mode(SaveMode.Append).jdbc(url, "TEST.APPENDTEST", new Properties())
+ assert(3 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).count())
+ assert(2 === spark.read.jdbc(url, "TEST.APPENDTEST", new Properties()).collect()(0).length)
}
test("CREATE then INSERT to truncate") {
@@ -125,7 +151,7 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
df.write.jdbc(url1, "TEST.TRUNCATETEST", properties)
df2.write.mode(SaveMode.Overwrite).jdbc(url1, "TEST.TRUNCATETEST", properties)
- assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count)
+ assert(1 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.TRUNCATETEST", properties).collect()(0).length)
}
@@ -133,22 +159,22 @@ class JDBCWriteSuite extends SharedSQLContext with BeforeAndAfter {
val df = spark.createDataFrame(sparkContext.parallelize(arr2x2), schema2)
val df2 = spark.createDataFrame(sparkContext.parallelize(arr2x3), schema3)
- df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
+ df.write.jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
intercept[org.apache.spark.SparkException] {
- df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties)
+ df2.write.mode(SaveMode.Append).jdbc(url, "TEST.INCOMPATIBLETEST", new Properties())
}
}
test("INSERT to JDBC Datasource") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
- assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
+ assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).count())
assert(2 === spark.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 45e737f5ed04..be56c964a18f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -139,7 +139,7 @@ class FilteredScanSuite extends DataSourceTest with SharedSQLContext with Predic
super.beforeAll()
sql(
"""
- |CREATE TEMPORARY TABLE oneToTenFiltered
+ |CREATE TEMPORARY VIEW oneToTenFiltered
|USING org.apache.spark.sql.sources.FilteredScanSource
|OPTIONS (
| from '1',
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 207f89d3eaea..fb6123d1cc4b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -62,7 +62,7 @@ class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
super.beforeAll()
sql(
"""
- |CREATE TEMPORARY TABLE oneToTenPruned
+ |CREATE TEMPORARY VIEW oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
|OPTIONS (
| from '1',
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index 93116d84ced7..0fa0706a10b1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -137,7 +137,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
super.beforeAll()
sql(
"""
- |CREATE TEMPORARY TABLE oneToTen
+ |CREATE TEMPORARY VIEW oneToTen
|USING org.apache.spark.sql.sources.SimpleScanSource
|OPTIONS (
| From '1',
@@ -149,7 +149,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
sql(
"""
- |CREATE TEMPORARY TABLE tableWithSchema (
+ |CREATE TEMPORARY VIEW tableWithSchema (
|`string$%Field` stRIng,
|binaryField binary,
|`booleanField` boolean,
@@ -332,7 +332,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
test("defaultSource") {
sql(
"""
- |CREATE TEMPORARY TABLE oneToTenDef
+ |CREATE TEMPORARY VIEW oneToTenDef
|USING org.apache.spark.sql.sources
|OPTIONS (
| from '1',
@@ -351,7 +351,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
val schemaNotAllowed = intercept[Exception] {
sql(
"""
- |CREATE TEMPORARY TABLE relationProvierWithSchema (i int)
+ |CREATE TEMPORARY VIEW relationProvierWithSchema (i int)
|USING org.apache.spark.sql.sources.SimpleScanSource
|OPTIONS (
| From '1',
@@ -364,7 +364,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
val schemaNeeded = intercept[Exception] {
sql(
"""
- |CREATE TEMPORARY TABLE schemaRelationProvierWithoutSchema
+ |CREATE TEMPORARY VIEW schemaRelationProvierWithoutSchema
|USING org.apache.spark.sql.sources.AllDataTypesScanSource
|OPTIONS (
| From '1',
@@ -378,7 +378,7 @@ class TableScanSuite extends DataSourceTest with SharedSQLContext {
test("SPARK-5196 schema field with comment") {
sql(
"""
- |CREATE TEMPORARY TABLE student(name string comment "SN", age int comment "SA", grade int)
+ |CREATE TEMPORARY VIEW student(name string comment "SN", age int comment "SA", grade int)
|USING org.apache.spark.sql.sources.AllDataTypesScanSource
|OPTIONS (
| from '1',
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
index 0eade71d1ebc..29ce578bcde3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala
@@ -179,18 +179,24 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") { testError() }
}
- test("FileStreamSource schema: path doesn't exist, no schema") {
- val e = intercept[IllegalArgumentException] {
- createFileStreamSourceAndGetSchema(format = None, path = Some("/a/b/c"), schema = None)
+ test("FileStreamSource schema: path doesn't exist (without schema) should throw exception") {
+ withTempDir { dir =>
+ intercept[AnalysisException] {
+ val userSchema = new StructType().add(new StructField("value", IntegerType))
+ val schema = createFileStreamSourceAndGetSchema(
+ format = None, path = Some(new File(dir, "1").getAbsolutePath), schema = None)
+ }
}
- assert(e.getMessage.toLowerCase.contains("schema")) // reason is schema absence, not the path
}
- test("FileStreamSource schema: path doesn't exist, with schema") {
- val userSchema = new StructType().add(new StructField("value", IntegerType))
- val schema = createFileStreamSourceAndGetSchema(
- format = None, path = Some("/a/b/c"), schema = Some(userSchema))
- assert(schema === userSchema)
+ test("FileStreamSource schema: path doesn't exist (with schema) should throw exception") {
+ withTempDir { dir =>
+ intercept[AnalysisException] {
+ val userSchema = new StructType().add(new StructField("value", IntegerType))
+ val schema = createFileStreamSourceAndGetSchema(
+ format = None, path = Some(new File(dir, "1").getAbsolutePath), schema = Some(userSchema))
+ }
+ }
}
@@ -225,20 +231,6 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
// =============== Parquet file stream schema tests ================
- test("FileStreamSource schema: parquet, no existing files, no schema") {
- withTempDir { src =>
- withSQLConf(SQLConf.STREAMING_SCHEMA_INFERENCE.key -> "true") {
- val e = intercept[AnalysisException] {
- createFileStreamSourceAndGetSchema(
- format = Some("parquet"),
- path = Some(new File(src, "1").getCanonicalPath),
- schema = None)
- }
- assert("Unable to infer schema. It must be specified manually.;" === e.getMessage)
- }
- }
- }
-
test("FileStreamSource schema: parquet, existing files, no schema") {
withTempDir { src =>
Seq("a", "b", "c").toDS().as("userColumn").toDF().write
@@ -593,6 +585,82 @@ class FileStreamSourceSuite extends FileStreamSourceTest {
}
}
+ test("max files per trigger") {
+ withTempDir { case src =>
+ var lastFileModTime: Option[Long] = None
+
+ /** Create a text file with a single data item */
+ def createFile(data: Int): File = {
+ val file = stringToFile(new File(src, s"$data.txt"), data.toString)
+ if (lastFileModTime.nonEmpty) file.setLastModified(lastFileModTime.get + 1000)
+ lastFileModTime = Some(file.lastModified)
+ file
+ }
+
+ createFile(1)
+ createFile(2)
+ createFile(3)
+
+ // Set up a query to read text files 2 at a time
+ val df = spark
+ .readStream
+ .option("maxFilesPerTrigger", 2)
+ .text(src.getCanonicalPath)
+ val q = df
+ .writeStream
+ .format("memory")
+ .queryName("file_data")
+ .start()
+ .asInstanceOf[StreamExecution]
+ q.processAllAvailable()
+ val memorySink = q.sink.asInstanceOf[MemorySink]
+ val fileSource = q.logicalPlan.collect {
+ case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] =>
+ source.asInstanceOf[FileStreamSource]
+ }.head
+
+ /** Check the data read in the last batch */
+ def checkLastBatchData(data: Int*): Unit = {
+ val schema = StructType(Seq(StructField("value", StringType)))
+ val df = spark.createDataFrame(
+ spark.sparkContext.makeRDD(memorySink.latestBatchData), schema)
+ checkAnswer(df, data.map(_.toString).toDF("value"))
+ }
+
+ /** Check how many batches have executed since the last time this check was made */
+ var lastBatchId = -1L
+ def checkNumBatchesSinceLastCheck(numBatches: Int): Unit = {
+ require(lastBatchId >= 0)
+ assert(memorySink.latestBatchId.get === lastBatchId + numBatches)
+ lastBatchId = memorySink.latestBatchId.get
+ }
+
+ checkLastBatchData(3) // (1 and 2) should be in batch 1, (3) should be in batch 2 (last)
+ lastBatchId = memorySink.latestBatchId.get
+
+ fileSource.withBatchingLocked {
+ createFile(4)
+ createFile(5) // 4 and 5 should be in a batch
+ createFile(6)
+ createFile(7) // 6 and 7 should be in the last batch
+ }
+ q.processAllAvailable()
+ checkLastBatchData(6, 7)
+ checkNumBatchesSinceLastCheck(2)
+
+ fileSource.withBatchingLocked {
+ createFile(8)
+ createFile(9) // 8 and 9 should be in a batch
+ createFile(10)
+ createFile(11) // 10 and 11 should be in a batch
+ createFile(12) // 12 should be in the last batch
+ }
+ q.processAllAvailable()
+ checkLastBatchData(12)
+ checkNumBatchesSinceLastCheck(3)
+ }
+ }
+
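
Note: the test above drives maxFilesPerTrigger through internals (FileStreamSource, MemorySink) in order to count batches; the option itself is plain public API. A minimal usage sketch with placeholder paths and query name:

// Each micro-batch picks up at most two new files from the input directory.
import org.apache.spark.sql.SparkSession

object MaxFilesPerTriggerSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("max-files-sketch").getOrCreate()

    val lines = spark.readStream
      .option("maxFilesPerTrigger", 2)   // at most 2 files per micro-batch
      .text("/tmp/stream-input")

    val query = lines.writeStream
      .format("memory")
      .queryName("file_data_sketch")
      .start()

    query.processAllAvailable()
    spark.table("file_data_sketch").show()
    query.stop()
  }
}
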
test("explain") {
withTempDirs { case (src, tmp) =>
src.mkdirs()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
index c4a894b6816a..28170f30646a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala
@@ -120,12 +120,12 @@ class StreamSuite extends StreamTest {
}
// Running streaming plan as a batch query
- assertError("startStream" :: Nil) {
+ assertError("start" :: Nil) {
streamInput.toDS.map { i => i }.count()
}
// Running a non-streaming plan as a streaming query
- assertError("without streaming sources" :: "startStream" :: Nil) {
+ assertError("without streaming sources" :: "start" :: Nil) {
val ds = batchInput.map { i => i }
testStream(ds)()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
index ebbcc1d7ffbb..27a0a2a776c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala
@@ -82,6 +82,29 @@ class DefaultSource
}
}
+/** Dummy provider with only RelationProvider and CreatableRelationProvider. */
+class DefaultSourceWithoutUserSpecifiedSchema
+ extends RelationProvider
+ with CreatableRelationProvider {
+
+ case class FakeRelation(sqlContext: SQLContext) extends BaseRelation {
+ override def schema: StructType = StructType(Seq(StructField("a", StringType)))
+ }
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ parameters: Map[String, String]): BaseRelation = {
+ FakeRelation(sqlContext)
+ }
+
+ override def createRelation(
+ sqlContext: SQLContext,
+ mode: SaveMode,
+ parameters: Map[String, String],
+ data: DataFrame): BaseRelation = {
+ FakeRelation(sqlContext)
+ }
+}
class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with BeforeAndAfter {
@@ -120,6 +143,15 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
.save()
}
+ test("resolve default source without extending SchemaRelationProvider") {
+ spark.read
+ .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema")
+ .load()
+ .write
+ .format("org.apache.spark.sql.test.DefaultSourceWithoutUserSpecifiedSchema")
+ .save()
+ }
+
test("resolve full class") {
spark.read
.format("org.apache.spark.sql.test.DefaultSource")
@@ -246,8 +278,9 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be
spark.range(10).write.format("parquet").mode("overwrite").partitionBy("id").save(path)
}
intercept[AnalysisException] {
- spark.range(10).write.format("orc").mode("overwrite").partitionBy("id").save(path)
+ spark.range(10).write.format("csv").mode("overwrite").partitionBy("id").save(path)
}
+ spark.emptyDataFrame.write.format("parquet").mode("overwrite").save(path)
}
}
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
index 7389e18aefb1..5dafec1c3021 100644
--- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
+++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala
@@ -156,6 +156,14 @@ private[hive] object SparkSQLCLIDriver extends Logging {
// Execute -i init files (always in silent mode)
cli.processInitFiles(sessionState)
+ // Respect the configurations set by --hiveconf from the command line
+ // (based on Hive's CliDriver).
+ val it = sessionState.getOverriddenConfigurations.entrySet().iterator()
+ while (it.hasNext) {
+ val kv = it.next()
+ SparkSQLEnv.sqlContext.setConf(kv.getKey, kv.getValue)
+ }
+
if (sessionState.execString != null) {
System.exit(cli.processLine(sessionState.execString))
}
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
index 75535cad1b18..d3cec11bd756 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala
@@ -91,6 +91,8 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
| --hiveconf ${ConfVars.METASTORECONNECTURLKEY}=$jdbcUrl
| --hiveconf ${ConfVars.METASTOREWAREHOUSE}=$warehousePath
| --hiveconf ${ConfVars.SCRATCHDIR}=$scratchDirPath
+ | --hiveconf conf1=conftest
+ | --hiveconf conf2=1
""".stripMargin.split("\\s+").toSeq ++ extraArgs
}
@@ -272,4 +274,13 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging {
s"LIST FILE $dataFilePath;" -> "small_kv.txt"
)
}
+
+ test("apply hiveconf from cli command") {
+ runCliWithin(2.minute)(
+ "SET conf1;" -> "conftest",
+ "SET conf2;" -> "1",
+ "SET conf3=${hiveconf:conf1};" -> "conftest",
+ "SET conf3;" -> "conftest"
+ )
+ }
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2d5a970c1200..13d18fdec0e9 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -517,6 +517,18 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
// This test uses CREATE EXTERNAL TABLE without specifying LOCATION
"alter2",
+ // [SPARK-16248][SQL] Whitelist the list of Hive fallback functions
+ "udf_field",
+ "udf_reflect2",
+ "udf_xpath",
+ "udf_xpath_boolean",
+ "udf_xpath_double",
+ "udf_xpath_float",
+ "udf_xpath_int",
+ "udf_xpath_long",
+ "udf_xpath_short",
+ "udf_xpath_string",
+
// These tests DROP TABLE that don't exist (but do not specify IF EXISTS)
"alter_rename_partition1",
"date_1",
@@ -1004,7 +1016,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_elt",
"udf_equal",
"udf_exp",
- "udf_field",
"udf_find_in_set",
"udf_float",
"udf_floor",
@@ -1049,7 +1060,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_power",
"udf_radians",
"udf_rand",
- "udf_reflect2",
"udf_regexp",
"udf_regexp_extract",
"udf_regexp_replace",
@@ -1090,14 +1100,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
"udf_variance",
"udf_weekofyear",
"udf_when",
- "udf_xpath",
- "udf_xpath_boolean",
- "udf_xpath_double",
- "udf_xpath_float",
- "udf_xpath_int",
- "udf_xpath_long",
- "udf_xpath_short",
- "udf_xpath_string",
"union10",
"union11",
"union13",
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
index 6c3978154d4b..7ba5790c2979 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala
@@ -534,31 +534,6 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte
| rows between 2 preceding and 2 following);
""".stripMargin, reset = false)
- // collect_set() output array in an arbitrary order, hence causes different result
- // when running this test suite under Java 7 and 8.
- // We change the original sql query a little bit for making the test suite passed
- // under different JDK
- /* Disabled because:
- - Spark uses a different default stddev.
- - Tiny numerical differences in stddev results.
- createQueryTest("windowing.q -- 20. testSTATs",
- """
- |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp
- |from (
- |select p_mfgr,p_name, p_size,
- |stddev(p_retailprice) over w1 as sdev,
- |stddev_pop(p_retailprice) over w1 as sdev_pop,
- |collect_set(p_size) over w1 as uniq_size,
- |variance(p_retailprice) over w1 as var,
- |corr(p_size, p_retailprice) over w1 as cor,
- |covar_pop(p_size, p_retailprice) over w1 as covarp
- |from part
- |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name
- | rows between 2 preceding and 2 following)
- |) t lateral view explode(uniq_size) d as uniq_data
- |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp
- """.stripMargin, reset = false)
- */
createQueryTest("windowing.q -- 21. testDISTs",
"""
|select p_mfgr,p_name, p_size,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 2e0b5d59b578..789f94aff303 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -147,10 +147,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
// it is better here to invalidate the cache to avoid confusing warning logs from the
// cache loader (e.g. cannot find data source provider, which is only defined for
// data source tables).
- invalidateTable(tableIdent)
- }
-
- def invalidateTable(tableIdent: TableIdentifier): Unit = {
cachedDataSourceTables.invalidate(getQualifiedTableName(tableIdent))
}
@@ -191,6 +187,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
private def getCached(
tableIdentifier: QualifiedTableName,
+ pathsInMetastore: Seq[String],
metastoreRelation: MetastoreRelation,
schemaInMetastore: StructType,
expectedFileFormat: Class[_ <: FileFormat],
@@ -200,7 +197,6 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
cachedDataSourceTables.getIfPresent(tableIdentifier) match {
case null => None // Cache miss
case logical @ LogicalRelation(relation: HadoopFsRelation, _, _) =>
- val pathsInMetastore = metastoreRelation.catalogTable.storage.locationUri.toSeq
val cachedRelationFileFormatClass = relation.fileFormat.getClass
expectedFileFormat match {
@@ -265,9 +261,22 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
PartitionDirectory(values, location)
}
val partitionSpec = PartitionSpec(partitionSchema, partitions)
+ val partitionPaths = partitions.map(_.path.toString)
+
+ // By convention (for example, see MetaStorePartitionedTableFileCatalog), the definition of a
+ // partitioned table's paths depends on whether that table has any actual partitions.
+ // Partitioned tables without partitions use the location of the table's base path.
+ // Partitioned tables with partitions use the data locations of those partitions,
+ // _omitting_ the table's base path.
+ val paths = if (partitionPaths.isEmpty) {
+ Seq(metastoreRelation.hiveQlTable.getDataLocation.toString)
+ } else {
+ partitionPaths
+ }
val cached = getCached(
tableIdentifier,
+ paths,
metastoreRelation,
metastoreSchema,
fileFormatClass,
@@ -312,6 +321,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log
val paths = Seq(metastoreRelation.hiveQlTable.getDataLocation.toString)
val cached = getCached(tableIdentifier,
+ paths,
metastoreRelation,
metastoreSchema,
fileFormatClass,
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 2f6a2207855e..9c7f461362d8 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -30,12 +30,13 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog}
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
+import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper
import org.apache.spark.sql.hive.client.HiveClient
import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DecimalType, DoubleType}
import org.apache.spark.util.Utils
@@ -89,13 +90,10 @@ private[sql] class HiveSessionCatalog(
val CreateTables: Rule[LogicalPlan] = metastoreCatalog.CreateTables
override def refreshTable(name: TableIdentifier): Unit = {
+ super.refreshTable(name)
metastoreCatalog.refreshTable(name)
}
- override def invalidateTable(name: TableIdentifier): Unit = {
- metastoreCatalog.invalidateTable(name)
- }
-
def invalidateCache(): Unit = {
metastoreCatalog.cachedDataSourceTables.invalidateAll()
}
@@ -162,18 +160,20 @@ private[sql] class HiveSessionCatalog(
}
}
- // We have a list of Hive built-in functions that we do not support. So, we will check
- // Hive's function registry and lazily load needed functions into our own function registry.
- // Those Hive built-in functions are
- // assert_true, collect_list, collect_set, compute_stats, context_ngrams, create_union,
- // current_user ,elt, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field,
- // histogram_numeric, in_file, index, inline, java_method, map_keys, map_values,
- // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming,
- // parse_url, parse_url_tuple, percentile, percentile_approx, posexplode, reflect, reflect2,
- // regexp, sentences, stack, std, str_to_map, windowingtablefunction, xpath, xpath_boolean,
- // xpath_double, xpath_float, xpath_int, xpath_long, xpath_number,
- // xpath_short, and xpath_string.
override def lookupFunction(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
+ try {
+ lookupFunction0(name, children)
+ } catch {
+ case NonFatal(_) =>
+ // SPARK-16228 ExternalCatalog may recognize `double`-type only.
+ val newChildren = children.map { child =>
+ if (child.dataType.isInstanceOf[DecimalType]) Cast(child, DoubleType) else child
+ }
+ lookupFunction0(name, newChildren)
+ }
+ }
+
+ private def lookupFunction0(name: FunctionIdentifier, children: Seq[Expression]): Expression = {
// TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to
// if (super.functionExists(name)) {
// super.lookupFunction(name, children)
@@ -196,10 +196,12 @@ private[sql] class HiveSessionCatalog(
// built-in function.
// Hive is case insensitive.
val functionName = funcName.unquotedString.toLowerCase
- // TODO: This may not really work for current_user because current_user is not evaluated
- // with session info.
- // We do not need to use executionHive at here because we only load
- // Hive's builtin functions, which do not need current db.
+ if (!hiveFunctions.contains(functionName)) {
+ failFunctionLookup(funcName.unquotedString)
+ }
+
+ // TODO: Remove this fallback path once we implement the list of fallback functions
+ // defined below in hiveFunctions.
val functionInfo = {
try {
Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse(
@@ -221,4 +223,21 @@ private[sql] class HiveSessionCatalog(
}
}
}
+
+ /** List of functions we pass over to Hive. Note that over time this list should go to 0. */
+ // We have a list of Hive built-in functions that we do not support. So, we will check
+ // Hive's function registry and lazily load needed functions into our own function registry.
+ // The functions we explicitly do not support are:
+ // compute_stats, context_ngrams, create_union,
+ // current_user, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field,
+ // in_file, index, java_method,
+ // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming,
+ // parse_url_tuple, posexplode, reflect2,
+ // str_to_map, windowingtablefunction.
+ private val hiveFunctions = Seq(
+ "hash", "java_method", "histogram_numeric",
+ "percentile", "percentile_approx", "reflect", "str_to_map",
+ "xpath", "xpath_double", "xpath_float", "xpath_int", "xpath_long",
+ "xpath_number", "xpath_short", "xpath_string"
+ )
}
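
Note: the lookupFunction change above retries with decimal arguments cast to double, which is what lets calls such as percentile(value, 0.5) resolve against Hive UDFs that only declare double signatures (see the HiveUDFSuite test later in this patch). A stripped-down sketch of that retry pattern, using Catalyst's Cast but not the HiveSessionCatalog code path itself:

// Retry-with-widened-arguments sketch: if resolution fails, cast DecimalType children to
// DoubleType and try once more. Illustrative helper only, not Spark internals.
import scala.util.control.NonFatal
import org.apache.spark.sql.catalyst.expressions.{Cast, Expression}
import org.apache.spark.sql.types.{DecimalType, DoubleType}

object WidenDecimalArgsSketch {
  def resolveWithFallback(
      resolve: Seq[Expression] => Expression,
      children: Seq[Expression]): Expression = {
    try {
      resolve(children)
    } catch {
      case NonFatal(_) =>
        val widened = children.map {
          case c if c.dataType.isInstanceOf[DecimalType] => Cast(c, DoubleType)
          case c => c
        }
        resolve(widened)
    }
  }
}
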
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
index b8099385a466..15a5d79dcb08 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateHiveTableAsSelectCommand.scala
@@ -17,6 +17,8 @@
package org.apache.spark.sql.hive.execution
+import scala.util.control.NonFatal
+
import org.apache.spark.sql.{AnalysisException, Row, SparkSession}
import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable}
import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan}
@@ -87,8 +89,15 @@ case class CreateHiveTableAsSelectCommand(
throw new AnalysisException(s"$tableIdentifier already exists.")
}
} else {
- sparkSession.sessionState.executePlan(InsertIntoTable(
- metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd
+ try {
+ sparkSession.sessionState.executePlan(InsertIntoTable(
+ metastoreRelation, Map(), query, overwrite = true, ifNotExists = false)).toRdd
+ } catch {
+ case NonFatal(e) =>
+ // drop the created table.
+ sparkSession.sessionState.catalog.dropTable(tableIdentifier, ignoreIfNotExists = true)
+ throw e
+ }
}
Seq.empty[Row]
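
Note: the CTAS change above is an instance of a create/populate/roll-back-on-failure pattern. A generic sketch of that pattern; the helper below is illustrative, not Spark internals:

// If populating the freshly created table fails, remove the empty table before rethrowing,
// so a failed CTAS does not leave a dangling catalog entry.
import scala.util.control.NonFatal

object CreateThenPopulateSketch {
  def createTableAs(create: () => Unit, populate: () => Unit, drop: () => Unit): Unit = {
    create()
    try {
      populate()
    } catch {
      case NonFatal(e) =>
        drop()   // roll back the half-created table
        throw e
    }
  }
}
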
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
index 97cd29f541ed..3d58d490a51e 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala
@@ -298,6 +298,7 @@ case class InsertIntoHiveTable(
// Invalidate the cache.
sqlContext.sharedState.cacheManager.invalidateCache(table)
+ sqlContext.sessionState.catalog.refreshTable(table.catalogTable.identifier)
// It would be nice to just return the childRdd unchanged so insert operations could be chained,
// however for now we return an empty list to simplify compatibility checks with hive, which
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index 9e25e1d40ce8..dfb12512a40f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -312,15 +312,15 @@ private class ScriptTransformationWriterThread(
}
threwException = false
} catch {
- case NonFatal(e) =>
+ case t: Throwable =>
// An error occurred while writing input, so kill the child process. According to the
// Javadoc this call will not throw an exception:
- _exception = e
+ _exception = t
proc.destroy()
- throw e
+ throw t
} finally {
try {
- outputStream.close()
+ Utils.tryLogNonFatalError(outputStream.close())
if (proc.waitFor() != 0) {
logError(stderrBuffer.toString) // log the stderr circular buffer
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
index 794fe264ead5..e65c24e6f125 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala
@@ -44,6 +44,7 @@ import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc}
import org.apache.spark.sql.types._
import org.apache.spark.util.SerializableJobConf
+import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter
/**
* Internal helper class that saves an RDD using a Hive OutputFormat.
@@ -280,7 +281,9 @@ private[spark] class SparkHiveDynamicPartitionWriterContainer(
StructType.fromAttributes(dataOutput),
SparkEnv.get.blockManager,
SparkEnv.get.serializerManager,
- TaskContext.get().taskMemoryManager().pageSizeBytes)
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ SparkEnv.get.conf.getLong("spark.shuffle.spill.numElementsForceSpillThreshold",
+ UnsafeExternalSorter.DEFAULT_NUM_ELEMENTS_FOR_SPILL_THRESHOLD))
while (iterator.hasNext) {
val inputRow = iterator.next()
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index b45be0251d95..7f892047c707 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -73,8 +73,12 @@ class TestHiveContext(
@transient override val sparkSession: TestHiveSparkSession)
extends SQLContext(sparkSession) {
- def this(sc: SparkContext) {
- this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc)))
+ /**
+ * If loadTestTables is false, no test tables are loaded. Note that this flag can only be true
+ * when running in the JVM, i.e. it needs to be false when calling from Python.
+ */
+ def this(sc: SparkContext, loadTestTables: Boolean = true) {
+ this(new TestHiveSparkSession(HiveUtils.withHiveExternalCatalog(sc), loadTestTables))
}
override def newSession(): TestHiveContext = {
@@ -103,13 +107,24 @@ class TestHiveContext(
}
-
+/**
+ * A [[SparkSession]] used in [[TestHiveContext]].
+ *
+ * @param sc SparkContext
+ * @param warehousePath path to the Hive warehouse directory
+ * @param scratchDirPath scratch directory used by Hive's metastore client
+ * @param metastoreTemporaryConf configuration options for Hive's metastore
+ * @param existingSharedState optional [[TestHiveSharedState]]
+ * @param loadTestTables if true, load the test tables. They can only be loaded when running
+ * in the JVM, i.e. when calling from Python this flag has to be false.
+ */
private[hive] class TestHiveSparkSession(
@transient private val sc: SparkContext,
val warehousePath: File,
scratchDirPath: File,
metastoreTemporaryConf: Map[String, String],
- @transient private val existingSharedState: Option[TestHiveSharedState])
+ @transient private val existingSharedState: Option[TestHiveSharedState],
+ private val loadTestTables: Boolean)
extends SparkSession(sc) with Logging { self =>
// TODO: We need to set the temp warehouse path to sc's conf.
@@ -118,13 +133,14 @@ private[hive] class TestHiveSparkSession(
// when we creating metadataHive. This flow is not easy to follow and can introduce
// confusion when a developer is debugging an issue. We need to refactor this part
// to just set the temp warehouse path in sc's conf.
- def this(sc: SparkContext) {
+ def this(sc: SparkContext, loadTestTables: Boolean) {
this(
sc,
Utils.createTempDir(namePrefix = "warehouse"),
TestHiveContext.makeScratchDir(),
HiveUtils.newTemporaryConfiguration(useInMemoryDerby = false),
- None)
+ None,
+ loadTestTables)
}
assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive")
@@ -144,7 +160,7 @@ private[hive] class TestHiveSparkSession(
override def newSession(): TestHiveSparkSession = {
new TestHiveSparkSession(
- sc, warehousePath, scratchDirPath, metastoreTemporaryConf, Some(sharedState))
+ sc, warehousePath, scratchDirPath, metastoreTemporaryConf, Some(sharedState), loadTestTables)
}
private var cacheTables: Boolean = false
@@ -204,165 +220,173 @@ private[hive] class TestHiveSparkSession(
testTables += (testTable.name -> testTable)
}
- // The test tables that are defined in the Hive QTestUtil.
- // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java
- // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql
- @transient
- val hiveQTestUtilTables = Seq(
- TestTable("src",
- "CREATE TABLE src (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
- TestTable("src1",
- "CREATE TABLE src1 (key INT, value STRING)".cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
- TestTable("srcpart", () => {
- sql(
- "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
- for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
+ if (loadTestTables) {
+ // The test tables that are defined in the Hive QTestUtil.
+ // /itests/util/src/main/java/org/apache/hadoop/hive/ql/QTestUtil.java
+ // https://github.com/apache/hive/blob/branch-0.13/data/scripts/q_test_init.sql
+ @transient
+ val hiveQTestUtilTables: Seq[TestTable] = Seq(
+ TestTable("src",
+ "CREATE TABLE src (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' INTO TABLE src".cmd),
+ TestTable("src1",
+ "CREATE TABLE src1 (key INT, value STRING)".cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd),
+ TestTable("srcpart", () => {
sql(
- s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
- |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
- """.stripMargin)
- }
- }),
- TestTable("srcpart1", () => {
- sql(
- "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
- for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
+ "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) {
+ sql(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("srcpart1", () => {
sql(
- s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
- |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
- """.stripMargin)
- }
- }),
- TestTable("src_thrift", () => {
- import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
- import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
- import org.apache.thrift.protocol.TBinaryProtocol
+ "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)")
+ for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) {
+ sql(
+ s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}'
+ |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr')
+ """.stripMargin)
+ }
+ }),
+ TestTable("src_thrift", () => {
+ import org.apache.hadoop.hive.serde2.thrift.ThriftDeserializer
+ import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat}
+ import org.apache.thrift.protocol.TBinaryProtocol
- sql(
+ sql(
+ s"""
+ |CREATE TABLE src_thrift(fake INT)
+ |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}'
+ |WITH SERDEPROPERTIES(
+ | 'serialization.class'='org.apache.spark.sql.hive.test.Complex',
+ | 'serialization.format'='${classOf[TBinaryProtocol].getName}'
+ |)
+ |STORED AS
+ |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}'
+ |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}'
+ """.stripMargin)
+
+ sql(
+ s"""
+ |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}'
+ |INTO TABLE src_thrift
+ """.stripMargin)
+ }),
+ TestTable("serdeins",
+ s"""CREATE TABLE serdeins (key INT, value STRING)
+ |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}'
+ |WITH SERDEPROPERTIES ('field.delim'='\\t')
+ """.stripMargin.cmd,
+ "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd),
+ TestTable("episodes",
+ s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT)
+ |STORED AS avro
+ |TBLPROPERTIES (
+ | 'avro.schema.literal'='{
+ | "type": "record",
+ | "name": "episodes",
+ | "namespace": "testing.hive.avro.serde",
+ | "fields": [
+ | {
+ | "name": "title",
+ | "type": "string",
+ | "doc": "episode title"
+ | },
+ | {
+ | "name": "air_date",
+ | "type": "string",
+ | "doc": "initial date"
+ | },
+ | {
+ | "name": "doctor",
+ | "type": "int",
+ | "doc": "main actor playing the Doctor in episode"
+ | }
+ | ]
+ | }'
+ |)
+ """.stripMargin.cmd,
s"""
- |CREATE TABLE src_thrift(fake INT)
- |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}'
- |WITH SERDEPROPERTIES(
- | 'serialization.class'='org.apache.spark.sql.hive.test.Complex',
- | 'serialization.format'='${classOf[TBinaryProtocol].getName}'
- |)
- |STORED AS
- |INPUTFORMAT '${classOf[SequenceFileInputFormat[_, _]].getName}'
- |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}'
- """.stripMargin)
-
- sql(
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift")
- }),
- TestTable("serdeins",
- s"""CREATE TABLE serdeins (key INT, value STRING)
- |ROW FORMAT SERDE '${classOf[LazySimpleSerDe].getCanonicalName}'
- |WITH SERDEPROPERTIES ('field.delim'='\\t')
- """.stripMargin.cmd,
- "INSERT OVERWRITE TABLE serdeins SELECT * FROM src".cmd),
- TestTable("episodes",
- s"""CREATE TABLE episodes (title STRING, air_date STRING, doctor INT)
- |STORED AS avro
- |TBLPROPERTIES (
- | 'avro.schema.literal'='{
- | "type": "record",
- | "name": "episodes",
- | "namespace": "testing.hive.avro.serde",
- | "fields": [
- | {
- | "name": "title",
- | "type": "string",
- | "doc": "episode title"
- | },
- | {
- | "name": "air_date",
- | "type": "string",
- | "doc": "initial date"
- | },
- | {
- | "name": "doctor",
- | "type": "int",
- | "doc": "main actor playing the Doctor in episode"
- | }
- | ]
- | }'
- |)
- """.stripMargin.cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd
- ),
- // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC
- // PARTITIONING IS NOT YET SUPPORTED
- TestTable("episodes_part",
- s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT)
- |PARTITIONED BY (doctor_pt INT)
- |STORED AS avro
- |TBLPROPERTIES (
- | 'avro.schema.literal'='{
- | "type": "record",
- | "name": "episodes",
- | "namespace": "testing.hive.avro.serde",
- | "fields": [
- | {
- | "name": "title",
- | "type": "string",
- | "doc": "episode title"
- | },
- | {
- | "name": "air_date",
- | "type": "string",
- | "doc": "initial date"
- | },
- | {
- | "name": "doctor",
- | "type": "int",
- | "doc": "main actor playing the Doctor in episode"
- | }
- | ]
- | }'
- |)
- """.stripMargin.cmd,
- // WORKAROUND: Required to pass schema to SerDe for partitioned tables.
- // TODO: Pass this automatically from the table to partitions.
- s"""
- |ALTER TABLE episodes_part SET SERDEPROPERTIES (
- | 'avro.schema.literal'='{
- | "type": "record",
- | "name": "episodes",
- | "namespace": "testing.hive.avro.serde",
- | "fields": [
- | {
- | "name": "title",
- | "type": "string",
- | "doc": "episode title"
- | },
- | {
- | "name": "air_date",
- | "type": "string",
- | "doc": "initial date"
- | },
- | {
- | "name": "doctor",
- | "type": "int",
- | "doc": "main actor playing the Doctor in episode"
- | }
- | ]
- | }'
- |)
- """.stripMargin.cmd,
- s"""
- INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1)
- SELECT title, air_date, doctor FROM episodes
- """.cmd
+ |LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}'
+ |INTO TABLE episodes
+ """.stripMargin.cmd
),
- TestTable("src_json",
- s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE
- """.stripMargin.cmd,
- s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
- )
+ // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC
+ // PARTITIONING IS NOT YET SUPPORTED
+ TestTable("episodes_part",
+ s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT)
+ |PARTITIONED BY (doctor_pt INT)
+ |STORED AS avro
+ |TBLPROPERTIES (
+ | 'avro.schema.literal'='{
+ | "type": "record",
+ | "name": "episodes",
+ | "namespace": "testing.hive.avro.serde",
+ | "fields": [
+ | {
+ | "name": "title",
+ | "type": "string",
+ | "doc": "episode title"
+ | },
+ | {
+ | "name": "air_date",
+ | "type": "string",
+ | "doc": "initial date"
+ | },
+ | {
+ | "name": "doctor",
+ | "type": "int",
+ | "doc": "main actor playing the Doctor in episode"
+ | }
+ | ]
+ | }'
+ |)
+ """.stripMargin.cmd,
+ // WORKAROUND: Required to pass schema to SerDe for partitioned tables.
+ // TODO: Pass this automatically from the table to partitions.
+ s"""
+ |ALTER TABLE episodes_part SET SERDEPROPERTIES (
+ | 'avro.schema.literal'='{
+ | "type": "record",
+ | "name": "episodes",
+ | "namespace": "testing.hive.avro.serde",
+ | "fields": [
+ | {
+ | "name": "title",
+ | "type": "string",
+ | "doc": "episode title"
+ | },
+ | {
+ | "name": "air_date",
+ | "type": "string",
+ | "doc": "initial date"
+ | },
+ | {
+ | "name": "doctor",
+ | "type": "int",
+ | "doc": "main actor playing the Doctor in episode"
+ | }
+ | ]
+ | }'
+ |)
+ """.stripMargin.cmd,
+ s"""
+ INSERT OVERWRITE TABLE episodes_part PARTITION (doctor_pt=1)
+ SELECT title, air_date, doctor FROM episodes
+ """.cmd
+ ),
+ TestTable("src_json",
+ s"""CREATE TABLE src_json (json STRING) STORED AS TEXTFILE
+ """.stripMargin.cmd,
+ s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/json.txt")}' INTO TABLE src_json".cmd)
+ )
- hiveQTestUtilTables.foreach(registerTestTable)
+ hiveQTestUtilTables.foreach(registerTestTable)
+ }
private val loadedTables = new collection.mutable.HashSet[String]
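
Note: the new loadTestTables flag exists so callers that cannot run the JVM-side table setup (for example the PySpark test harness) can still build a TestHiveContext. A usage sketch; the SparkContext setup here is illustrative:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.hive.test.TestHiveContext

object TestHiveNoTablesSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("test-hive"))
    // Skip registering the Hive QTestUtil tables (src, srcpart, ...).
    val hiveCtx = new TestHiveContext(sc, loadTestTables = false)
    hiveCtx.sql("SHOW TABLES").show()
    sc.stop()
  }
}
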
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
new file mode 100644
index 000000000000..5714d06f0fe7
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala
@@ -0,0 +1,62 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.sql.QueryTest
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Test suite to handle metadata cache related issues.
+ */
+class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+
+ test("SPARK-16337 temporary view refresh") {
+ withTempTable("view_refresh") {
+ withTable("view_table") {
+ // Create a Parquet directory
+ spark.range(start = 0, end = 100, step = 1, numPartitions = 3)
+ .write.saveAsTable("view_table")
+
+ // Read the table in
+ spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh")
+ assert(sql("select count(*) from view_refresh").first().getLong(0) == 100)
+
+ // Delete a file using the Hadoop file system interface since the path returned by
+ // inputFiles is not recognizable by Java IO.
+ val p = new Path(spark.table("view_table").inputFiles.head)
+ assert(p.getFileSystem(hiveContext.sessionState.newHadoopConf()).delete(p, false))
+
+ // Read it again and now we should see a FileNotFoundException
+ val e = intercept[SparkException] {
+ sql("select count(*) from view_refresh").first()
+ }
+ assert(e.getMessage.contains("FileNotFoundException"))
+ assert(e.getMessage.contains("REFRESH"))
+
+ // Refresh and we should be able to read it again.
+ spark.catalog.refreshTable("view_refresh")
+ val newCount = sql("select count(*) from view_refresh").first().getLong(0)
+ assert(newCount > 0 && newCount < 100)
+ }
+ }
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index b420781e51bd..754aabb5ac93 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -26,15 +26,15 @@ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
-import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructField, StructType}
class HiveMetastoreCatalogSuite extends TestHiveSingleton {
import spark.implicits._
test("struct field should accept underscore in sub-column name") {
val hiveTypeStr = "struct<a: int, b_1: string, c_2: string>"
- val dateType = CatalystSqlParser.parseDataType(hiveTypeStr)
- assert(dateType.isInstanceOf[StructType])
+ val dataType = CatalystSqlParser.parseDataType(hiveTypeStr)
+ assert(dataType.isInstanceOf[StructType])
}
test("udt to metastore type conversion") {
@@ -49,6 +49,14 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton {
logInfo(df.queryExecution.toString)
df.as('a).join(df.as('b), $"a.key" === $"b.key")
}
+
+ test("should not truncate struct type catalog string") {
+ def field(n: Int): StructField = {
+ StructField("col" + n, StringType)
+ }
+ val dataType = StructType((1 to 100).map(field))
+ assert(CatalystSqlParser.parseDataType(dataType.catalogString) == dataType)
+ }
}
class DataSourceWithHiveMetastoreCatalogSuite
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index b028d49aff58..12d250d4fb60 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -255,13 +255,13 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq)
// Discard the cached relation.
- sessionState.invalidateTable("jsonTable")
+ sessionState.refreshTable("jsonTable")
checkAnswer(
sql("SELECT * FROM jsonTable"),
sql("SELECT `c_!@(3)` FROM expectedJsonTable").collect().toSeq)
- sessionState.invalidateTable("jsonTable")
+ sessionState.refreshTable("jsonTable")
val expectedSchema = StructType(StructField("c_!@(3)", IntegerType, true) :: Nil)
assert(expectedSchema === table("jsonTable").schema)
@@ -349,7 +349,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
""".stripMargin)
// Discard the cached relation.
- sessionState.invalidateTable("ctasJsonTable")
+ sessionState.refreshTable("ctasJsonTable")
// Schema should not be changed.
assert(table("ctasJsonTable").schema === table("jsonTable").schema)
@@ -424,7 +424,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sql("SELECT * FROM savedJsonTable tmp where tmp.a > 5"),
(6 to 10).map(i => Row(i, s"str$i")))
- sessionState.invalidateTable("savedJsonTable")
+ sessionState.refreshTable("savedJsonTable")
checkAnswer(
sql("SELECT * FROM savedJsonTable where savedJsonTable.a < 5"),
@@ -710,7 +710,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
options = Map("path" -> tempDir.getCanonicalPath),
isExternal = false)
- sessionState.invalidateTable("wide_schema")
+ sessionState.refreshTable("wide_schema")
val actualSchema = table("wide_schema").schema
assert(schema === actualSchema)
@@ -743,7 +743,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
sharedState.externalCatalog.createTable("default", hiveTable, ignoreIfExists = false)
- sessionState.invalidateTable(tableName)
+ sessionState.refreshTable(tableName)
val actualSchema = table(tableName).schema
assert(schema === actualSchema)
@@ -758,7 +758,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
withTable(tableName) {
df.write.format("parquet").partitionBy("d", "b").saveAsTable(tableName)
- sessionState.invalidateTable(tableName)
+ sessionState.refreshTable(tableName)
val metastoreTable = sharedState.externalCatalog.getTable("default", tableName)
val expectedPartitionColumns = StructType(df.schema("d") :: df.schema("b") :: Nil)
@@ -793,7 +793,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv
.bucketBy(8, "d", "b")
.sortBy("c")
.saveAsTable(tableName)
- sessionState.invalidateTable(tableName)
+ sessionState.refreshTable(tableName)
val metastoreTable = sharedState.externalCatalog.getTable("default", tableName)
val expectedBucketByColumns = StructType(df.schema("d") :: df.schema("b") :: Nil)
val expectedSortByColumns = StructType(df.schema("c") :: Nil)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
index 89f69c8e4d7f..93e50f4ee907 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala
@@ -391,6 +391,29 @@ class HiveDDLSuite
}
}
+ test("create view with mismatched schema") {
+ withTable("tab1") {
+ spark.range(10).write.saveAsTable("tab1")
+ withView("view1") {
+ val e = intercept[AnalysisException] {
+ sql("CREATE VIEW view1 (col1, col3) AS SELECT * FROM tab1")
+ }.getMessage
+ assert(e.contains("the SELECT clause (num: `1`) does not match")
+ && e.contains("CREATE VIEW (num: `2`)"))
+ }
+ }
+ }
+
+ test("create view with specified schema") {
+ withView("view1") {
+ sql("CREATE VIEW view1 (col1, col2) AS SELECT 1, 2")
+ checkAnswer(
+ sql("SELECT * FROM view1"),
+ Row(1, 2) :: Nil
+ )
+ }
+ }
+
test("desc table for Hive table") {
withTable("tab1") {
val tabName = "tab1"
@@ -554,6 +577,21 @@ class HiveDDLSuite
}
}
+ test("Create Cataloged Table As Select - Drop Table After Runtime Exception") {
+ withTable("tab") {
+ intercept[RuntimeException] {
+ sql(
+ """
+ |CREATE TABLE tab
+ |STORED AS TEXTFILE
+ |SELECT 1 AS a, (SELECT a FROM (SELECT 1 AS a UNION ALL SELECT 2 AS a) t) AS b
+ """.stripMargin)
+ }
+ // After hitting runtime exception, we should drop the created table.
+ assert(!spark.sessionState.catalog.tableExists(TableIdentifier("tab")))
+ }
+ }
+
test("desc table for data source table") {
withTable("tab1") {
val tabName = "tab1"
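
The new HiveDDLSuite tests above cover CREATE VIEW with an explicit column list: the listed names must match the arity of the SELECT output, and a mismatch is rejected with an AnalysisException. A hedged sketch of the user-facing behaviour being tested (the view names and surrounding session are illustrative):

    // Sketch only: a column list that matches the SELECT arity defines the view's columns...
    spark.sql("CREATE VIEW v_ok (col1, col2) AS SELECT 1, 2")
    spark.table("v_ok").columns                       // expected: Array("col1", "col2")

    // ...while a mismatched column count fails at analysis time.
    try {
      spark.sql("CREATE VIEW v_bad (c1, c2, c3) AS SELECT 1, 2")
    } catch {
      case e: org.apache.spark.sql.AnalysisException => println(e.getMessage)
    }
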
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 0f56b2c0d1f4..def4601cf615 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -142,6 +142,13 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq)
}
+ test("SPARK-16228 Percentile needs explicit cast to double") {
+ sql("select percentile(value, cast(0.5 as double)) from values 1,2,3 T(value)")
+ sql("select percentile_approx(value, cast(0.5 as double)) from values 1.0,2.0,3.0 T(value)")
+ sql("select percentile(value, 0.5) from values 1,2,3 T(value)")
+ sql("select percentile_approx(value, 0.5) from values 1.0,2.0,3.0 T(value)")
+ }
+
test("Generic UDAF aggregates") {
checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999D)) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
index 871b9e02eb38..0f37cd7bf365 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala
@@ -153,7 +153,7 @@ class OrcSourceSuite extends OrcSuite {
super.beforeAll()
spark.sql(
- s"""CREATE TEMPORARY TABLE normal_orc_source
+ s"""CREATE TEMPORARY VIEW normal_orc_source
|USING org.apache.spark.sql.hive.orc
|OPTIONS (
| PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}'
@@ -161,7 +161,7 @@ class OrcSourceSuite extends OrcSuite {
""".stripMargin)
spark.sql(
- s"""CREATE TEMPORARY TABLE normal_orc_as_source
+ s"""CREATE TEMPORARY VIEW normal_orc_as_source
|USING org.apache.spark.sql.hive.orc
|OPTIONS (
| PATH '${new File(orcTableAsDir.getAbsolutePath).getCanonicalPath}'
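
The ORC hunks above replace the deprecated CREATE TEMPORARY TABLE ... USING form with CREATE TEMPORARY VIEW ... USING. An equivalent, hedged sketch through the DataFrame API (the path below is a placeholder, not taken from the patch):

    // Sketch only: register an ORC directory as a temporary view without the SQL DDL form.
    val orcPath = "/tmp/orc-table-dir"                         // placeholder path
    spark.read.orc(orcPath).createOrReplaceTempView("normal_orc_source")
    spark.sql("SELECT COUNT(*) FROM normal_orc_source").show()
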
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 6af9976ea0b8..96beb2d3427b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -389,17 +389,18 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
test("SPARK-7749: non-partitioned metastore Parquet table lookup should use cached relation") {
withTable("nonPartitioned") {
sql(
- s"""CREATE TABLE nonPartitioned (
- | key INT,
- | value STRING
- |)
- |STORED AS PARQUET
- """.stripMargin)
+ """
+ |CREATE TABLE nonPartitioned (
+ | key INT,
+ | value STRING
+ |)
+ |STORED AS PARQUET
+ """.stripMargin)
// First lookup fills the cache
- val r1 = collectHadoopFsRelation (table("nonPartitioned"))
+ val r1 = collectHadoopFsRelation(table("nonPartitioned"))
// Second lookup should reuse the cache
- val r2 = collectHadoopFsRelation (table("nonPartitioned"))
+ val r2 = collectHadoopFsRelation(table("nonPartitioned"))
// They should be the same instance
assert(r1 eq r2)
}
@@ -408,18 +409,42 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
test("SPARK-7749: partitioned metastore Parquet table lookup should use cached relation") {
withTable("partitioned") {
sql(
- s"""CREATE TABLE partitioned (
- | key INT,
- | value STRING
- |)
- |PARTITIONED BY (part INT)
- |STORED AS PARQUET
- """.stripMargin)
+ """
+ |CREATE TABLE partitioned (
+ | key INT,
+ | value STRING
+ |)
+ |PARTITIONED BY (part INT)
+ |STORED AS PARQUET
+ """.stripMargin)
+
+ // First lookup fills the cache
+ val r1 = collectHadoopFsRelation(table("partitioned"))
+ // Second lookup should reuse the cache
+ val r2 = collectHadoopFsRelation(table("partitioned"))
+ // They should be the same instance
+ assert(r1 eq r2)
+ }
+ }
+
+ test("SPARK-15968: nonempty partitioned metastore Parquet table lookup should use cached " +
+ "relation") {
+ withTable("partitioned") {
+ sql(
+ """
+ |CREATE TABLE partitioned (
+ | key INT,
+ | value STRING
+ |)
+ |PARTITIONED BY (part INT)
+ |STORED AS PARQUET
+ """.stripMargin)
+ sql("INSERT INTO TABLE partitioned PARTITION(part=0) SELECT 1 as key, 'one' as value")
// First lookup fills the cache
- val r1 = collectHadoopFsRelation (table("partitioned"))
+ val r1 = collectHadoopFsRelation(table("partitioned"))
// Second lookup should reuse the cache
- val r2 = collectHadoopFsRelation (table("partitioned"))
+ val r2 = collectHadoopFsRelation(table("partitioned"))
// They should be the same instance
assert(r1 eq r2)
}
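
The new SPARK-15968 test differs from the preceding one only in that it inserts a row first, so the partitioned metastore Parquet table is non-empty before the two lookups; the cached relation should still be reused. A hedged sketch of the caching-sensitive sequence using public APIs only (collectHadoopFsRelation is a helper internal to this suite and is not reproduced here):

    // Sketch only: populate one partition, then look the table up twice; with the fix the
    // second lookup is expected to reuse the cached data-source relation.
    spark.sql("INSERT INTO TABLE partitioned PARTITION (part = 0) SELECT 1 AS key, 'one' AS value")
    val first  = spark.table("partitioned")
    val second = spark.table("partitioned")
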
@@ -462,7 +487,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
checkCached(tableIdentifier)
// For insert into non-partitioned table, we will do the conversion,
// so the converted test_insert_parquet should be cached.
- sessionState.invalidateTable("test_insert_parquet")
+ sessionState.refreshTable("test_insert_parquet")
assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
sql(
"""
@@ -475,7 +500,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
sql("select * from test_insert_parquet"),
sql("select a, b from jt").collect())
// Invalidate the cache.
- sessionState.invalidateTable("test_insert_parquet")
+ sessionState.refreshTable("test_insert_parquet")
assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
// Create a partitioned table.
@@ -525,7 +550,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
|select b, '2015-04-02', a FROM jt
""".stripMargin).collect())
- sessionState.invalidateTable("test_parquet_partitioned_cache_test")
+ sessionState.refreshTable("test_parquet_partitioned_cache_test")
assert(sessionState.catalog.getCachedDataSourceTable(tableIdentifier) === null)
dropTables("test_insert_parquet", "test_parquet_partitioned_cache_test")
@@ -557,7 +582,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest {
Seq(("foo", 0), ("bar", 0)).toDF("a", "b"))
// Add data files to partition directory and check whether they can be read
- Seq("baz").toDF("a").write.mode(SaveMode.Overwrite).parquet(partitionDir)
+ sql("INSERT INTO TABLE test_added_partitions PARTITION (b=1) select 'baz' as a")
checkAnswer(
sql("SELECT * FROM test_added_partitions"),
Seq(("foo", 0), ("bar", 0), ("baz", 1)).toDF("a", "b"))
@@ -582,7 +607,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
"normal_parquet")
sql( s"""
- create temporary table partitioned_parquet
+ CREATE TEMPORARY VIEW partitioned_parquet
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDir.getCanonicalPath}'
@@ -590,7 +615,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
""")
sql( s"""
- create temporary table partitioned_parquet_with_key
+ CREATE TEMPORARY VIEW partitioned_parquet_with_key
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDirWithKey.getCanonicalPath}'
@@ -598,7 +623,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
""")
sql( s"""
- create temporary table normal_parquet
+ CREATE TEMPORARY VIEW normal_parquet
USING org.apache.spark.sql.parquet
OPTIONS (
path '${new File(partitionedTableDir, "p=1").getCanonicalPath}'
@@ -606,7 +631,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
""")
sql( s"""
- CREATE TEMPORARY TABLE partitioned_parquet_with_key_and_complextypes
+ CREATE TEMPORARY VIEW partitioned_parquet_with_key_and_complextypes
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDirWithKeyAndComplexTypes.getCanonicalPath}'
@@ -614,7 +639,7 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
""")
sql( s"""
- CREATE TEMPORARY TABLE partitioned_parquet_with_complextypes
+ CREATE TEMPORARY VIEW partitioned_parquet_with_complextypes
USING org.apache.spark.sql.parquet
OPTIONS (
path '${partitionedTableDirWithComplexTypes.getCanonicalPath}'
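
As in the ORC suite, ParquetSourceSuite now registers its test inputs with CREATE TEMPORARY VIEW ... USING. A hedged sketch of the same registration issued directly as SQL, with a placeholder path ("parquet" is the built-in short name for the data source):

    // Sketch only: a temporary view over a Parquet directory via the USING clause.
    val parquetDir = "/tmp/partitioned-parquet"        // placeholder, not taken from the patch
    spark.sql(
      s"""CREATE TEMPORARY VIEW partitioned_parquet
         |USING parquet
         |OPTIONS (path '$parquetDir')
       """.stripMargin)
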
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 9bb369549d94..01aa12a3c9a7 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -1053,7 +1053,14 @@ private[spark] class Client(
case YarnApplicationState.RUNNING =>
reportLauncherState(SparkAppHandle.State.RUNNING)
case YarnApplicationState.FINISHED =>
- reportLauncherState(SparkAppHandle.State.FINISHED)
+ report.getFinalApplicationStatus match {
+ case FinalApplicationStatus.FAILED =>
+ reportLauncherState(SparkAppHandle.State.FAILED)
+ case FinalApplicationStatus.KILLED =>
+ reportLauncherState(SparkAppHandle.State.KILLED)
+ case _ =>
+ reportLauncherState(SparkAppHandle.State.FINISHED)
+ }
case YarnApplicationState.FAILED =>
reportLauncherState(SparkAppHandle.State.FAILED)
case YarnApplicationState.KILLED =>
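
The Client change above refines how a FINISHED YARN application is reported to the launcher: the final application status is consulted so that failed or killed applications are not reported as successfully finished. A hedged sketch of that mapping as a standalone helper (the function name is illustrative, not part of the patch):

    import org.apache.hadoop.yarn.api.records.FinalApplicationStatus
    import org.apache.spark.launcher.SparkAppHandle

    // Sketch only: translate YARN's final status into the launcher-facing state.
    def launcherStateForFinishedApp(status: FinalApplicationStatus): SparkAppHandle.State =
      status match {
        case FinalApplicationStatus.FAILED => SparkAppHandle.State.FAILED
        case FinalApplicationStatus.KILLED => SparkAppHandle.State.KILLED
        case _                             => SparkAppHandle.State.FINISHED
      }
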
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
index 6b20dea5908a..9085fca1d3cc 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnClusterSuite.scala
@@ -120,6 +120,11 @@ class YarnClusterSuite extends BaseYarnClusterSuite {
finalState should be (SparkAppHandle.State.FAILED)
}
+ test("run Spark in yarn-cluster mode failure after sc initialized") {
+ val finalState = runSpark(false, mainClassName(YarnClusterDriverWithFailure.getClass))
+ finalState should be (SparkAppHandle.State.FAILED)
+ }
+
test("run Python application in yarn-client mode") {
testPySpark(true)
}
@@ -259,6 +264,16 @@ private[spark] class SaveExecutorInfo extends SparkListener {
}
}
+private object YarnClusterDriverWithFailure extends Logging with Matchers {
+ def main(args: Array[String]): Unit = {
+ val sc = new SparkContext(new SparkConf()
+ .set("spark.extraListeners", classOf[SaveExecutorInfo].getName)
+ .setAppName("yarn test with failure"))
+
+ throw new Exception("exception after sc initialized")
+ }
+}
+
private object YarnClusterDriver extends Logging with Matchers {
val WAIT_TIMEOUT_MILLIS = 10000
@@ -287,19 +302,19 @@ private object YarnClusterDriver extends Logging with Matchers {
sc.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)
data should be (Set(1, 2, 3, 4))
result = "success"
+
+ // Verify that the config archive is correctly placed in the classpath of all containers.
+ val confFile = "/" + Client.SPARK_CONF_FILE
+ assert(getClass().getResource(confFile) != null)
+ val configFromExecutors = sc.parallelize(1 to 4, 4)
+ .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull }
+ .collect()
+ assert(configFromExecutors.find(_ == null) === None)
} finally {
Files.write(result, status, StandardCharsets.UTF_8)
sc.stop()
}
- // Verify that the config archive is correctly placed in the classpath of all containers.
- val confFile = "/" + Client.SPARK_CONF_FILE
- assert(getClass().getResource(confFile) != null)
- val configFromExecutors = sc.parallelize(1 to 4, 4)
- .map { _ => Option(getClass().getResource(confFile)).map(_.toString).orNull }
- .collect()
- assert(configFromExecutors.find(_ == null) === None)
-
// verify log urls are present
val listeners = sc.listenerBus.findListenersByClass[SaveExecutorInfo]
assert(listeners.size === 1)
@@ -330,9 +345,6 @@ private object YarnClusterDriver extends Logging with Matchers {
}
private object YarnClasspathTest extends Logging {
-
- var exitCode = 0
-
def error(m: String, ex: Throwable = null): Unit = {
logError(m, ex)
// scalastyle:off println
@@ -361,7 +373,6 @@ private object YarnClasspathTest extends Logging {
} finally {
sc.stop()
}
- System.exit(exitCode)
}
private def readResource(resultPath: String): Unit = {
@@ -374,8 +385,6 @@ private object YarnClasspathTest extends Logging {
} catch {
case t: Throwable =>
error(s"loading test.resource to $resultPath", t)
- // set the exit code if not yet set
- exitCode = 2
} finally {
Files.write(result, new File(resultPath), StandardCharsets.UTF_8)
}
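
The YarnClusterSuite additions above depend on the launcher reporting a terminal FAILED state when the driver throws after SparkContext initialization, which the Client change makes possible. A hedged sketch of observing an application's final state through the public launcher API (the jar path, main class, and master value are placeholders):

    import org.apache.spark.launcher.{SparkAppHandle, SparkLauncher}

    // Sketch only: launch an application and poll the handle until it reaches a terminal state.
    val handle = new SparkLauncher()
      .setAppResource("/path/to/app.jar")              // placeholder
      .setMainClass("com.example.FailingDriver")       // placeholder
      .setMaster("yarn")
      .setDeployMode("cluster")
      .startApplication()

    while (!handle.getState.isFinal) {
      Thread.sleep(1000)
    }
    println(s"final state: ${handle.getState}")        // expected FAILED if the driver threw
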