diff --git a/.gitignore b/.gitignore index 07524bc429e9..8ecf536e79a5 100644 --- a/.gitignore +++ b/.gitignore @@ -60,7 +60,6 @@ dev/create-release/*final spark-*-bin-*.tgz unit-tests.log /lib/ -ec2/lib/ rat-results.txt scalastyle.txt scalastyle-output.xml diff --git a/.rat-excludes b/.rat-excludes index 7262c960ed6b..bf071eba652b 100644 --- a/.rat-excludes +++ b/.rat-excludes @@ -84,3 +84,5 @@ gen-java.* org.apache.spark.sql.sources.DataSourceRegister org.apache.spark.scheduler.SparkHistoryListenerFactory .*parquet +LZ4BlockInputStream.java +spark-deps-.* diff --git a/R/pkg/.lintr b/R/pkg/.lintr index 39c872663ad4..038236fc149e 100644 --- a/R/pkg/.lintr +++ b/R/pkg/.lintr @@ -1,2 +1,2 @@ -linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE), commented_code_linter = NULL) +linters: with_defaults(line_length_linter(100), camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE)) exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R") diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 369714f7b99c..465bc37788e5 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,7 +1,7 @@ Package: SparkR Type: Package Title: R frontend for Spark -Version: 1.6.0 +Version: 2.0.0 Date: 2013-09-09 Author: The Apache Software Foundation Maintainer: Shivaram Venkataraman diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index ccc01fe16960..34be7f0ebd75 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -94,7 +94,8 @@ exportMethods("arrange", "withColumnRenamed", "write.df", "write.json", - "write.parquet") + "write.parquet", + "write.text") exportClasses("Column") @@ -129,6 +130,7 @@ exportMethods("%in%", "count", "countDistinct", "crc32", + "hash", "cume_dist", "date_add", "date_format", @@ -274,6 +276,7 @@ export("as.DataFrame", "parquetFile", "read.df", "read.parquet", + "read.text", "sql", "table", "tableNames", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index 0cfa12b997d6..3bf5bc924f7d 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -458,7 +458,10 @@ setMethod("registerTempTable", setMethod("insertInto", signature(x = "DataFrame", tableName = "character"), function(x, tableName, overwrite = FALSE) { - callJMethod(x@sdf, "insertInto", tableName, overwrite) + jmode <- convertToJSaveMode(ifelse(overwrite, "overwrite", "append")) + write <- callJMethod(x@sdf, "write") + write <- callJMethod(write, "mode", jmode) + callJMethod(write, "insertInto", tableName) }) #' Cache @@ -661,6 +664,34 @@ setMethod("saveAsParquetFile", write.parquet(x, path) }) +#' write.text +#' +#' Saves the content of the DataFrame in a text file at the specified path. +#' The DataFrame must have only one column of string type with the name "value". +#' Each row becomes a new line in the output file. 
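A minimal round-trip sketch of this behaviour (the output directory below is hypothetical, and sqlContext is assumed to be initialized as in the examples further down):

lines <- list(list(value = "line one"), list(value = "line two"))
df <- createDataFrame(sqlContext, lines)         # single string column named "value"
write.text(df, "/tmp/sparkr-text-out")           # each row written as one line
df2 <- read.text(sqlContext, "/tmp/sparkr-text-out")
count(df2)                                       # 2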
+#' +#' @param x A SparkSQL DataFrame +#' @param path The directory where the file is saved +#' +#' @family DataFrame functions +#' @rdname write.text +#' @name write.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' write.text(df, "/tmp/sparkr-tmp/") +#'} +setMethod("write.text", + signature(x = "DataFrame", path = "character"), + function(x, path) { + write <- callJMethod(x@sdf, "write") + invisible(callJMethod(write, "text", path)) + }) + #' Distinct #' #' Return a new DataFrame containing the distinct rows in this DataFrame. @@ -1948,18 +1979,15 @@ setMethod("write.df", source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') - } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) if (!is.null(path)) { options[["path"]] <- path } - callJMethod(df@sdf, "save", source, jmode, options) + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "save", path) }) #' @rdname write.df @@ -2013,15 +2041,14 @@ setMethod("saveAsTable", source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default", "org.apache.spark.sql.parquet") } - allModes <- c("append", "overwrite", "error", "ignore") - # nolint start - if (!(mode %in% allModes)) { - stop('mode should be one of "append", "overwrite", "error", "ignore"') - } - # nolint end - jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode <- convertToJSaveMode(mode) options <- varargsToEnv(...) - callJMethod(df@sdf, "saveAsTable", tableName, source, jmode, options) + + write <- callJMethod(df@sdf, "write") + write <- callJMethod(write, "format", source) + write <- callJMethod(write, "mode", jmode) + write <- callJMethod(write, "options", options) + callJMethod(write, "saveAsTable", tableName) }) #' summary diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R index 00c40c38cabc..a78fbb714f2b 100644 --- a/R/pkg/R/RDD.R +++ b/R/pkg/R/RDD.R @@ -180,7 +180,7 @@ setMethod("getJRDD", signature(rdd = "PipelinedRDD"), } # Save the serialization flag after we create a RRDD rdd@env$serializedMode <- serializedMode - rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") # rddRef$asJavaRDD() + rdd@env$jrdd_val <- callJMethod(rddRef, "asJavaRDD") rdd@env$jrdd_val }) @@ -225,7 +225,7 @@ setMethod("cache", #' #' Persist this RDD with the specified storage level. For details of the #' supported storage levels, refer to -#' http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence. +#'\url{http://spark.apache.org/docs/latest/programming-guide.html#rdd-persistence}. #' #' @param x The RDD to persist #' @param newLevel The new storage level to be assigned @@ -382,11 +382,13 @@ setMethod("collectPartition", #' \code{collectAsMap} returns a named list as a map that contains all of the elements #' in a key-value pair RDD. 
#' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4)), 2L) #' collectAsMap(rdd) # list(`1` = 2, `3` = 4) #'} +# nolint end #' @rdname collect-methods #' @aliases collectAsMap,RDD-method #' @noRd @@ -442,11 +444,13 @@ setMethod("length", #' @return list of (value, count) pairs, where count is number of each unique #' value in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,3,2,1)) #' countByValue(rdd) # (1,2L), (2,2L), (3,1L) #'} +# nolint end #' @rdname countByValue #' @aliases countByValue,RDD-method #' @noRd @@ -597,11 +601,13 @@ setMethod("mapPartitionsWithIndex", #' @param x The RDD to be filtered. #' @param f A unary predicate function. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' unlist(collect(filterRDD(rdd, function (x) { x < 3 }))) # c(1, 2) #'} +# nolint end #' @rdname filterRDD #' @aliases filterRDD,RDD,function-method #' @noRd @@ -756,11 +762,13 @@ setMethod("foreachPartition", #' @param x The RDD to take elements from #' @param num Number of elements to take #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:10) #' take(rdd, 2L) # list(1, 2) #'} +# nolint end #' @rdname take #' @aliases take,RDD,numeric-method #' @noRd @@ -824,11 +832,13 @@ setMethod("first", #' @param x The RDD to remove duplicates from. #' @param numPartitions Number of partitions to create. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, c(1,2,2,3,3,3)) #' sort(unlist(collect(distinct(rdd)))) # c(1, 2, 3) #'} +# nolint end #' @rdname distinct #' @aliases distinct,RDD-method #' @noRd @@ -974,11 +984,13 @@ setMethod("takeSample", signature(x = "RDD", withReplacement = "logical", #' @param x The RDD. #' @param func The function to be applied. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3)) #' collect(keyBy(rdd, function(x) { x*x })) # list(list(1, 1), list(4, 2), list(9, 3)) #'} +# nolint end #' @rdname keyBy #' @aliases keyBy,RDD #' @noRd @@ -1113,11 +1125,13 @@ setMethod("saveAsTextFile", #' @param numPartitions Number of partitions to create. #' @return An RDD where all elements are sorted. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(3, 2, 1)) #' collect(sortBy(rdd, function(x) { x })) # list (1, 2, 3) #'} +# nolint end #' @rdname sortBy #' @aliases sortBy,RDD,RDD-method #' @noRd @@ -1188,11 +1202,13 @@ takeOrderedElem <- function(x, num, ascending = TRUE) { #' @param num Number of elements to return. #' @return The first N elements from the RDD in ascending order. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' takeOrdered(rdd, 6L) # list(1, 2, 3, 4, 5, 6) #'} +# nolint end #' @rdname takeOrdered #' @aliases takeOrdered,RDD,RDD-method #' @noRd @@ -1209,11 +1225,13 @@ setMethod("takeOrdered", #' @return The top N elements from the RDD. 
#' @rdname top #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(10, 1, 2, 9, 3, 4, 5, 6, 7)) #' top(rdd, 6L) # list(10, 9, 7, 6, 5, 4) #'} +# nolint end #' @aliases top,RDD,RDD-method #' @noRd setMethod("top", @@ -1261,6 +1279,7 @@ setMethod("fold", #' @rdname aggregateRDD #' @seealso reduce #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(1, 2, 3, 4)) @@ -1269,6 +1288,7 @@ setMethod("fold", #' combOp <- function(x, y) { list(x[[1]] + y[[1]], x[[2]] + y[[2]]) } #' aggregateRDD(rdd, zeroValue, seqOp, combOp) # list(10, 4) #'} +# nolint end #' @aliases aggregateRDD,RDD,RDD-method #' @noRd setMethod("aggregateRDD", @@ -1367,12 +1387,14 @@ setMethod("setName", #' @return An RDD with zipped items. #' @seealso zipWithIndex #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithUniqueId(rdd)) #' # list(list("a", 0), list("b", 3), list("c", 1), list("d", 4), list("e", 2)) #'} +# nolint end #' @rdname zipWithUniqueId #' @aliases zipWithUniqueId,RDD #' @noRd @@ -1408,12 +1430,14 @@ setMethod("zipWithUniqueId", #' @return An RDD with zipped items. #' @seealso zipWithUniqueId #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list("a", "b", "c", "d", "e"), 3L) #' collect(zipWithIndex(rdd)) #' # list(list("a", 0), list("b", 1), list("c", 2), list("d", 3), list("e", 4)) #'} +# nolint end #' @rdname zipWithIndex #' @aliases zipWithIndex,RDD #' @noRd @@ -1454,12 +1478,14 @@ setMethod("zipWithIndex", #' @return An RDD created by coalescing all elements within #' each partition into a list. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, as.list(1:4), 2L) #' collect(glom(rdd)) #' # list(list(1, 2), list(3, 4)) #'} +# nolint end #' @rdname glom #' @aliases glom,RDD #' @noRd @@ -1519,6 +1545,7 @@ setMethod("unionRDD", #' @param other Another RDD to be zipped. #' @return An RDD zipped from the two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 0:4) @@ -1526,6 +1553,7 @@ setMethod("unionRDD", #' collect(zipRDD(rdd1, rdd2)) #' # list(list(0, 1000), list(1, 1001), list(2, 1002), list(3, 1003), list(4, 1004)) #'} +# nolint end #' @rdname zipRDD #' @aliases zipRDD,RDD #' @noRd @@ -1557,12 +1585,14 @@ setMethod("zipRDD", #' @param other An RDD. #' @return A new RDD which is the Cartesian product of these two RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, 1:2) #' sortByKey(cartesian(rdd, rdd)) #' # list(list(1, 1), list(1, 2), list(2, 1), list(2, 2)) #'} +# nolint end #' @rdname cartesian #' @aliases cartesian,RDD,RDD-method #' @noRd @@ -1587,6 +1617,7 @@ setMethod("cartesian", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the elements from this that are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 1, 2, 2, 3, 4)) @@ -1594,6 +1625,7 @@ setMethod("cartesian", #' collect(subtract(rdd1, rdd2)) #' # list(1, 1, 3) #'} +# nolint end #' @rdname subtract #' @aliases subtract,RDD #' @noRd @@ -1619,6 +1651,7 @@ setMethod("subtract", #' @param numPartitions The number of partitions in the result RDD. #' @return An RDD which is the intersection of these two RDDs. 
#' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(1, 10, 2, 3, 4, 5)) @@ -1626,6 +1659,7 @@ setMethod("subtract", #' collect(sortBy(intersection(rdd1, rdd2), function(x) { x })) #' # list(1, 2, 3) #'} +# nolint end #' @rdname intersection #' @aliases intersection,RDD #' @noRd @@ -1653,6 +1687,7 @@ setMethod("intersection", #' Assumes that all the RDDs have the *same number of partitions*, but #' does *not* require them to have the same number of elements in each partition. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, 1:2, 2L) # 1, 2 @@ -1662,6 +1697,7 @@ setMethod("intersection", #' func = function(x, y, z) { list(list(x, y, z))} )) #' # list(list(1, c(1,2), c(1,2,3)), list(2, c(3,4), c(4,5,6))) #'} +# nolint end #' @rdname zipRDD #' @aliases zipPartitions,RDD #' @noRd diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 9243d70e66f7..99679b4a774d 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -256,9 +256,12 @@ jsonFile <- function(sqlContext, path) { # TODO: support schema jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) { + .Deprecated("read.json") rdd <- serializeToString(rdd) if (is.null(schema)) { - sdf <- callJMethod(sqlContext, "jsonRDD", callJMethod(getJRDD(rdd), "rdd"), samplingRatio) + read <- callJMethod(sqlContext, "read") + # samplingRatio is deprecated + sdf <- callJMethod(read, "json", callJMethod(getJRDD(rdd), "rdd")) dataFrame(sdf) } else { stop("not implemented") @@ -289,9 +292,32 @@ read.parquet <- function(sqlContext, path) { # TODO: Implement saveasParquetFile and write examples for both parquetFile <- function(sqlContext, ...) { .Deprecated("read.parquet") + read.parquet(sqlContext, unlist(list(...))) +} + +#' Create a DataFrame from a text file. +#' +#' Loads a text file and returns a DataFrame with a single string column named "value". +#' Each line in the text file is a new row in the resulting DataFrame. +#' +#' @param sqlContext SQLContext to use +#' @param path Path of file to read. A vector of multiple paths is allowed. +#' @return DataFrame +#' @rdname read.text +#' @name read.text +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' path <- "path/to/file.txt" +#' df <- read.text(sqlContext, path) +#' } +read.text <- function(sqlContext, path) { # Allow the user to have a more flexible definiton of the text file path - paths <- lapply(list(...), function(x) suppressWarnings(normalizePath(x))) - sdf <- callJMethod(sqlContext, "parquetFile", paths) + paths <- as.list(suppressWarnings(normalizePath(path))) + read <- callJMethod(sqlContext, "read") + sdf <- callJMethod(read, "text", paths) dataFrame(sdf) } diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index 7bb8ef2595b5..3ffd9a9890b2 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -209,13 +209,13 @@ setMethod("cast", setMethod("%in%", signature(x = "Column"), function(x, table) { - jc <- callJMethod(x@jc, "in", as.list(table)) + jc <- callJMethod(x@jc, "isin", as.list(table)) return(column(jc)) }) #' otherwise #' -#' If values in the specified column are null, returns the value. +#' If values in the specified column are null, returns the value. #' Can be used in conjunction with `when` to specify a default value for expressions. 
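A small sketch of that pairing (mirroring the when()/otherwise() test added later in this patch; sqlContext is assumed to be initialized):

df <- createDataFrame(sqlContext, list(list(a = 1, b = 2), list(a = 3, b = 4)))
# Rows where df$a > 1 get lit(1); all other rows fall back to the otherwise() value.
collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1]
# returns c(0, 1)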
#' #' @rdname otherwise @@ -225,7 +225,7 @@ setMethod("%in%", setMethod("otherwise", signature(x = "Column", value = "ANY"), function(x, value) { - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJMethod(x@jc, "otherwise", value) column(jc) }) diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index f7e56e43016e..d8a039327539 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -17,6 +17,7 @@ # Utility functions to deserialize objects from Java. +# nolint start # Type mapping from Java to R # # void -> NULL @@ -32,6 +33,8 @@ # # Array[T] -> list() # Object -> jobj +# +# nolint end readObject <- function(con) { # Read type first diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 09e4e04335a3..9bb7876b384c 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -37,7 +37,7 @@ setMethod("lit", signature("ANY"), function(x) { jc <- callJStatic("org.apache.spark.sql.functions", "lit", - ifelse(class(x) == "Column", x@jc, x)) + if (class(x) == "Column") { x@jc } else { x }) column(jc) }) @@ -340,6 +340,26 @@ setMethod("crc32", column(jc) }) +#' hash +#' +#' Calculates the hash code of given columns, and returns the result as a int column. +#' +#' @rdname hash +#' @name hash +#' @family misc_funcs +#' @export +#' @examples \dontrun{hash(df$c)} +setMethod("hash", + signature(x = "Column"), + function(x, ...) { + jcols <- lapply(list(x, ...), function (x) { + stopifnot(class(x) == "Column") + x@jc + }) + jc <- callJStatic("org.apache.spark.sql.functions", "hash", jcols) + column(jc) + }) + #' dayofmonth #' #' Extracts the day of the month as an integer from a given date/timestamp/string. @@ -2262,7 +2282,7 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), setMethod("when", signature(condition = "Column", value = "ANY"), function(condition, value) { condition <- condition@jc - value <- ifelse(class(value) == "Column", value@jc, value) + value <- if (class(value) == "Column") { value@jc } else { value } jc <- callJStatic("org.apache.spark.sql.functions", "when", condition, value) column(jc) }) @@ -2277,13 +2297,16 @@ setMethod("when", signature(condition = "Column", value = "ANY"), #' @name ifelse #' @seealso \link{when} #' @export -#' @examples \dontrun{ifelse(df$a > 1 & df$b > 2, 0, 1)} +#' @examples \dontrun{ +#' ifelse(df$a > 1 & df$b > 2, 0, 1) +#' ifelse(df$a > 1, df$a, 1) +#' } setMethod("ifelse", signature(test = "Column", yes = "ANY", no = "ANY"), function(test, yes, no) { test <- test@jc - yes <- ifelse(class(yes) == "Column", yes@jc, yes) - no <- ifelse(class(no) == "Column", no@jc, no) + yes <- if (class(yes) == "Column") { yes@jc } else { yes } + no <- if (class(no) == "Column") { no@jc } else { no } jc <- callJMethod(callJStatic("org.apache.spark.sql.functions", "when", test, yes), diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 62be2ddc8f52..5ba68e3a4f37 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -549,6 +549,10 @@ setGeneric("write.parquet", function(x, path) { standardGeneric("write.parquet") #' @export setGeneric("saveAsParquetFile", function(x, path) { standardGeneric("saveAsParquetFile") }) +#' @rdname write.text +#' @export +setGeneric("write.text", function(x, path) { standardGeneric("write.text") }) + #' @rdname schema #' @export setGeneric("schema", function(x) { standardGeneric("schema") }) @@ -732,6 +736,10 @@ setGeneric("countDistinct", function(x, ...) 
{ standardGeneric("countDistinct") #' @export setGeneric("crc32", function(x) { standardGeneric("crc32") }) +#' @rdname hash +#' @export +setGeneric("hash", function(x, ...) { standardGeneric("hash") }) + #' @rdname cume_dist #' @export setGeneric("cume_dist", function(x) { standardGeneric("cume_dist") }) diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 334c11d2f89a..f7131140feaf 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -30,12 +30,14 @@ NULL #' @param key The key to look up for #' @return a list of values in this RDD for key key #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(c(1, 1), c(2, 2), c(1, 3)) #' rdd <- parallelize(sc, pairs) #' lookup(rdd, 1) # list(1, 3) #'} +# nolint end #' @rdname lookup #' @aliases lookup,RDD-method #' @noRd @@ -58,11 +60,13 @@ setMethod("lookup", #' @param x The RDD to count keys. #' @return list of (key, count) pairs, where count is number of each key in rdd. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(c("a", 1), c("b", 1), c("a", 1))) #' countByKey(rdd) # ("a", 2L), ("b", 1L) #'} +# nolint end #' @rdname countByKey #' @aliases countByKey,RDD-method #' @noRd @@ -77,11 +81,13 @@ setMethod("countByKey", #' #' @param x The RDD from which the keys of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(keys(rdd)) # list(1, 3) #'} +# nolint end #' @rdname keys #' @aliases keys,RDD #' @noRd @@ -98,11 +104,13 @@ setMethod("keys", #' #' @param x The RDD from which the values of each tuple is returned. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 2), list(3, 4))) #' collect(values(rdd)) # list(2, 4) #'} +# nolint end #' @rdname values #' @aliases values,RDD #' @noRd @@ -348,6 +356,7 @@ setMethod("reduceByKey", #' @return A list of elements of type list(K, V') where V' is the merged value for each key #' @seealso reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -355,6 +364,7 @@ setMethod("reduceByKey", #' reduced <- reduceByKeyLocally(rdd, "+") #' reduced # list(list(1, 6), list(1.1, 3)) #'} +# nolint end #' @rdname reduceByKeyLocally #' @aliases reduceByKeyLocally,RDD,integer-method #' @noRd @@ -412,6 +422,7 @@ setMethod("reduceByKeyLocally", #' @return An RDD where each element is list(K, C) where C is the combined type #' @seealso groupByKey, reduceByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' pairs <- list(list(1, 2), list(1.1, 3), list(1, 4)) @@ -420,6 +431,7 @@ setMethod("reduceByKeyLocally", #' combined <- collect(parts) #' combined[[1]] # Should be a list(1, 6) #'} +# nolint end #' @rdname combineByKey #' @aliases combineByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -473,6 +485,7 @@ setMethod("combineByKey", #' @return An RDD containing the aggregation result. 
#' @seealso foldByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) @@ -482,6 +495,7 @@ setMethod("combineByKey", #' aggregateByKey(rdd, zeroValue, seqOp, combOp, 2L) #' # list(list(1, list(3, 2)), list(2, list(7, 2))) #'} +# nolint end #' @rdname aggregateByKey #' @aliases aggregateByKey,RDD,ANY,ANY,ANY,integer-method #' @noRd @@ -509,11 +523,13 @@ setMethod("aggregateByKey", #' @return An RDD containing the aggregation result. #' @seealso aggregateByKey, combineByKey #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(1, 1), list(1, 2), list(2, 3), list(2, 4))) #' foldByKey(rdd, 0, "+", 2L) # list(list(1, 3), list(2, 7)) #'} +# nolint end #' @rdname foldByKey #' @aliases foldByKey,RDD,ANY,ANY,integer-method #' @noRd @@ -540,12 +556,14 @@ setMethod("foldByKey", #' @return a new RDD containing all pairs of elements with matching keys in #' two input RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) #' rdd2 <- parallelize(sc, list(list(1, 2), list(1, 3))) #' join(rdd1, rdd2, 2L) # list(list(1, list(1, 2)), list(1, list(1, 3)) #'} +# nolint end #' @rdname join-methods #' @aliases join,RDD,RDD-method #' @noRd @@ -578,6 +596,7 @@ setMethod("join", #' all pairs (k, (v, w)) for (k, w) in rdd2, or the pair (k, (v, NULL)) #' if no elements in rdd2 have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -585,6 +604,7 @@ setMethod("join", #' leftOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(1, 2)), list(1, list(1, 3)), list(2, list(4, NULL))) #'} +# nolint end #' @rdname join-methods #' @aliases leftOuterJoin,RDD,RDD-method #' @noRd @@ -616,6 +636,7 @@ setMethod("leftOuterJoin", #' all pairs (k, (v, w)) for (k, v) in x, or the pair (k, (NULL, w)) #' if no elements in x have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3))) @@ -623,6 +644,7 @@ setMethod("leftOuterJoin", #' rightOuterJoin(rdd1, rdd2, 2L) #' # list(list(1, list(2, 1)), list(1, list(3, 1)), list(2, list(NULL, 4))) #'} +# nolint end #' @rdname join-methods #' @aliases rightOuterJoin,RDD,RDD-method #' @noRd @@ -655,6 +677,7 @@ setMethod("rightOuterJoin", #' (k, w) in y, or the pair (k, (NULL, w))/(k, (v, NULL)) if no elements #' in x/y have key k. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 2), list(1, 3), list(3, 3))) @@ -664,6 +687,7 @@ setMethod("rightOuterJoin", #' # list(2, list(NULL, 4))) #' # list(3, list(3, NULL)), #'} +# nolint end #' @rdname join-methods #' @aliases fullOuterJoin,RDD,RDD-method #' @noRd @@ -688,6 +712,7 @@ setMethod("fullOuterJoin", #' @return a new RDD containing all pairs of elements with values in a list #' in all RDDs. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list(1, 1), list(2, 4))) @@ -695,6 +720,7 @@ setMethod("fullOuterJoin", #' cogroup(rdd1, rdd2, numPartitions = 2L) #' # list(list(1, list(1, list(2, 3))), list(2, list(list(4), list())) #'} +# nolint end #' @rdname cogroup #' @aliases cogroup,RDD-method #' @noRd @@ -740,11 +766,13 @@ setMethod("cogroup", #' @param numPartitions Number of partitions to create. #' @return An RDD where all (k, v) pair elements are sorted. 
#' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd <- parallelize(sc, list(list(3, 1), list(2, 2), list(1, 3))) #' collect(sortByKey(rdd)) # list (list(1, 3), list(2, 2), list(3, 1)) #'} +# nolint end #' @rdname sortByKey #' @aliases sortByKey,RDD,RDD-method #' @noRd @@ -805,6 +833,7 @@ setMethod("sortByKey", #' @param numPartitions Number of the partitions in the result RDD. #' @return An RDD with the pairs from x whose keys are not in other. #' @examples +# nolint start #'\dontrun{ #' sc <- sparkR.init() #' rdd1 <- parallelize(sc, list(list("a", 1), list("b", 4), @@ -813,6 +842,7 @@ setMethod("sortByKey", #' collect(subtractByKey(rdd1, rdd2)) #' # list(list("b", 4), list("b", 5)) #'} +# nolint end #' @rdname subtractByKey #' @aliases subtractByKey,RDD #' @noRd diff --git a/R/pkg/R/serialize.R b/R/pkg/R/serialize.R index 17082b4e52fc..095ddb9aed2e 100644 --- a/R/pkg/R/serialize.R +++ b/R/pkg/R/serialize.R @@ -17,6 +17,7 @@ # Utility functions to serialize R objects so they can be read in Java. +# nolint start # Type mapping from R to Java # # NULL -> Void @@ -31,6 +32,7 @@ # list[T] -> Array[T], where T is one of above mentioned types # environment -> Map[String, T], where T is a native type # jobj -> Object, where jobj is an object created in the backend +# nolint end getSerdeType <- function(object) { type <- class(object)[[1]] diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 43105aaa3842..aa386e5da933 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -641,3 +641,12 @@ assignNewEnv <- function(data) { splitString <- function(input) { Filter(nzchar, unlist(strsplit(input, ",|\\s"))) } + +convertToJSaveMode <- function(mode) { + allModes <- c("append", "overwrite", "error", "ignore") + if (!(mode %in% allModes)) { + stop('mode should be one of "append", "overwrite", "error", "ignore"') # nolint + } + jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) + jmode +} diff --git a/R/pkg/inst/tests/testthat/test_rdd.R b/R/pkg/inst/tests/testthat/test_rdd.R index 7423b4f2bed1..1b3a22486e95 100644 --- a/R/pkg/inst/tests/testthat/test_rdd.R +++ b/R/pkg/inst/tests/testthat/test_rdd.R @@ -223,14 +223,14 @@ test_that("takeSample() on RDDs", { s <- takeSample(data, TRUE, 100L, seed) expect_equal(length(s), 100L) # Chance of getting all distinct elements is astronomically low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } for (seed in 4:5) { s <- takeSample(data, TRUE, 200L, seed) expect_equal(length(s), 200L) # Chance of getting all distinct elements is still quite low, so test we - # got < 100 + # got less than 100 expect_true(length(unique(s)) < 100L) } }) diff --git a/R/pkg/inst/tests/testthat/test_shuffle.R b/R/pkg/inst/tests/testthat/test_shuffle.R index adf0b91d25fe..d3d0f8a24d01 100644 --- a/R/pkg/inst/tests/testthat/test_shuffle.R +++ b/R/pkg/inst/tests/testthat/test_shuffle.R @@ -176,8 +176,8 @@ test_that("partitionBy() partitions data correctly", { resultRDD <- partitionBy(numPairsRdd, 2L, partitionByMagnitude) - expected_first <- list(list(1, 100), list(2, 200)) # key < 3 - expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key >= 3 + expected_first <- list(list(1, 100), list(2, 200)) # key less than 3 + expected_second <- list(list(4, -1), list(3, 1), list(3, 0)) # key greater than or equal 3 actual_first <- collectPartition(resultRDD, 0L) actual_second <- collectPartition(resultRDD, 1L) diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R 
index 135c7576e529..97625b94a0e2 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -62,6 +62,10 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern="sparkr-test", fileext=".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) +test_that("calling sparkRSQL.init returns existing SQL context", { + expect_equal(sparkRSQL.init(sc), sqlContext) +}) + test_that("infer types and check types", { expect_equal(infer_type(1L), "integer") expect_equal(infer_type(1.0), "double") @@ -419,12 +423,12 @@ test_that("read/write json files", { test_that("jsonRDD() on a RDD with json string", { rdd <- parallelize(sc, mockLines) expect_equal(count(rdd), 3) - df <- jsonRDD(sqlContext, rdd) + df <- suppressWarnings(jsonRDD(sqlContext, rdd)) expect_is(df, "DataFrame") expect_equal(count(df), 3) rdd2 <- flatMap(rdd, function(x) c(x, x)) - df <- jsonRDD(sqlContext, rdd2) + df <- suppressWarnings(jsonRDD(sqlContext, rdd2)) expect_is(df, "DataFrame") expect_equal(count(df), 6) }) @@ -494,9 +498,11 @@ test_that("table() returns a new DataFrame", { expect_equal(count(tabledf), 3) dropTempTable(sqlContext, "table1") + # nolint start # Test base::table is working #a <- letters[1:3] #expect_equal(class(table(a, sample(a))), "table") + # nolint end }) test_that("toRDD() returns an RRDD", { @@ -762,8 +768,10 @@ test_that("sample on a DataFrame", { sampled3 <- sample_frac(df, FALSE, 0.1, 0) # set seed for predictable result expect_true(count(sampled3) < 3) + # nolint start # Test base::sample is working #expect_equal(length(sample(1:12)), 12) + # nolint end }) test_that("select operators", { @@ -914,7 +922,7 @@ test_that("column functions", { c <- column("a") c1 <- abs(c) + acos(c) + approxCountDistinct(c) + ascii(c) + asin(c) + atan(c) c2 <- avg(c) + base64(c) + bin(c) + bitwiseNOT(c) + cbrt(c) + ceil(c) + cos(c) - c3 <- cosh(c) + count(c) + crc32(c) + exp(c) + c3 <- cosh(c) + count(c) + crc32(c) + hash(c) + exp(c) c4 <- explode(c) + expm1(c) + factorial(c) + first(c) + floor(c) + hex(c) c5 <- hour(c) + initcap(c) + last(c) + last_day(c) + length(c) c6 <- log(c) + (c) + log1p(c) + log2(c) + lower(c) + ltrim(c) + max(c) + md5(c) @@ -1048,8 +1056,8 @@ test_that("string operators", { df2 <- createDataFrame(sqlContext, l2) expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1) expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2) - expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") - expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") + expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # nolint + expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # nolint l3 <- list(list(a = "a.b.c.d")) df3 <- createDataFrame(sqlContext, l3) @@ -1120,6 +1128,14 @@ test_that("when(), otherwise() and ifelse() on a DataFrame", { expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, 0, 1)))[, 1], c(1, 0)) }) +test_that("when(), otherwise() and ifelse() with column on a DataFrame", { + l <- list(list(a = 1, b = 2), list(a = 3, b = 4)) + df <- createDataFrame(sqlContext, l) + expect_equal(collect(select(df, when(df$a > 1 & df$b > 2, lit(1))))[, 1], c(NA, 1)) + expect_equal(collect(select(df, otherwise(when(df$a > 1, lit(1)), lit(0))))[, 1], c(0, 1)) + expect_equal(collect(select(df, ifelse(df$a > 1 & df$b > 2, lit(0), lit(1))))[, 1], c(1, 0)) +}) + test_that("group by, agg functions", { df <- read.json(sqlContext, jsonPath) df1 <- agg(df, name = "max", age 
= "sum") @@ -1247,7 +1263,7 @@ test_that("filter() on a DataFrame", { expect_equal(count(filtered6), 2) # Test stats::filter is working - #expect_true(is.ts(filter(1:100, rep(1, 3)))) + #expect_true(is.ts(filter(1:100, rep(1, 3)))) # nolint }) test_that("join() and merge() on a DataFrame", { @@ -1481,6 +1497,27 @@ test_that("read/write Parquet files", { unlink(parquetPath4) }) +test_that("read/write text files", { + # Test write.df and read.df + df <- read.df(sqlContext, jsonPath, "text") + expect_is(df, "DataFrame") + expect_equal(colnames(df), c("value")) + expect_equal(count(df), 3) + textPath <- tempfile(pattern = "textPath", fileext = ".txt") + write.df(df, textPath, "text", mode="overwrite") + + # Test write.text and read.text + textPath2 <- tempfile(pattern = "textPath2", fileext = ".txt") + write.text(df, textPath2) + df2 <- read.text(sqlContext, c(textPath, textPath2)) + expect_is(df2, "DataFrame") + expect_equal(colnames(df2), c("value")) + expect_equal(count(df2), count(df) * 2) + + unlink(textPath) + unlink(textPath2) +}) + test_that("describe() and summarize() on a DataFrame", { df <- read.json(sqlContext, jsonPath) stats <- describe(df, "age") @@ -1647,7 +1684,7 @@ test_that("cov() and corr() on a DataFrame", { expect_true(abs(result - 1.0) < 1e-12) # Test stats::cov is working - #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) + #expect_true(abs(max(cov(swiss)) - 1739.295) < 1e-3) # nolint }) test_that("freqItems() on a DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 12df4cf4f65b..56f14a3bce61 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -95,7 +95,9 @@ test_that("cleanClosure on R functions", { # TODO(shivaram): length(ls(env)) is 4 here for some reason and `lapply` is included in `env`. # Disabling this test till we debug this. # + # nolint start # expect_equal(length(ls(env)), 3) # Only "g", "l" and "f". No "base", "field" or "defUse". 
+ # nolint end expect_true("g" %in% ls(env)) expect_true("l" %in% ls(env)) expect_true("f" %in% ls(env)) diff --git a/assembly/pom.xml b/assembly/pom.xml index 4b60ee00ffbe..6c79f9189787 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml @@ -44,11 +44,6 @@ spark-core_${scala.binary.version} ${project.version} - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - org.apache.spark spark-mllib_${scala.binary.version} diff --git a/bagel/pom.xml b/bagel/pom.xml deleted file mode 100644 index 672e9469aec9..000000000000 --- a/bagel/pom.xml +++ /dev/null @@ -1,64 +0,0 @@ - - - - - 4.0.0 - - org.apache.spark - spark-parent_2.10 - 1.6.0-SNAPSHOT - ../pom.xml - - - org.apache.spark - spark-bagel_2.10 - - bagel - - jar - Spark Project Bagel - http://spark.apache.org/ - - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - - - org.apache.spark - spark-core_${scala.binary.version} - ${project.version} - test-jar - test - - - org.scalacheck - scalacheck_${scala.binary.version} - test - - - org.apache.spark - spark-test-tags_${scala.binary.version} - - - - target/scala-${scala.binary.version}/classes - target/scala-${scala.binary.version}/test-classes - - diff --git a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala b/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala deleted file mode 100644 index 8399033ac61e..000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala +++ /dev/null @@ -1,318 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.bagel - -import org.apache.spark._ -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -object Bagel extends Logging { - val DEFAULT_STORAGE_LEVEL = StorageLevel.MEMORY_AND_DISK - - /** - * Runs a Bagel program. - * @param sc org.apache.spark.SparkContext to use for the program. - * @param vertices vertices of the graph represented as an RDD of (Key, Vertex) pairs. Often the - * Key will be the vertex id. - * @param messages initial set of messages represented as an RDD of (Key, Message) pairs. Often - * this will be an empty array, i.e. sc.parallelize(Array[K, Message]()). - * @param combiner [[org.apache.spark.bagel.Combiner]] combines multiple individual messages to a - * given vertex into one message before sending (which often involves network - * I/O). - * @param aggregator [[org.apache.spark.bagel.Aggregator]] performs a reduce across all vertices - * after each superstep and provides the result to each vertex in the next - * superstep. 
- * @param partitioner org.apache.spark.Partitioner partitions values by key - * @param numPartitions number of partitions across which to split the graph. - * Default is the default parallelism of the SparkContext - * @param storageLevel org.apache.spark.storage.StorageLevel to use for caching of - * intermediate RDDs in each superstep. Defaults to caching in memory. - * @param compute function that takes a Vertex, optional set of (possibly combined) messages to - * the Vertex, optional Aggregator and the current superstep, - * and returns a set of (Vertex, outgoing Messages) pairs - * @tparam K key - * @tparam V vertex type - * @tparam M message type - * @tparam C combiner - * @tparam A aggregator - * @return an RDD of (K, V) pairs representing the graph after completion of the program - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, - C: Manifest, A: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - aggregator: Option[Aggregator[V, A]], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel = DEFAULT_STORAGE_LEVEL - )( - compute: (V, Option[C], Option[A], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val splits = if (numPartitions != 0) numPartitions else sc.defaultParallelism - - var superstep = 0 - var verts = vertices - var msgs = messages - var noActivity = false - var lastRDD: RDD[(K, (V, Array[M]))] = null - do { - logInfo("Starting superstep " + superstep + ".") - val startTime = System.currentTimeMillis - - val aggregated = agg(verts, aggregator) - val combinedMsgs = msgs.combineByKeyWithClassTag( - combiner.createCombiner _, combiner.mergeMsg _, combiner.mergeCombiners _, partitioner) - val grouped = combinedMsgs.groupWith(verts) - val superstep_ = superstep // Create a read-only copy of superstep for capture in closure - val (processed, numMsgs, numActiveVerts) = - comp[K, V, M, C](sc, grouped, compute(_, _, aggregated, superstep_), storageLevel) - if (lastRDD != null) { - lastRDD.unpersist(false) - } - lastRDD = processed - - val timeTaken = System.currentTimeMillis - startTime - logInfo("Superstep %d took %d s".format(superstep, timeTaken / 1000)) - - verts = processed.mapValues { case (vert, msgs) => vert } - msgs = processed.flatMap { - case (id, (vert, msgs)) => msgs.map(m => (m.targetId, m)) - } - superstep += 1 - - noActivity = numMsgs == 0 && numActiveVerts == 0 - } while (!noActivity) - - verts - } - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the default - * storage level */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M])): RDD[(K, V)] = run(sc, vertices, messages, - combiner, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - partitioner: Partitioner, - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, partitioner, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - 
} - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], default - * org.apache.spark.HashPartitioner and default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, combiner, numPartitions, - DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]] and the - * default org.apache.spark.HashPartitioner - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C: Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - combiner: Combiner[M, C], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[C], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, C, Nothing]( - sc, vertices, messages, combiner, None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, C](compute)) - } - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * default org.apache.spark.HashPartitioner, - * [[org.apache.spark.bagel.DefaultCombiner]] and the default storage level - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = run(sc, vertices, messages, numPartitions, DEFAULT_STORAGE_LEVEL)(compute) - - /** - * Runs a Bagel program with no [[org.apache.spark.bagel.Aggregator]], - * the default org.apache.spark.HashPartitioner - * and [[org.apache.spark.bagel.DefaultCombiner]] - */ - def run[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest]( - sc: SparkContext, - vertices: RDD[(K, V)], - messages: RDD[(K, M)], - numPartitions: Int, - storageLevel: StorageLevel - )( - compute: (V, Option[Array[M]], Int) => (V, Array[M]) - ): RDD[(K, V)] = { - val part = new HashPartitioner(numPartitions) - run[K, V, M, Array[M], Nothing]( - sc, vertices, messages, new DefaultCombiner(), None, part, numPartitions, storageLevel)( - addAggregatorArg[K, V, M, Array[M]](compute)) - } - - /** - * Aggregates the given vertices using the given aggregator, if it - * is specified. - */ - private def agg[K, V <: Vertex, A: Manifest]( - verts: RDD[(K, V)], - aggregator: Option[Aggregator[V, A]] - ): Option[A] = aggregator match { - case Some(a) => - Some(verts.map { - case (id, vert) => a.createAggregator(vert) - }.reduce(a.mergeAggregators(_, _))) - case None => None - } - - /** - * Processes the given vertex-message RDD using the compute - * function. Returns the processed RDD, the number of messages - * created, and the number of active vertices. 
- */ - private def comp[K: Manifest, V <: Vertex, M <: Message[K], C]( - sc: SparkContext, - grouped: RDD[(K, (Iterable[C], Iterable[V]))], - compute: (V, Option[C]) => (V, Array[M]), - storageLevel: StorageLevel - ): (RDD[(K, (V, Array[M]))], Int, Int) = { - var numMsgs = sc.accumulator(0) - var numActiveVerts = sc.accumulator(0) - val processed = grouped.mapValues(x => (x._1.iterator, x._2.iterator)) - .flatMapValues { - case (_, vs) if !vs.hasNext => None - case (c, vs) => { - val (newVert, newMsgs) = - compute(vs.next, - c.hasNext match { - case true => Some(c.next) - case false => None - } - ) - - numMsgs += newMsgs.size - if (newVert.active) { - numActiveVerts += 1 - } - - Some((newVert, newMsgs)) - } - }.persist(storageLevel) - - // Force evaluation of processed RDD for accurate performance measurements - processed.foreach(x => {}) - - (processed, numMsgs.value, numActiveVerts.value) - } - - /** - * Converts a compute function that doesn't take an aggregator to - * one that does, so it can be passed to Bagel.run. - */ - private def addAggregatorArg[K: Manifest, V <: Vertex : Manifest, M <: Message[K] : Manifest, C]( - compute: (V, Option[C], Int) => (V, Array[M]) - ): (V, Option[C], Option[Nothing], Int) => (V, Array[M]) = { - (vert: V, msgs: Option[C], aggregated: Option[Nothing], superstep: Int) => - compute(vert, msgs, superstep) - } -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -/** Default combiner that simply appends messages together (i.e. performs no aggregation) */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -class DefaultCombiner[M: Manifest] extends Combiner[M, Array[M]] with Serializable { - def createCombiner(msg: M): Array[M] = - Array(msg) - def mergeMsg(combiner: Array[M], msg: M): Array[M] = - combiner :+ msg - def mergeCombiners(a: Array[M], b: Array[M]): Array[M] = - a ++ b -} - -/** - * Represents a Bagel vertex. - * - * Subclasses may store state along with each vertex and must - * inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Vertex { - def active: Boolean -} - -/** - * Represents a Bagel message to a target vertex. - * - * Subclasses may contain a payload to deliver to the target vertex - * and must inherit from java.io.Serializable or scala.Serializable. - */ -@deprecated("Uses of Bagel should migrate to GraphX", "1.6.0") -trait Message[K] { - def targetId: K -} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java b/bagel/src/main/scala/org/apache/spark/bagel/package-info.java deleted file mode 100644 index 81f26f276549..000000000000 --- a/bagel/src/main/scala/org/apache/spark/bagel/package-info.java +++ /dev/null @@ -1,21 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package org.apache.spark.bagel; \ No newline at end of file diff --git a/build/mvn b/build/mvn index 7603ea03deb7..63ca9c98067d 100755 --- a/build/mvn +++ b/build/mvn @@ -81,11 +81,11 @@ install_mvn() { # Install zinc under the build/ folder install_zinc() { - local zinc_path="zinc-0.3.5.3/bin/zinc" + local zinc_path="zinc-0.3.9/bin/zinc" [ ! -f "${_DIR}/${zinc_path}" ] && ZINC_INSTALL_FLAG=1 install_app \ - "http://downloads.typesafe.com/zinc/0.3.5.3" \ - "zinc-0.3.5.3.tgz" \ + "http://downloads.typesafe.com/zinc/0.3.9" \ + "zinc-0.3.9.tgz" \ "${zinc_path}" ZINC_BIN="${_DIR}/${zinc_path}" } diff --git a/core/pom.xml b/core/pom.xml index 61744bb5c7bf..34ecb19654f1 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/core/src/main/java/org/apache/spark/api/java/Optional.java b/core/src/main/java/org/apache/spark/api/java/Optional.java new file mode 100644 index 000000000000..ca7babc3f01c --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/Optional.java @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import java.io.Serializable; + +import com.google.common.base.Preconditions; + +/** + *

+ * <p>Like {@code java.util.Optional} in Java 8, {@code scala.Option} in Scala, and
+ * {@code com.google.common.base.Optional} in Google Guava, this class represents a
+ * value of a given type that may or may not exist. It is used in methods that wish
+ * to optionally return a value, in preference to returning {@code null}.</p>
+ *
+ * <p>In fact, the class here is a reimplementation of the essential API of both
+ * {@code java.util.Optional} and {@code com.google.common.base.Optional}. From
+ * {@code java.util.Optional}, it implements:</p>
+ *
+ * <ul>
+ *   <li>{@code empty()}</li>
+ *   <li>{@code of(T)}</li>
+ *   <li>{@code ofNullable(T)}</li>
+ *   <li>{@code get()}</li>
+ *   <li>{@code orElse(T)}</li>
+ *   <li>{@code isPresent()}</li>
+ * </ul>
+ *
+ * <p>From {@code com.google.common.base.Optional} it implements:</p>
+ *
+ * <ul>
+ *   <li>{@code absent()}</li>
+ *   <li>{@code of(T)}</li>
+ *   <li>{@code fromNullable(T)}</li>
+ *   <li>{@code get()}</li>
+ *   <li>{@code or(T)}</li>
+ *   <li>{@code orNull()}</li>
+ *   <li>{@code isPresent()}</li>
+ * </ul>
+ *
+ * <p>{@code java.util.Optional} itself is not used at this time because the
+ * project does not require Java 8. Using {@code com.google.common.base.Optional}
+ * has in the past caused serious library version conflicts with Guava that can't
+ * be resolved by shading. Hence this work-alike clone.</p>
+ *
+ * @param <T> type of value held inside
+ */
+public final class Optional<T> implements Serializable {
+
+  private static final Optional<?> EMPTY = new Optional<>();
+
+  private final T value;
+
+  private Optional() {
+    this.value = null;
+  }
+
+  private Optional(T value) {
+    Preconditions.checkNotNull(value);
+    this.value = value;
+  }
+
+  // java.util.Optional API (subset)
+
+  /**
+   * @return an empty {@code Optional}
+   */
+  public static <T> Optional<T> empty() {
+    @SuppressWarnings("unchecked")
+    Optional<T> t = (Optional<T>) EMPTY;
+    return t;
+  }
+
+  /**
+   * @param value non-null value to wrap
+   * @return {@code Optional} wrapping this value
+   * @throws NullPointerException if value is null
+   */
+  public static <T> Optional<T> of(T value) {
+    return new Optional<>(value);
+  }
+
+  /**
+   * @param value value to wrap, which may be null
+   * @return {@code Optional} wrapping this value, which may be empty
+   */
+  public static <T> Optional<T> ofNullable(T value) {
+    if (value == null) {
+      return empty();
+    } else {
+      return of(value);
+    }
+  }
+
+  /**
+   * @return the value wrapped by this {@code Optional}
+   * @throws NullPointerException if this is empty (contains no value)
+   */
+  public T get() {
+    Preconditions.checkNotNull(value);
+    return value;
+  }
+
+  /**
+   * @param other value to return if this is empty
+   * @return this {@code Optional}'s value if present, or else the given value
+   */
+  public T orElse(T other) {
+    return value != null ? value : other;
+  }
+
+  /**
+   * @return true iff this {@code Optional} contains a value (non-empty)
+   */
+  public boolean isPresent() {
+    return value != null;
+  }
+
+  // Guava API (subset)
+  // of(), get() and isPresent() are identically present in the Guava API
+
+  /**
+   * @return an empty {@code Optional}
+   */
+  public static <T> Optional<T> absent() {
+    return empty();
+  }
+
+  /**
+   * @param value value to wrap, which may be null
+   * @return {@code Optional} wrapping this value, which may be empty
+   */
+  public static <T> Optional<T> fromNullable(T value) {
+    return ofNullable(value);
+  }
+
+  /**
+   * @param other value to return if this is empty
+   * @return this {@code Optional}'s value if present, or else the given value
+   */
+  public T or(T other) {
+    return value != null ? value : other;
+  }
+
+  /**
+   * @return this {@code Optional}'s value if present, or else null
+   */
+  public T orNull() {
+    return value;
+  }
+
+  // Common methods
+
+  @Override
+  public boolean equals(Object obj) {
+    if (!(obj instanceof Optional)) {
+      return false;
+    }
+    Optional<?> other = (Optional<?>) obj;
+    return value == null ? other.value == null : value.equals(other.value);
+  }
+
+  @Override
+  public int hashCode() {
+    return value == null ? 0 : value.hashCode();
+  }
+
+  @Override
+  public String toString() {
+    return value == null ?
"Optional.empty" : String.format("Optional[%s]", value); + } + +} diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 79d74b23ceae..68dc0c6d415f 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -45,7 +45,9 @@ public final class UnsafeExternalSorter extends MemoryConsumer { private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class); + @Nullable private final PrefixComparator prefixComparator; + @Nullable private final RecordComparator recordComparator; private final TaskMemoryManager taskMemoryManager; private final BlockManager blockManager; @@ -400,6 +402,7 @@ public void merge(UnsafeExternalSorter other) throws IOException { * after consuming this iterator. */ public UnsafeSorterIterator getSortedIterator() throws IOException { + assert(recordComparator != null); if (spillWriters.isEmpty()) { assert(inMemSorter != null); readingIterator = new SpillableIterator(inMemSorter.getSortedIterator()); @@ -430,7 +433,11 @@ class SpillableIterator extends UnsafeSorterIterator { public SpillableIterator(UnsafeInMemorySorter.SortedIterator inMemIterator) { this.upstream = inMemIterator; - this.numRecords = inMemIterator.numRecordsLeft(); + this.numRecords = inMemIterator.getNumRecords(); + } + + public int getNumRecords() { + return numRecords; } public long spill() throws IOException { @@ -531,18 +538,20 @@ public long getKeyPrefix() { * * It is the caller's responsibility to call `cleanupResources()` * after consuming this iterator. + * + * TODO: support forced spilling */ public UnsafeSorterIterator getIterator() throws IOException { if (spillWriters.isEmpty()) { assert(inMemSorter != null); - return inMemSorter.getIterator(); + return inMemSorter.getSortedIterator(); } else { LinkedList queue = new LinkedList<>(); for (UnsafeSorterSpillWriter spillWriter : spillWriters) { queue.add(spillWriter.getReader(blockManager)); } if (inMemSorter != null) { - queue.add(inMemSorter.getIterator()); + queue.add(inMemSorter.getSortedIterator()); } return new ChainedIterator(queue); } @@ -555,13 +564,23 @@ class ChainedIterator extends UnsafeSorterIterator { private final Queue iterators; private UnsafeSorterIterator current; + private int numRecords; public ChainedIterator(Queue iterators) { assert iterators.size() > 0; + this.numRecords = 0; + for (UnsafeSorterIterator iter: iterators) { + this.numRecords += iter.getNumRecords(); + } this.iterators = iterators; this.current = iterators.remove(); } + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { while (!current.hasNext() && !iterators.isEmpty()) { @@ -572,6 +591,9 @@ public boolean hasNext() { @Override public void loadNext() throws IOException { + while (!current.hasNext() && !iterators.isEmpty()) { + current = iterators.remove(); + } current.loadNext(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c16cbce9a0f6..f71b8d154cc2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ 
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c16cbce9a0f6..f71b8d154cc2 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -19,6 +19,8 @@ import java.util.Comparator; +import org.apache.avro.reflect.Nullable; + import org.apache.spark.memory.MemoryConsumer; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.unsafe.Platform; @@ -66,7 +68,9 @@ public int compare(RecordPointerAndKeyPrefix r1, RecordPointerAndKeyPrefix r2) { private final MemoryConsumer consumer; private final TaskMemoryManager memoryManager; + @Nullable private final Sorter<RecordPointerAndKeyPrefix, LongArray> sorter; + @Nullable private final Comparator<RecordPointerAndKeyPrefix> sortComparator; /** @@ -98,8 +102,13 @@ public UnsafeInMemorySorter( LongArray array) { this.consumer = consumer; this.memoryManager = memoryManager; - this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); - this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + if (recordComparator != null) { + this.sorter = new Sorter<>(UnsafeSortDataFormat.INSTANCE); + this.sortComparator = new SortComparator(recordComparator, prefixComparator, memoryManager); + } else { + this.sorter = null; + this.sortComparator = null; + } this.array = array; } @@ -186,12 +195,13 @@ public SortedIterator clone() { } @Override - public boolean hasNext() { - return position / 2 < numRecords; + public int getNumRecords() { + return numRecords; } - public int numRecordsLeft() { - return numRecords - position / 2; + @Override + public boolean hasNext() { + return position / 2 < numRecords; } @Override @@ -223,14 +233,9 @@ public void loadNext() { * {@code next()} will return the same mutable object. */ public SortedIterator getSortedIterator() { - sorter.sort(array, 0, pos / 2, sortComparator); - return new SortedIterator(pos / 2); - } - - /** - * Returns an iterator over record pointers in original order (inserted). - */ - public SortedIterator getIterator() { + if (sorter != null) { + sorter.sort(array, 0, pos / 2, sortComparator); + } return new SortedIterator(pos / 2); } } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java index 16ac2e8d821b..1b3167fcc250 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterIterator.java @@ -32,4 +32,6 @@ public abstract class UnsafeSorterIterator { public abstract int getRecordLength(); public abstract long getKeyPrefix(); + + public abstract int getNumRecords(); } diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 3874a9f9cbdb..ceb59352af64 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -23,6 +23,7 @@ final class UnsafeSorterSpillMerger { + private int numRecords = 0; private final PriorityQueue<UnsafeSorterIterator> priorityQueue; public UnsafeSorterSpillMerger( @@ -59,6 +60,7 @@ public void addSpillIfNotEmpty(UnsafeSorterIterator spillReader) throws IOExcept // priorityQueue, we will have n extra empty records in the result of the UnsafeSorterIterator. 
spillReader.loadNext(); priorityQueue.add(spillReader); + numRecords += spillReader.getNumRecords(); } } @@ -67,6 +69,11 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { private UnsafeSorterIterator spillReader; + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { return !priorityQueue.isEmpty() || (spillReader != null && spillReader.hasNext()); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java index dcb13e6581e5..20ee1c8eb0c7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillReader.java @@ -38,6 +38,7 @@ public final class UnsafeSorterSpillReader extends UnsafeSorterIterator implemen // Variables that change with every record read: private int recordLength; private long keyPrefix; + private int numRecords; private int numRecordsRemaining; private byte[] arr = new byte[1024 * 1024]; @@ -53,13 +54,18 @@ public UnsafeSorterSpillReader( try { this.in = blockManager.wrapForCompression(blockId, bs); this.din = new DataInputStream(this.in); - numRecordsRemaining = din.readInt(); + numRecords = numRecordsRemaining = din.readInt(); } catch (IOException e) { Closeables.close(bs, /* swallowIOException = */ true); throw e; } } + @Override + public int getNumRecords() { + return numRecords; + } + @Override public boolean hasNext() { return (numRecordsRemaining > 0); diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index b54e33a96fa2..48f86d1536c9 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -225,4 +225,13 @@ a.expandbutton { background-color: #49535a !important; color: white; cursor:pointer; -} \ No newline at end of file +} + +.table-head-clickable th a, .table-head-clickable th a:hover { + /* Make the entire header clickable, not just the text label */ + display: block; + width: 100%; + /* Suppress the default link styling */ + color: #333; + text-decoration: none; +} diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala index 7196e57d5d2e..62629000cfc2 100644 --- a/core/src/main/scala/org/apache/spark/Aggregator.scala +++ b/core/src/main/scala/org/apache/spark/Aggregator.scala @@ -34,10 +34,6 @@ case class Aggregator[K, V, C] ( mergeValue: (C, V) => C, mergeCombiners: (C, C) => C) { - @deprecated("use combineValuesByKey with TaskContext argument", "0.9.0") - def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] = - combineValuesByKey(iter, null) - def combineValuesByKey( iter: Iterator[_ <: Product2[K, V]], context: TaskContext): Iterator[(K, C)] = { @@ -47,10 +43,6 @@ case class Aggregator[K, V, C] ( combiners.iterator } - @deprecated("use combineCombinersByKey with TaskContext argument", "0.9.0") - def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]]) : Iterator[(K, C)] = - combineCombinersByKey(iter, null) - def combineCombinersByKey( iter: Iterator[_ <: Product2[K, C]], context: TaskContext): Iterator[(K, C)] = { diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 
bc732535fed8..4628093b91cb 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -18,7 +18,7 @@ package org.apache.spark import java.lang.ref.{ReferenceQueue, WeakReference} -import java.util.concurrent.{TimeUnit, ScheduledExecutorService} +import java.util.concurrent.{ScheduledExecutorService, TimeUnit} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala index 6176e258989d..3431fc13dcb4 100644 --- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala +++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala @@ -24,9 +24,9 @@ import scala.util.control.ControlThrowable import com.codahale.metrics.{Gauge, MetricRegistry} -import org.apache.spark.scheduler._ import org.apache.spark.metrics.source.Source -import org.apache.spark.util.{ThreadUtils, Clock, SystemClock, Utils} +import org.apache.spark.scheduler._ +import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** * An agent that dynamically allocates and removes executors based on the workload. @@ -423,7 +423,8 @@ private[spark] class ExecutorAllocationManager( executorsPendingToRemove.add(executorId) true } else { - logWarning(s"Unable to reach the cluster manager to kill executor $executorId!") + logWarning(s"Unable to reach the cluster manager to kill executor $executorId," + + s" or no executor eligible to kill!") false } } @@ -524,7 +525,6 @@ private[spark] class ExecutorAllocationManager( private def onExecutorBusy(executorId: String): Unit = synchronized { logDebug(s"Clearing idle timer for $executorId because it is now running a task") removeTimes.remove(executorId) - executorsPendingToRemove.remove(executorId) } /** diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 1f1f0b75de5f..e03977828b86 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{ThreadSafeRpcEndpoint, RpcEnv, RpcCallContext} -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala index 77d8ec9bb160..46f9f9e9af7d 100644 --- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala @@ -63,12 +63,12 @@ private[spark] class HttpFileServer( def addFile(file: File) : String = { addFileToDir(file, fileDir) - serverUri + "/files/" + file.getName + serverUri + "/files/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addJar(file: File) : String = { addFileToDir(file, jarDir) - serverUri + "/jars/" + file.getName + serverUri + "/jars/" + Utils.encodeFileNameToURIRawPath(file.getName) } def addDirectory(path: String, resourceBase: String): String = { @@ -85,7 +85,7 @@ private[spark] class HttpFileServer( throw new
IllegalArgumentException(s"$file cannot be a directory.") } Files.copy(file, new File(dir, file.getName)) - dir + "/" + file.getName + dir + "/" + Utils.encodeFileNameToURIRawPath(file.getName) } } diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala index faa3ef3d7561..3c808420c8b2 100644 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ b/core/src/main/scala/org/apache/spark/HttpServer.scala @@ -19,18 +19,17 @@ package org.apache.spark import java.io.File -import org.eclipse.jetty.server.ssl.SslSocketConnector -import org.eclipse.jetty.util.security.{Constraint, Password} -import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} +import org.eclipse.jetty.security.authentication.DigestAuthenticator import org.eclipse.jetty.server.Server import org.eclipse.jetty.server.bio.SocketConnector +import org.eclipse.jetty.server.ssl.SslSocketConnector import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} +import org.eclipse.jetty.util.security.{Constraint, Password} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.apache.spark.util.Utils - /** * Exception type thrown by HttpServer when it is in the wrong state for an operation. */ diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 72355cdfa68b..1b59beb8d6ef 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -18,7 +18,6 @@ package org.apache.spark import java.io._ -import java.util.Arrays import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} @@ -26,7 +25,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.reflect.ClassTag -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId} @@ -267,8 +266,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging } /** - * MapOutputTracker for the driver. This uses TimeStampedHashMap to keep track of map - * output information, which allows old output information based on a TTL. + * MapOutputTracker for the driver. */ private[spark] class MapOutputTrackerMaster(conf: SparkConf) extends MapOutputTracker(conf) { @@ -291,17 +289,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) // can be read locally, but may lead to more delay in scheduling if those locations are busy. private val REDUCER_PREF_LOCS_FRACTION = 0.2 - /** - * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver, - * so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set). - * Other than these two scenarios, nothing should be dropped from this HashMap. 
- */ - protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]() - private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]() - - // For cleaning up TimeStampedHashMaps - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) + // HashMaps for storing mapStatuses and cached serialized statuses in the driver. + // Statuses are dropped only by explicit de-registering. + protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala + private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala def registerShuffle(shuffleId: Int, numMaps: Int) { if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) { @@ -462,14 +453,8 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf) sendTracker(StopMapOutputTracker) mapStatuses.clear() trackerEndpoint = null - metadataCleaner.cancel() cachedSerializedStatuses.clear() } - - private def cleanup(cleanupTime: Long) { - mapStatuses.clearOldValues(cleanupTime) - cachedSerializedStatuses.clearOldValues(cleanupTime) - } } /** diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index ef9a2dab1c10..a7c2790c8360 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -21,13 +21,13 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import scala.util.hashing.byteswap32 import org.apache.spark.rdd.{PartitionPruningRDD, RDD} import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.{CollectionsUtils, Utils} -import org.apache.spark.util.random.{XORShiftRandom, SamplingUtils} +import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index d3384fb29773..340e1f7824d1 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ import scala.collection.mutable.LinkedHashSet -import org.apache.avro.{SchemaNormalization, Schema} +import org.apache.avro.{Schema, SchemaNormalization} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.Utils @@ -544,7 +544,8 @@ private[spark] object SparkConf extends Logging { DeprecatedConfig("spark.kryoserializer.buffer.mb", "1.4", "Please use spark.kryoserializer.buffer instead. The default value for " + "spark.kryoserializer.buffer.mb was previously specified as '0.064'. Fractional values " + - "are no longer accepted. To specify the equivalent now, one may use '64k'.") + "are no longer accepted. 
To specify the equivalent now, one may use '64k'."), + DeprecatedConfig("spark.rpc", "2.0", "Not used any more.") ) Map(configs.map { cfg => (cfg.key -> cfg) } : _*) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 194ecc0a0434..98075cef112d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -17,22 +17,23 @@ package org.apache.spark -import scala.language.implicitConversions - import java.io._ import java.lang.reflect.Constructor import java.net.URI import java.util.{Arrays, Properties, UUID} -import java.util.concurrent.atomic.{AtomicReference, AtomicBoolean, AtomicInteger} +import java.util.concurrent.ConcurrentMap +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} import java.util.UUID.randomUUID import scala.collection.JavaConverters._ -import scala.collection.{Map, Set} +import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.HashMap -import scala.reflect.{ClassTag, classTag} +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} import scala.util.control.NonFatal +import com.google.common.collect.MapMaker import org.apache.commons.lang.SerializationUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -42,27 +43,26 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence TextInputFormat} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} - import org.apache.mesos.MesosNativeLibrary import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil} -import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, - FixedLengthBinaryInputFormat} +import org.apache.spark.input.{FixedLengthBinaryInputFormat, PortableDataStream, StreamInputFormat, + WholeTextFileInputFormat} import org.apache.spark.io.CompressionCodec import org.apache.spark.metrics.MetricsSystem import org.apache.spark.partial.{ApproximateEvaluator, PartialResult} import org.apache.spark.rdd._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, - SparkDeploySchedulerBackend, SimrSchedulerBackend} +import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SimrSchedulerBackend, + SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump -import org.apache.spark.ui.{SparkUI, ConsoleProgressBar} +import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} import org.apache.spark.ui.jobs.JobProgressListener import org.apache.spark.util._ @@ -122,20 +122,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli */ def this() = this(new SparkConf()) - /** - * :: DeveloperApi :: - * Alternative constructor for setting preferred locations where Spark will create executors. 
- * - * @param config a [[org.apache.spark.SparkConf]] object specifying other Spark parameters - * @param preferredNodeLocationData not used. Left for backward compatibility. - */ - @deprecated("Passing in preferred locations has no effect at all, see SPARK-8949", "1.5.0") - @DeveloperApi - def this(config: SparkConf, preferredNodeLocationData: Map[String, Set[SplitInfo]]) = { - this(config) - logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - } - /** * Alternative constructor that allows setting common Spark properties directly * @@ -155,21 +141,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli * @param jars Collection of JARs to send to the cluster. These can be paths on the local file * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. - * @param preferredNodeLocationData not used. Left for backward compatibility. */ - @deprecated("Passing in preferred locations has no effect at all, see SPARK-10921", "1.6.0") def this( master: String, appName: String, sparkHome: String = null, jars: Seq[String] = Nil, - environment: Map[String, String] = Map(), - preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = + environment: Map[String, String] = Map()) = { this(SparkContext.updatedConf(new SparkConf(), master, appName, sparkHome, jars, environment)) - if (preferredNodeLocationData.nonEmpty) { - logWarning("Passing in preferred locations has no effect at all, see SPARK-8949") - } } // NOTE: The below constructors could be consolidated using default arguments. Due to @@ -221,7 +201,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private var _eventLogDir: Option[URI] = None private var _eventLogCodec: Option[String] = None private var _env: SparkEnv = _ - private var _metadataCleaner: MetadataCleaner = _ private var _jobProgressListener: JobProgressListener = _ private var _statusTracker: SparkStatusTracker = _ private var _progressBar: Option[ConsoleProgressBar] = None @@ -267,8 +246,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Generate the random name for a temp folder in external block store. 
// Add a timestamp as the suffix here to make it more safe val externalBlockStoreFolderName = "spark-" + randomUUID.toString() - @deprecated("Use externalBlockStoreFolderName instead.", "1.4.0") - val tachyonFolderName = externalBlockStoreFolderName def isLocal: Boolean = (master == "local" || master.startsWith("local[")) @@ -295,8 +272,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli private[spark] val addedJars = HashMap[String, Long]() // Keeps track of all persisted RDDs - private[spark] val persistentRdds = new TimeStampedWeakValueHashMap[Int, RDD[_]] - private[spark] def metadataCleaner: MetadataCleaner = _metadataCleaner + private[spark] val persistentRdds = { + val map: ConcurrentMap[Int, RDD[_]] = new MapMaker().weakValues().makeMap[Int, RDD[_]]() + map.asScala + } private[spark] def jobProgressListener: JobProgressListener = _jobProgressListener def statusTracker: SparkStatusTracker = _statusTracker @@ -463,8 +442,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli _conf.set("spark.repl.class.uri", replUri) } - _metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, _conf) - _statusTracker = new SparkStatusTracker(this) _progressBar = @@ -641,11 +618,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli localProperties.set(props) } - @deprecated("Properties no longer need to be explicitly initialized.", "1.0.0") - def initLocalProperties() { - localProperties.set(new Properties()) - } - /** * Set a local property that affects jobs submitted from this thread, such as the * Spark fair scheduler pool. @@ -759,7 +731,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val numElements: BigInt = { val safeStart = BigInt(start) val safeEnd = BigInt(end) - if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { (safeEnd - safeStart) / step } else { // the remainder has the same sign with range, could add 1 more @@ -836,7 +808,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli minPartitions: Int = defaultMinPartitions): RDD[String] = withScope { assertNotStopped() hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text], - minPartitions).map(pair => pair._2.toString) + minPartitions).map(pair => pair._2.toString).setName(path) } /** @@ -874,18 +846,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, String)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that wholeTextFiles aligns with hadoopFile/textFile in taking // comma separated files as input. 
(see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new WholeTextFileRDD( this, classOf[WholeTextFileInputFormat], classOf[Text], classOf[Text], updateConf, - minPartitions).setName(path).map(record => (record._1.toString, record._2.toString)) + minPartitions).map(record => (record._1.toString, record._2.toString)).setName(path) } /** @@ -923,11 +895,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli path: String, minPartitions: Int = defaultMinPartitions): RDD[(String, PortableDataStream)] = withScope { assertNotStopped() - val job = new NewHadoopJob(hadoopConfiguration) + val job = NewHadoopJob.getInstance(hadoopConfiguration) // Use setInputPaths so that binaryFiles aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updateConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updateConf = job.getConfiguration new BinaryFileRDD( this, classOf[StreamInputFormat], @@ -1100,13 +1072,13 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli vClass: Class[V], conf: Configuration = hadoopConfiguration): RDD[(K, V)] = withScope { assertNotStopped() - // The call to new NewHadoopJob automatically adds security credentials to conf, + // The call to NewHadoopJob automatically adds security credentials to conf, // so we don't need to explicitly add them ourselves - val job = new NewHadoopJob(conf) + val job = NewHadoopJob.getInstance(conf) // Use setInputPaths so that newAPIHadoopFile aligns with hadoopFile/textFile in taking // comma separated files as input. (see SPARK-7155) NewFileInputFormat.setInputPaths(job, path) - val updatedConf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val updatedConf = job.getConfiguration new NewHadoopRDD(this, fClass, kClass, vClass, updatedConf).setName(path) } @@ -1248,7 +1220,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } /** Get an RDD that has no partitions or elements. */ - def emptyRDD[T: ClassTag]: EmptyRDD[T] = new EmptyRDD[T](this) + def emptyRDD[T: ClassTag]: RDD[T] = new EmptyRDD[T](this) // Methods for creating shared variables @@ -1369,7 +1341,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (!fs.exists(hadoopPath)) { throw new FileNotFoundException(s"Added file $hadoopPath does not exist.") } - val isDir = fs.getFileStatus(hadoopPath).isDir + val isDir = fs.getFileStatus(hadoopPath).isDirectory if (!isLocal && scheme == "file" && isDir) { throw new SparkException(s"addFile does not support local directories when not running " + "local mode.") @@ -1585,15 +1557,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli taskScheduler.schedulingMode } - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. 
- */ - @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0") - def clearFiles() { - addedFiles.clear() - } - /** * Gets the locality information associated with the partition in a particular rdd * @param rdd of interest @@ -1685,15 +1648,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli postEnvironmentUpdate() } - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0") - def clearJars() { - addedJars.clear() - } - // Shut down the SparkContext. def stop() { if (AsynchronousListenerBus.withinListenerThread.value) { @@ -1721,11 +1675,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli env.metricsSystem.report() } } - if (metadataCleaner != null) { - Utils.tryLogNonFatalError { - metadataCleaner.cancel() - } - } Utils.tryLogNonFatalError { _cleaner.foreach(_.stop()) } @@ -1864,63 +1813,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli runJob(rdd, (ctx: TaskContext, it: Iterator[T]) => cleanedFunc(it), partitions) } - - /** - * Run a function on a given set of partitions in an RDD and pass the results to the given - * handler function. This is the main entry point for all actions in Spark. - * - * The allowLocal flag is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean, - resultHandler: (Int, U) => Unit): Unit = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions, resultHandler) - } - - /** - * Run a function on a given set of partitions in an RDD and return the results as an array. - * - * The allowLocal flag is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: (TaskContext, Iterator[T]) => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions) - } - - /** - * Run a job on a given set of partitions of an RDD, but take a function of type - * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. - * - * The allowLocal argument is deprecated as of Spark 1.5.0+. - */ - @deprecated("use the version of runJob without the allowLocal parameter", "1.5.0") - def runJob[T, U: ClassTag]( - rdd: RDD[T], - func: Iterator[T] => U, - partitions: Seq[Int], - allowLocal: Boolean - ): Array[U] = { - if (allowLocal) { - logWarning("sc.runJob with allowLocal=true is deprecated in Spark 1.5.0+") - } - runJob(rdd, func, partitions) - } - /** * Run a job on all partitions in an RDD and return the results in an array. */ @@ -2073,8 +1965,9 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // its own local file system, which is incorrect because the checkpoint files // are actually on the executor machines. 
if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) { - logWarning("Checkpoint directory must be non-local " + - "if Spark is running on a cluster: " + directory) + logWarning("Spark is not running in local mode, therefore the checkpoint directory " + + s"must not be on the local filesystem. Directory '$directory' " + + "appears to be on the local filesystem.") } checkpointDir = Option(directory).map { dir => @@ -2093,10 +1986,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli taskScheduler.defaultParallelism } - /** Default min number of partitions for Hadoop RDDs when not given by user */ - @deprecated("use defaultMinPartitions", "1.0.0") - def defaultMinSplits: Int = math.min(defaultParallelism, 2) - /** * Default min number of partitions for Hadoop RDDs when not given by user * Notice that we use math.min so the "defaultMinPartitions" cannot be higher than 2. @@ -2192,11 +2081,6 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli } } - /** Called by MetadataCleaner to clean up the persistentRdds map periodically */ - private[spark] def cleanup(cleanupTime: Long) { - persistentRdds.clearOldValues(cleanupTime) - } - // In order to prevent multiple SparkContexts from being active at the same time, mark this // context as having finished construction. // NOTE: this must be placed at the end of the SparkContext constructor. @@ -2363,113 +2247,6 @@ object SparkContext extends Logging { */ private[spark] val LEGACY_DRIVER_IDENTIFIER = "" - // The following deprecated objects have already been copied to `object AccumulatorParam` to - // make the compiler find them automatically. They are duplicate codes only for backward - // compatibility, please update `object AccumulatorParam` accordingly if you plan to modify the - // following ones. - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object DoubleAccumulatorParam extends AccumulatorParam[Double] { - def addInPlace(t1: Double, t2: Double): Double = t1 + t2 - def zero(initialValue: Double): Double = 0.0 - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object IntAccumulatorParam extends AccumulatorParam[Int] { - def addInPlace(t1: Int, t2: Int): Int = t1 + t2 - def zero(initialValue: Int): Int = 0 - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object LongAccumulatorParam extends AccumulatorParam[Long] { - def addInPlace(t1: Long, t2: Long): Long = t1 + t2 - def zero(initialValue: Long): Long = 0L - } - - @deprecated("Replaced by implicit objects in AccumulatorParam. This is kept here only for " + - "backward compatibility.", "1.3.0") - object FloatAccumulatorParam extends AccumulatorParam[Float] { - def addInPlace(t1: Float, t2: Float): Float = t1 + t2 - def zero(initialValue: Float): Float = 0f - } - - // The following deprecated functions have already been moved to `object RDD` to - // make the compiler find them automatically. They are still kept here for backward compatibility - // and just call the corresponding functions in `object RDD`. - - @deprecated("Replaced by implicit functions in the RDD companion object. 
This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToPairRDDFunctions[K, V](rdd: RDD[(K, V)]) - (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null): PairRDDFunctions[K, V] = - RDD.rddToPairRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]): AsyncRDDActions[T] = - RDD.rddToAsyncRDDActions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( - rdd: RDD[(K, V)]): SequenceFileRDDFunctions[K, V] = { - val kf = implicitly[K => Writable] - val vf = implicitly[V => Writable] - // Set the Writable class to null and `SequenceFileRDDFunctions` will use Reflection to get it - implicit val keyWritableFactory = new WritableFactory[K](_ => null, kf) - implicit val valueWritableFactory = new WritableFactory[V](_ => null, vf) - RDD.rddToSequenceFileRDDFunctions(rdd) - } - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def rddToOrderedRDDFunctions[K : Ordering : ClassTag, V: ClassTag]( - rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V, (K, V)] = - RDD.rddToOrderedRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def doubleRDDToDoubleRDDFunctions(rdd: RDD[Double]): DoubleRDDFunctions = - RDD.doubleRDDToDoubleRDDFunctions(rdd) - - @deprecated("Replaced by implicit functions in the RDD companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - def numericRDDToDoubleRDDFunctions[T](rdd: RDD[T])(implicit num: Numeric[T]): DoubleRDDFunctions = - RDD.numericRDDToDoubleRDDFunctions(rdd) - - // The following deprecated functions have already been moved to `object WritableFactory` to - // make the compiler find them automatically. They are still kept here for backward compatibility. - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def intToIntWritable(i: Int): IntWritable = new IntWritable(i) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def longToLongWritable(l: Long): LongWritable = new LongWritable(l) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def floatToFloatWritable(f: Float): FloatWritable = new FloatWritable(f) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def doubleToDoubleWritable(d: Double): DoubleWritable = new DoubleWritable(d) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def boolToBoolWritable (b: Boolean): BooleanWritable = new BooleanWritable(b) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. 
This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def bytesToBytesWritable (aob: Array[Byte]): BytesWritable = new BytesWritable(aob) - - @deprecated("Replaced by implicit functions in the WritableFactory companion object. This is " + - "kept here only for backward compatibility.", "1.3.0") - implicit def stringToText(s: String): Text = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]) : ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u @@ -2478,50 +2255,6 @@ object SparkContext extends Logging { arr.map(x => anyToWritable(x)).toArray) } - // The following deprecated functions have already been moved to `object WritableConverter` to - // make the compiler find them automatically. They are still kept here for backward compatibility - // and just call the corresponding functions in `object WritableConverter`. - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def intWritableConverter(): WritableConverter[Int] = - WritableConverter.intWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def longWritableConverter(): WritableConverter[Long] = - WritableConverter.longWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def doubleWritableConverter(): WritableConverter[Double] = - WritableConverter.doubleWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def floatWritableConverter(): WritableConverter[Float] = - WritableConverter.floatWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def booleanWritableConverter(): WritableConverter[Boolean] = - WritableConverter.booleanWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def bytesWritableConverter(): WritableConverter[Array[Byte]] = - WritableConverter.bytesWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def stringWritableConverter(): WritableConverter[String] = - WritableConverter.stringWritableConverter() - - @deprecated("Replaced by implicit functions in WritableConverter. This is kept here only for " + - "backward compatibility.", "1.3.0") - def writableWritableConverter[T <: Writable](): WritableConverter[T] = - WritableConverter.writableWritableConverter() - /** * Find the JAR from which a given class was loaded, to make it easy for users to pass * their JARs to SparkContext. 
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 52acde1b414e..ec43be0e2f3a 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -29,13 +29,12 @@ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.python.PythonWorkerFactory import org.apache.spark.broadcast.BroadcastManager -import org.apache.spark.metrics.MetricsSystem import org.apache.spark.memory.{MemoryManager, StaticMemoryManager, UnifiedMemoryManager} +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} -import org.apache.spark.rpc.akka.AkkaRpcEnv -import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} +import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} +import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager @@ -97,9 +96,7 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - if (!rpcEnv.isInstanceOf[AkkaRpcEnv]) { - actorSystem.shutdown() - } + actorSystem.shutdown() rpcEnv.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut @@ -248,14 +245,11 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) - // Create the ActorSystem for Akka and get the port it binds to. val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + // Create the ActorSystem for Akka and get the port it binds to. val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager, clientMode = !isDriver) - val actorSystem: ActorSystem = - if (rpcEnv.isInstanceOf[AkkaRpcEnv]) { - rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem - } else { + val actorSystem: ActorSystem = { val actorSystemPort = if (port == 0 || rpcEnv.address == null) { port diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala index ac6eaab20d8d..58647860623e 100644 --- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala +++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala @@ -22,9 +22,10 @@ import java.text.NumberFormat import java.text.SimpleDateFormat import java.util.Date -import org.apache.hadoop.mapred._ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.fs.Path +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.TaskType import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.HadoopRDD @@ -37,10 +38,7 @@ import org.apache.spark.util.SerializableJobConf * a filename to write to, etc, exactly like in a Hadoop MapReduce job. 
*/ private[spark] -class SparkHadoopWriter(jobConf: JobConf) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { +class SparkHadoopWriter(jobConf: JobConf) extends Logging with Serializable { private val now = new Date() private val conf = new SerializableJobConf(jobConf) @@ -131,7 +129,7 @@ class SparkHadoopWriter(jobConf: JobConf) private def getJobContext(): JobContext = { if (jobContext == null) { - jobContext = newJobContext(conf.value, jID.value) + jobContext = new JobContextImpl(conf.value, jID.value) } jobContext } @@ -143,6 +141,12 @@ class SparkHadoopWriter(jobConf: JobConf) taskContext } + protected def newTaskAttemptContext( + conf: JobConf, + attemptId: TaskAttemptID): TaskAttemptContext = { + new TaskAttemptContextImpl(conf, attemptId) + } + private def setIDs(jobid: Int, splitid: Int, attemptid: Int) { jobID = jobid splitID = splitid @@ -150,7 +154,7 @@ class SparkHadoopWriter(jobConf: JobConf) jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobid)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } } @@ -168,9 +172,9 @@ object SparkHadoopWriter { } val outputPath = new Path(path) val fs = outputPath.getFileSystem(conf) - if (outputPath == null || fs == null) { + if (fs == null) { throw new IllegalArgumentException("Incorrectly formatted output path") } - outputPath.makeQualified(fs) + outputPath.makeQualified(fs.getUri, fs.getWorkingDirectory) } } diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index af558d6e5b47..e25ed0fdd7fd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -95,9 +95,6 @@ abstract class TaskContext extends Serializable { */ def isInterrupted(): Boolean - @deprecated("use isRunningLocally", "1.2.0") - def runningLocally(): Boolean - /** * Returns true if the task is running locally in the driver program. * @return @@ -118,16 +115,6 @@ abstract class TaskContext extends Serializable { */ def addTaskCompletionListener(f: (TaskContext) => Unit): TaskContext - /** - * Adds a callback function to be executed on task completion. An example use - * is for HadoopRDD to register a callback to close the input stream. - * Will be called in any situation - success, failure, or cancellation. - * - * @param f Callback function. - */ - @deprecated("use addTaskCompletionListener", "1.2.0") - def addOnCompleteCallback(f: () => Unit) - /** * The ID of the stage that this task belong to. */ @@ -144,9 +131,6 @@ abstract class TaskContext extends Serializable { */ def attemptNumber(): Int - @deprecated("use attemptNumber", "1.3.0") - def attemptId(): Long - /** * An ID that is unique to this task attempt (within the same SparkContext, no two task attempts * will share the same attempt ID). This is roughly equivalent to Hadoop's TaskAttemptID. diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index f0ae83a9341b..6c493630997e 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -38,9 +38,6 @@ private[spark] class TaskContextImpl( extends TaskContext with Logging { - // For backwards-compatibility; this method is now deprecated as of 1.3.0. 
- override def attemptId(): Long = taskAttemptId - // List of callback functions to execute when the task completes. @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] @@ -62,13 +59,6 @@ private[spark] class TaskContextImpl( this } - @deprecated("use addTaskCompletionListener", "1.1.0") - override def addOnCompleteCallback(f: () => Unit) { - onCompleteCallbacks += new TaskCompletionListener { - override def onTaskCompletion(context: TaskContext): Unit = f() - } - } - /** Marks the task as completed and triggers the listeners. */ private[spark] def markTaskCompleted(): Unit = { completed = true diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index c32aefac465b..37ae007f880c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -23,7 +23,6 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.Partitioner -import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.rdd.RDD diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 87deaf20e2b2..fb04472ee73f 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -17,14 +17,14 @@ package org.apache.spark.api.java -import java.util.{Comparator, List => JList, Map => JMap} +import java.{lang => jl} import java.lang.{Iterable => JIterable} +import java.util.{Comparator, List => JList} import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{JobConf, OutputFormat} @@ -139,7 +139,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * math.ceil(numItems * samplingRate) over all key values. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double], + fractions: java.util.Map[K, Double], seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sampleByKey(withReplacement, fractions.asScala, seed)) @@ -154,7 +154,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * Use Utils.random.nextLong as the default seed for the random number generator. */ def sampleByKey(withReplacement: Boolean, - fractions: JMap[K, Double]): JavaPairRDD[K, V] = + fractions: java.util.Map[K, Double]): JavaPairRDD[K, V] = sampleByKey(withReplacement, fractions, Utils.random.nextLong) /** @@ -168,7 +168,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * two additional passes. */ def sampleByKeyExact(withReplacement: Boolean, - fractions: JMap[K, Double], + fractions: java.util.Map[K, Double], seed: Long): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.sampleByKeyExact(withReplacement, fractions.asScala, seed)) @@ -184,7 +184,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * * Use Utils.random.nextLong as the default seed for the random number generator. 
*/ - def sampleByKeyExact(withReplacement: Boolean, fractions: JMap[K, Double]): JavaPairRDD[K, V] = + def sampleByKeyExact( + withReplacement: Boolean, + fractions: java.util.Map[K, Double]): JavaPairRDD[K, V] = sampleByKeyExact(withReplacement, fractions, Utils.random.nextLong) /** @@ -292,7 +294,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func)) /** Count the number of elements for each key, and return the result to the master as a Map. */ - def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey()) + def countByKey(): java.util.Map[K, jl.Long] = + mapAsSerializableJavaMap(rdd.countByKey()).asInstanceOf[java.util.Map[K, jl.Long]] /** * Approximate version of countByKey that can return a partial result if it does @@ -651,7 +654,6 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * keys; this also retains the original RDD's partitioning. */ def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { - import scala.collection.JavaConverters._ def fn: (V) => Iterable[U] = (x: V) => f.call(x).asScala implicit val ctag: ClassTag[U] = fakeClassTag fromRDD(rdd.flatMapValues(fn)) @@ -934,9 +936,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * It must be greater than 0.000017. * @param partitioner partitioner of the resulting RDD. */ - def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner): JavaPairRDD[K, Long] = - { - fromRDD(rdd.countApproxDistinctByKey(relativeSD, partitioner)) + def countApproxDistinctByKey(relativeSD: Double, partitioner: Partitioner) + : JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD, partitioner)). + asInstanceOf[JavaPairRDD[K, jl.Long]] } /** @@ -950,8 +953,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * It must be greater than 0.000017. * @param numPartitions number of partitions of the resulting RDD. */ - def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaPairRDD[K, Long] = { - fromRDD(rdd.countApproxDistinctByKey(relativeSD, numPartitions)) + def countApproxDistinctByKey(relativeSD: Double, numPartitions: Int): JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD, numPartitions)). + asInstanceOf[JavaPairRDD[K, jl.Long]] } /** @@ -964,8 +968,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) * @param relativeSD Relative accuracy. Smaller values create counters that require more space. * It must be greater than 0.000017. 
*/ - def countApproxDistinctByKey(relativeSD: Double): JavaPairRDD[K, Long] = { - fromRDD(rdd.countApproxDistinctByKey(relativeSD)) + def countApproxDistinctByKey(relativeSD: Double): JavaPairRDD[K, jl.Long] = { + fromRDD(rdd.countApproxDistinctByKey(relativeSD)).asInstanceOf[JavaPairRDD[K, jl.Long]] } /** Assign a name to this RDD */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 0e4d7dce0f2f..0f8d13cf5cc2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -18,13 +18,12 @@ package org.apache.spark.api.java import java.{lang => jl} -import java.lang.{Iterable => JIterable, Long => JLong} -import java.util.{Comparator, List => JList, Iterator => JIterator} +import java.lang.{Iterable => JIterable} +import java.util.{Comparator, Iterator => JIterator, List => JList} import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec import org.apache.spark._ @@ -57,9 +56,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def rdd: RDD[T] - @deprecated("Use partitions() instead.", "1.1.0") - def splits: JList[Partition] = rdd.partitions.toSeq.asJava - /** Set of partitions in this RDD. */ def partitions: JList[Partition] = rdd.partitions.toSeq.asJava @@ -125,7 +121,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMap[U](f: FlatMapFunction[T, U]): JavaRDD[U] = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[U] = (x: T) => f.call(x).asScala JavaRDD.fromRDD(rdd.flatMap(fn)(fakeClassTag[U]))(fakeClassTag[U]) } @@ -135,7 +130,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[jl.Double] = (x: T) => f.call(x).asScala new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) } @@ -145,7 +139,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * RDD, and then flattening the results. */ def flatMapToPair[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - import scala.collection.JavaConverters._ def fn: (T) => Iterable[(K2, V2)] = (x: T) => f.call(x).asScala def cm: ClassTag[(K2, V2)] = implicitly[ClassTag[(K2, V2)]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(fakeClassTag[K2], fakeClassTag[V2]) @@ -308,8 +301,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * 2*n+k, ..., where n is the number of partitions. So there may exist gaps, but this method * won't trigger a spark job, which is different from [[org.apache.spark.rdd.RDD#zipWithIndex]]. */ - def zipWithUniqueId(): JavaPairRDD[T, JLong] = { - JavaPairRDD.fromRDD(rdd.zipWithUniqueId()).asInstanceOf[JavaPairRDD[T, JLong]] + def zipWithUniqueId(): JavaPairRDD[T, jl.Long] = { + JavaPairRDD.fromRDD(rdd.zipWithUniqueId()).asInstanceOf[JavaPairRDD[T, jl.Long]] } /** @@ -319,8 +312,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * This is similar to Scala's zipWithIndex but it uses Long instead of Int as the index type. * This method needs to trigger a spark job when this RDD contains more than one partitions. 
*/ - def zipWithIndex(): JavaPairRDD[T, JLong] = { - JavaPairRDD.fromRDD(rdd.zipWithIndex()).asInstanceOf[JavaPairRDD[T, JLong]] + def zipWithIndex(): JavaPairRDD[T, jl.Long] = { + JavaPairRDD.fromRDD(rdd.zipWithIndex()).asInstanceOf[JavaPairRDD[T, jl.Long]] } // Actions (launch a job to return a value to the user program) @@ -346,13 +339,6 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def toLocalIterator(): JIterator[T] = asJavaIteratorConverter(rdd.toLocalIterator).asJava - /** - * Return an array that contains all of the elements in this RDD. - * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead - */ - @deprecated("use collect()", "1.0.0") - def toArray(): JList[T] = collect() - /** * Return an array that contains all of the elements in a specific partition of this RDD. */ @@ -458,7 +444,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): java.util.Map[T, jl.Long] = - mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new jl.Long(x._2))))) + mapAsSerializableJavaMap(rdd.countByValue()).asInstanceOf[java.util.Map[T, jl.Long]] /** * (Experimental) Approximate version of countByValue(). @@ -641,8 +627,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * The asynchronous version of `count`, which returns a * future for counting the number of elements in this RDD. */ - def countAsync(): JavaFutureAction[JLong] = { - new JavaFutureActionWrapper[Long, JLong](rdd.countAsync(), JLong.valueOf) + def countAsync(): JavaFutureAction[jl.Long] = { + new JavaFutureActionWrapper[Long, jl.Long](rdd.countAsync(), jl.Long.valueOf) } /** diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 4f54cd69e217..01433ca2efc1 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -25,9 +25,7 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration -import org.apache.spark.input.PortableDataStream import org.apache.hadoop.mapred.{InputFormat, JobConf} import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -35,6 +33,7 @@ import org.apache.spark._ import org.apache.spark.AccumulatorParam._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.broadcast.Broadcast +import org.apache.spark.input.PortableDataStream import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD} /** @@ -102,7 +101,7 @@ class JavaSparkContext(val sc: SparkContext) */ def this(master: String, appName: String, sparkHome: String, jars: Array[String], environment: JMap[String, String]) = - this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala, Map())) + this(new SparkContext(master, appName, sparkHome, jars.toSeq, environment.asScala)) private[spark] val env = sc.env @@ -126,14 +125,6 @@ class JavaSparkContext(val sc: SparkContext) /** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */ def defaultParallelism: java.lang.Integer = sc.defaultParallelism - /** - * Default min number of partitions for Hadoop RDDs when not given by user. 
- * @deprecated As of Spark 1.0.0, defaultMinSplits is deprecated, use - * {@link #defaultMinPartitions()} instead - */ - @deprecated("use defaultMinPartitions", "1.0.0") - def defaultMinSplits: java.lang.Integer = sc.defaultMinSplits - /** Default min number of partitions for Hadoop RDDs when not given by user */ def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions @@ -671,24 +662,6 @@ class JavaSparkContext(val sc: SparkContext) sc.addJar(path) } - /** - * Clear the job's list of JARs added by `addJar` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding jars no longer creates local copies that need to be deleted", "1.0.0") - def clearJars() { - sc.clearJars() - } - - /** - * Clear the job's list of files added by `addFile` so that they do not get downloaded to - * any new nodes. - */ - @deprecated("adding files no longer creates local copies that need to be deleted", "1.0.0") - def clearFiles() { - sc.clearFiles() - } - /** * Returns the Hadoop configuration used for the Hadoop code (e.g. file systems) we reuse. * diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala index 3300cad9efba..99ca3c77cced 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.java -import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext} +import org.apache.spark.{SparkContext, SparkJobInfo, SparkStageInfo} /** * Low-level status reporting APIs for monitoring job and stage progress. diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala index 8f9647eea9e2..f820401da2fc 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala @@ -17,18 +17,17 @@ package org.apache.spark.api.java +import java.{util => ju} import java.util.Map.Entry -import com.google.common.base.Optional - -import java.{util => ju} import scala.collection.mutable private[spark] object JavaUtils { def optionToOptional[T](option: Option[T]): Optional[T] = - option match { - case Some(value) => Optional.of(value) - case None => Optional.absent() + if (option.isDefined) { + Optional.of(option.get) + } else { + Optional.empty[T] } // Workaround for SPARK-3926 / SI-8911 diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 8464b578ed09..f12e2dfafa19 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -19,7 +19,7 @@ package org.apache.spark.api.python import java.io._ import java.net._ -import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap} +import java.util.{ArrayList => JArrayList, Collections, List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala index 292ac4cfc35b..2d97cd9a9a20 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonUtils.scala @@ -24,7 +24,7 @@ 
import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.spark.SparkContext -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} private[spark] object PythonUtils { /** Get the PYTHONPATH for PySpark, either from SPARK_HOME, if it is set, or from our JAR */ diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 7039b734d2e4..a2a2f89f1e87 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -17,7 +17,7 @@ package org.apache.spark.api.python -import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter} +import java.io.{DataInputStream, DataOutputStream, InputStream, OutputStreamWriter} import java.net.{InetAddress, ServerSocket, Socket, SocketException} import java.util.Arrays diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala index fd27276e70bf..b0d858486bfb 100644 --- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala @@ -20,16 +20,15 @@ package org.apache.spark.api.python import java.nio.ByteOrder import java.util.{ArrayList => JArrayList} -import org.apache.spark.api.java.JavaRDD - import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Failure import scala.util.Try -import net.razorvine.pickle.{Unpickler, Pickler} +import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.{Logging, SparkException} +import org.apache.spark.api.java.JavaRDD import org.apache.spark.rdd.RDD /** Utilities for serialization / deserialization between Python and Java, using Pickle. 
*/ diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala index ee1fb056f0d9..9549784aeabf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala +++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala @@ -17,13 +17,12 @@ package org.apache.spark.api.python -import java.io.{DataOutput, DataInput} import java.{util => ju} +import java.io.{DataInput, DataOutput} import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 - import org.apache.hadoop.io._ import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat diff --git a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala index 0095548c463c..9bddd7248c7e 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala @@ -22,8 +22,8 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.mutable.HashMap import scala.language.existentials -import io.netty.channel.ChannelHandler.Sharable import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.channel.ChannelHandler.Sharable import org.apache.spark.Logging import org.apache.spark.api.r.SerDe._ diff --git a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala index 7509b3d3f44b..401f362fee82 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRDD.scala @@ -19,8 +19,7 @@ package org.apache.spark.api.r import java.io._ import java.net.{InetAddress, ServerSocket} -import java.util.Arrays -import java.util.{Map => JMap} +import java.util.{Arrays, Map => JMap} import scala.collection.JavaConverters._ import scala.io.Source diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index da126bac7ad1..af815f885e8a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -18,7 +18,7 @@ package org.apache.spark.api.r import java.io.{DataInputStream, DataOutputStream} -import java.sql.{Timestamp, Date, Time} +import java.sql.{Date, Time, Timestamp} import scala.collection.JavaConverters._ import scala.collection.mutable.WrappedArray diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 12d79f6ed311..0d68872dcb6e 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -19,12 +19,12 @@ package org.apache.spark.broadcast import java.io.Serializable -import org.apache.spark.SparkException +import scala.reflect.ClassTag + import org.apache.spark.Logging +import org.apache.spark.SparkException import org.apache.spark.util.Utils -import scala.reflect.ClassTag - /** * A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable * cached on each machine rather than shipping a copy of it with tasks. 
They can be used, for diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala index 6a187b40628a..7f35ac47479b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala @@ -24,14 +24,12 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi /** - * :: DeveloperApi :: * An interface for all the broadcast implementations in Spark (to allow * multiple broadcast implementations). SparkContext uses a user-specified * BroadcastFactory implementation to instantiate a particular broadcast for the * entire Spark job. */ -@DeveloperApi -trait BroadcastFactory { +private[spark] trait BroadcastFactory { def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala index fac6666bb341..be416c4f74cb 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala @@ -21,8 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.reflect.ClassTag -import org.apache.spark._ -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SecurityManager, SparkConf} private[spark] class BroadcastManager( val isDriver: Boolean, @@ -39,15 +38,8 @@ private[spark] class BroadcastManager( private def initialize() { synchronized { if (!initialized) { - val broadcastFactoryClass = - conf.get("spark.broadcast.factory", "org.apache.spark.broadcast.TorrentBroadcastFactory") - - broadcastFactory = - Utils.classForName(broadcastFactoryClass).newInstance.asInstanceOf[BroadcastFactory] - - // Initialize appropriate BroadcastFactory and BroadcastObject + broadcastFactory = new TorrentBroadcastFactory broadcastFactory.initialize(isDriver, conf, securityManager) - initialized = true } } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala deleted file mode 100644 index b69af639f786..000000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ /dev/null @@ -1,269 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.broadcast - -import java.io.{File, FileOutputStream, ObjectInputStream, ObjectOutputStream, OutputStream} -import java.io.{BufferedInputStream, BufferedOutputStream} -import java.net.{URL, URLConnection, URI} -import java.util.concurrent.TimeUnit - -import scala.reflect.ClassTag - -import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv} -import org.apache.spark.io.CompressionCodec -import org.apache.spark.storage.{BroadcastBlockId, StorageLevel} -import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashSet, Utils} - -/** - * A [[org.apache.spark.broadcast.Broadcast]] implementation that uses HTTP server - * as a broadcast mechanism. The first time a HTTP broadcast variable (sent as part of a - * task) is deserialized in the executor, the broadcasted data is fetched from the driver - * (through a HTTP server running at the driver) and stored in the BlockManager of the - * executor to speed up future accesses. - */ -private[spark] class HttpBroadcast[T: ClassTag]( - @transient var value_ : T, isLocal: Boolean, id: Long) - extends Broadcast[T](id) with Logging with Serializable { - - override protected def getValue() = value_ - - private val blockId = BroadcastBlockId(id) - - /* - * Broadcasted data is also stored in the BlockManager of the driver. The BlockManagerMaster - * does not need to be told about this block as not only need to know about this data block. - */ - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - } - - if (!isLocal) { - HttpBroadcast.write(id, value_) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors. - */ - override protected def doUnpersist(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = false, blocking) - } - - /** - * Remove all persisted state associated with this HTTP broadcast on the executors and driver. - */ - override protected def doDestroy(blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver = true, blocking) - } - - /** Used by the JVM when serializing this object. */ - private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException { - assertValid() - out.defaultWriteObject() - } - - /** Used by the JVM when deserializing this object. */ - private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException { - in.defaultReadObject() - HttpBroadcast.synchronized { - SparkEnv.get.blockManager.getSingle(blockId) match { - case Some(x) => value_ = x.asInstanceOf[T] - case None => { - logInfo("Started reading broadcast variable " + id) - val start = System.nanoTime - value_ = HttpBroadcast.read[T](id) - /* - * We cache broadcast data in the BlockManager so that subsequent tasks using it - * do not need to re-fetch. This data is only used locally and no other node - * needs to fetch this block, so we don't notify the master. 
- */ - SparkEnv.get.blockManager.putSingle( - blockId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false) - val time = (System.nanoTime - start) / 1e9 - logInfo("Reading broadcast variable " + id + " took " + time + " s") - } - } - } - } -} - -private[broadcast] object HttpBroadcast extends Logging { - private var initialized = false - private var broadcastDir: File = null - private var compress: Boolean = false - private var bufferSize: Int = 65536 - private var serverUri: String = null - private var server: HttpServer = null - private var securityManager: SecurityManager = null - - // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist - private val files = new TimeStampedHashSet[File] - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private var compressionCodec: CompressionCodec = null - private var cleaner: MetadataCleaner = null - - def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - synchronized { - if (!initialized) { - bufferSize = conf.getInt("spark.buffer.size", 65536) - compress = conf.getBoolean("spark.broadcast.compress", true) - securityManager = securityMgr - if (isDriver) { - createServer(conf) - conf.set("spark.httpBroadcast.uri", serverUri) - } - serverUri = conf.get("spark.httpBroadcast.uri") - cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) - compressionCodec = CompressionCodec.createCodec(conf) - initialized = true - } - } - } - - def stop() { - synchronized { - if (server != null) { - server.stop() - server = null - } - if (cleaner != null) { - cleaner.cancel() - cleaner = null - } - compressionCodec = null - initialized = false - } - } - - private def createServer(conf: SparkConf) { - broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast") - val broadcastPort = conf.getInt("spark.broadcast.port", 0) - server = - new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server") - server.start() - serverUri = server.uri - logInfo("Broadcast server started at " + serverUri) - } - - def getFile(id: Long): File = new File(broadcastDir, BroadcastBlockId(id).name) - - private def write(id: Long, value: Any) { - val file = getFile(id) - val fileOutputStream = new FileOutputStream(file) - Utils.tryWithSafeFinally { - val out: OutputStream = { - if (compress) { - compressionCodec.compressedOutputStream(fileOutputStream) - } else { - new BufferedOutputStream(fileOutputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serOut = ser.serializeStream(out) - Utils.tryWithSafeFinally { - serOut.writeObject(value) - } { - serOut.close() - } - files += file - } { - fileOutputStream.close() - } - } - - private def read[T: ClassTag](id: Long): T = { - logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id) - val url = serverUri + "/" + BroadcastBlockId(id).name - - var uc: URLConnection = null - if (securityManager.isAuthenticationEnabled()) { - logDebug("broadcast security enabled") - val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager) - uc = newuri.toURL.openConnection() - uc.setConnectTimeout(httpReadTimeout) - uc.setAllowUserInteraction(false) - } else { - logDebug("broadcast not using security") - uc = new URL(url).openConnection() - uc.setConnectTimeout(httpReadTimeout) - } - Utils.setupSecureURLConnection(uc, securityManager) - - val in = { - uc.setReadTimeout(httpReadTimeout) - val inputStream = uc.getInputStream - 
if (compress) { - compressionCodec.compressedInputStream(inputStream) - } else { - new BufferedInputStream(inputStream, bufferSize) - } - } - val ser = SparkEnv.get.serializer.newInstance() - val serIn = ser.deserializeStream(in) - Utils.tryWithSafeFinally { - serIn.readObject[T]() - } { - serIn.close() - } - } - - /** - * Remove all persisted blocks associated with this HTTP broadcast on the executors. - * If removeFromDriver is true, also remove these persisted blocks on the driver - * and delete the associated broadcast file. - */ - def unpersist(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit = synchronized { - SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking) - if (removeFromDriver) { - val file = getFile(id) - files.remove(file) - deleteBroadcastFile(file) - } - } - - /** - * Periodically clean up old broadcasts by removing the associated map entries and - * deleting the associated files. - */ - private def cleanup(cleanupTime: Long) { - val iterator = files.internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - val (file, time) = (entry.getKey, entry.getValue) - if (time < cleanupTime) { - iterator.remove() - deleteBroadcastFile(file) - } - } - } - - private def deleteBroadcastFile(file: File) { - try { - if (file.exists) { - if (file.delete()) { - logInfo("Deleted broadcast file: %s".format(file)) - } else { - logWarning("Could not delete broadcast file: %s".format(file)) - } - } - } catch { - case e: Exception => - logError("Exception while deleting broadcast file: %s".format(file), e) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala deleted file mode 100644 index cf3ae36f2794..000000000000 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.broadcast - -import scala.reflect.ClassTag - -import org.apache.spark.{SecurityManager, SparkConf} - -/** - * A [[org.apache.spark.broadcast.BroadcastFactory]] implementation that uses a - * HTTP server as the broadcast mechanism. Refer to - * [[org.apache.spark.broadcast.HttpBroadcast]] for more details about this mechanism. 
- */ -class HttpBroadcastFactory extends BroadcastFactory { - override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { - HttpBroadcast.initialize(isDriver, conf, securityMgr) - } - - override def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long): Broadcast[T] = - new HttpBroadcast[T](value_, isLocal, id) - - override def stop() { HttpBroadcast.stop() } - - /** - * Remove all persisted state associated with the HTTP broadcast with the given ID. - * @param removeFromDriver Whether to remove state from the driver - * @param blocking Whether to block until unbroadcasted - */ - override def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean) { - HttpBroadcast.unpersist(id, removeFromDriver, blocking) - } -} diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 7e3764d802fe..9bd69727f608 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.io.ByteArrayChunkOutputStream * BlockManager, ready for other executors to fetch from. * * This prevents the driver from being the bottleneck in sending out multiple copies of the - * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]]. + * broadcast data (one per executor). * * When initialized, TorrentBroadcast objects read SparkEnv.get.conf. * diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala index 96d8dd79908c..b11f9ba171b8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SecurityManager, SparkConf} * protocol to do a distributed transfer of the broadcasted data to the executors. Refer to * [[org.apache.spark.broadcast.TorrentBroadcast]] for more details. */ -class TorrentBroadcastFactory extends BroadcastFactory { +private[spark] class TorrentBroadcastFactory extends BroadcastFactory { override def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager) { } diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index f03875a3e8c8..63a20ab41a0f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -24,11 +24,11 @@ import scala.util.{Failure, Success} import org.apache.log4j.{Level, Logger} -import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ThreadUtils, SparkExitCode, Utils} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.util.{SparkExitCode, ThreadUtils, Utils} /** * Proxy that relays messages to the driver. @@ -230,7 +230,7 @@ object Client { RpcEnv.create("driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) val masterEndpoints = driverArgs.masters.map(RpcAddress.fromSparkURL). 
- map(rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, _, Master.ENDPOINT_NAME)) + map(rpcEnv.setupEndpointRef(_, Master.ENDPOINT_NAME)) rpcEnv.setupEndpoint("client", new ClientEndpoint(rpcEnv, driverArgs, masterEndpoints, conf)) rpcEnv.awaitTermination() diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala index 72cc330a398d..255420182b49 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala @@ -22,6 +22,7 @@ import java.net.{URI, URISyntaxException} import scala.collection.mutable.ListBuffer import org.apache.log4j.Level + import org.apache.spark.util.{IntParam, MemoryParam, Utils} /** diff --git a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala index 7fc96e4f764b..c514a1a86bab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExternalShuffleService.scala @@ -21,11 +21,11 @@ import java.util.concurrent.CountDownLatch import scala.collection.JavaConverters._ -import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.network.TransportContext import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.sasl.SaslServerBootstrap -import org.apache.spark.network.server.{TransportServerBootstrap, TransportServer} +import org.apache.spark.network.server.{TransportServer, TransportServerBootstrap} import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler import org.apache.spark.network.util.TransportConf import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -108,6 +108,7 @@ object ExternalShuffleService extends Logging { private[spark] def main( args: Array[String], newShuffleService: (SparkConf, SecurityManager) => ExternalShuffleService): Unit = { + Utils.initDaemon(log) val sparkConf = new SparkConf Utils.loadDefaultSparkProperties(sparkConf) val securityManager = new SecurityManager(sparkConf) diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index b4edb6109e83..c0ede4b7c873 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -22,7 +22,7 @@ import java.net.URL import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{Await, future, promise} +import scala.concurrent.{future, promise, Await} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 5bb62d37d637..2dfb813d5fb4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,10 +19,10 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import org.apache.spark.rpc.RpcEnv import org.apache.spark.{Logging, SparkConf} -import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master +import 
org.apache.spark.deploy.worker.Worker +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index d85327603f64..c0a9e3f280ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy -import java.net.URI import java.io.File +import java.net.URI import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ diff --git a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala index d46dc87a92c9..4911c3be3a02 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RPackageUtils.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import com.google.common.io.{ByteStreams, Files} -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.api.r.RUtils import org.apache.spark.util.{RedirectThread, Utils} diff --git a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala index 661f7317c674..d0466830b217 100644 --- a/core/src/main/scala/org/apache/spark/deploy/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/RRunner.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path -import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.{SparkException, SparkUserAppException} +import org.apache.spark.api.r.{RBackend, RUtils} import org.apache.spark.util.RedirectThread /** diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 59e90564b351..8ba3f5e24189 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -29,18 +29,15 @@ import scala.util.control.NonFatal import com.google.common.primitives.Longs import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} +import org.apache.hadoop.fs.FileSystem.Statistics import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier import org.apache.hadoop.mapred.JobConf -import org.apache.hadoop.mapreduce.JobContext -import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} -import org.apache.hadoop.mapreduce.{TaskAttemptID => MapReduceTaskAttemptID} import org.apache.hadoop.security.{Credentials, UserGroupInformation} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf, SparkException} /** * :: DeveloperApi :: @@ -76,9 +73,6 @@ class SparkHadoopUtil extends Logging { } } - @deprecated("use newConfiguration with SparkConf argument", "1.2.0") - def newConfiguration(): Configuration = newConfiguration(null) - /** * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop * subsystems. 
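The SparkHadoopUtil hunk above deletes the deprecated zero-argument `newConfiguration()`, leaving only the overload that takes a `SparkConf`. A minimal migration sketch for downstream callers, assuming the public `SparkHadoopUtil.get` singleton; the object name and config key below are illustrative, not part of this patch:

    // Hypothetical caller updated for the removal of the 0-arg newConfiguration().
    import org.apache.hadoop.conf.Configuration
    import org.apache.spark.SparkConf
    import org.apache.spark.deploy.SparkHadoopUtil

    object NewConfigurationMigration {
      def main(args: Array[String]): Unit = {
        val sparkConf = new SparkConf()
        // Before: SparkHadoopUtil.get.newConfiguration()  (deprecated since 1.2.0, removed here)
        // After: pass the SparkConf explicitly so spark.hadoop.* settings are propagated.
        val hadoopConf: Configuration = SparkHadoopUtil.get.newConfiguration(sparkConf)
        println(hadoopConf.get("fs.defaultFS"))
      }
    }
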
@@ -190,33 +184,6 @@ class SparkHadoopUtil extends Logging { statisticsDataClass.getDeclaredMethod(methodName) } - /** - * Using reflection to get the Configuration from JobContext/TaskAttemptContext. If we directly - * call `JobContext/TaskAttemptContext.getConfiguration`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because JobContext/TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getConfigurationFromJobContext(context: JobContext): Configuration = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getConfiguration") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[Configuration] - } - - /** - * Using reflection to call `getTaskAttemptID` from TaskAttemptContext. If we directly - * call `TaskAttemptContext.getTaskAttemptID`, it will generate different byte codes - * for Hadoop 1.+ and Hadoop 2.+ because TaskAttemptContext is class in Hadoop 1.+ - * while it's interface in Hadoop 2.+. - */ - def getTaskAttemptIDFromTaskAttemptContext( - context: MapReduceTaskAttemptContext): MapReduceTaskAttemptID = { - // scalastyle:off jobconfig - val method = context.getClass.getMethod("getTaskAttemptID") - // scalastyle:on jobconfig - method.invoke(context).asInstanceOf[MapReduceTaskAttemptID] - } - /** * Get [[FileStatus]] objects for all leaf children (files) under the given base path. If the * given path points to a file, return a single-element collection containing [[FileStatus]] of @@ -233,11 +200,11 @@ class SparkHadoopUtil extends Logging { */ def listLeafStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, leaves) = fs.listStatus(status.getPath).partition(_.isDirectory) leaves ++ directories.flatMap(f => listLeafStatuses(fs, f)) } - if (baseStatus.isDir) recurse(baseStatus) else Seq(baseStatus) + if (baseStatus.isDirectory) recurse(baseStatus) else Seq(baseStatus) } def listLeafDirStatuses(fs: FileSystem, basePath: Path): Seq[FileStatus] = { @@ -246,12 +213,12 @@ class SparkHadoopUtil extends Logging { def listLeafDirStatuses(fs: FileSystem, baseStatus: FileStatus): Seq[FileStatus] = { def recurse(status: FileStatus): Seq[FileStatus] = { - val (directories, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (directories, files) = fs.listStatus(status.getPath).partition(_.isDirectory) val leaves = if (directories.isEmpty) Seq(status) else Seq.empty[FileStatus] leaves ++ directories.flatMap(dir => listLeafDirStatuses(fs, dir)) } - assert(baseStatus.isDir) + assert(baseStatus.isDirectory) recurse(baseStatus) } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 52d3ab34c178..a1e8da1ec8f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -37,9 +37,9 @@ import org.apache.ivy.core.retrieve.RetrieveOptions import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.matcher.GlobPatternMatcher import org.apache.ivy.plugins.repository.file.FileRepository -import org.apache.ivy.plugins.resolver.{FileSystemResolver, ChainResolver, IBiblioResolver} +import org.apache.ivy.plugins.resolver.{ChainResolver, FileSystemResolver, IBiblioResolver} -import org.apache.spark.{SparkException, SparkUserAppException, 
SPARK_VERSION} +import org.apache.spark.{SPARK_VERSION, SparkException, SparkUserAppException} import org.apache.spark.api.r.RUtils import org.apache.spark.deploy.rest._ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -965,7 +965,7 @@ private[spark] object SparkSubmitUtils { // We need to specify each component explicitly, otherwise we miss spark-streaming-kafka and // other spark-streaming utility components. Underscore is there to differentiate between // spark-streaming_2.1x and spark-streaming-kafka-assembly_2.1x - val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") components.foreach { comp => diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 1e2f469214b8..a7a0a78f1456 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -18,8 +18,8 @@ package org.apache.spark.deploy.client import java.util.concurrent._ -import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} import scala.util.control.NonFatal @@ -104,8 +104,7 @@ private[spark] class AppClient( return } logInfo("Connecting to master " + masterAddress.toSparkURL + "...") - val masterRef = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterRef = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) masterRef.send(RegisterApplication(appDescription, self)) } catch { case ie: InterruptedException => // Cancelled @@ -124,7 +123,7 @@ private[spark] class AppClient( */ private def registerWithMaster(nthRetry: Int) { registerMasterFutures.set(tryRegisterAllMasters()) - registrationRetryTimer.set(registrationRetryThread.scheduleAtFixedRate(new Runnable { + registrationRetryTimer.set(registrationRetryThread.schedule(new Runnable { override def run(): Unit = { Utils.tryOrExit { if (registered.get) { @@ -138,7 +137,7 @@ private[spark] class AppClient( } } } - }, REGISTRATION_TIMEOUT_SECONDS, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) + }, REGISTRATION_TIMEOUT_SECONDS, TimeUnit.SECONDS)) } /** diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index adb3f0225802..f8d3da24b956 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,9 @@ package org.apache.spark.deploy.client -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, Command} +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils private[spark] object TestClient { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 718efc4f3bd5..22e4155cc545 100644 --- 
a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -19,7 +19,7 @@ package org.apache.spark.deploy.history import java.io.{BufferedInputStream, FileNotFoundException, InputStream, IOException, OutputStream} import java.util.UUID -import java.util.concurrent.{ExecutorService, Executors, TimeUnit} +import java.util.concurrent.{Executors, ExecutorService, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} import scala.collection.mutable @@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} import org.apache.hadoop.hdfs.DistributedFileSystem +import org.apache.hadoop.hdfs.protocol.HdfsConstants import org.apache.hadoop.security.AccessControlException import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} @@ -167,7 +168,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } throw new IllegalArgumentException(msg) } - if (!fs.getFileStatus(path).isDir) { + if (!fs.getFileStatus(path).isDirectory) { throw new IllegalArgumentException( "Logging directory specified is not a directory: %s".format(logDir)) } @@ -304,7 +305,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) logError("Exception encountered when attempting to update last scan time", e) lastScanTime } finally { - if (!fs.delete(path)) { + if (!fs.delete(path, true)) { logWarning(s"Error deleting ${path}") } } @@ -603,7 +604,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * As of Spark 1.3, these files are consolidated into a single one that replaces the directory. * See SPARK-2261 for more detail. */ - private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDir() + private def isLegacyLogDirectory(entry: FileStatus): Boolean = entry.isDirectory /** * Returns the modification time of the given event log. If the status points at an empty @@ -648,8 +649,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } /** - * Checks whether HDFS is in safe mode. The API is slightly different between hadoop 1 and 2, - * so we have to resort to ugly reflection (as usual...). + * Checks whether HDFS is in safe mode. * * Note that DistributedFileSystem is a `@LimitedPrivate` class, which for all practical reasons * makes it more public than not. @@ -663,19 +663,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // For testing. 
private[history] def isFsInSafeMode(dfs: DistributedFileSystem): Boolean = { - val hadoop1Class = "org.apache.hadoop.hdfs.protocol.FSConstants$SafeModeAction" - val hadoop2Class = "org.apache.hadoop.hdfs.protocol.HdfsConstants$SafeModeAction" - val actionClass: Class[_] = - try { - getClass().getClassLoader().loadClass(hadoop2Class) - } catch { - case _: ClassNotFoundException => - getClass().getClassLoader().loadClass(hadoop1Class) - } - - val action = actionClass.getField("SAFEMODE_GET").get(null) - val method = dfs.getClass().getMethod("setSafeMode", action.getClass()) - method.invoke(dfs, action).asInstanceOf[Boolean] + dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index 642d71b18c9e..04bad79dccab 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -21,7 +21,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index f31fef0eccc3..96007a06e3c5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -21,16 +21,17 @@ import java.util.NoSuchElementException import java.util.zip.ZipOutputStream import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} +import scala.util.control.NonFatal + import com.google.common.cache._ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, - UIRoot} +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, UIUtils, WebUI} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.util.{ShutdownHookManager, SignalLogger, Utils} +import org.apache.spark.util.{ShutdownHookManager, Utils} /** * A web server that renders SparkUIs of completed applications. 
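With the Hadoop 1 code path gone, `isFsInSafeMode` above collapses from classloader reflection to a direct `DistributedFileSystem.setSafeMode` call. A self-contained sketch of the same check, assuming an HDFS default file system; the driver object is illustrative, not part of this patch:

    // Sketch: querying HDFS safe mode directly against the Hadoop 2 API.
    // SAFEMODE_GET only reads the current state; it never toggles safe mode.
    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.fs.FileSystem
    import org.apache.hadoop.hdfs.DistributedFileSystem
    import org.apache.hadoop.hdfs.protocol.HdfsConstants

    object SafeModeCheck {
      def main(args: Array[String]): Unit = {
        FileSystem.get(new Configuration()) match {
          case dfs: DistributedFileSystem =>
            val inSafeMode = dfs.setSafeMode(HdfsConstants.SafeModeAction.SAFEMODE_GET)
            println(s"HDFS in safe mode: $inSafeMode")
          case other =>
            println(s"${other.getClass.getSimpleName} is not HDFS; safe mode does not apply.")
        }
      }
    }
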
@@ -115,7 +116,17 @@ class HistoryServer( } def getSparkUI(appKey: String): Option[SparkUI] = { - Option(appCache.get(appKey)) + try { + val ui = appCache.get(appKey) + Some(ui) + } catch { + case NonFatal(e) => e.getCause() match { + case nsee: NoSuchElementException => + None + + case cause: Exception => throw cause + } + } } initialize() @@ -195,7 +206,7 @@ class HistoryServer( appCache.get(appId + attemptId.map { id => s"/$id" }.getOrElse("")) true } catch { - case e: Exception => e.getCause() match { + case NonFatal(e) => e.getCause() match { case nsee: NoSuchElementException => false @@ -223,7 +234,7 @@ object HistoryServer extends Logging { val UI_PATH_PREFIX = "/history" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) new HistoryServerArguments(conf, argStrings) initSecurity() val securityManager = new SecurityManager(conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index fc42bf06e40a..0deab8ddd527 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -24,14 +24,13 @@ import java.util.Date import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import scala.language.postfixOps import scala.util.Random import org.apache.hadoop.fs.Path -import org.apache.spark.rpc._ import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.{ApplicationDescription, DriverDescription, ExecutorState, SparkHadoopUtil} @@ -42,10 +41,11 @@ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.deploy.rest.StandaloneRestServer import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc._ import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} private[deploy] class Master( override val rpcEnv: RpcEnv, @@ -979,7 +979,11 @@ private[deploy] class Master( futureUI.onSuccess { case Some(ui) => appIdToUI.put(app.id, ui) - self.send(AttachCompletedRebuildUI(app.id)) + // `self` can be null if we are already in the process of shutting down + // This happens frequently in tests where `local-cluster` is used + if (self != null) { + self.send(AttachCompletedRebuildUI(app.id)) + } // Application UI is successfully rebuilt, so link the Master UI to it // NOTE - app.appUIUrlAtHistoryServer is volatile app.appUIUrlAtHistoryServer = Some(ui.basePath) @@ -1083,7 +1087,7 @@ private[deploy] object Master extends Logging { val ENDPOINT_NAME = "Master" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) val (rpcEnv, _, _) = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index 58a00bceee6a..dddf2be57ee4 100644 
--- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -17,11 +17,11 @@ package org.apache.spark.deploy.master +import scala.reflect.ClassTag + import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rpc.RpcEnv -import scala.reflect.ClassTag - /** * Allows Master to persist any state that is necessary in order to recover from a failure. * The following semantics are required: diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index d317206a614f..336cb24c19b1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.master -import org.apache.spark.{Logging, SparkConf} import org.apache.curator.framework.CuratorFramework -import org.apache.curator.framework.recipes.leader.{LeaderLatchListener, LeaderLatch} +import org.apache.curator.framework.recipes.leader.{LeaderLatch, LeaderLatchListener} + +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.SparkCuratorUtil private[master] class ZooKeeperLeaderElectionAgent(val masterInstance: LeaderElectable, diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f405aa2bdc8b..1b18cf0ded69 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -21,8 +21,8 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState} +import org.apache.spark.deploy.ExecutorState import org.apache.spark.deploy.master.ExecutorDesc import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index ee539dd1f511..f9b0279c3d1e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -23,10 +23,10 @@ import scala.xml.Node import org.json4s.JValue +import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, MasterStateResponse, RequestKillDriver, RequestMasterState} import org.apache.spark.deploy.JsonProtocol -import org.apache.spark.deploy.DeployMessages.{KillDriverResponse, RequestKillDriver, MasterStateResponse, RequestMasterState} import org.apache.spark.deploy.master._ -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index e41554a5a6d2..750ef0a96255 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -19,7 +19,7 @@ package 
org.apache.spark.deploy.master.ui import org.apache.spark.Logging import org.apache.spark.deploy.master.Master -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationsListResource, ApplicationInfo, +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, ApplicationsListResource, UIRoot} import org.apache.spark.ui.{SparkUI, WebUI} import org.apache.spark.ui.JettyUtils._ diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala index 389eff5e0645..66e1e645007a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcher.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.mesos import java.util.concurrent.CountDownLatch +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.mesos.ui.MesosClusterUI import org.apache.spark.deploy.rest.mesos.MesosRestServer import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.util.{ShutdownHookManager, SignalLogger} -import org.apache.spark.{Logging, SecurityManager, SparkConf} +import org.apache.spark.util.{ShutdownHookManager, Utils} /* * A dispatcher that is responsible for managing and launching drivers, and is intended to be @@ -92,7 +92,7 @@ private[mesos] class MesosClusterDispatcher( private[mesos] object MesosClusterDispatcher extends Logging { def main(args: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val dispatcherArgs = new MesosClusterDispatcherArguments(args, conf) conf.setMaster(dispatcherArgs.masterUrl) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala index 8ffcfc0878a4..4172d924c802 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosExternalShuffleService.scala @@ -65,7 +65,7 @@ private[mesos] class MesosExternalShuffleBlockHandler(transportConf: TransportCo /** * On connection termination, clean up shuffle files written by the associated application. 
*/ - override def connectionTerminated(client: TransportClient): Unit = { + override def channelInactive(client: TransportClient): Unit = { val address = client.getSocketAddress if (connectedApps.contains(address)) { val appId = connectedApps(address) diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala index bc67fd460d9a..807835105ec3 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/DriverPage.scala @@ -23,10 +23,9 @@ import scala.xml.Node import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription -import org.apache.spark.scheduler.cluster.mesos.{MesosClusterSubmissionState, MesosClusterRetryState} +import org.apache.spark.scheduler.cluster.mesos.{MesosClusterRetryState, MesosClusterSubmissionState} import org.apache.spark.ui.{UIUtils, WebUIPage} - private[ui] class DriverPage(parent: MesosClusterUI) extends WebUIPage("driver") { override def render(request: HttpServletRequest): Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala index 7419fa969964..166f666fbcfd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterPage.scala @@ -22,6 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.mesos.Protos.TaskStatus + import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos.MesosClusterSubmissionState import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala index 3f693545a034..da9740bb41f5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/ui/MesosClusterUI.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy.mesos.ui -import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.ui.JettyUtils._ +import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.ui.{SparkUI, WebUI} +import org.apache.spark.ui.JettyUtils._ /** * UI that displays driver results from the [[org.apache.spark.deploy.mesos.MesosClusterDispatcher]] diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index f0dd667ea1b2..4ec6bfe2f9eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -23,15 +23,15 @@ import java.util.concurrent.TimeoutException import javax.servlet.http.HttpServletResponse import scala.collection.mutable -import scala.concurrent.duration._ import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ import scala.io.Source import com.fasterxml.jackson.core.JsonProcessingException import com.google.common.base.Charsets -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} +import 
org.apache.spark.util.Utils /** * A client that submits applications to a [[RestSubmissionServer]]. @@ -428,8 +428,10 @@ private[spark] object RestSubmissionClient { * Filter non-spark environment variables from any environment. */ private[rest] def filterSystemEnvironment(env: Map[String, String]): Map[String, String] = { - env.filter { case (k, _) => - (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED") || k.startsWith("MESOS_") + env.filterKeys { k => + // SPARK_HOME is filtered out because it is usually wrong on the remote machine (SPARK-12345) + (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") || + k.startsWith("MESOS_") } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala index 2e78d03e5c0c..8e0862df4c29 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionServer.scala @@ -21,14 +21,15 @@ import java.net.InetSocketAddress import javax.servlet.http.{HttpServlet, HttpServletRequest, HttpServletResponse} import scala.io.Source + import com.fasterxml.jackson.core.JsonProcessingException import org.eclipse.jetty.server.Server -import org.eclipse.jetty.servlet.{ServletHolder, ServletContextHandler} +import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.eclipse.jetty.util.thread.QueuedThreadPool import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION => sparkVersion} +import org.apache.spark.{Logging, SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index d5b9bcab1423..c19296c7b3e0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -20,11 +20,11 @@ package org.apache.spark.deploy.rest import java.io.File import javax.servlet.http.HttpServletResponse -import org.apache.spark.deploy.ClientArguments._ +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.{Command, DeployMessages, DriverDescription} +import org.apache.spark.deploy.ClientArguments._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} /** * A server that responds to requests submitted by the [[RestSubmissionClient]]. 
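The filterSystemEnvironment change in RestSubmissionClient above is the client-side half of SPARK-12345: SPARK_HOME is now stripped before a submission is sent, because (per the hunk's own comment) the submitter's SPARK_HOME is usually wrong on the remote machine. Restated as a standalone sketch — the predicate is copied from the hunk; the sample values are illustrative only:

    def filterSystemEnvironment(env: Map[String, String]): Map[String, String] =
      env.filterKeys { k =>
        // Keep Spark/Mesos variables, but drop the two that must not leak
        // to the remote machine.
        (k.startsWith("SPARK_") && k != "SPARK_ENV_LOADED" && k != "SPARK_HOME") ||
          k.startsWith("MESOS_")
      }

    // filterSystemEnvironment(Map(
    //   "SPARK_HOME"     -> "/opt/spark",  // dropped: wrong on the remote end
    //   "SPARK_LOCAL_IP" -> "10.0.0.5",    // kept
    //   "PATH"           -> "/usr/bin"))   // dropped: not a Spark/Mesos variable
    // == Map("SPARK_LOCAL_IP" -> "10.0.0.5")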
diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala index 24510db2bd0b..a8b2f788893d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/mesos/MesosRestServer.scala @@ -23,13 +23,12 @@ import java.util.Date import java.util.concurrent.atomic.AtomicLong import javax.servlet.http.HttpServletResponse +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest._ import org.apache.spark.scheduler.cluster.mesos.MesosClusterScheduler import org.apache.spark.util.Utils -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} - /** * A server that responds to requests submitted by the [[RestSubmissionClient]]. @@ -94,12 +93,7 @@ private[mesos] class MesosSubmitRequestServlet( val driverMemory = sparkProperties.get("spark.driver.memory") val driverCores = sparkProperties.get("spark.driver.cores") val appArgs = request.appArgs - // We don't want to pass down SPARK_HOME when launching Spark apps - // with Mesos cluster mode since it's populated by default on the client and it will - // cause spark-submit script to look for files in SPARK_HOME instead. - // We only need the ability to specify where to find spark-submit script - // which user can user spark.executor.home or spark.home configurations. - val environmentVariables = request.environmentVariables.filter(!_.equals("SPARK_HOME")) + val environmentVariables = request.environmentVariables val name = request.sparkProperties.get("spark.app.name").getOrElse(mainClass) // Construct driver description diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 89159ff5e2b3..6049db6d989a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -25,13 +25,13 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.fs.Path -import org.apache.spark.{Logging, SparkConf, SecurityManager} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{DriverDescription, SparkHadoopUtil} import org.apache.spark.deploy.DeployMessages.DriverStateChanged import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.util.{Utils, Clock, SystemClock} +import org.apache.spark.util.{Clock, SystemClock, Utils} /** * Manages the execution of one driver, including automatically restarting the driver on failure. 
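The MesosRestServer hunk above is the server-side counterpart: its SPARK_HOME filter could be deleted outright, not only because the client now strips the variable, but because the removed expression never matched anything — filter on a Map[String, String] sees (key, value) tuples, and a tuple never equals a bare String. A minimal sketch of the difference (illustrative values, not Spark source):

    val env = Map("SPARK_HOME" -> "/opt/spark", "SPARK_USER" -> "alice")

    // The removed server-side code: each element is a (String, String) pair,
    // so !_.equals("SPARK_HOME") is always true and nothing is filtered.
    env.filter(!_.equals("SPARK_HOME"))   // == env, unchanged

    // The key-based test the client now applies (see RestSubmissionClient):
    env.filterKeys(_ != "SPARK_HOME")     // == Map("SPARK_USER" -> "alice")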
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index 9a42487bb37a..c6687a4c63a6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -24,10 +24,10 @@ import scala.collection.JavaConverters._ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files -import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.{SecurityManager, SparkConf, Logging} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.util.logging.FileAppender @@ -60,6 +60,9 @@ private[deploy] class ExecutorRunner( private var stdoutAppender: FileAppender = null private var stderrAppender: FileAppender = null + // Timeout to wait for when trying to terminate an executor. + private val EXECUTOR_TERMINATE_TIMEOUT_MS = 10 * 1000 + // NOTE: This is now redundant with the automated shut-down enforced by the Executor. It might // make sense to remove this in the future. private var shutdownHook: AnyRef = null @@ -94,8 +97,11 @@ private[deploy] class ExecutorRunner( if (stderrAppender != null) { stderrAppender.stop() } - process.destroy() - exitCode = Some(process.waitFor()) + exitCode = Utils.terminateProcess(process, EXECUTOR_TERMINATE_TIMEOUT_MS) + if (exitCode.isEmpty) { + logWarning("Failed to terminate process: " + process + + ". This process will likely be orphaned.") + } } try { worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index f41efb097b4b..98e17da48974 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat -import java.util.{UUID, Date} +import java.util.{Date, UUID} import java.util.concurrent._ import java.util.concurrent.{Future => JFuture, ScheduledFuture => JScheduledFuture} @@ -37,7 +37,7 @@ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.rpc._ -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, ThreadUtils, Utils} private[deploy] class Worker( override val rpcEnv: RpcEnv, @@ -45,7 +45,6 @@ private[deploy] class Worker( cores: Int, memory: Int, masterRpcAddresses: Array[RpcAddress], - systemName: String, endpointName: String, workDirPath: String = null, val conf: SparkConf, @@ -101,7 +100,7 @@ private[deploy] class Worker( private var master: Option[RpcEndpointRef] = None private var activeMasterUrl: String = "" private[worker] var activeMasterWebUiUrl : String = "" - private val workerUri = rpcEnv.uriOf(systemName, rpcEnv.address, endpointName) + private val workerUri = RpcEndpointAddress(rpcEnv.address, endpointName).toString private var registered = false private var connected = false private val workerId = 
generateWorkerId() @@ -209,8 +208,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -266,8 +264,7 @@ private[deploy] class Worker( override def run(): Unit = { try { logInfo("Connecting to master " + masterAddress + "...") - val masterEndpoint = - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, masterAddress, Master.ENDPOINT_NAME) + val masterEndpoint = rpcEnv.setupEndpointRef(masterAddress, Master.ENDPOINT_NAME) registerWithMaster(masterEndpoint) } catch { case ie: InterruptedException => // Cancelled @@ -686,7 +683,7 @@ private[deploy] object Worker extends Logging { val ENDPOINT_NAME = "Worker" def main(argStrings: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) val rpcEnv = startRpcEnvAndEndpoint(args.host, args.port, args.webUiPort, args.cores, @@ -711,7 +708,7 @@ private[deploy] object Worker extends Logging { val rpcEnv = RpcEnv.create(systemName, host, port, conf, securityMgr) val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL(_)) rpcEnv.setupEndpoint(ENDPOINT_NAME, new Worker(rpcEnv, webUiPort, cores, memory, - masterAddresses, systemName, ENDPOINT_NAME, workDir, conf, securityMgr)) + masterAddresses, ENDPOINT_NAME, workDir, conf, securityMgr)) rpcEnv } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 5181142c5f80..de3c7cd265d2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -175,7 +175,7 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { def checkWorkerMemory(): Unit = { if (memory <= 0) { - val message = "Memory can't be 0, missing a M or G on the end of the memory specification?" + val message = "Memory is below 1MB, or missing a M/G at the end of the memory specification?" 
throw new IllegalStateException(message) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index 5a1d06eb87db..49803a27a5b0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -23,9 +23,9 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import org.apache.spark.ui.{WebUIPage, UIUtils} -import org.apache.spark.util.Utils import org.apache.spark.Logging +import org.apache.spark.ui.{UIUtils, WebUIPage} +import org.apache.spark.util.Utils import org.apache.spark.util.logging.RollingFileAppender private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with Logging { diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index fd905feb97e9..8ebcbcb6a173 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,16 +17,17 @@ package org.apache.spark.deploy.worker.ui +import javax.servlet.http.HttpServletRequest + import scala.xml.Node -import javax.servlet.http.HttpServletRequest import org.json4s.JValue -import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.DeployMessages.{RequestWorkerState, WorkerStateResponse} +import org.apache.spark.deploy.JsonProtocol import org.apache.spark.deploy.master.DriverState import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} import org.apache.spark.util.Utils private[ui] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c2ebf3059621..58bd9ca3d12c 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -20,20 +20,18 @@ package org.apache.spark.executor import java.net.URL import java.nio.ByteBuffer -import org.apache.hadoop.conf.Configuration - import scala.collection.mutable import scala.util.{Failure, Success} -import org.apache.spark.rpc._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.rpc._ import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.serializer.SerializerInstance -import org.apache.spark.util.{ThreadUtils, SignalLogger, Utils} +import org.apache.spark.util.{ThreadUtils, Utils} private[spark] class CoarseGrainedExecutorBackend( override val rpcEnv: RpcEnv, @@ -146,7 +144,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { workerUrl: Option[String], userClassPath: Seq[URL]) { - SignalLogger.register(log) + Utils.initDaemon(log) SparkHadoopUtil.get.runAsSparkUser { () => // Debug code @@ -257,7 +255,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // scalastyle:off println System.err.println( """ - |"Usage: CoarseGrainedExecutorBackend [options] + |Usage: 
CoarseGrainedExecutorBackend [options] | | Options are: | --driver-url diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 552b644d13aa..9b1418436424 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -30,6 +30,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.rpc.RpcTimeout import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{StorageLevel, TaskResultBlockId} @@ -445,7 +446,8 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse](message) + val response = heartbeatReceiverRef.askWithRetry[HeartbeatResponse]( + message, RpcTimeout(conf, "spark.executor.heartbeatInterval", "10s")) if (response.reregisterBlockManager) { logInfo("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala index c9f18ebc7f0e..cfd9bcd65c56 100644 --- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala @@ -21,15 +21,15 @@ import java.nio.ByteBuffer import scala.collection.JavaConverters._ -import org.apache.mesos.protobuf.ByteString import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver} import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _} +import org.apache.mesos.protobuf.ByteString -import org.apache.spark.{Logging, TaskState, SparkConf, SparkEnv} +import org.apache.spark.{Logging, SparkConf, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler.cluster.mesos.MesosTaskLaunchData -import org.apache.spark.util.{SignalLogger, Utils} +import org.apache.spark.util.Utils private[spark] class MesosExecutorBackend extends MesosExecutor @@ -121,7 +121,7 @@ private[spark] class MesosExecutorBackend */ private[spark] object MesosExecutorBackend extends Logging { def main(args: Array[String]) { - SignalLogger.register(log) + Utils.initDaemon(log) // Create a new Executor and start it running val runner = new MesosExecutorBackend() new MesosExecutorDriver(runner).run() diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala index 532850dd5771..bc98273add3a 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala @@ -19,11 +19,10 @@ package org.apache.spark.input import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.Logging -import 
org.apache.spark.deploy.SparkHadoopUtil /** * Custom Input Format for reading and splitting flat binary files that contain records, @@ -36,7 +35,7 @@ private[spark] object FixedLengthBinaryInputFormat { /** Retrieves the record length property from a Hadoop configuration */ def getRecordLength(context: JobContext): Int = { - SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt + context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt } } diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala index 67a96925da01..549395314ba6 100644 --- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala @@ -20,11 +20,10 @@ package org.apache.spark.input import java.io.IOException import org.apache.hadoop.fs.FSDataInputStream -import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.io.{BytesWritable, LongWritable} +import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileSplit -import org.apache.spark.deploy.SparkHadoopUtil /** * FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat. @@ -83,16 +82,16 @@ private[spark] class FixedLengthBinaryRecordReader // the actual file we will be reading from val file = fileSplit.getPath // job configuration - val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration // check compression - val codec = new CompressionCodecFactory(job).getCodec(file) + val codec = new CompressionCodecFactory(conf).getCodec(file) if (codec != null) { throw new IOException("FixedLengthRecordReader does not support reading compressed files") } // get the record length recordLength = FixedLengthBinaryInputFormat.getRecordLength(context) // get the filesystem - val fs = file.getFileSystem(job) + val fs = file.getFileSystem(conf) // open the File fileInputStream = fs.open(file) // seek to the splitStart position diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala index 280e7a5fe893..18cb7631b3d4 100644 --- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala +++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala @@ -21,14 +21,12 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, Da import scala.collection.JavaConverters._ -import com.google.common.io.{Closeables, ByteStreams} +import com.google.common.io.{ByteStreams, Closeables} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit} -import org.apache.spark.deploy.SparkHadoopUtil - /** * A general format for reading whole files in as streams, byte arrays, * or other functions to be added @@ -43,9 +41,8 @@ private[spark] abstract class StreamFileInputFormat[T] * which is set through setMaxSplitSize */ def setMinPartitions(context: JobContext, minPartitions: Int) { - val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L 
else file.getLen).sum - val maxSplitSize = Math.ceil(totalLen * 1.0 / files.size).toLong + val totalLen = listStatus(context).asScala.filterNot(_.isDirectory).map(_.getLen).sum + val maxSplitSize = math.ceil(totalLen / math.max(minPartitions, 1.0)).toLong super.setMaxSplitSize(maxSplitSize) } @@ -135,8 +132,7 @@ class PortableDataStream( private val confBytes = { val baos = new ByteArrayOutputStream() - SparkHadoopUtil.get.getConfigurationFromJobContext(context). - write(new DataOutputStream(baos)) + context.getConfiguration.write(new DataOutputStream(baos)) baos.toByteArray } diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala index 413408723b54..fa34f1e886c7 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala @@ -53,7 +53,7 @@ private[spark] class WholeTextFileInputFormat */ def setMinPartitions(context: JobContext, minPartitions: Int) { val files = listStatus(context).asScala - val totalLen = files.map(file => if (file.isDir) 0L else file.getLen).sum + val totalLen = files.map(file => if (file.isDirectory) 0L else file.getLen).sum val maxSplitSize = Math.ceil(totalLen * 1.0 / (if (minPartitions == 0) 1 else minPartitions)).toLong super.setMaxSplitSize(maxSplitSize) diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala index b56b2aa88a41..6b7f086678e9 100644 --- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala +++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala @@ -17,17 +17,14 @@ package org.apache.spark.input -import org.apache.hadoop.conf.{Configuration, Configurable => HConfigurable} import com.google.common.io.{ByteStreams, Closeables} - +import org.apache.hadoop.conf.{Configurable => HConfigurable, Configuration} import org.apache.hadoop.io.Text import org.apache.hadoop.io.compress.CompressionCodecFactory import org.apache.hadoop.mapreduce.InputSplit -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader} import org.apache.hadoop.mapreduce.RecordReader import org.apache.hadoop.mapreduce.TaskAttemptContext -import org.apache.spark.deploy.SparkHadoopUtil - +import org.apache.hadoop.mapreduce.lib.input.{CombineFileRecordReader, CombineFileSplit} /** * A trait to implement [[org.apache.hadoop.conf.Configurable Configurable]] interface. @@ -52,8 +49,7 @@ private[spark] class WholeTextFileRecordReader( extends RecordReader[Text, Text] with Configurable { private[this] val path = split.getPath(index) - private[this] val fs = path.getFileSystem( - SparkHadoopUtil.get.getConfigurationFromJobContext(context)) + private[this] val fs = path.getFileSystem(context.getConfiguration) // True means the current file has been processed, then skip it. 
private[this] var processed = false diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index ca74eedf89be..717804626f85 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -17,10 +17,10 @@ package org.apache.spark.io -import java.io.{IOException, InputStream, OutputStream} +import java.io._ import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} -import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream} +import net.jpountz.lz4.LZ4BlockOutputStream import org.xerial.snappy.{Snappy, SnappyInputStream, SnappyOutputStream} import org.apache.spark.SparkConf @@ -49,7 +49,8 @@ private[spark] object CompressionCodec { private val configKey = "spark.io.compression.codec" private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = { - codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + (codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec] + || codec.isInstanceOf[LZ4CompressionCodec]) } private val shortCompressionCodecNames = Map( @@ -92,12 +93,11 @@ private[spark] object CompressionCodec { } } - val FALLBACK_COMPRESSION_CODEC = "lzf" - val DEFAULT_COMPRESSION_CODEC = "snappy" + val FALLBACK_COMPRESSION_CODEC = "snappy" + val DEFAULT_COMPRESSION_CODEC = "lz4" val ALL_COMPRESSION_CODECS = shortCompressionCodecNames.values.toSeq } - /** * :: DeveloperApi :: * LZ4 implementation of [[org.apache.spark.io.CompressionCodec]]. diff --git a/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java b/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java new file mode 100644 index 000000000000..27b6f0d4a388 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/io/LZ4BlockInputStream.java @@ -0,0 +1,263 @@ +package org.apache.spark.io; + +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.io.EOFException; +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.zip.Checksum; + +import net.jpountz.lz4.LZ4BlockOutputStream; +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; +import net.jpountz.util.SafeUtils; +import net.jpountz.xxhash.StreamingXXHash32; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; + +/** + * {@link InputStream} implementation to decode data written with + * {@link LZ4BlockOutputStream}. This class is not thread-safe and does not + * support {@link #mark(int)}/{@link #reset()}. 
+ * @see LZ4BlockOutputStream + * + * This is based on net.jpountz.lz4.LZ4BlockInputStream + * + * changes: https://github.com/davies/lz4-java/commit/cc1fa940ac57cc66a0b937300f805d37e2bf8411 + * + * TODO: merge this into upstream + */ +public final class LZ4BlockInputStream extends FilterInputStream { + + // Copied from net.jpountz.lz4.LZ4BlockOutputStream + static final byte[] MAGIC = new byte[] { 'L', 'Z', '4', 'B', 'l', 'o', 'c', 'k' }; + static final int MAGIC_LENGTH = MAGIC.length; + + static final int HEADER_LENGTH = + MAGIC_LENGTH // magic bytes + + 1 // token + + 4 // compressed length + + 4 // decompressed length + + 4; // checksum + + static final int COMPRESSION_LEVEL_BASE = 10; + + static final int COMPRESSION_METHOD_RAW = 0x10; + static final int COMPRESSION_METHOD_LZ4 = 0x20; + + static final int DEFAULT_SEED = 0x9747b28c; + + private final LZ4FastDecompressor decompressor; + private final Checksum checksum; + private byte[] buffer; + private byte[] compressedBuffer; + private int originalLen; + private int o; + private boolean finished; + + /** + * Create a new {@link InputStream}. + * + * @param in the {@link InputStream} to poll + * @param decompressor the {@link LZ4FastDecompressor decompressor} instance to + * use + * @param checksum the {@link Checksum} instance to use, must be + * equivalent to the instance which has been used to + * write the stream + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor, Checksum checksum) { + super(in); + this.decompressor = decompressor; + this.checksum = checksum; + this.buffer = new byte[0]; + this.compressedBuffer = new byte[HEADER_LENGTH]; + o = originalLen = 0; + finished = false; + } + + /** + * Create a new instance using {@link XXHash32} for checksuming. + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor, Checksum) + * @see StreamingXXHash32#asChecksum() + */ + public LZ4BlockInputStream(InputStream in, LZ4FastDecompressor decompressor) { + this(in, decompressor, XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum()); + } + + /** + * Create a new instance which uses the fastest {@link LZ4FastDecompressor} available. 
+ * @see LZ4Factory#fastestInstance() + * @see #LZ4BlockInputStream(InputStream, LZ4FastDecompressor) + */ + public LZ4BlockInputStream(InputStream in) { + this(in, LZ4Factory.fastestInstance().fastDecompressor()); + } + + @Override + public int available() throws IOException { + refill(); + return originalLen - o; + } + + @Override + public int read() throws IOException { + refill(); + if (finished) { + return -1; + } + return buffer[o++] & 0xFF; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + SafeUtils.checkRange(b, off, len); + refill(); + if (finished) { + return -1; + } + len = Math.min(len, originalLen - o); + System.arraycopy(buffer, o, b, off, len); + o += len; + return len; + } + + @Override + public int read(byte[] b) throws IOException { + return read(b, 0, b.length); + } + + @Override + public long skip(long n) throws IOException { + refill(); + if (finished) { + return -1; + } + final int skipped = (int) Math.min(n, originalLen - o); + o += skipped; + return skipped; + } + + private void refill() throws IOException { + if (finished || o < originalLen) { + return; + } + try { + readFully(compressedBuffer, HEADER_LENGTH); + } catch (EOFException e) { + finished = true; + return; + } + for (int i = 0; i < MAGIC_LENGTH; ++i) { + if (compressedBuffer[i] != MAGIC[i]) { + throw new IOException("Stream is corrupted"); + } + } + final int token = compressedBuffer[MAGIC_LENGTH] & 0xFF; + final int compressionMethod = token & 0xF0; + final int compressionLevel = COMPRESSION_LEVEL_BASE + (token & 0x0F); + if (compressionMethod != COMPRESSION_METHOD_RAW && compressionMethod != COMPRESSION_METHOD_LZ4) + { + throw new IOException("Stream is corrupted"); + } + final int compressedLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 1); + originalLen = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 5); + final int check = SafeUtils.readIntLE(compressedBuffer, MAGIC_LENGTH + 9); + assert HEADER_LENGTH == MAGIC_LENGTH + 13; + if (originalLen > 1 << compressionLevel + || originalLen < 0 + || compressedLen < 0 + || (originalLen == 0 && compressedLen != 0) + || (originalLen != 0 && compressedLen == 0) + || (compressionMethod == COMPRESSION_METHOD_RAW && originalLen != compressedLen)) { + throw new IOException("Stream is corrupted"); + } + if (originalLen == 0 && compressedLen == 0) { + if (check != 0) { + throw new IOException("Stream is corrupted"); + } + refill(); + return; + } + if (buffer.length < originalLen) { + buffer = new byte[Math.max(originalLen, buffer.length * 3 / 2)]; + } + switch (compressionMethod) { + case COMPRESSION_METHOD_RAW: + readFully(buffer, originalLen); + break; + case COMPRESSION_METHOD_LZ4: + if (compressedBuffer.length < originalLen) { + compressedBuffer = new byte[Math.max(compressedLen, compressedBuffer.length * 3 / 2)]; + } + readFully(compressedBuffer, compressedLen); + try { + final int compressedLen2 = + decompressor.decompress(compressedBuffer, 0, buffer, 0, originalLen); + if (compressedLen != compressedLen2) { + throw new IOException("Stream is corrupted"); + } + } catch (LZ4Exception e) { + throw new IOException("Stream is corrupted", e); + } + break; + default: + throw new AssertionError(); + } + checksum.reset(); + checksum.update(buffer, 0, originalLen); + if ((int) checksum.getValue() != check) { + throw new IOException("Stream is corrupted"); + } + o = 0; + } + + private void readFully(byte[] b, int len) throws IOException { + int read = 0; + while (read < len) { + final int r = in.read(b, read, len - 
read); + if (r < 0) { + throw new EOFException("Stream ended prematurely"); + } + read += r; + } + assert len == read; + } + + @Override + public boolean markSupported() { + return false; + } + + @SuppressWarnings("sync-override") + @Override + public void mark(int readlimit) { + // unsupported + } + + @SuppressWarnings("sync-override") + @Override + public void reset() throws IOException { + throw new IOException("mark/reset not supported"); + } + + @Override + public String toString() { + return getClass().getSimpleName() + "(in=" + in + + ", decompressor=" + decompressor + ", checksum=" + checksum + ")"; + } + +} diff --git a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala index f7298e8d5c62..6841485f4b93 100644 --- a/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala +++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala @@ -18,61 +18,12 @@ package org.apache.spark.mapred import java.io.IOException -import java.lang.reflect.Modifier -import org.apache.hadoop.mapred._ import org.apache.hadoop.mapreduce.{TaskAttemptContext => MapReduceTaskAttemptContext} import org.apache.hadoop.mapreduce.{OutputCommitter => MapReduceOutputCommitter} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.CommitDeniedException import org.apache.spark.{Logging, SparkEnv, TaskContext} -import org.apache.spark.util.{Utils => SparkUtils} - -private[spark] -trait SparkHadoopMapRedUtil { - def newJobContext(conf: JobConf, jobId: JobID): JobContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl", - "org.apache.hadoop.mapred.JobContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], - classOf[org.apache.hadoop.mapreduce.JobID]) - // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private. - // Make it accessible if it's not in order to access it. 
- if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: JobConf, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl", - "org.apache.hadoop.mapred.TaskAttemptContext") - val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID]) - // See above - if (!Modifier.isPublic(ctor.getModifiers)) { - ctor.setAccessible(true) - } - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - new TaskAttemptID(jtIdentifier, jobId, isMap, taskId, attemptId) - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - SparkUtils.classForName(first) - } catch { - case e: ClassNotFoundException => - SparkUtils.classForName(second) - } - } -} +import org.apache.spark.executor.CommitDeniedException object SparkHadoopMapRedUtil extends Logging { /** @@ -93,7 +44,7 @@ object SparkHadoopMapRedUtil extends Logging { jobId: Int, splitId: Int): Unit = { - val mrTaskAttemptID = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(mrTaskContext) + val mrTaskAttemptID = mrTaskContext.getTaskAttemptID // Called after we have decided to commit def performCommit(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala deleted file mode 100644 index 943ebcb7bd0a..000000000000 --- a/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mapreduce - -import java.lang.{Boolean => JBoolean, Integer => JInteger} - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID} -import org.apache.spark.util.Utils - -private[spark] -trait SparkHadoopMapReduceUtil { - def newJobContext(conf: Configuration, jobId: JobID): JobContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.JobContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.JobContext") // hadoop1 - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[JobID]) - ctor.newInstance(conf, jobId).asInstanceOf[JobContext] - } - - def newTaskAttemptContext(conf: Configuration, attemptId: TaskAttemptID): TaskAttemptContext = { - val klass = firstAvailableClass( - "org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl", // hadoop2, hadoop2-yarn - "org.apache.hadoop.mapreduce.TaskAttemptContext") // hadoop1 - val ctor = klass.getDeclaredConstructor(classOf[Configuration], classOf[TaskAttemptID]) - ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext] - } - - def newTaskAttemptID( - jtIdentifier: String, - jobId: Int, - isMap: Boolean, - taskId: Int, - attemptId: Int): TaskAttemptID = { - val klass = Utils.classForName("org.apache.hadoop.mapreduce.TaskAttemptID") - try { - // First, attempt to use the old-style constructor that takes a boolean isMap - // (not available in YARN) - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], classOf[Boolean], - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), new JBoolean(isMap), new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } catch { - case exc: NoSuchMethodException => { - // If that failed, look for the new constructor that takes a TaskType (not available in 1.x) - val taskTypeClass = Utils.classForName("org.apache.hadoop.mapreduce.TaskType") - .asInstanceOf[Class[Enum[_]]] - val taskType = taskTypeClass.getMethod("valueOf", classOf[String]).invoke( - taskTypeClass, if (isMap) "MAP" else "REDUCE") - val ctor = klass.getDeclaredConstructor(classOf[String], classOf[Int], taskTypeClass, - classOf[Int], classOf[Int]) - ctor.newInstance(jtIdentifier, new JInteger(jobId), taskType, new JInteger(taskId), - new JInteger(attemptId)).asInstanceOf[TaskAttemptID] - } - } - } - - private def firstAvailableClass(first: String, second: String): Class[_] = { - try { - Utils.classForName(first) - } catch { - case e: ClassNotFoundException => - Utils.classForName(second) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala index e707e27d96b5..33f8b9f16c11 100644 --- a/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/MemoryManager.scala @@ -21,7 +21,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.memory.MemoryAllocator diff --git a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala index 70af83b5ee09..4036484aada2 100644 --- 
a/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala +++ b/core/src/main/scala/org/apache/spark/memory/StorageMemoryPool.scala @@ -22,8 +22,8 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{TaskContext, Logging} -import org.apache.spark.storage.{MemoryStore, BlockStatus, BlockId} +import org.apache.spark.{Logging, TaskContext} +import org.apache.spark.storage.{BlockId, BlockStatus, MemoryStore} /** * Performs bookkeeping for managing an adjustable-size pool of memory that is used for storage diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index 829f054dba0e..57a24ac14028 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import scala.collection.mutable import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} +import org.apache.spark.storage.{BlockId, BlockStatus} /** * A [[MemoryManager]] that enforces a soft boundary between execution and storage such that diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala index dd2d325d8703..8540984bfe82 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsConfig.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.matching.Regex -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils private[spark] class MetricsConfig(conf: SparkConf) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index fdf76d312db3..e34cfc698dce 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -20,8 +20,6 @@ package org.apache.spark.metrics import java.util.Properties import java.util.concurrent.TimeUnit -import org.apache.spark.util.Utils - import scala.collection.mutable import com.codahale.metrics.{Metric, MetricFilter, MetricRegistry} @@ -30,6 +28,7 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.metrics.sink.{MetricsServlet, Sink} import org.apache.spark.metrics.source.Source +import org.apache.spark.util.Utils /** * Spark Metrics System, created by specific "instance", combined by source, diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index 2d25ebd66159..22454e50b14b 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -22,7 +22,7 @@ import java.util.Properties import java.util.concurrent.TimeUnit import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{GraphiteUDP, Graphite, GraphiteReporter} +import com.codahale.metrics.graphite.{Graphite, GraphiteReporter, GraphiteUDP} import org.apache.spark.SecurityManager import 
org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala index 2588fe2c9edb..1992b42ac7f6 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/JmxSink.scala @@ -20,6 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import com.codahale.metrics.{JmxReporter, MetricRegistry} + import org.apache.spark.SecurityManager private[spark] class JmxSink(val property: Properties, val registry: MetricRegistry, diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala index 4193e1d21d3c..68b58b849064 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/MetricsServlet.scala @@ -19,7 +19,6 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit - import javax.servlet.http.HttpServletRequest import com.codahale.metrics.MetricRegistry @@ -27,7 +26,7 @@ import com.codahale.metrics.json.MetricsModule import com.fasterxml.jackson.databind.ObjectMapper import org.eclipse.jetty.servlet.ServletContextHandler -import org.apache.spark.{SparkConf, SecurityManager} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ private[spark] class MetricsServlet( diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala index 11dfcfe2f04e..773e074336cb 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/Slf4jSink.scala @@ -20,7 +20,7 @@ package org.apache.spark.metrics.sink import java.util.Properties import java.util.concurrent.TimeUnit -import com.codahale.metrics.{Slf4jReporter, MetricRegistry} +import com.codahale.metrics.{MetricRegistry, Slf4jReporter} import org.apache.spark.SecurityManager import org.apache.spark.metrics.MetricsSystem diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index dcbda5a8515d..15d3540f3427 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,13 +20,13 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Promise, Await, Future} +import scala.concurrent.{Await, Future, Promise} import scala.concurrent.duration.Duration import org.apache.spark.Logging -import org.apache.spark.network.buffer.{NioManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{ShuffleClient, BlockFetchingListener} -import org.apache.spark.storage.{BlockManagerId, BlockId, StorageLevel} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} +import org.apache.spark.storage.{BlockId, BlockManagerId, StorageLevel} private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala 
b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index 40604a4da18d..f588a28eed28 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -25,10 +25,10 @@ import scala.concurrent.{Future, Promise} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer -import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory} +import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.sasl.{SaslClientBootstrap, SaslServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala index 84833f59d7af..86874e2067dd 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala @@ -18,7 +18,7 @@ package org.apache.spark.network.netty import org.apache.spark.SparkConf -import org.apache.spark.network.util.{TransportConf, ConfigProvider} +import org.apache.spark.network.util.{ConfigProvider, TransportConf} /** * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor, diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala index 7515aad09db7..cc5e7ef3ae00 100644 --- a/core/src/main/scala/org/apache/spark/package.scala +++ b/core/src/main/scala/org/apache/spark/package.scala @@ -43,5 +43,5 @@ package org.apache package object spark { // For package docs only - val SPARK_VERSION = "1.6.0-SNAPSHOT" + val SPARK_VERSION = "2.0.0-SNAPSHOT" } diff --git a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala index 828bf96c2c0b..55acb9ca64d3 100644 --- a/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala +++ b/core/src/main/scala/org/apache/spark/partial/StudentTCacher.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} /** * A utility class for caching Student's T distribution values for a given confidence level diff --git a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala index 1753c2561b67..44295e5a1aff 100644 --- a/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala +++ b/core/src/main/scala/org/apache/spark/partial/SumEvaluator.scala @@ -17,7 +17,7 @@ package org.apache.spark.partial -import org.apache.commons.math3.distribution.{TDistribution, NormalDistribution} +import org.apache.commons.math3.distribution.{NormalDistribution, TDistribution} import 
org.apache.spark.util.StatCounter diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index 14f541f937b4..7de9df1e489f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -20,10 +20,10 @@ package org.apache.spark.rdd import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} import scala.reflect.ClassTag -import org.apache.spark.{JobSubmitter, ComplexFutureAction, FutureAction, Logging} +import org.apache.spark.{ComplexFutureAction, FutureAction, JobSubmitter, Logging} import org.apache.spark.util.ThreadUtils /** @@ -68,7 +68,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi val localProperties = self.context.getLocalProperties // Cached thread pool to handle aggregation of subtasks. implicit val executionContext = AsyncRDDActions.futureExecutionContext - val results = new ArrayBuffer[T](num) + val results = new ArrayBuffer[T] val totalParts = self.partitions.length /* @@ -77,13 +77,13 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi This implementation is non-blocking, asynchronously handling the results of each job and triggering the next job using callbacks on futures. */ - def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter) : Future[Seq[T]] = + def continue(partsScanned: Int)(implicit jobSubmitter: JobSubmitter): Future[Seq[T]] = if (results.size >= num || partsScanned >= totalParts) { Future.successful(results.toSeq) } else { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. 
// Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -99,7 +99,7 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi } val left = num - results.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val buf = new Array[Array[T]](p.size) self.context.setCallSite(callSite) @@ -109,9 +109,9 @@ class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Loggi p, (index: Int, data: Array[T]) => buf(index) = data, Unit) - job.flatMap {_ => + job.flatMap { _ => buf.foreach(results ++= _.take(num - results.size)) - continue(partsScanned + numPartsToTry) + continue(partsScanned + p.size) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala index aedced7408cd..be0cb175f534 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BinaryFileRDD.scala @@ -20,8 +20,10 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{ Configurable, Configuration } import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ +import org.apache.hadoop.mapreduce.task.JobContextImpl + +import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.StreamFileInputFormat -import org.apache.spark.{ Partition, SparkContext } private[spark] class BinaryFileRDD[T]( sc: SparkContext, @@ -40,7 +42,7 @@ private[spark] class BinaryFileRDD[T]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index fc1710fbad0a..8358244987a6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.storage.{BlockId, BlockManager} -import scala.Some private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends Partition { val index = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 3a0ca1d81329..3587e7eb1afa 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -17,18 +17,17 @@ package org.apache.spark.rdd -import scala.language.existentials - import java.io.{IOException, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer +import scala.language.existentials import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.serializer.Serializer import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils -import org.apache.spark.serializer.Serializer /** The references to rdd and splitIndex are transient because redundant information is stored * in the CoGroupedRDD object. 
Because CoGroupedRDD is serialized separately from diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 7fbaadcea3a3..c07f346bbafd 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -17,8 +17,8 @@ package org.apache.spark.rdd +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.annotation.Experimental -import org.apache.spark.{TaskContext, Logging} import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.MeanEvaluator import org.apache.spark.partial.PartialResult diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index f37c95bedc0a..a7a6e0b8a94f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -17,25 +17,26 @@ package org.apache.spark.rdd +import java.io.EOFException import java.text.SimpleDateFormat import java.util.Date -import java.io.EOFException import scala.collection.immutable.Map -import scala.reflect.ClassTag import scala.collection.mutable.ListBuffer +import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.mapred.FileSplit import org.apache.hadoop.mapred.InputFormat import org.apache.hadoop.mapred.InputSplit import org.apache.hadoop.mapred.JobConf +import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.RecordReader import org.apache.hadoop.mapred.Reporter -import org.apache.hadoop.mapred.JobID import org.apache.hadoop.mapred.TaskAttemptID import org.apache.hadoop.mapred.TaskID import org.apache.hadoop.mapred.lib.CombineFileSplit +import org.apache.hadoop.mapreduce.TaskType import org.apache.hadoop.util.ReflectionUtils import org.apache.spark._ @@ -44,9 +45,9 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager, NextIterator, Utils} -import org.apache.spark.scheduler.{HostTaskLocation, HDFSCacheTaskLocation} +import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{NextIterator, SerializableConfiguration, ShutdownHookManager, Utils} /** * A Spark split class that wraps around a Hadoop InputSplit. 
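The HadoopRDD hunk that follows swaps the Hadoop 1 style TaskID constructor, which encoded the task kind as a boolean isMap flag, for the TaskType enum that Hadoop 2 introduced. A minimal sketch of the new construction, with hypothetical identifier values, assuming Hadoop 2.x on the classpath:

    import org.apache.hadoop.mapred.{JobID, TaskAttemptID, TaskID}
    import org.apache.hadoop.mapreduce.TaskType

    // The boolean flag could only distinguish map from reduce; the enum also
    // models setup/cleanup tasks (TaskType.JOB_SETUP, TaskType.JOB_CLEANUP, ...).
    val jobID = new JobID("201601010000", 1) // (jtIdentifier, jobId) - hypothetical
    val attempt = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, 0), 0)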
@@ -357,7 +358,7 @@ private[spark] object HadoopRDD extends Logging { def addLocalConfiguration(jobTrackerId: String, jobId: Int, splitId: Int, attemptId: Int, conf: JobConf) { val jobID = new JobID(jobTrackerId, jobId) - val taId = new TaskAttemptID(new TaskID(jobID, true, splitId), attemptId) + val taId = new TaskAttemptID(new TaskID(jobID, TaskType.MAP, splitId), attemptId) conf.set("mapred.tip.id", taId.getTaskID.toString) conf.set("mapred.task.id", taId.toString) diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 0c28f045e46e..469962db6763 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.rdd -import java.sql.{PreparedStatement, Connection, ResultSet} +import java.sql.{Connection, PreparedStatement, ResultSet} import scala.reflect.ClassTag +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.util.NextIterator -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { override def index: Int = idx diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 4312d3a41775..e4587c96eae1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -25,7 +25,7 @@ import org.apache.spark.{Partition, TaskContext} * An RDD that applies the provided function to every partition of the parent RDD. 
*/ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], + var prev: RDD[T], f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { @@ -36,4 +36,9 @@ private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[U] = f(context, split.index, firstParent[T].iterator(split, context)) + + override def clearDependencies() { + super.clearDependencies() + prev = null + } } diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 86f38ae836b2..7a1197830443 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -24,17 +24,18 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable +import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} -import org.apache.spark.annotation.DeveloperApi import org.apache.spark._ +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} private[spark] class NewHadoopPartition( rddId: Int, @@ -66,9 +67,7 @@ class NewHadoopRDD[K, V]( keyClass: Class[K], valueClass: Class[V], @transient private val _conf: Configuration) - extends RDD[(K, V)](sc, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[(K, V)](sc, Nil) with Logging { // A Hadoop Configuration can be about 10 KB, which is pretty big, so broadcast it private val confBroadcast = sc.broadcast(new SerializableConfiguration(_conf)) @@ -95,7 +94,13 @@ class NewHadoopRDD[K, V]( // issues, this cloning is disabled by default. NewHadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { logDebug("Cloning Hadoop Configuration") - new Configuration(conf) + // The Configuration passed in is actually a JobConf and possibly contains credentials. + // To keep those credentials properly we have to create a new JobConf not a Configuration. 
+ if (conf.isInstanceOf[JobConf]) { + new JobConf(conf) + } else { + new Configuration(conf) + } } } else { conf @@ -109,7 +114,7 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = newJobContext(_conf, jobId) + val jobContext = new JobContextImpl(_conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -144,8 +149,8 @@ class NewHadoopRDD[K, V]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private var reader = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 44d195587a08..16a856f594e9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import java.text.SimpleDateFormat import java.util.{Date, HashMap => JHashMap} -import scala.collection.{Map, mutable} +import scala.collection.{mutable, Map} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag @@ -33,15 +33,14 @@ import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.io.SequenceFile.CompressionType import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat} -import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, - RecordWriter => NewRecordWriter} +import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptID, TaskType} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl import org.apache.spark._ import org.apache.spark.Partitioner.defaultPartitioner import org.apache.spark.annotation.Experimental import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.{DataWriteMethod, OutputMetrics} -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.serializer.Serializer import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -53,10 +52,7 @@ import org.apache.spark.util.random.StratifiedSamplingUtils */ class PairRDDFunctions[K, V](self: RDD[(K, V)]) (implicit kt: ClassTag[K], vt: ClassTag[V], ord: Ordering[K] = null) - extends Logging - with SparkHadoopMapReduceUtil - with Serializable -{ + extends Logging with Serializable { /** * :: Experimental :: @@ -363,12 +359,6 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) self.mapPartitions(reducePartition).reduce(mergeMaps).asScala } - /** Alias for reduceByKeyLocally */ - @deprecated("Use reduceByKeyLocally", "1.0.0") - def reduceByKeyToDriver(func: (V, V) => V): Map[K, V] = self.withScope { - reduceByKeyLocally(func) - } - /** * Count the number of elements for each key, collecting the results to a local Map. 
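The PairRDDFunctions hunks below finish the migration off the SparkHadoopMapReduceUtil shim: the deprecated `new Job(conf)` constructor gives way to Job.getInstance, and the reflective getConfigurationFromJobContext accessor becomes a plain job.getConfiguration call, both available across the Hadoop 2.x line. A minimal sketch of the new-API job setup, with a hypothetical output path:

    import org.apache.hadoop.conf.Configuration
    import org.apache.hadoop.mapreduce.Job

    // Job.getInstance copies the Configuration it is given, so later mutations
    // of `conf` do not leak into the submitted job.
    val conf = new Configuration()
    val job = Job.getInstance(conf)
    job.getConfiguration.set("mapred.output.dir", "/tmp/out") // hypothetical path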
* @@ -985,11 +975,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) conf: Configuration = self.context.hadoopConfiguration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) job.setOutputKeyClass(keyClass) job.setOutputValueClass(valueClass) job.setOutputFormatClass(outputFormatClass) - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration jobConfiguration.set("mapred.output.dir", path) saveAsNewAPIHadoopDataset(jobConfiguration) } @@ -1074,11 +1064,11 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) def saveAsNewAPIHadoopDataset(conf: Configuration): Unit = self.withScope { // Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038). val hadoopConf = conf - val job = new NewAPIHadoopJob(hadoopConf) + val job = NewAPIHadoopJob.getInstance(hadoopConf) val formatter = new SimpleDateFormat("yyyyMMddHHmm") val jobtrackerID = formatter.format(new Date()) val stageId = self.id - val jobConfiguration = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfiguration = job.getConfiguration val wrappedConf = new SerializableConfiguration(jobConfiguration) val outfmt = job.getOutputFormatClass val jobFormat = outfmt.newInstance @@ -1091,9 +1081,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) val writeShard = (context: TaskContext, iter: Iterator[(K, V)]) => { val config = wrappedConf.value /* "reduce task" */ - val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId, + val attemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.REDUCE, context.partitionId, context.attemptNumber) - val hadoopContext = newTaskAttemptContext(config, attemptId) + val hadoopContext = new TaskAttemptContextImpl(config, attemptId) val format = outfmt.newInstance format match { case c: Configurable => c.setConf(config) @@ -1125,8 +1115,8 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) 1 } : Int - val jobAttemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = true, 0, 0) - val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId) + val jobAttemptId = new TaskAttemptID(jobtrackerID, stageId, TaskType.MAP, 0, 0) + val jobTaskContext = new TaskAttemptContextImpl(wrappedConf.value, jobAttemptId) val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext) // When speculation is on and output committer class name contains "Direct", we should warn diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 9fe9d83a705b..de7102f5b624 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapred.TextOutputFormat import org.apache.spark._ import org.apache.spark.Partitioner._ -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.partial.BoundedDouble import org.apache.spark.partial.CountEvaluator @@ -40,7 +40,7 @@ import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{BoundedPriorityQueue, Utils} import org.apache.spark.util.collection.OpenHashMap -import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, 
BernoulliCellSampler, +import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler, SamplingUtils} /** @@ -746,99 +746,6 @@ abstract class RDD[T: ClassTag]( preservesPartitioning) } - /** - * :: DeveloperApi :: - * Return a new RDD by applying a function to each partition of this RDD. This is a variant of - * mapPartitions that also passes the TaskContext into the closure. - * - * `preservesPartitioning` indicates whether the input function preserves the partitioner, which - * should be `false` unless this is a pair RDD and the input function doesn't modify the keys. - */ - @DeveloperApi - @deprecated("use TaskContext.get", "1.2.0") - def mapPartitionsWithContext[U: ClassTag]( - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { - val cleanF = sc.clean(f) - val func = (context: TaskContext, index: Int, iter: Iterator[T]) => cleanF(context, iter) - new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) - } - - /** - * Return a new RDD by applying a function to each partition of this RDD, while tracking the index - * of the original partition. - */ - @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassTag]( - f: (Int, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false): RDD[U] = withScope { - mapPartitionsWithIndex(f, preservesPartitioning) - } - - /** - * Maps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex", "1.0.0") - def mapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => U): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.map(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * FlatMaps f over this RDD, where f takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and flatMap", "1.0.0") - def flatMapWith[A, U: ClassTag] - (constructA: Int => A, preservesPartitioning: Boolean = false) - (f: (T, A) => Seq[U]): RDD[U] = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.flatMap(t => cleanF(t, a)) - }, preservesPartitioning) - } - - /** - * Applies f to each element of this RDD, where f takes an additional parameter of type A. - * This additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. - */ - @deprecated("use mapPartitionsWithIndex and foreach", "1.0.0") - def foreachWith[A](constructA: Int => A)(f: (T, A) => Unit): Unit = withScope { - val cleanF = sc.clean(f) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex { (index, iter) => - val a = cleanA(index) - iter.map(t => {cleanF(t, a); t}) - } - } - - /** - * Filters this RDD with p, where p takes an additional parameter of type A. This - * additional parameter is produced by constructA, which is called in each - * partition with the index of that partition. 
- */ - @deprecated("use mapPartitionsWithIndex and filter", "1.0.0") - def filterWith[A](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = withScope { - val cleanP = sc.clean(p) - val cleanA = sc.clean(constructA) - mapPartitionsWithIndex((index, iter) => { - val a = cleanA(index) - iter.filter(t => cleanP(t, a)) - }, preservesPartitioning = true) - } - /** * Zips this RDD with another one, returning key-value pairs with the first element in each RDD, * second element in each RDD, etc. Assumes that the two RDDs have the *same number of @@ -944,14 +851,6 @@ abstract class RDD[T: ClassTag]( (0 until partitions.length).iterator.flatMap(i => collectPartition(i)) } - /** - * Return an array that contains all of the elements in this RDD. - */ - @deprecated("use collect", "1.0.0") - def toArray(): Array[T] = withScope { - collect() - } - /** * Return an RDD that contains all matching values by applying `f`. */ @@ -1295,7 +1194,7 @@ abstract class RDD[T: ClassTag]( while (buf.size < num && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate @@ -1310,11 +1209,11 @@ abstract class RDD[T: ClassTag]( } val left = num - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(num - buf.size)) - partsScanned += numPartsToTry + partsScanned += p.size } buf.toArray diff --git a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala index fa71b8c26233..a9b3d52bbee0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ReliableCheckpointRDD.scala @@ -174,7 +174,8 @@ private[spark] object ReliableCheckpointRDD extends Logging { fs.create(tempOutputPath, false, bufferSize) } else { // This is mainly for testing purpose - fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize) + fs.create(tempOutputPath, false, bufferSize, + fs.getDefaultReplication(fs.getWorkingDirectory), blockSize) } val serializer = env.serializer.newInstance() val serializeStream = serializer.serializeStream(fileOutputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala deleted file mode 100644 index 9e8cee5331cf..000000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.util.Random - -import scala.reflect.ClassTag - -import org.apache.commons.math3.distribution.PoissonDistribution - -import org.apache.spark.{Partition, TaskContext} - -@deprecated("Replaced by PartitionwiseSampledRDDPartition", "1.0.0") -private[spark] -class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition with Serializable { - override val index: Int = prev.index -} - -@deprecated("Replaced by PartitionwiseSampledRDD", "1.0.0") -private[spark] class SampledRDD[T: ClassTag]( - prev: RDD[T], - withReplacement: Boolean, - frac: Double, - seed: Int) - extends RDD[T](prev) { - - override def getPartitions: Array[Partition] = { - val rg = new Random(seed) - firstParent[T].partitions.map(x => new SampledRDDPartition(x, rg.nextInt)) - } - - override def getPreferredLocations(split: Partition): Seq[String] = - firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDPartition].prev) - - override def compute(splitIn: Partition, context: TaskContext): Iterator[T] = { - val split = splitIn.asInstanceOf[SampledRDDPartition] - if (withReplacement) { - // For large datasets, the expected number of occurrences of each element in a sample with - // replacement is Poisson(frac). We use that to get a count for each element. - val poisson = new PoissonDistribution(frac) - poisson.reseedRandomGenerator(split.seed) - - firstParent[T].iterator(split.prev, context).flatMap { element => - val count = poisson.sample() - if (count == 0) { - Iterator.empty // Avoid object allocation when we return 0 items, which is quite often - } else { - Iterator.fill(count)(element) - } - } - } else { // Sampling without replacement - val rand = new Random(split.seed) - firstParent[T].iterator(split.prev, context).filter(x => (rand.nextDouble <= frac)) - } - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 4b5f15dd06b8..92d9e3581ee5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.rdd -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import org.apache.hadoop.io.Writable import org.apache.hadoop.io.compress.CompressionCodec @@ -38,11 +38,6 @@ class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag extends Logging with Serializable { - @deprecated("It's used to provide backward compatibility for pre 1.3.0.", "1.3.0") - def this(self: RDD[(K, V)]) { - this(self, null, null) - } - private val keyWritableClass = if (_keyWritableClass == null) { // pre 1.3.0, we need to use Reflection to get the Writable class diff --git a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala index e3f14fe7ef0f..8e1baae796fc 100644 --- a/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala +++ 
b/core/src/main/scala/org/apache/spark/rdd/WholeTextFileRDD.scala @@ -20,6 +20,7 @@ package org.apache.spark.rdd import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.{Text, Writable} import org.apache.hadoop.mapreduce.InputSplit +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.spark.{Partition, SparkContext} import org.apache.spark.input.WholeTextFileInputFormat @@ -44,7 +45,7 @@ private[spark] class WholeTextFileRDD( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) inputFormat.setMinPartitions(jobContext, minPartitions) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[Partition](rawSplits.size) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala similarity index 87% rename from core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala rename to core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala index d2e94f943aba..b9db60a7797d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/RpcEndpointAddress.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointAddress.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.rpc.netty +package org.apache.spark.rpc import org.apache.spark.SparkException -import org.apache.spark.rpc.RpcAddress /** * An address identifier for an RPC endpoint. @@ -26,10 +25,10 @@ import org.apache.spark.rpc.RpcAddress * The `rpcAddress` may be null, in which case the endpoint is registered via a client-only * connection and can only be reached via the client that sent the endpoint reference. * - * @param rpcAddress The socket address of the endpint. + * @param rpcAddress The socket address of the endpoint. * @param name Name of the endpoint. */ -private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { +private[spark] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val name: String) { require(name != null, "RpcEndpoint name must be provided.") @@ -44,7 +43,11 @@ private[netty] case class RpcEndpointAddress(val rpcAddress: RpcAddress, val nam } } -private[netty] object RpcEndpointAddress { +private[spark] object RpcEndpointAddress { + + def apply(host: String, port: Int, name: String): RpcEndpointAddress = { + new RpcEndpointAddress(host, port, name) + } def apply(sparkUrl: String): RpcEndpointAddress = { try { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala index 623da3e9c11b..154398b57280 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpointRef.scala @@ -20,8 +20,8 @@ package org.apache.spark.rpc import scala.concurrent.Future import scala.reflect.ClassTag +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.util.RpcUtils -import org.apache.spark.{SparkException, Logging, SparkConf} /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. 
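Hoisting RpcEndpointAddress out of rpc.netty into rpc, and widening it to private[spark], is what lets the RpcEnv diff in the next file delete uriOf: callers can now render an endpoint URI themselves. A sketch of the two entry points, with hypothetical host and port, assuming the spark://name@host:port rendering this class produces (both classes are private[spark], so this only compiles inside the org.apache.spark tree):

    import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress}

    // Both forms should render as "spark://Master@master.example.com:7077".
    val a = RpcEndpointAddress("master.example.com", 7077, "Master") // new overload
    val b = RpcEndpointAddress(RpcAddress("master.example.com", 7077), "Master")

The same address type backs the reworked lookup, which drops the Akka-era systemName parameter: setupEndpointRef(RpcAddress(host, port), name).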
diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 64a4a8bf7c5e..56683771335a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -23,7 +23,8 @@ import java.nio.channels.ReadableByteChannel import scala.concurrent.Future import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{RpcUtils, Utils} +import org.apache.spark.rpc.netty.NettyRpcEnvFactory +import org.apache.spark.util.RpcUtils /** @@ -32,15 +33,6 @@ import org.apache.spark.util.{RpcUtils, Utils} */ private[spark] object RpcEnv { - private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { - val rpcEnvNames = Map( - "akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory", - "netty" -> "org.apache.spark.rpc.netty.NettyRpcEnvFactory") - val rpcEnvName = conf.get("spark.rpc", "netty") - val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - Utils.classForName(rpcEnvFactoryClassName).newInstance().asInstanceOf[RpcEnvFactory] - } - def create( name: String, host: String, @@ -48,9 +40,8 @@ private[spark] object RpcEnv { conf: SparkConf, securityManager: SecurityManager, clientMode: Boolean = false): RpcEnv = { - // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager, clientMode) - getRpcEnvFactory(conf).create(config) + new NettyRpcEnvFactory().create(config) } } @@ -98,12 +89,11 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { } /** - * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * Retrieve the [[RpcEndpointRef]] represented by `address` and `endpointName`. * This is a blocking action. */ - def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByURI(uriOf(systemName, address, endpointName)) + def setupEndpointRef(address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByURI(RpcEndpointAddress(address, endpointName).toString) } /** @@ -124,12 +114,6 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def awaitTermination(): Unit - /** - * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of - * creating it manually because different [[RpcEnv]] may have different formats. - */ - def uriOf(systemName: String, address: RpcAddress, endpointName: String): String - /** * [[RpcEndpointRef]] cannot be deserialized without [[RpcEnv]]. So when deserializing any object * that contains [[RpcEndpointRef]]s, the deserialization codes should be wrapped by this method. diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 285786ebf9f1..8b4ebf34ba83 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,13 +19,12 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Awaitable, Await} +import scala.concurrent.{Await, Awaitable} import scala.concurrent.duration._ import org.apache.spark.SparkConf import org.apache.spark.util.Utils - /** * An exception thrown if RpcTimeout modifies a [[TimeoutException]]. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala deleted file mode 100644 index 9d098154f719..000000000000 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ /dev/null @@ -1,404 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc.akka - -import java.io.File -import java.nio.channels.ReadableByteChannel -import java.util.concurrent.ConcurrentHashMap - -import scala.concurrent.Future -import scala.language.postfixOps -import scala.reflect.ClassTag -import scala.util.control.NonFatal - -import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} -import akka.event.Logging.Error -import akka.pattern.{ask => akkaAsk} -import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} -import akka.serialization.JavaSerializer - -import org.apache.spark.{HttpFileServer, Logging, SecurityManager, SparkConf, SparkException} -import org.apache.spark.rpc._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, ThreadUtils} - -/** - * A RpcEnv implementation based on Akka. - * - * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and - * remove Akka from the dependencies. - */ -private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, - val securityManager: SecurityManager, - conf: SparkConf, - boundPort: Int) - extends RpcEnv(conf) with Logging { - - private val defaultAddress: RpcAddress = { - val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress - // In some test case, ActorSystem doesn't bind to any address. - // So just use some default value since they are only some unit tests - RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) - } - - override val address: RpcAddress = defaultAddress - - /** - * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make - * [[RpcEndpoint.self]] work. 
- */ - private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - - /** - * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` - */ - private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() - - private val _fileServer = new AkkaFileServer(conf, securityManager) - - private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { - endpointToRef.put(endpoint, endpointRef) - refToEndpoint.put(endpointRef, endpoint) - } - - private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { - val endpoint = refToEndpoint.remove(endpointRef) - if (endpoint != null) { - endpointToRef.remove(endpoint) - } - } - - /** - * Retrieve the [[RpcEndpointRef]] of `endpoint`. - */ - override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = endpointToRef.get(endpoint) - - override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - @volatile var endpointRef: AkkaRpcEndpointRef = null - // Use defered function because the Actor needs to use `endpointRef`. - // So `actorRef` should be created after assigning `endpointRef`. - val actorRef = () => actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { - - assert(endpointRef != null) - - override def preStart(): Unit = { - // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[AssociationEvent]) - safelyCall(endpoint) { - endpoint.onStart() - } - } - - override def receiveWithLogging: Receive = { - case AssociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case DisassociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) - } - - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => - safelyCall(endpoint) { - endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) - } - - case e: AssociationEvent => - // TODO ignore? - - case m: AkkaMessage => - logDebug(s"Received RPC message: $m") - safelyCall(endpoint) { - processMessage(endpoint, m, sender) - } - - case AkkaFailure(e) => - safelyCall(endpoint) { - throw e - } - - case message: Any => { - logWarning(s"Unknown message: $message") - } - - } - - override def postStop(): Unit = { - unregisterEndpoint(endpoint.self) - safelyCall(endpoint) { - endpoint.onStop() - } - } - - }), name = name) - endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) - registerEndpoint(endpoint, endpointRef) - // Now actorRef can be created safely - endpointRef.init() - endpointRef - } - - private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { - val message = m.message - val needReply = m.needReply - val pf: PartialFunction[Any, Unit] = - if (needReply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - _sender ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - _sender ! 
AkkaMessage(response, false) - } - - // Use "lazy" because most of RpcEndpoints don't need "senderAddress" - override lazy val senderAddress: RpcAddress = - new AkkaRpcEndpointRef(defaultAddress, _sender, conf).address - }) - } else { - endpoint.receive - } - try { - pf.applyOrElse[Any, Unit](message, { message => - throw new SparkException(s"Unmatched message $message from ${_sender}") - }) - } catch { - case NonFatal(e) => - _sender ! AkkaFailure(e) - if (!needReply) { - // If the sender does not require a reply, it may not handle the exception. So we rethrow - // "e" to make sure it will be processed. - throw e - } - } - } - - /** - * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will - * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. - */ - private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { - try { - action - } catch { - case NonFatal(e) => { - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) - } - } - } - } - - private def akkaAddressToRpcAddress(address: Address): RpcAddress = { - RpcAddress(address.host.getOrElse(defaultAddress.host), - address.port.getOrElse(defaultAddress.port)) - } - - override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - import actorSystem.dispatcher - actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout.duration). - map(new AkkaRpcEndpointRef(defaultAddress, _, conf)). - // this is just in case there is a timeout from creating the future in resolveOne, we want the - // exception to indicate the conf that determines the timeout - recover(defaultLookupTimeout.addMessageIfTimeout) - } - - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { - AkkaUtils.address( - AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) - } - - override def shutdown(): Unit = { - actorSystem.shutdown() - _fileServer.shutdown() - } - - override def stop(endpoint: RpcEndpointRef): Unit = { - require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) - actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) - } - - override def awaitTermination(): Unit = { - actorSystem.awaitTermination() - } - - override def toString: String = s"${getClass.getSimpleName}($actorSystem)" - - override def deserialize[T](deserializationAction: () => T): T = { - JavaSerializer.currentSystem.withValue(actorSystem.asInstanceOf[ExtendedActorSystem]) { - deserializationAction() - } - } - - override def openChannel(uri: String): ReadableByteChannel = { - throw new UnsupportedOperationException( - "AkkaRpcEnv's files should be retrieved using an HTTP client.") - } - - override def fileServer: RpcEnvFileServer = _fileServer - -} - -private[akka] class AkkaFileServer( - conf: SparkConf, - securityManager: SecurityManager) extends RpcEnvFileServer { - - @volatile private var httpFileServer: HttpFileServer = _ - - override def addFile(file: File): String = { - getFileServer().addFile(file) - } - - override def addJar(file: File): String = { - getFileServer().addJar(file) - } - - override def addDirectory(baseUri: String, path: File): String = { - val fixedBaseUri = validateDirectoryUri(baseUri) - getFileServer().addDirectory(fixedBaseUri, path.getAbsolutePath()) - } - - def shutdown(): Unit = { - if (httpFileServer != null) { - httpFileServer.stop() - } - } - - private def getFileServer(): HttpFileServer = { - if (httpFileServer == 
null) synchronized { - if (httpFileServer == null) { - httpFileServer = startFileServer() - } - } - httpFileServer - } - - private def startFileServer(): HttpFileServer = { - val fileServerPort = conf.getInt("spark.fileserver.port", 0) - val server = new HttpFileServer(conf, securityManager, fileServerPort) - server.initialize() - server - } - -} - -private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { - - def create(config: RpcEnvConfig): RpcEnv = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - config.name, config.host, config.port, config.conf, config.securityManager) - actorSystem.actorOf(Props(classOf[ErrorMonitor]), "ErrorMonitor") - new AkkaRpcEnv(actorSystem, config.securityManager, config.conf, boundPort) - } -} - -/** - * Monitor errors reported by Akka and log them. - */ -private[akka] class ErrorMonitor extends Actor with ActorLogReceive with Logging { - - override def preStart(): Unit = { - context.system.eventStream.subscribe(self, classOf[Error]) - } - - override def receiveWithLogging: Actor.Receive = { - case Error(cause: Throwable, _, _, message: String) => logDebug(message, cause) - } -} - -private[akka] class AkkaRpcEndpointRef( - @transient private val defaultAddress: RpcAddress, - @transient private val _actorRef: () => ActorRef, - conf: SparkConf, - initInConstructor: Boolean) - extends RpcEndpointRef(conf) with Logging { - - def this( - defaultAddress: RpcAddress, - _actorRef: ActorRef, - conf: SparkConf) = { - this(defaultAddress, () => _actorRef, conf, true) - } - - lazy val actorRef = _actorRef() - - override lazy val address: RpcAddress = { - val akkaAddress = actorRef.path.address - RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), - akkaAddress.port.getOrElse(defaultAddress.port)) - } - - override lazy val name: String = actorRef.path.name - - private[akka] def init(): Unit = { - // Initialize the lazy vals - actorRef - address - name - } - - if (initInConstructor) { - init() - } - - override def send(message: Any): Unit = { - actorRef ! AkkaMessage(message, false) - } - - override def ask[T: ClassTag](message: Any, timeout: RpcTimeout): Future[T] = { - actorRef.ask(AkkaMessage(message, true))(timeout.duration).flatMap { - // The function will run in the calling thread, so it should be short and never block. - case msg @ AkkaMessage(message, reply) => - if (reply) { - logError(s"Receive $msg but the sender cannot reply") - Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) - } else { - Future.successful(message) - } - case AkkaFailure(e) => - Future.failed(e) - }(ThreadUtils.sameThread).mapTo[T]. - recover(timeout.addMessageIfTimeout)(ThreadUtils.sameThread) - } - - override def toString: String = s"${getClass.getSimpleName}($actorRef)" - - final override def equals(that: Any): Boolean = that match { - case other: AkkaRpcEndpointRef => actorRef == other.actorRef - case _ => false - } - - final override def hashCode(): Int = if (actorRef == null) 0 else actorRef.hashCode() -} - -/** - * A wrapper to `message` so that the receiver knows if the sender expects a reply. 
- * @param message - * @param needReply if the sender expects a reply message - */ -private[akka] case class AkkaMessage(message: Any, needReply: Boolean) - -/** - * A reply with the failure error from the receiver to the sender - */ -private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 533c9847661b..19259e0e800c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,14 +17,14 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap, LinkedBlockingQueue, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise import scala.util.control.NonFatal -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ import org.apache.spark.util.ThreadUtils diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala index de3db6ba624f..ef876b1d8c15 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala @@ -257,9 +257,6 @@ private[netty] class NettyRpcEnv( dispatcher.getRpcEndpointRef(endpoint) } - override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = - new RpcEndpointAddress(address, endpointName).toString - override def shutdown(): Unit = { cleanup() } @@ -363,15 +360,14 @@ private[netty] class NettyRpcEnv( } override def read(dst: ByteBuffer): Int = { - val result = if (error == null) { - Try(source.read(dst)) - } else { - Failure(error) - } - - result match { + Try(source.read(dst)) match { case Success(bytesRead) => bytesRead - case Failure(error) => throw error + case Failure(readErr) => + if (error != null) { + throw error + } else { + throw readErr + } } } @@ -397,7 +393,7 @@ private[netty] class NettyRpcEnv( } override def onFailure(streamId: String, cause: Throwable): Unit = { - logError(s"Error downloading stream $streamId.", cause) + logDebug(s"Error downloading stream $streamId.", cause) source.setError(cause) sink.close() } @@ -428,7 +424,7 @@ private[netty] object NettyRpcEnv extends Logging { } -private[netty] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { +private[rpc] class NettyRpcEnvFactory extends RpcEnvFactory with Logging { def create(config: RpcEnvConfig): RpcEnv = { val sparkConf = config.conf @@ -549,10 +545,6 @@ private[netty] class NettyRpcHandler( nettyEnv: NettyRpcEnv, streamManager: StreamManager) extends RpcHandler with Logging { - // TODO: Can we add connection callback (channel registered) to the underlying framework? - // A variable to track whether we should dispatch the RemoteProcessConnected message. 
- private val clients = new ConcurrentHashMap[TransportClient, JBoolean]() - // A variable to track the remote RpcEnv addresses of all clients private val remoteAddresses = new ConcurrentHashMap[RpcAddress, RpcAddress]() @@ -575,9 +567,6 @@ private[netty] class NettyRpcHandler( val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] assert(addr != null) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) - if (clients.putIfAbsent(client, JBoolean.TRUE) == null) { - dispatcher.postToAll(RemoteProcessConnected(clientAddr)) - } val requestMessage = nettyEnv.deserialize[RequestMessage](client, message) if (requestMessage.senderAddress == null) { // Create a new message with the socket address of the client as the sender. @@ -614,10 +603,16 @@ private[netty] class NettyRpcHandler( } } - override def connectionTerminated(client: TransportClient): Unit = { + override def channelActive(client: TransportClient): Unit = { + val addr = client.getChannel().remoteAddress().asInstanceOf[InetSocketAddress] + assert(addr != null) + val clientAddr = RpcAddress(addr.getHostName, addr.getPort) + dispatcher.postToAll(RemoteProcessConnected(clientAddr)) + } + + override def channelInactive(client: TransportClient): Unit = { val addr = client.getChannel.remoteAddress().asInstanceOf[InetSocketAddress] if (addr != null) { - clients.remove(client) val clientAddr = RpcAddress(addr.getHostName, addr.getPort) nettyEnv.removeOutbox(clientAddr) dispatcher.postToAll(RemoteProcessDisconnected(clientAddr)) diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala index ecd96972455d..afcb023a99da 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/NettyStreamManager.scala @@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc.RpcEnvFileServer +import org.apache.spark.util.Utils /** * StreamManager implementation for serving files from a NettyRpcEnv. 
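The NettyStreamManager hunk below makes two behavioural changes: openStream now returns null for an unknown file, letting the transport layer surface a per-stream failure instead of failing a require in the shared RPC handler, and addFile/addJar percent-encode the file name so that names containing spaces or other reserved characters still produce valid fetch URIs. A sketch of what the Utils.encodeFileNameToURIRawPath helper is assumed to amount to, using only java.net.URI:

    import java.net.URI

    // Escape a bare file name as a URI path segment,
    // e.g. "my lib.jar" -> "my%20lib.jar".
    def encodeFileNameToURIRawPath(fileName: String): String =
      new URI("file", null, "/" + fileName, null, null).getRawPath.stripPrefix("/")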
@@ -57,20 +58,23 @@ private[netty] class NettyStreamManager(rpcEnv: NettyRpcEnv) new File(dir, fname) } - require(file != null && file.isFile(), s"File not found: $streamId") - new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + if (file != null && file.isFile()) { + new FileSegmentManagedBuffer(rpcEnv.transportConf, file, 0, file.length()) + } else { + null + } } override def addFile(file: File): String = { require(files.putIfAbsent(file.getName(), file) == null, s"File ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/files/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/files/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addJar(file: File): String = { require(jars.putIfAbsent(file.getName(), file) == null, s"JAR ${file.getName()} already registered.") - s"${rpcEnv.address.toSparkURL}/jars/${file.getName()}" + s"${rpcEnv.address.toSparkURL}/jars/${Utils.encodeFileNameToURIRawPath(file.getName())}" } override def addDirectory(baseUri: String, path: File): String = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b128ed50cad5..6b01a10fc136 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -40,8 +40,8 @@ import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcTimeout import org.apache.spark.storage._ -import org.apache.spark.util._ import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.util._ /** * The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of @@ -747,7 +747,7 @@ class DAGScheduler( } /** - * Check for waiting or failed stages which are now eligible for resubmission. + * Check for waiting stages which are now eligible for resubmission. * Ordinarily run on every iteration of the event loop. */ private def submitWaitingStages() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index eaa07acc5132..aa607c5a2df9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.fs.permission.FsPermission import org.json4s.JsonAST.JValue import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{Logging, SparkConf, SPARK_VERSION} +import org.apache.spark.{Logging, SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -77,14 +77,6 @@ private[spark] class EventLoggingListener( // Only defined if the file system scheme is not local private var hadoopDataStream: Option[FSDataOutputStream] = None - // The Hadoop APIs have changed over time, so we use reflection to figure out - // the correct method to use to flush a hadoop data stream. See SPARK-1518 - // for details. - private val hadoopFlushMethod = { - val cls = classOf[FSDataOutputStream] - scala.util.Try(cls.getMethod("hflush")).getOrElse(cls.getMethod("sync")) - } - private var writer: Option[PrintWriter] = None // For testing. Keep track of all JSON serialized events that have been logged. 
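The EventLoggingListener hunks below are straight Hadoop 2 cleanups: FileStatus.isDir gives way to the non-deprecated isDirectory, and the reflective hadoopFlushMethod lookup removed just above (the SPARK-1518 workaround for Hadoop 1, which only offered sync()) becomes a direct hflush() call, since FSDataOutputStream exposes hflush throughout the Hadoop 2 line. A minimal sketch of the direct flush, mirroring the Option-wrapped stream field:

    import org.apache.hadoop.fs.FSDataOutputStream

    // Flush buffered event-log bytes out to HDFS; on Hadoop 2 this is a plain
    // interface call, no reflection required.
    def flushHadoopStream(hadoopDataStream: Option[FSDataOutputStream]): Unit =
      hadoopDataStream.foreach(_.hflush())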
@@ -97,7 +89,7 @@ private[spark] class EventLoggingListener( * Creates the log file in the configured log directory. */ def start() { - if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDir) { + if (!fileSystem.getFileStatus(new Path(logBaseDir)).isDirectory) { throw new IllegalArgumentException(s"Log directory $logBaseDir does not exist.") } @@ -147,7 +139,7 @@ private[spark] class EventLoggingListener( // scalastyle:on println if (flushLogger) { writer.foreach(_.flush()) - hadoopDataStream.foreach(hadoopFlushMethod.invoke(_)) + hadoopDataStream.foreach(_.hflush()) } if (testing) { loggedEvents += eventJson diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0e438ab4366d..8235b1024537 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -103,7 +103,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl val instance: org.apache.hadoop.mapreduce.InputFormat[_, _] = ReflectionUtils.newInstance(inputFormatClazz.asInstanceOf[Class[_]], conf).asInstanceOf[ org.apache.hadoop.mapreduce.InputFormat[_, _]] - val job = new Job(conf) + val job = Job.getInstance(conf) val retval = new ArrayBuffer[SplitInfo]() val list = instance.getSplits(job) diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala deleted file mode 100644 index f96eb8ca0ae0..000000000000 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ /dev/null @@ -1,277 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -import java.io.{File, FileNotFoundException, IOException, PrintWriter} -import java.text.SimpleDateFormat -import java.util.{Date, Properties} - -import scala.collection.mutable.HashMap - -import org.apache.spark._ -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.executor.TaskMetrics - -/** - * :: DeveloperApi :: - * A logger class to record runtime information for jobs in Spark. This class outputs one log file - * for each Spark job, containing tasks start/stop and shuffle information. JobLogger is a subclass - * of SparkListener, use addSparkListener to add JobLogger to a SparkContext after the SparkContext - * is created. Note that each JobLogger only works for one SparkContext - * - * NOTE: The functionality of this class is heavily stripped down to accommodate for a general - * refactor of the SparkListener interface. In its place, the EventLoggingListener is introduced - * to log application information as SparkListenerEvents. 
To enable this functionality, set - * spark.eventLog.enabled to true. - */ -@DeveloperApi -@deprecated("Log application information by setting spark.eventLog.enabled.", "1.0.0") -class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging { - - def this() = this(System.getProperty("user.name", ""), - String.valueOf(System.currentTimeMillis())) - - private val logDir = - if (System.getenv("SPARK_LOG_DIR") != null) { - System.getenv("SPARK_LOG_DIR") - } else { - "/tmp/spark-%s".format(user) - } - - private val jobIdToPrintWriter = new HashMap[Int, PrintWriter] - private val stageIdToJobId = new HashMap[Int, Int] - private val jobIdToStageIds = new HashMap[Int, Seq[Int]] - private val dateFormat = new ThreadLocal[SimpleDateFormat]() { - override def initialValue(): SimpleDateFormat = new SimpleDateFormat("yyyy/MM/dd HH:mm:ss") - } - - createLogDir() - - /** Create a folder for log files, the folder's name is the creation time of jobLogger */ - protected def createLogDir() { - val dir = new File(logDir + "/" + logDirName + "/") - if (dir.exists()) { - return - } - if (!dir.mkdirs()) { - // JobLogger should throw a exception rather than continue to construct this object. - throw new IOException("create log directory error:" + logDir + "/" + logDirName + "/") - } - } - - /** - * Create a log file for one job - * @param jobId ID of the job - * @throws FileNotFoundException Fail to create log file - */ - protected def createLogWriter(jobId: Int) { - try { - val fileWriter = new PrintWriter(logDir + "/" + logDirName + "/" + jobId) - jobIdToPrintWriter += (jobId -> fileWriter) - } catch { - case e: FileNotFoundException => e.printStackTrace() - } - } - - /** - * Close log file, and clean the stage relationship in stageIdToJobId - * @param jobId ID of the job - */ - protected def closeLogWriter(jobId: Int) { - jobIdToPrintWriter.get(jobId).foreach { fileWriter => - fileWriter.close() - jobIdToStageIds.get(jobId).foreach(_.foreach { stageId => - stageIdToJobId -= stageId - }) - jobIdToPrintWriter -= jobId - jobIdToStageIds -= jobId - } - } - - /** - * Build up the maps that represent stage-job relationships - * @param jobId ID of the job - * @param stageIds IDs of the associated stages - */ - protected def buildJobStageDependencies(jobId: Int, stageIds: Seq[Int]) = { - jobIdToStageIds(jobId) = stageIds - stageIds.foreach { stageId => stageIdToJobId(stageId) = jobId } - } - - /** - * Write info into log file - * @param jobId ID of the job - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def jobLogInfo(jobId: Int, info: String, withTime: Boolean = true) { - var writeInfo = info - if (withTime) { - val date = new Date(System.currentTimeMillis()) - writeInfo = dateFormat.get.format(date) + ": " + info - } - // scalastyle:off println - jobIdToPrintWriter.get(jobId).foreach(_.println(writeInfo)) - // scalastyle:on println - } - - /** - * Write info into log file - * @param stageId ID of the stage - * @param info Info to be recorded - * @param withTime Controls whether to record time stamp before the info, default is true - */ - protected def stageLogInfo(stageId: Int, info: String, withTime: Boolean = true) { - stageIdToJobId.get(stageId).foreach(jobId => jobLogInfo(jobId, info, withTime)) - } - - /** - * Record task metrics into job log files, including execution info and shuffle metrics - * @param stageId Stage ID of the task - * @param status Status info of the task - * @param 
taskInfo Task description info - * @param taskMetrics Task running metrics - */ - protected def recordTaskMetrics(stageId: Int, status: String, - taskInfo: TaskInfo, taskMetrics: TaskMetrics) { - val info = " TID=" + taskInfo.taskId + " STAGE_ID=" + stageId + - " START_TIME=" + taskInfo.launchTime + " FINISH_TIME=" + taskInfo.finishTime + - " EXECUTOR_ID=" + taskInfo.executorId + " HOST=" + taskMetrics.hostname - val executorRunTime = " EXECUTOR_RUN_TIME=" + taskMetrics.executorRunTime - val gcTime = " GC_TIME=" + taskMetrics.jvmGCTime - val inputMetrics = taskMetrics.inputMetrics match { - case Some(metrics) => - " READ_METHOD=" + metrics.readMethod.toString + - " INPUT_BYTES=" + metrics.bytesRead - case None => "" - } - val outputMetrics = taskMetrics.outputMetrics match { - case Some(metrics) => - " OUTPUT_BYTES=" + metrics.bytesWritten - case None => "" - } - val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match { - case Some(metrics) => - " BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched + - " BLOCK_FETCHED_LOCAL=" + metrics.localBlocksFetched + - " BLOCK_FETCHED_REMOTE=" + metrics.remoteBlocksFetched + - " REMOTE_FETCH_WAIT_TIME=" + metrics.fetchWaitTime + - " REMOTE_BYTES_READ=" + metrics.remoteBytesRead + - " LOCAL_BYTES_READ=" + metrics.localBytesRead - case None => "" - } - val writeMetrics = taskMetrics.shuffleWriteMetrics match { - case Some(metrics) => - " SHUFFLE_BYTES_WRITTEN=" + metrics.shuffleBytesWritten + - " SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime - case None => "" - } - stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics + - shuffleReadMetrics + writeMetrics) - } - - /** - * When stage is submitted, record stage submit info - * @param stageSubmitted Stage submitted event - */ - override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted) { - val stageInfo = stageSubmitted.stageInfo - stageLogInfo(stageInfo.stageId, "STAGE_ID=%d STATUS=SUBMITTED TASK_SIZE=%d".format( - stageInfo.stageId, stageInfo.numTasks)) - } - - /** - * When stage is completed, record stage completion status - * @param stageCompleted Stage completed event - */ - override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { - val stageId = stageCompleted.stageInfo.stageId - if (stageCompleted.stageInfo.failureReason.isEmpty) { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=COMPLETED") - } else { - stageLogInfo(stageId, s"STAGE_ID=$stageId STATUS=FAILED") - } - } - - /** - * When task ends, record task completion status and metrics - * @param taskEnd Task end event - */ - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - val taskInfo = taskEnd.taskInfo - var taskStatus = "TASK_TYPE=%s".format(taskEnd.taskType) - val taskMetrics = if (taskEnd.taskMetrics != null) taskEnd.taskMetrics else TaskMetrics.empty - taskEnd.reason match { - case Success => taskStatus += " STATUS=SUCCESS" - recordTaskMetrics(taskEnd.stageId, taskStatus, taskInfo, taskMetrics) - case Resubmitted => - taskStatus += " STATUS=RESUBMITTED TID=" + taskInfo.taskId + - " STAGE_ID=" + taskEnd.stageId - stageLogInfo(taskEnd.stageId, taskStatus) - case FetchFailed(bmAddress, shuffleId, mapId, reduceId, message) => - taskStatus += " STATUS=FETCHFAILED TID=" + taskInfo.taskId + " STAGE_ID=" + - taskEnd.stageId + " SHUFFLE_ID=" + shuffleId + " MAP_ID=" + - mapId + " REDUCE_ID=" + reduceId - stageLogInfo(taskEnd.stageId, taskStatus) - case _ => - } - } - - /** - * When job ends, recording job completion status and close log file - * @param 
jobEnd Job end event - */ - override def onJobEnd(jobEnd: SparkListenerJobEnd) { - val jobId = jobEnd.jobId - var info = "JOB_ID=" + jobId - jobEnd.jobResult match { - case JobSucceeded => info += " STATUS=SUCCESS" - case JobFailed(exception) => - info += " STATUS=FAILED REASON=" - exception.getMessage.split("\\s+").foreach(info += _ + "_") - case _ => - } - jobLogInfo(jobId, info.substring(0, info.length - 1).toUpperCase) - closeLogWriter(jobId) - } - - /** - * Record job properties into job log file - * @param jobId ID of the job - * @param properties Properties of the job - */ - protected def recordJobProperties(jobId: Int, properties: Properties) { - if (properties != null) { - val description = properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION, "") - jobLogInfo(jobId, description, withTime = false) - } - } - - /** - * When job starts, record job property and stage graph - * @param jobStart Job start event - */ - override def onJobStart(jobStart: SparkListenerJobStart) { - val jobId = jobStart.jobId - val properties = jobStart.properties - createLogWriter(jobId) - recordJobProperties(jobId, properties) - buildJobStageDependencies(jobId, jobStart.stageIds) - jobLogInfo(jobId, "JOB_ID=" + jobId + " STATUS=STARTED") - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 4d146678174f..3e3ab15d8a24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} private sealed trait OutputCommitCoordinationMessage extends Serializable diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index fb693721a9cb..6590cf6ffd24 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import java.nio.ByteBuffer - import java.io._ +import java.nio.ByteBuffer import org.apache.spark._ import org.apache.spark.broadcast.Broadcast diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index 075a7f13172d..3130a65240a9 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -29,8 +29,8 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.storage.{BlockManagerId, BlockUpdatedInfo} -import org.apache.spark.util.{Distribution, Utils} import org.apache.spark.ui.SparkUI +import org.apache.spark.util.{Distribution, Utils} @DeveloperApi @JsonTypeInfo(use = JsonTypeInfo.Id.CLASS, include = JsonTypeInfo.As.PROPERTY, property = "Event") diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 9f27eed626be..0379ca2af6ab 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala 
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -22,14 +22,13 @@ import java.nio.ByteBuffer import scala.collection.mutable.HashMap -import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.{Accumulator, SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager +import org.apache.spark.metrics.MetricsSystem import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils} - /** * A unit of execution. We have two kinds of Task's in Spark: * diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index f113c2b1b843..a42990addb9c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -95,9 +95,6 @@ class TaskInfo( } } - @deprecated("Use attemptNumber", "1.6.0") - def attempt: Int = attemptNumber - def id: String = s"$index.$attemptNumber" def duration: Long = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index cb9a3008107d..7c0b007db708 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -17,8 +17,8 @@ package org.apache.spark.scheduler -import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.storage.BlockManagerId /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index bdf19f9f277d..6e3ef0e54f0f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer -import java.util.{TimerTask, Timer} +import java.util.{Timer, TimerTask} import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicLong @@ -30,11 +30,11 @@ import scala.util.Random import org.apache.spark._ import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.SchedulingMode.SchedulingMode import org.apache.spark.scheduler.TaskLocality.TaskLocality -import org.apache.spark.util.{ThreadUtils, Utils} -import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.{ThreadUtils, Utils} /** * Schedules tasks for multiple types of clusters by acting through a SchedulerBackend. 
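// Most hunks above only reorder imports. As a reference, this is the header that
// Task.scala ends up with after its hunk, and the grouping the other files
// converge on (assumed to follow the project's Scala style checks): java/javax
// first, then scala, then third-party, then org.apache.spark, with the members
// of each brace import sorted alphabetically.
import java.nio.ByteBuffer

import scala.collection.mutable.HashMap

import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Utils}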
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a02f3017cb6e..aa39b59d8cce 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -25,7 +25,7 @@ import java.util.concurrent.ConcurrentLinkedQueue import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.math.{min, max} +import scala.math.{max, min} import scala.util.control.NonFatal import org.apache.spark._ @@ -608,7 +608,7 @@ private[spark] class TaskSetManager( } /** - * Marks the task as successful and notifies the DAGScheduler that a task has ended. + * Marks a task as successful and notifies the DAGScheduler that the task has ended. */ def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]): Unit = { val info = taskInfos(tid) @@ -705,7 +705,7 @@ private[spark] class TaskSetManager( ef.exception case e: ExecutorLostFailure if !e.exitCausedByApp => - logInfo(s"Task $tid failed because while it was being computed, its executor" + + logInfo(s"Task $tid failed because while it was being computed, its executor " + "exited for a reason unrelated to the task. Not counting this failure towards the " + "maximum number of failures for the task.") None diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 7efe16749e59..b808993aa6cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -22,15 +22,15 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import org.apache.spark.rpc._ import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc._ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend.ENDPOINT_NAME -import org.apache.spark.util.{ThreadUtils, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, SerializableBuffer, ThreadUtils, Utils} /** - * A scheduler backend that waits for coarse grained executors to connect to it through Akka. + * A scheduler backend that waits for coarse-grained executors to connect. * This backend holds onto each executor for the duration of the Spark job rather than relinquishing * executors whenever a task is done and asking the scheduler to launch a new executor for * each new task. Executors may be launched in a variety of ways, such as Mesos tasks for the @@ -471,7 +471,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp /** * Request that the cluster manager kill the specified executors. - * @return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. If the list of executors to kill is empty, + * it will return false.
*/ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized { killExecutors(executorIds, replace = false, force = false) } @@ -487,7 +488,8 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp * @param executorIds identifiers of executors to kill * @param replace whether to replace the killed executors with new ones * @param force whether to force kill busy executors - * @return whether the kill request is acknowledged. + * @return whether the kill request is acknowledged. If the list of executors to kill is empty, + * it will return false. */ final def killExecutors( executorIds: Seq[String], @@ -516,7 +518,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp numPendingExecutors += knownExecutors.size } - doKillExecutors(executorsToKill) + executorsToKill.nonEmpty && doKillExecutors(executorsToKill) } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index 626a2b7d69ab..b25a4bfb501f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler.cluster -import org.apache.spark.rpc.{RpcEndpointRef, RpcAddress} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 641638a77d5f..0a6f2c01c18d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -17,11 +17,11 @@ package org.apache.spark.scheduler.cluster -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.rpc.RpcAddress -import org.apache.spark.{Logging, SparkContext, SparkEnv} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.rpc.RpcEndpointAddress import org.apache.spark.scheduler.TaskSchedulerImpl private[spark] class SimrSchedulerBackend( @@ -39,9 +39,10 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val conf = SparkHadoopUtil.get.newConfiguration(sc.conf) val fs = FileSystem.get(conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 5105475c760e..16f33163789a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -19,11 +19,11 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.Semaphore -import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging,
SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} import org.apache.spark.scheduler._ import org.apache.spark.util.Utils @@ -54,9 +54,10 @@ private[spark] class SparkDeploySchedulerBackend( launcherBackend.connect() // The endpoint for executors to talk to us - val driverUrl = rpcEnv.uriOf(SparkEnv.driverActorSystemName, - RpcAddress(sc.conf.get("spark.driver.host"), sc.conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + val driverUrl = RpcEndpointAddress( + sc.conf.get("spark.driver.host"), + sc.conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString val args = Seq( "--driver-url", driverUrl, "--executor-id", "{{EXECUTOR_ID}}", diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 7d08eae0b487..58c30e7d9788 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -18,20 +18,20 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, List => JList} +import java.util.concurrent.locks.ReentrantLock import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import com.google.common.collect.HashBiMap -import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.mesos.{Scheduler => MScheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} import org.apache.spark.{SecurityManager, SparkContext, SparkEnv, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient -import org.apache.spark.rpc.RpcAddress +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress} import org.apache.spark.scheduler.{SlaveLost, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.util.Utils @@ -215,10 +215,10 @@ private[spark] class CoarseMesosSchedulerBackend( if (conf.contains("spark.testing")) { "driverURL" } else { - sc.env.rpcEnv.uriOf( - SparkEnv.driverActorSystemName, - RpcAddress(conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt), - CoarseGrainedSchedulerBackend.ENDPOINT_NAME) + RpcEndpointAddress( + conf.get("spark.driver.host"), + conf.get("spark.driver.port").toInt, + CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index a6d9374eb9e8..05fda0fded7f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -18,23 +18,22 @@ package org.apache.spark.scheduler.cluster.mesos import java.io.File -import java.util.concurrent.locks.ReentrantLock import java.util.{Collections, Date, List => JList} import 
scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import org.apache.mesos.{Scheduler, SchedulerDriver} +import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} import org.apache.mesos.Protos.Environment.Variable import org.apache.mesos.Protos.TaskStatus.Reason -import org.apache.mesos.Protos.{TaskState => MesosTaskState, _} -import org.apache.mesos.{Scheduler, SchedulerDriver} + +import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.deploy.rest.{CreateSubmissionResponse, KillSubmissionResponse, SubmissionStatusResponse} import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.Utils -import org.apache.spark.{SecurityManager, SparkConf, SparkException, TaskState} - /** * Tracks the current state of a Mesos Task that runs a Spark driver. @@ -126,7 +125,7 @@ private[spark] class MesosClusterScheduler( private val retainedDrivers = conf.getInt("spark.mesos.retainedDrivers", 200) private val maxRetryWaitTime = conf.getInt("spark.mesos.cluster.retry.wait.max", 60) // 1 minute private val schedulerState = engineFactory.createEngine("scheduler") - private val stateLock = new ReentrantLock() + private val stateLock = new Object() private val finishedDrivers = new mutable.ArrayBuffer[MesosClusterSubmissionState](retainedDrivers) private var frameworkId: String = null diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 281965a5981b..eaf0cb06d6c7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.{Scheduler => MScheduler, _} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} import org.apache.mesos.protobuf.ByteString + import org.apache.spark.{SparkContext, SparkException, TaskState} import org.apache.spark.executor.MesosExecutorBackend import org.apache.spark.scheduler._ diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 573355ba5813..010caff3e39b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -25,16 +25,16 @@ import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal import com.google.common.base.Splitter -import org.apache.mesos.{MesosSchedulerDriver, SchedulerDriver, Scheduler, Protos} +import org.apache.mesos.{MesosSchedulerDriver, Protos, Scheduler, SchedulerDriver} import org.apache.mesos.Protos._ import org.apache.mesos.protobuf.{ByteString, GeneratedMessage} -import org.apache.spark.{SparkException, SparkConf, Logging, SparkContext} -import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf, SparkContext, SparkException} +import org.apache.spark.util.Utils /** * Shared trait for implementing a Mesos Scheduler. This holds common state and helper - * methods the Mesos scheduler will use. + * methods that the Mesos scheduler will use.
*/ private[mesos] trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered diff --git a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala index 8d6af9cae892..3d5b7105f0ca 100644 --- a/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/GenericAvroSerializer.scala @@ -29,7 +29,7 @@ import org.apache.avro.generic.{GenericData, GenericRecord} import org.apache.avro.io._ import org.apache.commons.io.IOUtils -import org.apache.spark.{SparkException, SparkEnv} +import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.io.CompressionCodec /** diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index cb2ac5ea167e..150ddc12e069 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{DataInput, DataOutput, EOFException, IOException, InputStream, OutputStream} +import java.io._ import java.nio.ByteBuffer import javax.annotation.Nullable @@ -25,21 +25,20 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializer} -import com.esotericsoftware.kryo.{Kryo, KryoException, Serializer => KryoClassSerializer} import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator} import org.apache.avro.generic.{GenericData, GenericRecord} import org.roaringbitmap.RoaringBitmap import org.apache.spark._ import org.apache.spark.api.python.PythonBroadcast -import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus} import org.apache.spark.storage._ -import org.apache.spark.util.collection.CompactBuffer import org.apache.spark.util.{BoundedPriorityQueue, SerializableConfiguration, SerializableJobConf, Utils} +import org.apache.spark.util.collection.CompactBuffer /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. 
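// The hunks below widen the Kryo bridge classes from DataInput/DataOutput to
// ObjectInput/ObjectOutput. A usage sketch under the constructor signatures
// shown in this diff, round-tripping the RoaringBitmap handled by the registered
// serializer (ObjectOutput extends DataOutput, so the bridge satisfies
// RoaringBitmap#serialize directly):
import java.io.{ByteArrayInputStream, ByteArrayOutputStream}

import com.esotericsoftware.kryo.Kryo
import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput}
import org.roaringbitmap.RoaringBitmap

val kryo = new Kryo()
val bytesOut = new ByteArrayOutputStream()
val output = new KryoOutput(bytesOut)

// Serialize through the ObjectOutput view of the Kryo output stream.
val bitmap = new RoaringBitmap()
bitmap.add(1)
bitmap.add(42)
bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output))
output.flush()

// Read it back through the ObjectInput view of a Kryo input stream.
val input = new KryoInput(new ByteArrayInputStream(bytesOut.toByteArray))
val restored = new RoaringBitmap()
restored.deserialize(new KryoInputObjectInputBridge(kryo, input))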
@@ -107,7 +106,6 @@ class KryoSerializer(conf: SparkConf) kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer()) kryo.register(classOf[SerializableConfiguration], new KryoJavaSerializer()) kryo.register(classOf[SerializableJobConf], new KryoJavaSerializer()) - kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer()) kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer()) kryo.register(classOf[GenericRecord], new GenericAvroSerializer(avroSchemas)) @@ -380,18 +378,24 @@ private[serializer] object KryoSerializer { private val toRegisterSerializer = Map[Class[_], KryoClassSerializer[_]]( classOf[RoaringBitmap] -> new KryoClassSerializer[RoaringBitmap]() { override def write(kryo: Kryo, output: KryoOutput, bitmap: RoaringBitmap): Unit = { - bitmap.serialize(new KryoOutputDataOutputBridge(output)) + bitmap.serialize(new KryoOutputObjectOutputBridge(kryo, output)) } override def read(kryo: Kryo, input: KryoInput, cls: Class[RoaringBitmap]): RoaringBitmap = { val ret = new RoaringBitmap - ret.deserialize(new KryoInputDataInputBridge(input)) + ret.deserialize(new KryoInputObjectInputBridge(kryo, input)) ret } } ) } -private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends DataInput { +/** + * This is a bridge class to wrap KryoInput as an InputStream and ObjectInput. It forwards all + * methods of InputStream and ObjectInput to KryoInput. It's usually helpful when an API expects + * an InputStream or ObjectInput but you want to use Kryo. + */ +private[spark] class KryoInputObjectInputBridge( + kryo: Kryo, input: KryoInput) extends FilterInputStream(input) with ObjectInput { override def readLong(): Long = input.readLong() override def readChar(): Char = input.readChar() override def readFloat(): Float = input.readFloat() @@ -401,12 +405,7 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readInt(): Int = input.readInt() override def readUnsignedShort(): Int = input.readShortUnsigned() override def skipBytes(n: Int): Int = { - var remaining: Long = n - while (remaining > 0) { - val skip = Math.min(Integer.MAX_VALUE, remaining).asInstanceOf[Int] - input.skip(skip) - remaining -= skip - } + input.skip(n) n } override def readFully(b: Array[Byte]): Unit = input.read(b) @@ -415,9 +414,16 @@ private[serializer] class KryoInputDataInputBridge(input: KryoInput) extends Dat override def readBoolean(): Boolean = input.readBoolean() override def readUnsignedByte(): Int = input.readByteUnsigned() override def readDouble(): Double = input.readDouble() + override def readObject(): AnyRef = kryo.readClassAndObject(input) } -private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends DataOutput { +/** + * This is a bridge class to wrap KryoOutput as an OutputStream and ObjectOutput. It forwards all + * methods of OutputStream and ObjectOutput to KryoOutput. It's usually helpful when an API expects + * an OutputStream or ObjectOutput but you want to use Kryo. 
+ */ +private[spark] class KryoOutputObjectOutputBridge( + kryo: Kryo, output: KryoOutput) extends FilterOutputStream(output) with ObjectOutput { override def writeFloat(v: Float): Unit = output.writeFloat(v) // There is no "readChars" counterpart, except maybe "readLine", which is not supported override def writeChars(s: String): Unit = throw new UnsupportedOperationException("writeChars") @@ -433,6 +439,7 @@ private[serializer] class KryoOutputDataOutputBridge(output: KryoOutput) extends override def writeChar(v: Int): Unit = output.writeChar(v.toChar) override def writeLong(v: Long): Unit = output.writeLong(v) override def writeByte(v: Int): Unit = output.writeByte(v) + override def writeObject(obj: AnyRef): Unit = kryo.writeClassAndObject(output, obj) } /** diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index bd2704dc8187..90c0728557b9 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -25,7 +25,7 @@ import scala.reflect.ClassTag import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.annotation.{DeveloperApi, Private} -import org.apache.spark.util.{Utils, ByteBufferInputStream, NextIterator} +import org.apache.spark.util.{ByteBufferInputStream, NextIterator, Utils} /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala index b36c457d6d51..0a65bbf8ddab 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BaseShuffleHandle.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{ShuffleDependency, Aggregator, Partitioner} +import org.apache.spark.{Aggregator, Partitioner, ShuffleDependency} import org.apache.spark.serializer.Serializer /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index be184464e0ae..b2d050b218f5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,8 +17,8 @@ package org.apache.spark.shuffle -import org.apache.spark.storage.BlockManagerId import org.apache.spark.{FetchFailed, TaskEndReason} +import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index cc5f933393ad..294e16cde193 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -17,17 +17,17 @@ package org.apache.spark.shuffle -import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} import scala.collection.JavaConverters._ +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.serializer.Serializer import org.apache.spark.storage._ -import 
org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils} -import org.apache.spark.{Logging, SparkConf, SparkEnv} +import org.apache.spark.util.Utils /** A group of writers for a ShuffleMapTask, one writer per reducer. */ private[spark] trait ShuffleWriterGroup { @@ -63,10 +63,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) val completedMapTasks = new ConcurrentLinkedQueue[Int]() } - private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - - private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) + private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState] /** * Get a ShuffleWriterGroup for the given map task, which will register it as complete @@ -75,9 +72,12 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { new ShuffleWriterGroup { - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) - private val shuffleState = shuffleStates(shuffleId) - + private val shuffleState: ShuffleState = { + // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use + // .getOrElseUpdate() because that's actually NOT atomic. + shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) + shuffleStates.get(shuffleId) + } val openStartTime = System.nanoTime val serializerInstance = serializer.newInstance() val writers: Array[DiskBlockObjectWriter] = { @@ -114,7 +114,7 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) /** Remove all the blocks / files related to a particular shuffle. */ private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { - shuffleStates.get(shuffleId) match { + Option(shuffleStates.get(shuffleId)) match { case Some(state) => for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) @@ -131,11 +131,5 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf) } } - private def cleanup(cleanupTime: Long) { - shuffleStates.clearOldValues(cleanupTime, (shuffleId, state) => removeShuffleBlocks(shuffleId)) - } - - override def stop() { - metadataCleaner.cancel() - } + override def stop(): Unit = {} } diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala index fadb8fe7ed0a..68aba52fd7c6 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala @@ -21,12 +21,12 @@ import java.io._ import com.google.common.io.ByteStreams +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID import org.apache.spark.storage._ import org.apache.spark.util.Utils -import org.apache.spark.{SparkEnv, Logging, SparkConf} /** * Create and maintain the shuffle blocks' mapping between logic block and physical file location. 
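// Spelling out the resolver's note above: wrapping the ConcurrentHashMap in a
// Scala view and calling getOrElseUpdate would check-then-insert non-atomically,
// so two map tasks could each create (and write through) a different
// ShuffleState. putIfAbsent either installs our candidate or loses the race, and
// the follow-up get returns whichever instance won, so every caller sees the
// same state. A minimal sketch with illustrative names only:
import java.util.concurrent.ConcurrentHashMap

final class ShuffleState(val numReducers: Int)

val shuffleStates = new ConcurrentHashMap[Int, ShuffleState]()

def stateFor(shuffleId: Int, numReducers: Int): ShuffleState = {
  shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers))
  shuffleStates.get(shuffleId)
}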
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala index 4342b0d598b1..81aea33ee41b 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockResolver.scala @@ -18,6 +18,7 @@ package org.apache.spark.shuffle import java.nio.ByteBuffer + import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.storage.ShuffleBlockId diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index a3444bf4daa3..76fd249fbd2d 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{TaskContext, ShuffleDependency} +import org.apache.spark.{ShuffleDependency, TaskContext} /** * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 31b4dd7c0f42..341ae782362a 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -17,8 +17,8 @@ package org.apache.spark.status.api.v1 import java.util.{Arrays, Date, List => JList} -import javax.ws.rs.core.MediaType import javax.ws.rs.{GET, Produces, QueryParam} +import javax.ws.rs.core.MediaType import org.apache.spark.executor.{InputMetrics => InternalInputMetrics, OutputMetrics => InternalOutputMetrics, ShuffleReadMetrics => InternalShuffleReadMetrics, ShuffleWriteMetrics => InternalShuffleWriteMetrics, TaskMetrics => InternalTaskMetrics} import org.apache.spark.scheduler.{AccumulableInfo => InternalAccumulableInfo, StageInfo} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala index b5ef72649e29..d7e6a8b58995 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala @@ -16,8 +16,8 @@ */ package org.apache.spark.status.api.v1 +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType -import javax.ws.rs.{Produces, PathParam, GET} @Produces(Array(MediaType.APPLICATION_JSON)) private[v1] class OneApplicationResource(uiRoot: UIRoot) { diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala index 6d8a60d480ae..a0f6360bc5c7 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneJobResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.JobExecutionStatus diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala index dfdc09c6caf3..237aeac18587 100644 --- 
a/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/OneRDDResource.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.status.api.v1 -import javax.ws.rs.{PathParam, GET, Produces} +import javax.ws.rs.{GET, PathParam, Produces} import javax.ws.rs.core.MediaType import org.apache.spark.ui.SparkUI diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 6074fc58d70d..4479e6875a73 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -19,10 +19,12 @@ package org.apache.spark.storage import java.io._ import java.nio.{ByteBuffer, MappedByteBuffer} +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} -import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.util.Random import scala.util.control.NonFatal @@ -57,7 +59,7 @@ private[spark] class BlockResult( * Manager running on every node (driver and executors) which provides interfaces for putting and * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap). * - * Note that #initialize() must be called before the BlockManager is usable. + * Note that [[initialize()]] must be called before the BlockManager is usable. */ private[spark] class BlockManager( executorId: String, @@ -75,7 +77,7 @@ private[spark] class BlockManager( val diskBlockManager = new DiskBlockManager(this, conf) - private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo] + private val blockInfo = new ConcurrentHashMap[BlockId, BlockInfo] private val futureExecutionContext = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("block-manager-future", 128)) @@ -147,11 +149,6 @@ private[spark] class BlockManager( private var asyncReregisterTask: Future[Unit] = null private val asyncReregisterLock = new Object - private val metadataCleaner = new MetadataCleaner( - MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) - private val broadcastCleaner = new MetadataCleaner( - MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) - // Field related to peer block managers that are necessary for block replication @volatile private var cachedPeers: Seq[BlockManagerId] = _ private val peerFetchLock = new Object @@ -232,7 +229,7 @@ private[spark] class BlockManager( */ private def reportAllBlocks(): Unit = { logInfo(s"Reporting ${blockInfo.size} blocks to the master.") - for ((blockId, info) <- blockInfo) { + for ((blockId, info) <- blockInfo.asScala) { val status = getCurrentBlockStatus(blockId, info) if (!tryToReportBlockStatus(blockId, info, status)) { logError(s"Failed to report $blockId to master; giving up.") @@ -313,7 +310,7 @@ private[spark] class BlockManager( * NOTE: This is mainly for testing, and it doesn't fetch information from external block store. 
*/ def getStatus(blockId: BlockId): Option[BlockStatus] = { - blockInfo.get(blockId).map { info => + blockInfo.asScala.get(blockId).map { info => val memSize = if (memoryStore.contains(blockId)) memoryStore.getSize(blockId) else 0L val diskSize = if (diskStore.contains(blockId)) diskStore.getSize(blockId) else 0L // Assume that block is not in external block store @@ -327,7 +324,7 @@ private[spark] class BlockManager( * may not know of). */ def getMatchingBlockIds(filter: BlockId => Boolean): Seq[BlockId] = { - (blockInfo.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq + (blockInfo.asScala.keys ++ diskBlockManager.getAllBlocks()).filter(filter).toSeq } /** @@ -439,7 +436,7 @@ private[spark] class BlockManager( } private def doGetLocal(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) if (info != null) { info.synchronized { // Double check to make sure the block is still there. There is a small chance that the @@ -447,7 +444,7 @@ private[spark] class BlockManager( // Note that this only checks metadata tracking. If user intentionally deleted the block // on disk or from off heap storage without using removeBlock, this conditional check will // still pass but eventually we will get an exception because we can't find the block. - if (blockInfo.get(blockId).isEmpty) { + if (blockInfo.asScala.get(blockId).isEmpty) { logWarning(s"Block $blockId had been removed") return None } @@ -578,9 +575,19 @@ private[spark] class BlockManager( doGetRemote(blockId, asBlockResult = false).asInstanceOf[Option[ByteBuffer]] } + /** + * Return a list of locations for the given block, prioritizing the local machine since + * multiple block managers can share the same host. + */ + private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { + val locs = Random.shuffle(master.getLocations(blockId)) + val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } + preferredLocs ++ otherLocs + } + private def doGetRemote(blockId: BlockId, asBlockResult: Boolean): Option[Any] = { require(blockId != null, "BlockId is null") - val locations = Random.shuffle(master.getLocations(blockId)) + val locations = getLocations(blockId) var numFetchFailures = 0 for (loc <- locations) { logDebug(s"Getting remote block $blockId from $loc") @@ -721,7 +728,7 @@ private[spark] class BlockManager( val putBlockInfo = { val tinfo = new BlockInfo(level, tellMaster) // Do atomically ! - val oldBlockOpt = blockInfo.putIfAbsent(blockId, tinfo) + val oldBlockOpt = Option(blockInfo.putIfAbsent(blockId, tinfo)) if (oldBlockOpt.isDefined) { if (oldBlockOpt.get.waitForReady()) { logWarning(s"Block $blockId already exists on this machine; not re-adding it") @@ -1022,7 +1029,7 @@ private[spark] class BlockManager( data: () => Either[Array[Any], ByteBuffer]): Option[BlockStatus] = { logInfo(s"Dropping block $blockId from memory") - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) // If the block has not already been dropped if (info != null) { @@ -1033,7 +1040,7 @@ private[spark] class BlockManager( // If we get here, the block write failed. logWarning(s"Block $blockId was marked as failure. 
Nothing to drop") return None - } else if (blockInfo.get(blockId).isEmpty) { + } else if (blockInfo.asScala.get(blockId).isEmpty) { logWarning(s"Block $blockId was already dropped.") return None } @@ -1085,7 +1092,7 @@ private[spark] class BlockManager( def removeRdd(rddId: Int): Int = { // TODO: Avoid a linear scan by creating another mapping of RDD.id to blocks. logInfo(s"Removing RDD $rddId") - val blocksToRemove = blockInfo.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) + val blocksToRemove = blockInfo.asScala.keys.flatMap(_.asRDDId).filter(_.rddId == rddId) blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster = false) } blocksToRemove.size } @@ -1095,7 +1102,7 @@ private[spark] class BlockManager( */ def removeBroadcast(broadcastId: Long, tellMaster: Boolean): Int = { logDebug(s"Removing broadcast $broadcastId") - val blocksToRemove = blockInfo.keys.collect { + val blocksToRemove = blockInfo.asScala.keys.collect { case bid @ BroadcastBlockId(`broadcastId`, _) => bid } blocksToRemove.foreach { blockId => removeBlock(blockId, tellMaster) } @@ -1107,7 +1114,7 @@ private[spark] class BlockManager( */ def removeBlock(blockId: BlockId, tellMaster: Boolean = true): Unit = { logDebug(s"Removing block $blockId") - val info = blockInfo.get(blockId).orNull + val info = blockInfo.get(blockId) if (info != null) { info.synchronized { // Removals are idempotent in disk store and memory store. At worst, we get a warning. @@ -1131,36 +1138,6 @@ private[spark] class BlockManager( } } - private def dropOldNonBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping non broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, !_.isBroadcast) - } - - private def dropOldBroadcastBlocks(cleanupTime: Long): Unit = { - logInfo(s"Dropping broadcast blocks older than $cleanupTime") - dropOldBlocks(cleanupTime, _.isBroadcast) - } - - private def dropOldBlocks(cleanupTime: Long, shouldDrop: (BlockId => Boolean)): Unit = { - val iterator = blockInfo.getEntrySet.iterator - while (iterator.hasNext) { - val entry = iterator.next() - val (id, info, time) = (entry.getKey, entry.getValue.value, entry.getValue.timestamp) - if (time < cleanupTime && shouldDrop(id)) { - info.synchronized { - val level = info.level - if (level.useMemory) { memoryStore.remove(id) } - if (level.useDisk) { diskStore.remove(id) } - if (level.useOffHeap) { externalBlockStore.remove(id) } - iterator.remove() - logInfo(s"Dropped block $id") - } - val status = getCurrentBlockStatus(id, info) - reportBlockStatus(id, info, status) - } - } - } - private def shouldCompress(blockId: BlockId): Boolean = { blockId match { case _: ShuffleBlockId => compressShuffle @@ -1238,8 +1215,6 @@ private[spark] class BlockManager( if (externalBlockStoreInitialized) { externalBlockStore.clear() } - metadataCleaner.cancel() - broadcastCleaner.cancel() futureExecutionContext.shutdownNow() logInfo("BlockManager stopped") } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 440c4c18aadd..da1de11d605c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -21,10 +21,10 @@ import scala.collection.Iterable import scala.collection.generic.CanBuildFrom import scala.concurrent.{Await, Future} -import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.{Logging, SparkConf, SparkException} +import 
org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ThreadUtils, RpcUtils} +import org.apache.spark.util.{RpcUtils, ThreadUtils} private[spark] class BlockManagerMaster( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 41892b4ffce5..4db400a3442c 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcCallContext, ThreadSafeRpcEndpoint} import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index f7e84a2c2e14..4daf22f71415 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -17,10 +17,10 @@ package org.apache.spark.storage +import java.io.{File, IOException} import java.util.UUID -import java.io.{IOException, File} -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.executor.ExecutorExitCode import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index e2dd80f24393..e36a367323b2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -17,12 +17,12 @@ package org.apache.spark.storage -import java.io.{BufferedOutputStream, FileOutputStream, File, OutputStream} +import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream} import java.nio.channels.FileChannel import org.apache.spark.Logging -import org.apache.spark.serializer.{SerializerInstance, SerializationStream} import org.apache.spark.executor.ShuffleWriteMetrics +import org.apache.spark.serializer.{SerializationStream, SerializerInstance} import org.apache.spark.util.Utils /** diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 6c4477184d5b..1f3f193f2ffa 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -17,7 +17,7 @@ package org.apache.spark.storage -import java.io.{IOException, File, FileOutputStream, RandomAccessFile} +import java.io.{File, FileOutputStream, IOException, RandomAccessFile} import java.nio.ByteBuffer import java.nio.channels.FileChannel.MapMode diff --git a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala index 94e8559bd2e9..673f7ad79def 100644 --- 
a/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala +++ b/core/src/main/scala/org/apache/spark/storage/RDDInfo.scala @@ -18,7 +18,7 @@ package org.apache.spark.storage import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.rdd.{RDDOperationScope, RDD} +import org.apache.spark.rdd.{RDD, RDDOperationScope} import org.apache.spark.util.{CallSite, Utils} @DeveloperApi diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 0d0448feb5b0..037bec1d9c33 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -36,7 +36,7 @@ import org.apache.spark.util.Utils * This creates an iterator of (BlockID, InputStream) tuples so the caller can handle blocks * in a pipelined fashion as they are received. * - * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid + * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid * using too much memory. * * @param context [[TaskContext]], used for metrics update @@ -329,7 +329,7 @@ final class ShuffleBlockFetcherIterator( } /** - * Helper class that ensures a ManagedBuffer is release upon InputStream.close() + * Helper class that ensures a ManagedBuffer is released upon InputStream.close() */ private class BufferReleasingInputStream( private val delegate: InputStream, diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala index d14fe4613528..6aa7e1390177 100644 --- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala @@ -25,15 +25,17 @@ import java.util.{Date, Random} import scala.util.control.NonFatal import com.google.common.io.ByteStreams - -import tachyon.client.{ReadType, WriteType, TachyonFS, TachyonFile} +import tachyon.{Constants, TachyonURI} +import tachyon.client.ClientContext +import tachyon.client.file.{TachyonFile, TachyonFileSystem} +import tachyon.client.file.TachyonFileSystem.TachyonFileSystemFactory +import tachyon.client.file.options.DeleteOptions import tachyon.conf.TachyonConf -import tachyon.TachyonURI +import tachyon.exception.{FileAlreadyExistsException, FileDoesNotExistException} import org.apache.spark.Logging import org.apache.spark.executor.ExecutorExitCode -import org.apache.spark.util.{ShutdownHookManager, Utils} - +import org.apache.spark.util.Utils /** * Creates and maintains the logical mapping between logical blocks and tachyon fs locations. By @@ -44,15 +46,15 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log var rootDirs: String = _ var master: String = _ - var client: tachyon.client.TachyonFS = _ + var client: TachyonFileSystem = _ private var subDirsPerTachyonDir: Int = _ // Create one Tachyon directory for each path mentioned in spark.tachyonStore.folderName; // then, inside this directory, create multiple subdirectories that we will hash files into, // in order to avoid having really large inodes at the top level in Tachyon. 
private var tachyonDirs: Array[TachyonFile] = _ - private var subDirs: Array[Array[tachyon.client.TachyonFile]] = _ - + private var subDirs: Array[Array[TachyonFile]] = _ + private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() override def init(blockManager: BlockManager, executorId: String): Unit = { super.init(blockManager, executorId) @@ -62,7 +64,10 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log rootDirs = s"$storeDir/$appFolderName/$executorId" master = blockManager.conf.get(ExternalBlockStore.MASTER_URL, "tachyon://localhost:19998") client = if (master != null && master != "") { - TachyonFS.get(new TachyonURI(master), new TachyonConf()) + val tachyonConf = new TachyonConf() + tachyonConf.set(Constants.MASTER_ADDRESS, master) + ClientContext.reset(tachyonConf) + TachyonFileSystemFactory.get } else { null } @@ -80,7 +85,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log // in order to avoid having really large inodes at the top level in Tachyon. tachyonDirs = createTachyonDirs() subDirs = Array.fill(tachyonDirs.length)(new Array[TachyonFile](subDirsPerTachyonDir)) - tachyonDirs.foreach(tachyonDir => ShutdownHookManager.registerShutdownDeleteDir(tachyonDir)) + tachyonDirs.foreach(registerShutdownDeleteDir) } override def toString: String = {"ExternalBlockStore-Tachyon"} @@ -89,6 +94,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log val file = getFile(blockId) if (fileExists(file)) { removeFile(file) + true } else { false } @@ -101,7 +107,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putBytes(blockId: BlockId, bytes: ByteBuffer): Unit = { val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) + val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) try { Utils.writeByteBuffer(bytes, os) } catch { @@ -115,7 +121,7 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def putValues(blockId: BlockId, values: Iterator[_]): Unit = { val file = getFile(blockId) - val os = file.getOutStream(WriteType.TRY_CACHE) + val os = client.getOutStream(new TachyonURI(client.getInfo(file).getPath)) try { blockManager.dataSerializeStream(blockId, os, values) } catch { @@ -129,12 +135,17 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def getBytes(blockId: BlockId): Option[ByteBuffer] = { val file = getFile(blockId) - if (file == null || file.getLocationHosts.size == 0) { + if (file == null) { return None } - val is = file.getInStream(ReadType.CACHE) + val is = try { + client.getInStream(file) + } catch { + case _: FileDoesNotExistException => + return None + } try { - val size = file.length + val size = client.getInfo(file).length val bs = new Array[Byte](size.asInstanceOf[Int]) ByteStreams.readFully(is, bs) Some(ByteBuffer.wrap(bs)) @@ -149,25 +160,37 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log override def getValues(blockId: BlockId): Option[Iterator[_]] = { val file = getFile(blockId) - if (file == null || file.getLocationHosts().size() == 0) { + if (file == null) { return None } - val is = file.getInStream(ReadType.CACHE) - Option(is).map { is => - blockManager.dataDeserializeStream(blockId, is) + val is = try { + client.getInStream(file) + } catch { + case _: FileDoesNotExistException => + return None + } + try { + 
Some(blockManager.dataDeserializeStream(blockId, is)) + } finally { + is.close() } } override def getSize(blockId: BlockId): Long = { - getFile(blockId.name).length + client.getInfo(getFile(blockId.name)).length } - def removeFile(file: TachyonFile): Boolean = { - client.delete(new TachyonURI(file.getPath()), false) + def removeFile(file: TachyonFile): Unit = { + client.delete(file) } def fileExists(file: TachyonFile): Boolean = { - client.exist(new TachyonURI(file.getPath())) + try { + client.getInfo(file) + true + } catch { + case _: FileDoesNotExistException => false + } } def getFile(filename: String): TachyonFile = { @@ -186,18 +209,18 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log } else { val path = new TachyonURI(s"${tachyonDirs(dirId)}/${"%02x".format(subDirId)}") client.mkdir(path) - val newDir = client.getFile(path) + val newDir = client.loadMetadata(path) subDirs(dirId)(subDirId) = newDir newDir } } } val filePath = new TachyonURI(s"$subDir/$filename") - if(!client.exist(filePath)) { - client.createFile(filePath) + try { + client.create(filePath) + } catch { + case _: FileAlreadyExistsException => client.loadMetadata(filePath) } - val file = client.getFile(filePath) - file } def getFile(blockId: BlockId): TachyonFile = getFile(blockId.name) @@ -217,9 +240,11 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log try { tachyonDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536)) val path = new TachyonURI(s"$rootDir/spark-tachyon-$tachyonDirId") - if (!client.exist(path)) { + try { foundLocalDir = client.mkdir(path) - tachyonDir = client.getFile(path) + tachyonDir = client.loadMetadata(path) + } catch { + case _: FileAlreadyExistsException => // continue } } catch { case NonFatal(e) => @@ -240,14 +265,60 @@ private[spark] class TachyonBlockManager() extends ExternalBlockManager with Log logDebug("Shutdown hook called") tachyonDirs.foreach { tachyonDir => try { - if (!ShutdownHookManager.hasRootAsShutdownDeleteDir(tachyonDir)) { - Utils.deleteRecursively(tachyonDir, client) + if (!hasRootAsShutdownDeleteDir(tachyonDir)) { + deleteRecursively(tachyonDir, client) } } catch { case NonFatal(e) => logError("Exception while deleting tachyon spark dir: " + tachyonDir, e) } } - client.close() } + + /** + * Delete a file or directory and its contents recursively. + */ + private def deleteRecursively(dir: TachyonFile, client: TachyonFileSystem) { + client.delete(dir, new DeleteOptions.Builder(ClientContext.getConf).setRecursive(true).build()) + } + + // Register the tachyon path to be deleted via shutdown hook + private def registerShutdownDeleteDir(file: TachyonFile) { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths += absolutePath + } + } + + // Remove the tachyon path to be deleted via shutdown hook + private def removeShutdownDeleteDir(file: TachyonFile) { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths -= absolutePath + } + } + + // Is the path already registered to be deleted via a shutdown hook ? + private def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { + val absolutePath = client.getInfo(file).getPath + shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.contains(absolutePath) + } + } + + // Note: if file is child of some registered path, while not equal to it, then return true; + // else false. 
This is to ensure that two shutdown hooks do not try to delete each others + // paths - resulting in Exception and incomplete cleanup. + private def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { + val absolutePath = client.getInfo(file).getPath + val hasRoot = shutdownDeleteTachyonPaths.synchronized { + shutdownDeleteTachyonPaths.exists( + path => !absolutePath.equals(path) && absolutePath.startsWith(path)) + } + if (hasRoot) { + logInfo(s"path = $absolutePath, already present as root for deletion.") + } + hasRoot + } + } diff --git a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala index 6e2375477a68..9b6ed8cbbef1 100644 --- a/core/src/main/scala/org/apache/spark/ui/PagedTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/PagedTable.scala @@ -17,8 +17,15 @@ package org.apache.spark.ui +import java.net.URLDecoder + +import scala.collection.JavaConverters._ import scala.xml.{Node, Unparsed} +import com.google.common.base.Splitter + +import org.apache.spark.util.Utils + /** * A data source that provides data for a page. * @@ -71,6 +78,12 @@ private[ui] trait PagedTable[T] { def tableCssClass: String + def pageSizeFormField: String + + def prevPageSizeFormField: String + + def pageNumberFormField: String + def dataSource: PagedDataSource[T] def headers: Seq[Node] @@ -95,7 +108,12 @@ private[ui] trait PagedTable[T] { val PageData(totalPages, _) = _dataSource.pageData(1)
{pageNavigation(1, _dataSource.pageSize, totalPages)} -
<div class="alert alert-error">{e.getMessage}</div>
+          <div class="alert alert-error">
+            <p>Error while rendering table:</p>
+            <pre>
+              {Utils.exceptionString(e)}
+            </pre>
+          </div>
} } @@ -151,36 +169,56 @@ private[ui] trait PagedTable[T] { // The current page should be disabled so that it cannot be clicked.
  • {p}
  • } else { -
  • {p}
  • +
  • {p}
  • + } + } + + val hiddenFormFields = { + if (goButtonFormPath.contains('?')) { + val querystring = goButtonFormPath.split("\\?", 2)(1) + Splitter + .on('&') + .trimResults() + .withKeyValueSeparator("=") + .split(querystring) + .asScala + .filterKeys(_ != pageSizeFormField) + .filterKeys(_ != prevPageSizeFormField) + .filterKeys(_ != pageNumberFormField) + .mapValues(URLDecoder.decode(_, "UTF-8")) + .map { case (k, v) => + + } + } else { + Seq.empty } } - val (goButtonJsFuncName, goButtonJsFunc) = goButtonJavascriptFunction - // When clicking the "Go" button, it will call this javascript method and then call - // "goButtonJsFuncName" - val formJs = - s"""$$(function(){ - | $$( "#form-$tableId-page" ).submit(function(event) { - | var page = $$("#form-$tableId-page-no").val() - | var pageSize = $$("#form-$tableId-page-size").val() - | pageSize = pageSize ? pageSize: 100; - | if (page != "") { - | ${goButtonJsFuncName}(page, pageSize); - | } - | event.preventDefault(); - | }); - |}); - """.stripMargin
    + method="get" + action={Unparsed(goButtonFormPath)} + class="form-inline pull-right" + style="margin-bottom: 0px;"> + + {hiddenFormFields} - + + + id={s"form-$tableId-page-size"} + name={pageSizeFormField} + value={pageSize.toString} + class="span1" /> +
    @@ -189,7 +227,7 @@ private[ui] trait PagedTable[T] {
    - } } @@ -239,10 +272,7 @@ private[ui] trait PagedTable[T] { def pageLink(page: Int): String /** - * Only the implementation knows how to create the url with a page number and the page size, so we - * leave this one to the implementation. The implementation should create a JavaScript method that - * accepts a page number along with the page size and jumps to the page. The return value is this - * method name and its JavaScript codes. + * Returns the submission path for the "go to page #" form. */ - def goButtonJavascriptFunction: (String, String) + def goButtonFormPath: String } diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 8da6884a3853..e319937702f2 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -21,18 +21,18 @@ import java.util.{Date, ServiceLoader} import scala.collection.JavaConverters._ -import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, - UIRoot} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext} import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, + UIRoot} import org.apache.spark.storage.StorageStatusListener import org.apache.spark.ui.JettyUtils._ import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab} import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab} -import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab} -import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.ui.jobs.{JobProgressListener, JobsTab, StagesTab} import org.apache.spark.ui.scope.RDDOperationGraphListener +import org.apache.spark.ui.storage.{StorageListener, StorageTab} +import org.apache.spark.util.Utils /** * Top level user interface for a Spark application. diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 81a121fd441b..392523598472 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -26,9 +26,9 @@ import scala.xml.Node import org.eclipse.jetty.servlet.ServletContextHandler import org.json4s.JsonAST.{JNothing, JValue} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.ui.JettyUtils._ import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * The top level component of the UI hierarchy that contains the server. 
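The hunks above replace the goButtonJavascriptFunction contract with goButtonFormPath: a concrete table now only names its paging parameters and builds ordinary URLs, and no per-table JavaScript is emitted. A hypothetical implementation of the new members, mirroring what TaskPagedTable does further below (the class, field names, and "example." parameter prefix are made up for illustration):

import java.net.URLEncoder

// Illustrative only: a minimal table satisfying the reworked PagedTable contract.
class ExamplePagedTable(basePath: String, sortColumn: String, desc: Boolean, pageSize: Int) {

  def pageSizeFormField: String = "example.pageSize"
  def prevPageSizeFormField: String = "example.prevPageSize"
  def pageNumberFormField: String = "example.page"

  // A page link is a plain URL carrying the page number, sort state, and page size.
  def pageLink(page: Int): String = {
    val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
    s"$basePath&$pageNumberFormField=$page" +
      s"&example.sort=$encodedSortColumn&example.desc=$desc&$pageSizeFormField=$pageSize"
  }

  // The "go to page" form submits here; the page number and size arrive as form fields,
  // and hidden inputs re-add the remaining parameters of this path.
  def goButtonFormPath: String = {
    val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8")
    s"$basePath&example.sort=$encodedSortColumn&example.desc=$desc"
  }
}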
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index 58575d154ce5..1a6f0fdd50df 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -21,7 +21,7 @@ import java.net.URLDecoder import javax.servlet.http.HttpServletRequest import scala.util.Try -import scala.xml.{Text, Node} +import scala.xml.{Node, Text} import org.apache.spark.ui.{UIUtils, WebUIPage} diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index a88fc4c37d3c..2d955a66601e 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui.exec import scala.collection.mutable.HashMap -import org.apache.spark.{Resubmitted, ExceptionFailure, SparkContext} +import org.apache.spark.{ExceptionFailure, Resubmitted, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ import org.apache.spark.storage.{StorageStatus, StorageStatusListener} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala index d467dd9e1f29..451cd83b51ae 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala @@ -24,8 +24,8 @@ import scala.collection.mutable.{HashMap, ListBuffer} import scala.xml._ import org.apache.spark.JobExecutionStatus -import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} import org.apache.spark.ui.{ToolTips, UIUtils, WebUIPage} +import org.apache.spark.ui.jobs.UIData.{ExecutorUIData, JobUIData} /** Page showing list of all ongoing and recently finished jobs */ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { @@ -224,10 +224,10 @@ private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") { } val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown") val formattedSubmissionTime = job.submissionTime.map(UIUtils.formatDate).getOrElse("Unknown") - val jobDescription = UIUtils.makeDescription(lastStageDescription, parent.basePath) + val basePathUri = UIUtils.prependBaseUri(parent.basePath) + val jobDescription = UIUtils.makeDescription(lastStageDescription, basePathUri) - val detailUrl = - "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId) + val detailUrl = "%s/jobs/job?id=%s".format(basePathUri, job.jobId) {job.jobId} {job.jobGroup.map(id => s"($id)").getOrElse("")} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala index 5e52942b64f3..e75f1c57a69d 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.{Node, NodeSeq} import org.apache.spark.scheduler.Schedulable -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing list of all ongoing and recently finished stages and pools */ private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") { diff --git 
a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 1268f44596f8..1304efd8f2ec 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.jobs import scala.collection.mutable -import scala.xml.{Unparsed, Node} +import scala.xml.{Node, Unparsed} import org.apache.spark.ui.{ToolTips, UIUtils} import org.apache.spark.ui.jobs.UIData.StageUIData diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index 2cad0a796913..654d988807f9 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -18,11 +18,10 @@ package org.apache.spark.ui.jobs import java.util.Date +import javax.servlet.http.HttpServletRequest import scala.collection.mutable.{Buffer, HashMap, ListBuffer} -import scala.xml.{NodeSeq, Node, Unparsed, Utility} - -import javax.servlet.http.HttpServletRequest +import scala.xml.{Node, NodeSeq, Unparsed, Utility} import org.apache.spark.JobExecutionStatus import org.apache.spark.scheduler.StageInfo diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala index f3e0b38523f3..fa30f2bda427 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala @@ -22,7 +22,7 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.spark.scheduler.StageInfo -import org.apache.spark.ui.{WebUIPage, UIUtils} +import org.apache.spark.ui.{UIUtils, WebUIPage} /** Page showing specific pool details */ private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") { diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 1b34ba9f03c4..2cc6c75a9ac1 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -31,7 +31,7 @@ import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo, TaskLocality} import org.apache.spark.ui._ import org.apache.spark.ui.jobs.UIData._ -import org.apache.spark.util.{Utils, Distribution} +import org.apache.spark.util.{Distribution, Utils} /** Page showing statistics and task list for a given stage */ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { @@ -97,11 +97,13 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val parameterTaskSortColumn = request.getParameter("task.sort") val parameterTaskSortDesc = request.getParameter("task.desc") val parameterTaskPageSize = request.getParameter("task.pageSize") + val parameterTaskPrevPageSize = request.getParameter("task.prevPageSize") val taskPage = Option(parameterTaskPage).map(_.toInt).getOrElse(1) val taskSortColumn = Option(parameterTaskSortColumn).getOrElse("Index") val taskSortDesc = Option(parameterTaskSortDesc).map(_.toBoolean).getOrElse(false) val taskPageSize = Option(parameterTaskPageSize).map(_.toInt).getOrElse(100) + val taskPrevPageSize = Option(parameterTaskPrevPageSize).map(_.toInt).getOrElse(taskPageSize) // If this is set, expand the dag visualization by default val expandDagVizParam = 
request.getParameter("expandDagViz") @@ -274,6 +276,15 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { accumulableRow, externalAccumulables.toSeq) + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. + if (taskPageSize <= taskPrevPageSize) { + taskPage + } else { + 1 + } + } val currentTime = System.currentTimeMillis() val (taskTable, taskTableHTML) = try { val _taskTable = new TaskPagedTable( @@ -292,10 +303,17 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { sortColumn = taskSortColumn, desc = taskSortDesc ) - (_taskTable, _taskTable.table(taskPage)) + (_taskTable, _taskTable.table(page)) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
<div class="alert alert-error">{e.getMessage}</div>)
+          val errorMessage =
+            <div class="alert alert-error">
+              <p>Error while rendering stage table:</p>
+              <pre>
+                {Utils.exceptionString(e)}
+              </pre>
+            </div>
    + (null, errorMessage) } val jsForScrollingDownToTaskTable = @@ -1215,7 +1233,14 @@ private[ui] class TaskPagedTable( override def tableId: String = "task-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" + + override def pageSizeFormField: String = "task.pageSize" + + override def prevPageSizeFormField: String = "task.prevPageSize" + + override def pageNumberFormField: String = "task.page" override val dataSource: TaskDataSource = new TaskDataSource( data, @@ -1232,24 +1257,16 @@ private[ui] class TaskPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&task.page=$page&task.sort=${encodedSortColumn}&task.desc=${desc}" + - s"&task.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&task.sort=$encodedSortColumn" + + s"&task.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToTaskPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentTaskPageSize = ${pageSize} - |function goToTaskPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentTaskPageSize ? page : 1; - | var url = "${basePath}&task.sort=${encodedSortColumn}&task.desc=${desc}" + - | "&task.page=" + page + "&task.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&task.sort=$encodedSortColumn&task.desc=$desc" } def headers: Seq[Node] = { @@ -1298,21 +1315,27 @@ private[ui] class TaskPagedTable( val headerRow: Seq[Node] = { taskHeadersAndCssClasses.map { case (header, cssClass) => if (header == sortColumn) { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.desc=${!desc}" + - s"&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.desc=${!desc}" + + s"&task.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + +
    + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&task.sort=${URLEncoder.encode(header, "UTF-8")}&task.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&task.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&task.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala index f008d4018061..78165d7b743e 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/UIData.scala @@ -17,14 +17,14 @@ package org.apache.spark.ui.jobs +import scala.collection.mutable +import scala.collection.mutable.HashMap + import org.apache.spark.JobExecutionStatus import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo} import org.apache.spark.util.collection.OpenHashSet -import scala.collection.mutable -import scala.collection.mutable.HashMap - private[spark] object UIData { class ExecutorSummary { diff --git a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index e9c8a8e299cd..06da74f1b6b5 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui.scope import scala.collection.mutable -import scala.collection.mutable.{StringBuilder, ListBuffer} +import scala.collection.mutable.{ListBuffer, StringBuilder} import org.apache.spark.Logging import org.apache.spark.scheduler.StageInfo diff --git a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala index fd6cc3ed759b..606d15d599e8 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/RDDPage.scala @@ -38,11 +38,13 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { val parameterBlockSortColumn = request.getParameter("block.sort") val parameterBlockSortDesc = request.getParameter("block.desc") val parameterBlockPageSize = request.getParameter("block.pageSize") + val parameterBlockPrevPageSize = request.getParameter("block.prevPageSize") val blockPage = Option(parameterBlockPage).map(_.toInt).getOrElse(1) val blockSortColumn = Option(parameterBlockSortColumn).getOrElse("Block Name") val blockSortDesc = Option(parameterBlockSortDesc).map(_.toBoolean).getOrElse(false) val blockPageSize = Option(parameterBlockPageSize).map(_.toInt).getOrElse(100) + val blockPrevPageSize = Option(parameterBlockPrevPageSize).map(_.toInt).getOrElse(blockPageSize) val rddId = parameterId.toInt val rddStorageInfo = AllRDDResource.getRDDStorageInfo(rddId, listener, includeDetails = true) @@ -56,17 +58,26 @@ private[ui] class RDDPage(parent: StorageTab) extends WebUIPage("rdd") { rddStorageInfo.dataDistribution.get, id = Some("rdd-storage-by-worker-table")) // Block table - val (blockTable, blockTableHTML) = try { + val page: Int = { + // If the user has changed to a larger page size, then go to page 1 in order to avoid + // IndexOutOfBoundsException. 
+ if (blockPageSize <= blockPrevPageSize) { + blockPage + } else { + 1 + } + } + val blockTableHTML = try { val _blockTable = new BlockPagedTable( UIUtils.prependBaseUri(parent.basePath) + s"/storage/rdd/?id=${rddId}", rddStorageInfo.partitions.get, blockPageSize, blockSortColumn, blockSortDesc) - (_blockTable, _blockTable.table(blockPage)) + _blockTable.table(page) } catch { case e @ (_ : IllegalArgumentException | _ : IndexOutOfBoundsException) => - (null,
<div class="alert alert-error">{e.getMessage}</div>)
+        <div class="alert alert-error">{e.getMessage}</div>
    } val jsForScrollingDownToBlockTable = @@ -226,7 +237,14 @@ private[ui] class BlockPagedTable( override def tableId: String = "rdd-storage-by-block-table" - override def tableCssClass: String = "table table-bordered table-condensed table-striped" + override def tableCssClass: String = + "table table-bordered table-condensed table-striped table-head-clickable" + + override def pageSizeFormField: String = "block.pageSize" + + override def prevPageSizeFormField: String = "block.prevPageSize" + + override def pageNumberFormField: String = "block.page" override val dataSource: BlockDataSource = new BlockDataSource( rddPartitions, @@ -236,24 +254,16 @@ private[ui] class BlockPagedTable( override def pageLink(page: Int): String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - s"${basePath}&block.page=$page&block.sort=${encodedSortColumn}&block.desc=${desc}" + - s"&block.pageSize=${pageSize}" + basePath + + s"&$pageNumberFormField=$page" + + s"&block.sort=$encodedSortColumn" + + s"&block.desc=$desc" + + s"&$pageSizeFormField=$pageSize" } - override def goButtonJavascriptFunction: (String, String) = { - val jsFuncName = "goToBlockPage" + override def goButtonFormPath: String = { val encodedSortColumn = URLEncoder.encode(sortColumn, "UTF-8") - val jsFunc = s""" - |currentBlockPageSize = ${pageSize} - |function goToBlockPage(page, pageSize) { - | // Set page to 1 if the page size changes - | page = pageSize == currentBlockPageSize ? page : 1; - | var url = "${basePath}&block.sort=${encodedSortColumn}&block.desc=${desc}" + - | "&block.page=" + page + "&block.pageSize=" + pageSize; - | window.location.href = url; - |} - """.stripMargin - (jsFuncName, jsFunc) + s"$basePath&block.sort=$encodedSortColumn&block.desc=$desc" } override def headers: Seq[Node] = { @@ -271,22 +281,27 @@ private[ui] class BlockPagedTable( val headerRow: Seq[Node] = { blockHeaders.map { header => if (header == sortColumn) { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}&block.desc=${!desc}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.desc=${!desc}" + + s"&block.pageSize=$pageSize") val arrow = if (desc) "▾" else "▴" // UP or DOWN - - {header} -  {Unparsed(arrow)} + + + {header} +  {Unparsed(arrow)} + } else { - val headerLink = - s"$basePath&block.sort=${URLEncoder.encode(header, "UTF-8")}" + - s"&block.pageSize=${pageSize}" - val js = Unparsed(s"window.location.href='${headerLink}'") - - {header} + val headerLink = Unparsed( + basePath + + s"&block.sort=${URLEncoder.encode(header, "UTF-8")}" + + s"&block.pageSize=$pageSize") + + + {header} + } } diff --git a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala index 22e2993b3b5b..2d9b885c684b 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/StorageTab.scala @@ -20,9 +20,9 @@ package org.apache.spark.ui.storage import scala.collection.mutable import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ui._ import org.apache.spark.scheduler._ import org.apache.spark.storage._ +import org.apache.spark.ui._ /** Web UI showing storage status of all RDD's in the given SparkContext. 
*/ private[ui] class StorageTab(parent: SparkUI) extends SparkUITab(parent, "storage") { diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala deleted file mode 100644 index 81a7cbde01ce..000000000000 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import akka.actor.Actor -import org.slf4j.Logger - -/** - * A trait to enable logging all Akka actor messages. Here's an example of using this: - * - * {{{ - * class BlockManagerMasterActor extends Actor with ActorLogReceive with Logging { - * ... - * override def receiveWithLogging = { - * case GetLocations(blockId) => - * sender ! getLocations(blockId) - * ... - * } - * ... - * } - * }}} - * - */ -private[spark] trait ActorLogReceive { - self: Actor => - - override def receive: Actor.Receive = new Actor.Receive { - - private val _receiveWithLogging = receiveWithLogging - - override def isDefinedAt(o: Any): Boolean = { - val handled = _receiveWithLogging.isDefinedAt(o) - if (!handled) { - log.debug(s"Received unexpected actor system event: $o") - } - handled - } - - override def apply(o: Any): Unit = { - if (log.isDebugEnabled) { - log.debug(s"[actor] received message $o from ${self.sender}") - } - val start = System.nanoTime - _receiveWithLogging.apply(o) - val timeTaken = (System.nanoTime - start).toDouble / 1000000 - if (log.isDebugEnabled) { - log.debug(s"[actor] handled message ($timeTaken ms) $o from ${self.sender}") - } - } - } - - def receiveWithLogging: Actor.Receive - - protected def log: Logger -} diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 1738258a0c79..f2d93edd4fd2 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -19,14 +19,11 @@ package org.apache.spark.util import scala.collection.JavaConverters._ -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} -import akka.pattern.ask - +import akka.actor.{ActorSystem, ExtendedActorSystem} import com.typesafe.config.ConfigFactory import org.apache.log4j.{Level, Logger} -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv, SparkException} -import org.apache.spark.rpc.RpcTimeout +import org.apache.spark.{Logging, SecurityManager, SparkConf} /** * Various utility classes for working with Akka. @@ -139,104 +136,4 @@ private[spark] object AkkaUtils extends Logging { /** Space reserved for extra data in an Akka message besides serialized task or task result. 
*/ val reservedSizeBytes = 200 * 1024 - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - timeout: RpcTimeout): T = { - askWithReply[T](message, actor, maxAttempts = 1, retryInterval = Int.MaxValue, timeout) - } - - /** - * Send a message to the given actor and get its result within a default timeout, or - * throw a SparkException if this fails even after the specified number of retries. - */ - def askWithReply[T]( - message: Any, - actor: ActorRef, - maxAttempts: Int, - retryInterval: Long, - timeout: RpcTimeout): T = { - // TODO: Consider removing multiple attempts - if (actor == null) { - throw new SparkException(s"Error sending message [message = $message]" + - " as actor is null ") - } - var attempts = 0 - var lastException: Exception = null - while (attempts < maxAttempts) { - attempts += 1 - try { - val future = actor.ask(message)(timeout.duration) - val result = timeout.awaitResult(future) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result.asInstanceOf[T] - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - if (attempts < maxAttempts) { - Thread.sleep(retryInterval) - } - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - - def makeDriverRef(name: String, conf: SparkConf, actorSystem: ActorSystem): ActorRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName - val driverHost: String = conf.get("spark.driver.host", "localhost") - val driverPort: Int = conf.getInt("spark.driver.port", 7077) - Utils.checkHost(driverHost, "Expected hostname") - val url = address(protocol(actorSystem), driverActorSystemName, driverHost, driverPort, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def makeExecutorRef( - name: String, - conf: SparkConf, - host: String, - port: Int, - actorSystem: ActorSystem): ActorRef = { - val executorActorSystemName = SparkEnv.executorActorSystemName - Utils.checkHost(host, "Expected hostname") - val url = address(protocol(actorSystem), executorActorSystemName, host, port, name) - val timeout = RpcUtils.lookupRpcTimeout(conf) - logInfo(s"Connecting to $name: $url") - timeout.awaitResult(actorSystem.actorSelection(url).resolveOne(timeout.duration)) - } - - def protocol(actorSystem: ActorSystem): String = { - val akkaConf = actorSystem.settings.config - val sslProp = "akka.remote.netty.tcp.enable-ssl" - protocol(akkaConf.hasPath(sslProp) && akkaConf.getBoolean(sslProp)) - } - - def protocol(ssl: Boolean = false): String = { - if (ssl) { - "akka.ssl.tcp" - } else { - "akka.tcp" - } - } - - def address( - protocol: String, - systemName: String, - host: String, - port: Int, - actorName: String): String = { - s"$protocol://$systemName@$host:$port/user/$actorName" - } - } diff --git a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala index 6c1fca71f228..f6b7ea2f3786 100644 --- a/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/util/AsynchronousListenerBus.scala @@ -19,6 +19,7 @@ package 
org.apache.spark.util import java.util.concurrent._ import java.util.concurrent.atomic.AtomicBoolean + import scala.util.DynamicVariable import org.apache.spark.SparkContext diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala new file mode 100644 index 000000000000..457a1a05a1bf --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import scala.collection.mutable + +import org.apache.commons.lang3.SystemUtils + +/** + * Utility class to benchmark components. An example of how to use this is: + * val benchmark = new Benchmark("My Benchmark", valuesPerIteration) + * benchmark.addCase("V1")() + * benchmark.addCase("V2")() + * benchmark.run + * This will output the average time to run each function and the rate of each function. + * + * The benchmark function takes one argument that is the iteration that's being run. + * + * If outputPerIteration is true, the timing for each run will be printed to stdout. + */ +private[spark] class Benchmark( + name: String, valuesPerIteration: Long, + iters: Int = 5, + outputPerIteration: Boolean = false) { + val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] + + def addCase(name: String)(f: Int => Unit): Unit = { + benchmarks += Benchmark.Case(name, f) + } + + /** + * Runs the benchmark and outputs the results to stdout. This should be copied and added as + * a comment with the benchmark. Although the results vary from machine to machine, it should + * provide some baseline. + */ + def run(): Unit = { + require(benchmarks.nonEmpty) + // scalastyle:off + println("Running benchmark: " + name) + + val results = benchmarks.map { c => + println(" Running case: " + c.name) + Benchmark.measure(valuesPerIteration, iters, outputPerIteration)(c.fn) + } + println + + val firstRate = results.head.avgRate + // The results are going to be processor specific so it is useful to include that. + println(Benchmark.getProcessorName()) + printf("%-24s %16s %16s %14s\n", name + ":", "Avg Time(ms)", "Avg Rate(M/s)", "Relative Rate") + println("-------------------------------------------------------------------------") + results.zip(benchmarks).foreach { r => + printf("%-24s %16s %16s %14s\n", + r._2.name, + "%10.2f" format r._1.avgMs, + "%10.2f" format r._1.avgRate, + "%6.2f X" format (r._1.avgRate / firstRate)) + } + println + // scalastyle:on + } +} + +private[spark] object Benchmark { + case class Case(name: String, fn: Int => Unit) + case class Result(avgMs: Double, avgRate: Double) + + /** + * This should return a user helpful processor information. Getting at this depends on the OS. 
+ * This should return something like "Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz" + */ + def getProcessorName(): String = { + if (SystemUtils.IS_OS_MAC_OSX) { + Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) + } else if (SystemUtils.IS_OS_LINUX) { + Utils.executeAndGetOutput(Seq("/usr/bin/grep", "-m", "1", "\"model name\"", "/proc/cpuinfo")) + } else { + System.getenv("PROCESSOR_IDENTIFIER") + } + } + + /** + * Runs a single function `f` for iters, returning the average time the function took and + * the rate of the function. + */ + def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { + var totalTime = 0L + for (i <- 0 until iters + 1) { + val start = System.nanoTime() + + f(i) + + val end = System.nanoTime() + if (i != 0) totalTime += end - start + + if (outputPerIteration) { + // scalastyle:off + println(s"Iteration $i took ${(end - start) / 1000} microseconds") + // scalastyle:on + } + } + Result(totalTime.toDouble / 1000000 / iters, num * iters / (totalTime.toDouble / 1000)) + } +} + diff --git a/core/src/main/scala/org/apache/spark/util/EventLoop.scala b/core/src/main/scala/org/apache/spark/util/EventLoop.scala index e9b2b8d24b47..542c5fccf458 100644 --- a/core/src/main/scala/org/apache/spark/util/EventLoop.scala +++ b/core/src/main/scala/org/apache/spark/util/EventLoop.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util.concurrent.atomic.AtomicBoolean import java.util.concurrent.{BlockingQueue, LinkedBlockingDeque} +import java.util.concurrent.atomic.AtomicBoolean import scala.util.control.NonFatal diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index cb0f1bf79f3d..a62fd2f33928 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -25,8 +25,8 @@ import scala.collection.Map import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.json4s.DefaultFormats -import org.json4s.JsonDSL._ import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala deleted file mode 100644 index a8bbad086849..000000000000 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import java.util.{Timer, TimerTask} - -import org.apache.spark.{Logging, SparkConf} - -/** - * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) - */ -private[spark] class MetadataCleaner( - cleanerType: MetadataCleanerType.MetadataCleanerType, - cleanupFunc: (Long) => Unit, - conf: SparkConf) - extends Logging -{ - val name = cleanerType.toString - - private val delaySeconds = MetadataCleaner.getDelaySeconds(conf, cleanerType) - private val periodSeconds = math.max(10, delaySeconds / 10) - private val timer = new Timer(name + " cleanup timer", true) - - - private val task = new TimerTask { - override def run() { - try { - cleanupFunc(System.currentTimeMillis() - (delaySeconds * 1000)) - logInfo("Ran metadata cleaner for " + name) - } catch { - case e: Exception => logError("Error running cleanup task for " + name, e) - } - } - } - - if (delaySeconds > 0) { - logDebug( - "Starting metadata cleaner for " + name + " with delay of " + delaySeconds + " seconds " + - "and period of " + periodSeconds + " secs") - timer.schedule(task, delaySeconds * 1000, periodSeconds * 1000) - } - - def cancel() { - timer.cancel() - } -} - -private[spark] object MetadataCleanerType extends Enumeration { - - val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, BLOCK_MANAGER, - SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value - - type MetadataCleanerType = Value - - def systemProperty(which: MetadataCleanerType.MetadataCleanerType): String = { - "spark.cleaner.ttl." + which.toString - } -} - -// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the -// initialization of StreamingContext. It's okay for users trying to configure stuff themselves. -private[spark] object MetadataCleaner { - def getDelaySeconds(conf: SparkConf): Int = { - conf.getTimeAsSeconds("spark.cleaner.ttl", "-1").toInt - } - - def getDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { - conf.get(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString).toInt - } - - def setDelaySeconds( - conf: SparkConf, - cleanerType: MetadataCleanerType.MetadataCleanerType, - delay: Int) { - conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) - } - - /** - * Set the default delay time (in seconds). 
- * @param conf SparkConf instance - * @param delay default delay time to set - * @param resetAll whether to reset all to default - */ - def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { - conf.set("spark.cleaner.ttl", delay.toString) - if (resetAll) { - for (cleanerType <- MetadataCleanerType.values) { - System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) - } - } - } -} - diff --git a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala index 945217203be7..0a3180da8798 100644 --- a/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala +++ b/core/src/main/scala/org/apache/spark/util/MutableURLClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.net.{URLClassLoader, URL} +import java.net.{URL, URLClassLoader} import java.util.Enumeration import java.util.concurrent.ConcurrentHashMap diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala index 7578a3b1d85f..b68936f5c9f0 100644 --- a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -17,23 +17,21 @@ package org.apache.spark.util -import scala.concurrent.duration.FiniteDuration import scala.language.postfixOps -import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.SparkConf import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, RpcTimeout} -object RpcUtils { +private[spark] object RpcUtils { /** * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. */ def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { - val driverActorSystemName = SparkEnv.driverActorSystemName val driverHost: String = conf.get("spark.driver.host", "localhost") val driverPort: Int = conf.getInt("spark.driver.port", 7077) Utils.checkHost(driverHost, "Expected hostname") - rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + rpcEnv.setupEndpointRef(RpcAddress(driverHost, driverPort), name) } /** Returns the configured number of times to retry connecting */ @@ -47,22 +45,12 @@ object RpcUtils { } /** Returns the default Spark timeout to use for RPC ask operations. */ - private[spark] def askRpcTimeout(conf: SparkConf): RpcTimeout = { + def askRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.askTimeout", "spark.network.timeout"), "120s") } - @deprecated("use askRpcTimeout instead, this method was not intended to be public", "1.5.0") - def askTimeout(conf: SparkConf): FiniteDuration = { - askRpcTimeout(conf).duration - } - /** Returns the default Spark timeout to use for RPC remote endpoint lookup. 
*/ - private[spark] def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { + def lookupRpcTimeout(conf: SparkConf): RpcTimeout = { RpcTimeout(conf, Seq("spark.rpc.lookupTimeout", "spark.network.timeout"), "120s") } - - @deprecated("use lookupRpcTimeout instead, this method was not intended to be public", "1.5.0") - def lookupTimeout(conf: SparkConf): FiniteDuration = { - lookupRpcTimeout(conf).duration - } } diff --git a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala index 1a0f3b477ba3..38523be791ce 100644 --- a/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala +++ b/core/src/main/scala/org/apache/spark/util/ShutdownHookManager.scala @@ -20,10 +20,10 @@ package org.apache.spark.util import java.io.File import java.util.PriorityQueue -import scala.util.{Failure, Success, Try} -import tachyon.client.TachyonFile +import scala.util.Try import org.apache.hadoop.fs.FileSystem + import org.apache.spark.Logging /** @@ -52,7 +52,6 @@ private[spark] object ShutdownHookManager extends Logging { } private val shutdownDeletePaths = new scala.collection.mutable.HashSet[String]() - private val shutdownDeleteTachyonPaths = new scala.collection.mutable.HashSet[String]() // Add a shutdown hook to delete the temp dirs when the JVM exits addShutdownHook(TEMP_DIR_SHUTDOWN_PRIORITY) { () => @@ -77,14 +76,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Register the tachyon path to be deleted via shutdown hook - def registerShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths += absolutePath - } - } - // Remove the path to be deleted via shutdown hook def removeShutdownDeleteDir(file: File) { val absolutePath = file.getAbsolutePath() @@ -93,14 +84,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Remove the tachyon path to be deleted via shutdown hook - def removeShutdownDeleteDir(tachyonfile: TachyonFile) { - val absolutePath = tachyonfile.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.remove(absolutePath) - } - } - // Is the path already registered to be deleted via a shutdown hook ? def hasShutdownDeleteDir(file: File): Boolean = { val absolutePath = file.getAbsolutePath() @@ -109,14 +92,6 @@ private[spark] object ShutdownHookManager extends Logging { } } - // Is the path already registered to be deleted via a shutdown hook ? - def hasShutdownDeleteTachyonDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.contains(absolutePath) - } - } - // Note: if file is child of some registered path, while not equal to it, then return true; // else false. This is to ensure that two shutdown hooks do not try to delete each others // paths - resulting in IOException and incomplete cleanup. @@ -133,22 +108,6 @@ private[spark] object ShutdownHookManager extends Logging { retval } - // Note: if file is child of some registered path, while not equal to it, then return true; - // else false. This is to ensure that two shutdown hooks do not try to delete each others - // paths - resulting in Exception and incomplete cleanup. 
- def hasRootAsShutdownDeleteDir(file: TachyonFile): Boolean = { - val absolutePath = file.getPath() - val retval = shutdownDeleteTachyonPaths.synchronized { - shutdownDeleteTachyonPaths.exists { path => - !absolutePath.equals(path) && absolutePath.startsWith(path) - } - } - if (retval) { - logInfo("path = " + file + ", already present as root for deletion.") - } - retval - } - /** * Detect whether this thread might be executing a shutdown hook. Will always return true if * the current thread is a running a shutdown hook but may spuriously return true otherwise (e.g. @@ -219,21 +178,8 @@ private [util] class SparkShutdownHookManager { val hookTask = new Runnable() { override def run(): Unit = runAll() } - Try(Utils.classForName("org.apache.hadoop.util.ShutdownHookManager")) match { - case Success(shmClass) => - val fsPriority = classOf[FileSystem] - .getField("SHUTDOWN_HOOK_PRIORITY") - .get(null) // static field, the value is not used - .asInstanceOf[Int] - val shm = shmClass.getMethod("get").invoke(null) - shm.getClass().getMethod("addShutdownHook", classOf[Runnable], classOf[Int]) - .invoke(shm, hookTask, Integer.valueOf(fsPriority + 30)) - - case Failure(_) => - // scalastyle:off runtimeaddshutdownhook - Runtime.getRuntime.addShutdownHook(new Thread(hookTask, "Spark Shutdown Hook")); - // scalastyle:on runtimeaddshutdownhook - } + org.apache.hadoop.util.ShutdownHookManager.get().addShutdownHook( + hookTask, FileSystem.SHUTDOWN_HOOK_PRIORITY + 30) } def runAll(): Unit = { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 09864e3f8392..52587d218894 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import com.google.common.collect.MapMaker - import java.lang.management.ManagementFactory import java.lang.reflect.{Field, Modifier} import java.util.{IdentityHashMap, Random} @@ -27,6 +25,8 @@ import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable.ArrayBuffer import scala.runtime.ScalaRunTime +import com.google.common.collect.MapMaker + import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi import org.apache.spark.util.collection.OpenHashSet diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index d7e5143c3095..173302504106 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -17,8 +17,8 @@ package org.apache.spark.util -import java.util.Set import java.util.Map.Entry +import java.util.Set import java.util.concurrent.ConcurrentHashMap import scala.collection.JavaConverters._ diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala deleted file mode 100644 index 65efeb1f4c19..000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashSet.scala +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ -import scala.collection.mutable.Set - -private[spark] class TimeStampedHashSet[A] extends Set[A] { - val internalMap = new ConcurrentHashMap[A, Long]() - - def contains(key: A): Boolean = { - internalMap.contains(key) - } - - def iterator: Iterator[A] = { - val jIterator = internalMap.entrySet().iterator() - jIterator.asScala.map(_.getKey) - } - - override def + (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet += elem - newSet - } - - override def - (elem: A): Set[A] = { - val newSet = new TimeStampedHashSet[A] - newSet ++= this - newSet -= elem - newSet - } - - override def += (key: A): this.type = { - internalMap.put(key, currentTime) - this - } - - override def -= (key: A): this.type = { - internalMap.remove(key) - this - } - - override def empty: Set[A] = new TimeStampedHashSet[A]() - - override def size(): Int = internalMap.size() - - override def foreach[U](f: (A) => U): Unit = { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - f(iterator.next.getKey) - } - } - - /** - * Removes old values that have timestamp earlier than `threshTime` - */ - def clearOldValues(threshTime: Long) { - val iterator = internalMap.entrySet().iterator() - while(iterator.hasNext) { - val entry = iterator.next() - if (entry.getValue < threshTime) { - iterator.remove() - } - } - } - - private def currentTime: Long = System.currentTimeMillis() -} diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala deleted file mode 100644 index 310c0c109416..000000000000 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import java.lang.ref.WeakReference -import java.util.concurrent.atomic.AtomicInteger - -import scala.collection.mutable -import scala.language.implicitConversions - -import org.apache.spark.Logging - -/** - * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. - * - * If the value is garbage collected and the weak reference is null, get() will return a - * non-existent value. These entries are removed from the map periodically (every N inserts), as - * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are - * older than a particular threshold can be removed using the clearOldValues method. - * - * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it - * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, - * so all operations on this HashMap are thread-safe. - * - * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. - */ -private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() with Logging { - - import TimeStampedWeakValueHashMap._ - - private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) - private val insertCount = new AtomicInteger(0) - - /** Return a map consisting only of entries whose values are still strongly reachable. */ - private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } - - def get(key: A): Option[B] = internalMap.get(key) - - def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator - - override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { - val newMap = new TimeStampedWeakValueHashMap[A, B1] - val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] - newMap.internalMap.putAll(oldMap.toMap) - newMap.internalMap += kv - newMap - } - - override def - (key: A): mutable.Map[A, B] = { - val newMap = new TimeStampedWeakValueHashMap[A, B] - newMap.internalMap.putAll(nonNullReferenceMap.toMap) - newMap.internalMap -= key - newMap - } - - override def += (kv: (A, B)): this.type = { - internalMap += kv - if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { - clearNullValues() - } - this - } - - override def -= (key: A): this.type = { - internalMap -= key - this - } - - override def update(key: A, value: B): Unit = this += ((key, value)) - - override def apply(key: A): B = internalMap.apply(key) - - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) - - override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() - - override def size: Int = internalMap.size - - override def foreach[U](f: ((A, B)) => U): Unit = nonNullReferenceMap.foreach(f) - - def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - - def toMap: Map[A, B] = iterator.toMap - - /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ - def clearOldValues(threshTime: Long): Unit = internalMap.clearOldValues(threshTime) - - /** Remove entries with values that are no longer strongly reachable. 
*/ - def clearNullValues() { - val it = internalMap.getEntrySet.iterator - while (it.hasNext) { - val entry = it.next() - if (entry.getValue.value.get == null) { - logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") - it.remove() - } - } - } - - // For testing - - def getTimestamp(key: A): Option[Long] = { - internalMap.getTimeStampedValue(key).map(_.timestamp) - } - - def getReference(key: A): Option[WeakReference[B]] = { - internalMap.getTimeStampedValue(key).map(_.value) - } -} - -/** - * Helper methods for converting to and from WeakReferences. - */ -private object TimeStampedWeakValueHashMap { - - // Number of inserts after which entries with null references are removed - val CLEAR_NULL_VALUES_INTERVAL = 100 - - /* Implicit conversion methods to WeakReferences. */ - - implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) - - implicit def toWeakReferenceTuple[K, V](kv: (K, V)): (K, WeakReference[V]) = { - kv match { case (k, v) => (k, toWeakReference(v)) } - } - - implicit def toWeakReferenceFunction[K, V, R](p: ((K, V)) => R): ((K, WeakReference[V])) => R = { - (kv: (K, WeakReference[V])) => p(kv) - } - - /* Implicit conversion methods from WeakReferences. */ - - implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get - - implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v match { - case Some(ref) => Option(fromWeakReference(ref)) - case None => None - } - } - - implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { - kv match { case (k, v) => (k, fromWeakReference(v)) } - } - - implicit def fromWeakReferenceIterator[K, V]( - it: Iterator[(K, WeakReference[V])]): Iterator[(K, V)] = { - it.map(fromWeakReferenceTuple) - } - - implicit def fromWeakReferenceMap[K, V]( - map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { - mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) - } -} diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 9dbe66e7eefb..9ecbffbf715c 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,8 +22,8 @@ import java.lang.management.ManagementFactory import java.net._ import java.nio.ByteBuffer import java.nio.channels.Channels -import java.util.concurrent._ import java.util.{Locale, Properties, Random, UUID} +import java.util.concurrent._ import javax.net.ssl.HttpsURLConnection import scala.collection.JavaConverters._ @@ -43,8 +43,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.log4j.PropertyConfigurator import org.eclipse.jetty.util.MultiException import org.json4s._ -import tachyon.TachyonURI -import tachyon.client.{TachyonFS, TachyonFile} +import org.slf4j.Logger import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil @@ -331,6 +330,30 @@ private[spark] object Utils extends Logging { } /** + * A file name may contain some invalid URI characters, such as " ". This method will convert the + * file name to a raw path accepted by `java.net.URI(String)`. + * + * Note: the file name must not contain "/" or "\" + */ + def encodeFileNameToURIRawPath(fileName: String): String = { + require(!fileName.contains("/") && !fileName.contains("\\")) + // `file` and `localhost` are not used. Just to prevent URI from parsing `fileName` as + // scheme or host. The prefix "/" is required because URI doesn't accept a relative path. 
+ // We should remove it after we get the raw path. + new URI("file", null, "localhost", -1, "/" + fileName, null, null).getRawPath.substring(1) + } + + /** + * Get the file name from uri's raw path and decode it. If the raw path of uri ends with "/", + * return the name before the last "/". + */ + def decodeFileNameInURI(uri: URI): String = { + val rawPath = uri.getRawPath + val rawFileName = rawPath.split("/").last + new URI("file:///" + rawFileName).getPath.substring(1) + } + + /** * Download a file or directory to target directory. Supports fetching the file in a variety of * ways, including HTTP, Hadoop-compatible filesystems, and files on a standard filesystem, based * on the URL parameter. Fetching directories is only supported from Hadoop-compatible @@ -351,7 +374,7 @@ private[spark] object Utils extends Logging { hadoopConf: Configuration, timestamp: Long, useCache: Boolean) { - val fileName = url.split("/").last + val fileName = decodeFileNameInURI(new URI(url)) val targetFile = new File(targetDir, fileName) val fetchCacheEnabled = conf.getBoolean("spark.files.useFetchCache", defaultValue = true) if (useCache && fetchCacheEnabled) { @@ -639,9 +662,7 @@ private[spark] object Utils extends Logging { private[spark] def isRunningInYarnContainer(conf: SparkConf): Boolean = { // These environment variables are set by YARN. - // For Hadoop 0.23.X, we check for YARN_LOCAL_DIRS (we use this below in getYarnLocalDirs()) - // For Hadoop 2.X, we check for CONTAINER_ID. - conf.getenv("CONTAINER_ID") != null || conf.getenv("YARN_LOCAL_DIRS") != null + conf.getenv("CONTAINER_ID") != null } /** @@ -717,17 +738,12 @@ private[spark] object Utils extends Logging { logError(s"Failed to create local root dir in $root. Ignoring this directory.") None } - }.toArray + } } /** Get the Yarn approved local directories. */ private def getYarnLocalDirs(conf: SparkConf): String = { - // Hadoop 0.23 and 2.x have different Environment variable names for the - // local dirs, so lets check both. We assume one of the 2 is set. - // LOCAL_DIRS => 2.X, YARN_LOCAL_DIRS => 0.23.X - val localDirs = Option(conf.getenv("YARN_LOCAL_DIRS")) - .getOrElse(Option(conf.getenv("LOCAL_DIRS")) - .getOrElse("")) + val localDirs = Option(conf.getenv("LOCAL_DIRS")).getOrElse("") if (localDirs.isEmpty) { throw new Exception("Yarn Local dirs can't be empty") @@ -921,15 +937,6 @@ private[spark] object Utils extends Logging { } } - /** - * Delete a file or directory and its contents recursively. - */ - def deleteRecursively(dir: TachyonFile, client: TachyonFS) { - if (!client.delete(new TachyonURI(dir.getPath()), true)) { - throw new IOException("Failed to delete the tachyon dir: " + dir) - } - } - /** * Check to see if file is a symbolic link. */ @@ -1684,6 +1691,30 @@ private[spark] object Utils extends Logging { new File(path).getName } + /** + * Terminates a process waiting for at most the specified duration. Returns whether + * the process terminated. + */ + def terminateProcess(process: Process, timeoutMs: Long): Option[Int] = { + try { + // Java8 added a new API which will more forcibly kill the process. Use that if available. 
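// Aside, not part of this patch: the reflective pattern used here, reduced to a
// self-contained sketch that assumes only the JDK. `setAccessible` works around
// the concrete Process subclass (e.g. UNIXProcess) being package-private.
object ForceKillSketch {
  import scala.util.control.NonFatal

  def forceKill(process: Process): Unit = {
    try {
      // Resolved reflectively so this also compiles and runs on a Java 7 JVM.
      val m = process.getClass.getMethod("destroyForcibly")
      m.setAccessible(true)
      m.invoke(process)
    } catch {
      case _: NoSuchMethodException => process.destroy() // pre-Java-8 JVM
      case NonFatal(_) => process.destroy()              // reflection failed; fall back
    }
  }
}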
+ val destroyMethod = process.getClass().getMethod("destroyForcibly"); + destroyMethod.setAccessible(true) + destroyMethod.invoke(process) + } catch { + case NonFatal(e) => + if (!e.isInstanceOf[NoSuchMethodException]) { + logWarning("Exception when attempting to kill process", e) + } + process.destroy() + } + if (waitForProcess(process, timeoutMs)) { + Option(process.exitValue()) + } else { + None + } + } + /** * Wait for a process to terminate for at most the specified duration. * Return whether the process actually terminated after the given timeout. @@ -2197,6 +2228,23 @@ private[spark] object Utils extends Logging { def tempFileWith(path: File): File = { new File(path.getAbsolutePath + "." + UUID.randomUUID()) } + + /** + * Returns the name of this JVM process. This is OS dependent but typically (OSX, Linux, Windows), + * this is formatted as PID@hostname. + */ + def getProcessName(): String = { + ManagementFactory.getRuntimeMXBean().getName() + } + + /** + * Utility function that should be called early in `main()` for daemons to set up some common + * diagnostic state. + */ + def initDaemon(log: Logger): Unit = { + log.info(s"Started daemon with process name: ${Utils.getProcessName()}") + SignalLogger.register(log) + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/Vector.scala b/core/src/main/scala/org/apache/spark/util/Vector.scala deleted file mode 100644 index 6b3fa8491904..000000000000 --- a/core/src/main/scala/org/apache/spark/util/Vector.scala +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import scala.language.implicitConversions -import scala.util.Random - -import org.apache.spark.util.random.XORShiftRandom - -@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") -class Vector(val elements: Array[Double]) extends Serializable { - def length: Int = elements.length - - def apply(index: Int): Double = elements(index) - - def + (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) + other(i)) - } - - def add(other: Vector): Vector = this + other - - def - (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - Vector(length, i => this(i) - other(i)) - } - - def subtract(other: Vector): Vector = this - other - - def dot(other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += this(i) * other(i) - i += 1 - } - ans - } - - /** - * return (this + plus) dot other, but without creating any intermediate storage - * @param plus - * @param other - * @return - */ - def plusDot(plus: Vector, other: Vector): Double = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - if (length != plus.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) + plus(i)) * other(i) - i += 1 - } - ans - } - - def += (other: Vector): Vector = { - if (length != other.length) { - throw new IllegalArgumentException("Vectors of different length") - } - var i = 0 - while (i < length) { - elements(i) += other(i) - i += 1 - } - this - } - - def addInPlace(other: Vector): Vector = this +=other - - def * (scale: Double): Vector = Vector(length, i => this(i) * scale) - - def multiply (d: Double): Vector = this * d - - def / (d: Double): Vector = this * (1 / d) - - def divide (d: Double): Vector = this / d - - def unary_- : Vector = this * -1 - - def sum: Double = elements.reduceLeft(_ + _) - - def squaredDist(other: Vector): Double = { - var ans = 0.0 - var i = 0 - while (i < length) { - ans += (this(i) - other(i)) * (this(i) - other(i)) - i += 1 - } - ans - } - - def dist(other: Vector): Double = math.sqrt(squaredDist(other)) - - override def toString: String = elements.mkString("(", ", ", ")") -} - -@deprecated("Use Vectors.dense from Spark's mllib.linalg package instead.", "1.0.0") -object Vector { - def apply(elements: Array[Double]): Vector = new Vector(elements) - - def apply(elements: Double*): Vector = new Vector(elements.toArray) - - def apply(length: Int, initializer: Int => Double): Vector = { - val elements: Array[Double] = Array.tabulate(length)(initializer) - new Vector(elements) - } - - def zeros(length: Int): Vector = new Vector(new Array[Double](length)) - - def ones(length: Int): Vector = Vector(length, _ => 1) - - /** - * Creates this [[org.apache.spark.util.Vector]] of given length containing random numbers - * between 0.0 and 1.0. Optional scala.util.Random number generator can be provided. 
- */ - def random(length: Int, random: Random = new XORShiftRandom()): Vector = - Vector(length, _ => random.nextDouble()) - - class Multiplier(num: Double) { - def * (vec: Vector): Vector = vec * num - } - - implicit def doubleToMultiplier(num: Double): Multiplier = new Multiplier(num) - - implicit object VectorAccumParam extends org.apache.spark.AccumulatorParam[Vector] { - def addInPlace(t1: Vector, t2: Vector): Vector = t1 + t2 - - def zero(initialValue: Vector): Vector = Vector.zeros(initialValue.length) - } - -} diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index f6d81ee5bf05..4a44481cf4e1 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -28,12 +28,12 @@ import com.google.common.io.ByteStreams import org.apache.spark.{Logging, SparkEnv, TaskContext} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer.{DeserializationStream, Serializer} import org.apache.spark.storage.{BlockId, BlockManager} import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator -import org.apache.spark.executor.ShuffleWriteMetrics /** * :: DeveloperApi :: diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 44b1d90667e6..63ba954a7fa7 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -20,15 +20,15 @@ package org.apache.spark.util.collection import java.io._ import java.util.Comparator -import scala.collection.mutable.ArrayBuffer import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer import com.google.common.io.ByteStreams import org.apache.spark._ +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.serializer._ -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.storage.{BlockId, DiskBlockObjectWriter} /** diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 60bf4dd7469f..0f6a425e3db9 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -18,6 +18,7 @@ package org.apache.spark.util.collection import scala.reflect._ + import com.google.common.hash.Hashing import org.apache.spark.annotation.Private diff --git a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala index 3a48af82b1da..e1592184ca6d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/Spillable.scala @@ -17,8 +17,8 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} /** * Spills contents of an 
in-memory collection to disk when the memory threshold diff --git a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala index 14b6ba4af489..58c8560a3d04 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/FileAppender.scala @@ -29,7 +29,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi extends Logging { @volatile private var outputStream: FileOutputStream = null @volatile private var markedForStop = false // has the appender been asked to stopped - @volatile private var stopped = false // has the appender stopped // Thread that reads the input stream and writes to file private val writingThread = new Thread("File appending thread for " + file) { @@ -47,11 +46,7 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi * or because of any error in appending */ def awaitTermination() { - synchronized { - if (!stopped) { - wait() - } - } + writingThread.join() } /** Stop the appender */ @@ -77,10 +72,6 @@ private[spark] class FileAppender(inputStream: InputStream, file: File, bufferSi logError(s"Error writing stream to file $file", e) } finally { closeFile() - synchronized { - stopped = true - notifyAll() - } } } diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala index 1e8476c4a047..050ece12f172 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingFileAppender.scala @@ -20,8 +20,8 @@ package org.apache.spark.util.logging import java.io.{File, FileFilter, InputStream} import com.google.common.io.Files + import org.apache.spark.SparkConf -import RollingFileAppender._ /** * Continuously appends data from input stream into the given file, and rolls @@ -39,9 +39,11 @@ private[spark] class RollingFileAppender( activeFile: File, val rollingPolicy: RollingPolicy, conf: SparkConf, - bufferSize: Int = DEFAULT_BUFFER_SIZE + bufferSize: Int = RollingFileAppender.DEFAULT_BUFFER_SIZE ) extends FileAppender(inputStream, activeFile, bufferSize) { + import RollingFileAppender._ + private val maxRetainedFiles = conf.getInt(RETAINED_FILES_PROPERTY, -1) /** Stop the appender */ diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index c156b03cdb7c..1314217023d1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -19,8 +19,8 @@ package org.apache.spark.util.random import java.util.Random -import scala.reflect.ClassTag import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag import org.apache.commons.math3.distribution.PoissonDistribution diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index 11f1248c24d3..44d5cac7c2de 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -21,7 +21,17 @@ import java.nio.channels.FileChannel; import java.nio.ByteBuffer; import java.net.URI; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import 
java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; import java.util.concurrent.*; import scala.Tuple2; @@ -35,7 +45,6 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.base.Throwables; -import com.google.common.base.Optional; import com.google.common.base.Charsets; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; @@ -49,7 +58,12 @@ import org.junit.Before; import org.junit.Test; -import org.apache.spark.api.java.*; +import org.apache.spark.api.java.JavaDoubleRDD; +import org.apache.spark.api.java.JavaFutureAction; +import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.input.PortableDataStream; import org.apache.spark.partial.BoundedDouble; @@ -687,13 +701,6 @@ public Boolean call(Integer i) { }).isEmpty()); } - @Test - public void toArray() { - JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3)); - List list = rdd.toArray(); - Assert.assertEquals(Arrays.asList(1, 2, 3), list); - } - @Test public void cartesian() { JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0)); @@ -1246,7 +1253,7 @@ public Tuple2 call(Tuple2 pair) { JavaPairRDD output = sc.newAPIHadoopFile(outputDir, org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat.class, - IntWritable.class, Text.class, new Job().getConfiguration()); + IntWritable.class, Text.class, Job.getInstance().getConfiguration()); Assert.assertEquals(pairs.toString(), output.map(new Function, String>() { @Override public String call(Tuple2 x) { @@ -1587,11 +1594,11 @@ public void countApproxDistinctByKey() { } double relativeSD = 0.001; JavaPairRDD pairRdd = sc.parallelizePairs(arrayData); - List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); - for (Tuple2 resItem : res) { - double count = (double)resItem._1(); - Long resCount = (Long)resItem._2(); - Double error = Math.abs((resCount - count) / count); + List> res = pairRdd.countApproxDistinctByKey(relativeSD, 8).collect(); + for (Tuple2 resItem : res) { + double count = resItem._1(); + long resCount = resItem._2(); + double error = Math.abs((resCount - count) / count); Assert.assertTrue(error < 0.1); } @@ -1640,12 +1647,12 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wr = rdd2.sampleByKey(true, fractions, 1L); - Map wrCounts = (Map) (Object) wr.countByKey(); + Map wrCounts = wr.countByKey(); Assert.assertEquals(2, wrCounts.size()); Assert.assertTrue(wrCounts.get(0) > 0); Assert.assertTrue(wrCounts.get(1) > 0); JavaPairRDD wor = rdd2.sampleByKey(false, fractions, 1L); - Map worCounts = (Map) (Object) wor.countByKey(); + Map worCounts = wor.countByKey(); Assert.assertEquals(2, worCounts.size()); Assert.assertTrue(worCounts.get(0) > 0); Assert.assertTrue(worCounts.get(1) > 0); @@ -1666,12 +1673,12 @@ public Tuple2 call(Integer i) { fractions.put(0, 0.5); fractions.put(1, 1.0); JavaPairRDD wrExact = rdd2.sampleByKeyExact(true, fractions, 1L); - Map wrExactCounts = (Map) (Object) wrExact.countByKey(); + Map wrExactCounts = wrExact.countByKey(); Assert.assertEquals(2, wrExactCounts.size()); Assert.assertTrue(wrExactCounts.get(0) == 2); 
Assert.assertTrue(wrExactCounts.get(1) == 4); JavaPairRDD worExact = rdd2.sampleByKeyExact(false, fractions, 1L); - Map worExactCounts = (Map) (Object) worExact.countByKey(); + Map worExactCounts = worExact.countByKey(); Assert.assertEquals(2, worExactCounts.size()); Assert.assertTrue(worExactCounts.get(0) == 2); Assert.assertTrue(worExactCounts.get(1) == 4); @@ -1792,32 +1799,6 @@ public void testAsyncActionErrorWrapping() throws Exception { Assert.assertTrue(future.isDone()); } - - /** - * Test for SPARK-3647. This test needs to use the maven-built assembly to trigger the issue, - * since that's the only artifact where Guava classes have been relocated. - */ - @Test - public void testGuavaOptional() { - // Stop the context created in setUp() and start a local-cluster one, to force usage of the - // assembly. - sc.stop(); - JavaSparkContext localCluster = new JavaSparkContext("local-cluster[1,1,1024]", "JavaAPISuite"); - try { - JavaRDD rdd1 = localCluster.parallelize(Arrays.asList(1, 2, null), 3); - JavaRDD> rdd2 = rdd1.map( - new Function>() { - @Override - public Optional call(Integer i) { - return Optional.fromNullable(i); - } - }); - rdd2.collect(); - } finally { - localCluster.stop(); - } - } - static class Class1 {} static class Class2 {} diff --git a/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java new file mode 100644 index 000000000000..4b97c18198c1 --- /dev/null +++ b/core/src/test/java/org/apache/spark/api/java/OptionalSuite.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java; + +import org.junit.Assert; +import org.junit.Test; + +/** + * Tests {@link Optional}. 
+ */ +public class OptionalSuite { + + @Test + public void testEmpty() { + Assert.assertFalse(Optional.empty().isPresent()); + Assert.assertNull(Optional.empty().orNull()); + Assert.assertEquals("foo", Optional.empty().or("foo")); + Assert.assertEquals("foo", Optional.empty().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testEmptyGet() { + Optional.empty().get(); + } + + @Test + public void testAbsent() { + Assert.assertFalse(Optional.absent().isPresent()); + Assert.assertNull(Optional.absent().orNull()); + Assert.assertEquals("foo", Optional.absent().or("foo")); + Assert.assertEquals("foo", Optional.absent().orElse("foo")); + } + + @Test(expected = NullPointerException.class) + public void testAbsentGet() { + Optional.absent().get(); + } + + @Test + public void testOf() { + Assert.assertTrue(Optional.of(1).isPresent()); + Assert.assertNotNull(Optional.of(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.of(1).orElse(2)); + } + + @Test(expected = NullPointerException.class) + public void testOfWithNull() { + Optional.of(null); + } + + @Test + public void testOfNullable() { + Assert.assertTrue(Optional.ofNullable(1).isPresent()); + Assert.assertNotNull(Optional.ofNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.ofNullable(1).orElse(2)); + Assert.assertFalse(Optional.ofNullable(null).isPresent()); + Assert.assertNull(Optional.ofNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.ofNullable(null).orElse(2)); + } + + @Test + public void testFromNullable() { + Assert.assertTrue(Optional.fromNullable(1).isPresent()); + Assert.assertNotNull(Optional.fromNullable(1).orNull()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).get()); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).or(2)); + Assert.assertEquals(Integer.valueOf(1), Optional.fromNullable(1).orElse(2)); + Assert.assertFalse(Optional.fromNullable(null).isPresent()); + Assert.assertNull(Optional.fromNullable(null).orNull()); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).or(2)); + Assert.assertEquals(Integer.valueOf(2), Optional.fromNullable(null).orElse(2)); + } + +} diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala index 0b19861fc41e..f200ff36c7dd 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala +++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala @@ -42,6 +42,7 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa private val conf: SparkConf = new SparkConf(loadDefaults = false) override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() MockitoAnnotations.initMocks(this) @@ -55,7 +56,11 @@ class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEa } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("commit shuffle files 
multiple times") { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index e0ee281e98b7..32f5a1a7e6c5 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -369,6 +369,37 @@ public void forcedSpillingWithNotReadIterator() throws Exception { assertSpillFilesWereCleanedUp(); } + @Test + public void forcedSpillingWithoutComparator() throws Exception { + final UnsafeExternalSorter sorter = UnsafeExternalSorter.create( + taskMemoryManager, + blockManager, + taskContext, + null, + null, + /* initialSize */ 1024, + pageSizeBytes); + long[] record = new long[100]; + int recordSize = record.length * 8; + int n = (int) pageSizeBytes / recordSize * 3; + int batch = n / 4; + for (int i = 0; i < n; i++) { + record[0] = (long) i; + sorter.insertRecord(record, Platform.LONG_ARRAY_OFFSET, recordSize, 0); + if (i % batch == batch - 1) { + sorter.spill(); + } + } + UnsafeSorterIterator iter = sorter.getIterator(); + for (int i = 0; i < n; i++) { + iter.hasNext(); + iter.loadNext(); + assert(Platform.getLong(iter.getBaseObject(), iter.getBaseOffset()) == i); + } + sorter.cleanupResources(); + assertSpillFilesWereCleanedUp(); + } + @Test public void testPeakMemoryUsed() throws Exception { final long recordLengthBytes = 8; diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index 553d46285ac0..390764ba242f 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -256,8 +256,11 @@ class CheckpointSuite extends SparkFunSuite with RDDCheckpointTester with LocalS } override def afterEach(): Unit = { - super.afterEach() - Utils.deleteRecursively(checkpointDir) + try { + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterEach() + } } override def sparkContext: SparkContext = sc diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index 0c14bef7befd..7b0238091730 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -24,18 +24,14 @@ import scala.language.existentials import scala.util.Random import org.scalatest.BeforeAndAfter -import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.PatienceConfiguration import org.scalatest.time.SpanSugar._ -import org.apache.spark.rdd.{ReliableRDDCheckpointData, RDD} -import org.apache.spark.storage._ +import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BroadcastBlockId -import org.apache.spark.storage.RDDBlockId -import org.apache.spark.storage.ShuffleBlockId -import org.apache.spark.storage.ShuffleIndexBlockId +import org.apache.spark.storage._ /** * An abstract base class for context cleaner tests, which sets up a context with a config diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala 
b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index fedfbd547b91..4e678fbac6a3 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark import scala.collection.mutable import org.scalatest.{BeforeAndAfter, PrivateMethodTester} + import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.ExecutorInfo diff --git a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala index 1c775bcb3d9c..eb3fb99747d1 100644 --- a/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExternalShuffleServiceSuite.scala @@ -35,6 +35,7 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { var rpcHandler: ExternalShuffleBlockHandler = _ override def beforeAll() { + super.beforeAll() val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores = 2) rpcHandler = new ExternalShuffleBlockHandler(transportConf, null) val transportContext = new TransportContext(transportConf, rpcHandler) @@ -46,7 +47,11 @@ class ExternalShuffleServiceSuite extends ShuffleSuite with BeforeAndAfterAll { } override def afterAll() { - server.close() + try { + server.close() + } finally { + super.afterAll() + } } // This test ensures that the external shuffle service is actually in use for the other tests. diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala index 203dab934ca1..3def8b0b1850 100644 --- a/core/src/test/scala/org/apache/spark/FailureSuite.scala +++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark -import org.apache.spark.util.NonSerializable - import java.io.{IOException, NotSerializableException, ObjectInputStream} +import org.apache.spark.util.NonSerializable + // Common state shared by FailureSuite-launched tasks. We use a global object // for this because any local variables used in the task closures will rightfully // be copied for each task, so there's no other way for them to share state. 
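The FailureSuite comment above (local variables captured in task closures are copied for each task) is the reason these suites share state through a global object rather than a captured variable; a minimal sketch of the idiom, with hypothetical names, not taken from this patch:

import java.util.concurrent.atomic.AtomicInteger

// A JVM-global object: tasks running in the same executor JVM all see this one
// instance, whereas a `var` captured in a closure is serialized and copied per
// task, so mutating it only changes the task's private copy.
object TaskSideStateSketch {
  val tasksRun = new AtomicInteger(0)
}

// Usage sketch (local mode, where driver and executor share a JVM):
// sc.parallelize(1 to 4, 4).foreach(_ => TaskSideStateSketch.tasksRun.incrementAndGet())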
diff --git a/core/src/test/scala/org/apache/spark/FileServerSuite.scala b/core/src/test/scala/org/apache/spark/FileServerSuite.scala index 1255e71af6c0..bc7059b77fec 100644 --- a/core/src/test/scala/org/apache/spark/FileServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileServerSuite.scala @@ -27,10 +27,10 @@ import org.apache.commons.lang3.RandomUtils import org.apache.spark.util.Utils -import SSLSampleConfigs._ - class FileServerSuite extends SparkFunSuite with LocalSparkContext { + import SSLSampleConfigs._ + @transient var tmpDir: File = _ @transient var tmpFile: File = _ @transient var tmpJarUrl: String = _ @@ -75,8 +75,11 @@ class FileServerSuite extends SparkFunSuite with LocalSparkContext { } override def afterAll() { - super.afterAll() - Utils.deleteRecursively(tmpDir) + try { + Utils.deleteRecursively(tmpDir) + } finally { + super.afterAll() + } } test("Distributing files locally") { diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index fdb00aafc4a4..993834f8d7d4 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -19,20 +19,18 @@ package org.apache.spark import java.io.{File, FileWriter} -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.input.PortableDataStream -import org.apache.spark.storage.StorageLevel - import scala.io.Source import org.apache.hadoop.io._ import org.apache.hadoop.io.compress.DefaultCodec -import org.apache.hadoop.mapred.{JobConf, FileAlreadyExistsException, FileSplit, TextInputFormat, TextOutputFormat} +import org.apache.hadoop.mapred.{FileAlreadyExistsException, FileSplit, JobConf, TextInputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.rdd.{NewHadoopRDD, HadoopRDD} +import org.apache.spark.input.PortableDataStream +import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils class FileSuite extends SparkFunSuite with LocalSparkContext { @@ -44,8 +42,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } override def afterEach() { - super.afterEach() - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("text files") { @@ -503,11 +504,11 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { sc = new SparkContext("local", "test") val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) - val job = new Job(sc.hadoopConfiguration) + val job = Job.getInstance(sc.hadoopConfiguration) job.setOutputKeyClass(classOf[String]) job.setOutputValueClass(classOf[String]) job.setOutputFormatClass(classOf[NewTextOutputFormat[String, String]]) - val jobConfig = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val jobConfig = job.getConfiguration jobConfig.set("mapred.output.dir", tempDir.getPath + "/outputDataset_new") randomRDD.saveAsNewAPIHadoopDataset(jobConfig) assert(new File(tempDir.getPath + "/outputDataset_new/part-r-00000").exists() === true) diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala index 19180e88ebe0..10794235ed39 100644 --- 
a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala @@ -24,6 +24,7 @@ class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with hash-based shuffle. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "hash") } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 3cd80c0f7d17..18e53508406d 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -25,13 +25,13 @@ import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps -import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} -import org.mockito.Mockito.{mock, spy, verify, when} import org.mockito.Matchers import org.mockito.Matchers._ +import org.mockito.Mockito.{mock, spy, verify, when} +import org.scalatest.{BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.executor.TaskMetrics -import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv, RpcEndpointRef} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend @@ -66,6 +66,7 @@ class HeartbeatReceiverSuite * that uses a manual clock. */ override def beforeEach(): Unit = { + super.beforeEach() val conf = new SparkConf() .setMaster("local[2]") .setAppName("test") diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index 1168eb0b802f..e13a442463e8 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -38,8 +38,11 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft with LocalSparkContext { override def afterEach() { - super.afterEach() - resetSparkContext() + try { + resetSparkContext() + } finally { + super.afterEach() + } } test("local mode, FIFO scheduler") { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 8bf2e55defd0..e1a0bf7c933b 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark -import _root_.io.netty.util.internal.logging.{Slf4JLoggerFactory, InternalLoggerFactory} +import _root_.io.netty.util.internal.logging.{InternalLoggerFactory, Slf4JLoggerFactory} import org.scalatest.BeforeAndAfterAll import org.scalatest.BeforeAndAfterEach import org.scalatest.Suite @@ -28,13 +28,16 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) } override def afterEach() { - resetSparkContext() - super.afterEach() + try { + resetSparkContext() + } finally { + super.afterEach() + } } def resetSparkContext(): Unit = { diff --git 
a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index 7e70308bb360..3819c0a8f31d 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer -import org.mockito.Mockito._ import org.mockito.Matchers.{any, isA} +import org.mockito.Mockito._ -import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcCallContext, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} @@ -125,7 +125,7 @@ class MapOutputTrackerSuite extends SparkFunSuite { val slaveRpcEnv = createRpcEnv("spark-slave", hostname, 0, new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) + slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() diff --git a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala index 25b79bce6ab9..fa35819f55ac 100644 --- a/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/SSLOptionsSuite.scala @@ -21,9 +21,10 @@ import java.io.File import javax.net.ssl.SSLContext import com.google.common.io.Files -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfterAll +import org.apache.spark.util.Utils + class SSLOptionsSuite extends SparkFunSuite with BeforeAndAfterAll { test("test resolving property file as spark conf ") { diff --git a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala index 2d14249855c9..33270bec6247 100644 --- a/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala +++ b/core/src/test/scala/org/apache/spark/SSLSampleConfigs.scala @@ -41,7 +41,6 @@ object SSLSampleConfigs { def sparkSSLConfig(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", keyStorePath) conf.set("spark.ssl.keyStorePassword", "password") @@ -55,7 +54,6 @@ object SSLSampleConfigs { def sparkSSLConfigUntrusted(): SparkConf = { val conf = new SparkConf(loadDefaults = false) - conf.set("spark.rpc", "akka") conf.set("spark.ssl.enabled", "true") conf.set("spark.ssl.keyStore", untrustedKeyStorePath) conf.set("spark.ssl.keyStorePassword", "password") diff --git a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala index 26b95c06789f..e0226803bb1c 100644 --- a/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/SecurityManagerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark import java.io.File -import org.apache.spark.util.{SparkConfWithEnv, Utils} +import org.apache.spark.util.{ResetSystemProperties, SparkConfWithEnv, Utils} -class SecurityManagerSuite extends SparkFunSuite { +class SecurityManagerSuite extends SparkFunSuite with ResetSystemProperties { test("set security 
with conf") { val conf = new SparkConf diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 3d2700b7e6be..858bc742e07c 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -30,13 +30,16 @@ trait SharedSparkContext extends BeforeAndAfterAll { self: Suite => var conf = new SparkConf(false) override def beforeAll() { - _sc = new SparkContext("local[4]", "test", conf) super.beforeAll() + _sc = new SparkContext("local[4]", "test", conf) } override def afterAll() { - LocalSparkContext.stop(_sc) - _sc = null - super.afterAll() + try { + LocalSparkContext.stop(_sc) + _sc = null + } finally { + super.afterAll() + } } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala index d78c99c2e1e0..73638d9b131e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleNettySuite.scala @@ -24,6 +24,7 @@ class ShuffleNettySuite extends ShuffleSuite with BeforeAndAfterAll { // This test suite should run all tests in ShuffleSuite with Netty shuffle mode. override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.blockTransferService", "netty") } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 0de10ae48537..c45d81459e8e 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark -import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier} +import java.util.concurrent.{Callable, CyclicBarrier, Executors, ExecutorService} import org.scalatest.Matchers import org.apache.spark.ShuffleSuite.NonJavaSerializableClass import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD} -import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd} +import org.apache.spark.scheduler.{MapStatus, MyRDD, SparkListener, SparkListenerTaskEnd} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.ShuffleWriter -import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId} +import org.apache.spark.storage.{ShuffleBlockId, ShuffleDataBlockId} import org.apache.spark.util.MutablePair abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala index 01694a6e6f74..9f0a1b4c25dd 100644 --- a/core/src/test/scala/org/apache/spark/Smuggle.scala +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -21,6 +21,7 @@ import java.util.UUID import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable +import scala.language.implicitConversions /** * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. 
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala index b8ab227517cc..7a897c2b4698 100644 --- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala @@ -26,8 +26,8 @@ import org.apache.commons.io.filefilter.TrueFileFilter import org.scalatest.BeforeAndAfterAll import org.apache.spark.rdd.ShuffledRDD -import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.util.Utils class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { @@ -37,10 +37,12 @@ class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { private var tempDir: File = _ override def beforeAll() { + super.beforeAll() conf.set("spark.shuffle.manager", "sort") } override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() conf.set("spark.local.dir", tempDir.getAbsolutePath) } diff --git a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala index ff9a92cc0a42..2fe99e3f8194 100644 --- a/core/src/test/scala/org/apache/spark/SparkConfSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkConfSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Executors} +import java.util.concurrent.{Executors, TimeUnit} import scala.collection.JavaConverters._ import scala.concurrent.duration._ import scala.language.postfixOps -import scala.util.{Try, Random} +import scala.util.{Random, Try} + +import com.esotericsoftware.kryo.Kryo import org.apache.spark.network.util.ByteUnit import org.apache.spark.serializer.{KryoRegistrator, KryoSerializer} -import org.apache.spark.util.{RpcUtils, ResetSystemProperties} -import com.esotericsoftware.kryo.Kryo +import org.apache.spark.util.{ResetSystemProperties, RpcUtils} class SparkConfSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { test("Test byteString conversion") { diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 2bdbd70c638a..3706455c3fac 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark import org.scalatest.Assertions + import org.apache.spark.storage.StorageLevel class SparkContextInfoSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala index d18e0782c039..52919c1ec0b1 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark import org.scalatest.PrivateMethodTester -import org.apache.spark.util.Utils import org.apache.spark.scheduler.{SchedulerBackend, TaskScheduler, TaskSchedulerImpl} import org.apache.spark.scheduler.cluster.{SimrSchedulerBackend, SparkDeploySchedulerBackend} import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} import org.apache.spark.scheduler.local.LocalBackend +import 
org.apache.spark.util.Utils class SparkContextSchedulerCreationSuite extends SparkFunSuite with LocalSparkContext with PrivateMethodTester with Logging { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala index d4f2ea87650a..556afd08bbfe 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextSuite.scala @@ -20,18 +20,18 @@ package org.apache.spark import java.io.File import java.util.concurrent.TimeUnit +import scala.concurrent.Await +import scala.concurrent.duration.Duration + import com.google.common.base.Charsets._ import com.google.common.io.Files - import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce.lib.input.{TextInputFormat => NewTextInputFormat} -import org.apache.spark.util.Utils - -import scala.concurrent.Await -import scala.concurrent.duration.Duration import org.scalatest.Matchers._ +import org.apache.spark.util.Utils + class SparkContextSuite extends SparkFunSuite with LocalSparkContext { test("Only one SparkContext may be active at a time") { @@ -274,6 +274,31 @@ class SparkContextSuite extends SparkFunSuite with LocalSparkContext { } } + test("Default path for file based RDDs is properly set (SPARK-12517)") { + sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) + + // Test textFile, wholeTextFiles, binaryFiles, hadoopFile and + // newAPIHadoopFile for setting the default path as the RDD name + val mockPath = "default/path/for/" + + var targetPath = mockPath + "textFile" + assert(sc.textFile(targetPath).name === targetPath) + + targetPath = mockPath + "wholeTextFiles" + assert(sc.wholeTextFiles(targetPath).name === targetPath) + + targetPath = mockPath + "binaryFiles" + assert(sc.binaryFiles(targetPath).name === targetPath) + + targetPath = mockPath + "hadoopFile" + assert(sc.hadoopFile(targetPath).name === targetPath) + + targetPath = mockPath + "newAPIHadoopFile" + assert(sc.newAPIHadoopFile(targetPath).name === targetPath) + + sc.stop() + } + test("calling multiple sc.stop() must not throw any exception") { noException should be thrownBy { sc = new SparkContext(new SparkConf().setAppName("test").setMaster("local")) diff --git a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala index 54c131cdae36..fc31b784c7ae 100644 --- a/core/src/test/scala/org/apache/spark/ThreadingSuite.scala +++ b/core/src/test/scala/org/apache/spark/ThreadingSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark -import java.util.concurrent.{TimeUnit, Semaphore} -import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{Semaphore, TimeUnit} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} import org.apache.spark.scheduler._ diff --git a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala index 135c56bf5bc9..b38a3667abee 100644 --- a/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/api/python/PythonBroadcastSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.api.python -import scala.io.Source +import java.io.{File, PrintWriter} -import java.io.{PrintWriter, File} +import scala.io.Source import
org.scalatest.Matchers diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala index ba21075ce6be..88fdbbdaec90 100644 --- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala +++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala @@ -45,39 +45,8 @@ class DummyBroadcastClass(rdd: RDD[Int]) extends Serializable { class BroadcastSuite extends SparkFunSuite with LocalSparkContext { - private val httpConf = broadcastConf("HttpBroadcastFactory") - private val torrentConf = broadcastConf("TorrentBroadcastFactory") - - test("Using HttpBroadcast locally") { - sc = new SparkContext("local", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === Set((1, 10), (2, 10))) - } - - test("Accessing HttpBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", httpConf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to 10).map(x => (x, 10)).toSet) - } - - test("Accessing HttpBroadcast variables in a local cluster") { - val numSlaves = 4 - val conf = httpConf.clone - conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - conf.set("spark.broadcast.compress", "true") - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) - val list = List[Int](1, 2, 3, 4) - val broadcast = sc.broadcast(list) - val results = sc.parallelize(1 to numSlaves).map(x => (x, broadcast.value.sum)) - assert(results.collect().toSet === (1 to numSlaves).map(x => (x, 10)).toSet) - } - test("Using TorrentBroadcast locally") { - sc = new SparkContext("local", "test", torrentConf) + sc = new SparkContext("local", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 2).map(x => (x, broadcast.value.sum)) @@ -85,7 +54,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { } test("Accessing TorrentBroadcast variables from multiple threads") { - sc = new SparkContext("local[10]", "test", torrentConf) + sc = new SparkContext("local[10]", "test") val list = List[Int](1, 2, 3, 4) val broadcast = sc.broadcast(list) val results = sc.parallelize(1 to 10).map(x => (x, broadcast.value.sum)) @@ -94,7 +63,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Accessing TorrentBroadcast variables in a local cluster") { val numSlaves = 4 - val conf = torrentConf.clone + val conf = new SparkConf conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") conf.set("spark.broadcast.compress", "true") sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) @@ -124,31 +93,13 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { test("Test Lazy Broadcast variables with TorrentBroadcast") { val numSlaves = 2 - val conf = torrentConf.clone - sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", conf) + sc = new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") val rdd = sc.parallelize(1 to numSlaves) - val results = new DummyBroadcastClass(rdd).doSomething() assert(results.toSet === (1 to numSlaves).map(x => (x, false)).toSet) } - test("Unpersisting HttpBroadcast on executors 
only in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in local mode") { - testUnpersistHttpBroadcast(distributed = false, removeFromDriver = true) - } - - test("Unpersisting HttpBroadcast on executors only in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = false) - } - - test("Unpersisting HttpBroadcast on executors and driver in distributed mode") { - testUnpersistHttpBroadcast(distributed = true, removeFromDriver = true) - } - test("Unpersisting TorrentBroadcast on executors only in local mode") { testUnpersistTorrentBroadcast(distributed = false, removeFromDriver = false) } @@ -179,66 +130,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(thrown.getMessage.toLowerCase.contains("stopped")) } - /** - * Verify the persistence of state associated with an HttpBroadcast in either local mode or - * local-cluster mode (when distributed = true). - * - * This test creates a broadcast variable, uses it on all executors, and then unpersists it. - * In between each step, this test verifies that the broadcast blocks and the broadcast file - * are present only on the expected nodes. - */ - private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) { - val numSlaves = if (distributed) 2 else 0 - - // Verify that the broadcast file is created, and blocks are persisted only on the driver - def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === 1) - statuses.head match { case (bm, status) => - assert(bm.isDriver, "Block should only be on the driver") - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store on the driver") - assert(status.diskSize === 0, "Block should not be in disk store on the driver") - } - if (distributed) { - // this file is only generated in distributed mode - assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!") - } - } - - // Verify that blocks are persisted in both the executors and the driver - def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - assert(statuses.size === numSlaves + 1) - statuses.foreach { case (_, status) => - assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK) - assert(status.memSize > 0, "Block should be in memory store") - assert(status.diskSize === 0, "Block should not be in disk store") - } - } - - // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver - // is true. In the latter case, also verify that the broadcast file is deleted on the driver. 
- def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) { - val blockId = BroadcastBlockId(broadcastId) - val statuses = bmm.getBlockStatus(blockId, askSlaves = true) - val expectedNumBlocks = if (removeFromDriver) 0 else 1 - val possiblyNot = if (removeFromDriver) "" else " not" - assert(statuses.size === expectedNumBlocks, - "Block should%s be unpersisted on the driver".format(possiblyNot)) - if (distributed && removeFromDriver) { - // this file is only generated in distributed mode - assert(!HttpBroadcast.getFile(blockId.broadcastId).exists, - "Broadcast file should%s be deleted".format(possiblyNot)) - } - } - - testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation, - afterUsingBroadcast, afterUnpersist, removeFromDriver) - } - /** * Verify the persistence of state associated with a TorrentBroadcast in a local-cluster. * @@ -284,7 +175,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(statuses.size === expectedNumBlocks) } - testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation, + testUnpersistBroadcast(distributed, numSlaves, afterCreation, afterUsingBroadcast, afterUnpersist, removeFromDriver) } @@ -300,7 +191,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { private def testUnpersistBroadcast( distributed: Boolean, numSlaves: Int, // used only when distributed = true - broadcastConf: SparkConf, afterCreation: (Long, BlockManagerMaster) => Unit, afterUsingBroadcast: (Long, BlockManagerMaster) => Unit, afterUnpersist: (Long, BlockManagerMaster) => Unit, @@ -308,7 +198,7 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { sc = if (distributed) { val _sc = - new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test", broadcastConf) + new SparkContext("local-cluster[%d, 1, 1024]".format(numSlaves), "test") // Wait until all slaves are up try { _sc.jobProgressListener.waitUntilExecutorsUp(numSlaves, 60000) } catch { case e: Throwable => throw e } } else { - new SparkContext("local", "test", broadcastConf) + new SparkContext("local", "test") } val blockManagerMaster = sc.env.blockManager.master val list = List[Int](1, 2, 3, 4) @@ -356,13 +246,6 @@ class BroadcastSuite extends SparkFunSuite with LocalSparkContext { assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet) } } - - /** Helper method to create a SparkConf that uses the given broadcast factory.
*/ - private def broadcastConf(factoryName: String): SparkConf = { - val conf = new SparkConf - conf.set("spark.broadcast.factory", "org.apache.spark.broadcast.%s".format(factoryName)) - conf - } } package object testPackage extends Assertions { diff --git a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala index 3164760b08a7..86455a13d0fe 100644 --- a/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/DeployTestUtils.scala @@ -20,9 +20,9 @@ package org.apache.spark.deploy import java.io.File import java.util.Date +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.master.{ApplicationInfo, DriverInfo, WorkerInfo} import org.apache.spark.deploy.worker.{DriverRunner, ExecutorRunner} -import org.apache.spark.{SecurityManager, SparkConf} private[deploy] object DeployTestUtils { def createAppDesc(): ApplicationDescription = { diff --git a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala index d93febcfd23f..9ecf49b59898 100644 --- a/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala +++ b/core/src/test/scala/org/apache/spark/deploy/IvyTestUtils.scala @@ -24,10 +24,8 @@ import java.util.jar.Manifest import scala.collection.mutable.ArrayBuffer -import com.google.common.io.{Files, ByteStreams} - +import com.google.common.io.{ByteStreams, Files} import org.apache.commons.io.FileUtils - import org.apache.ivy.core.settings.IvySettings import org.apache.spark.TestUtils.{createCompiledClass, JavaSourceFromString} diff --git a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala index 0a9f128a3a6b..2d48e75cfbd9 100644 --- a/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/JsonProtocolSuite.scala @@ -23,10 +23,10 @@ import com.fasterxml.jackson.core.JsonParseException import org.json4s._ import org.json4s.jackson.JsonMethods +import org.apache.spark.{JsonTestUtils, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, WorkerStateResponse} import org.apache.spark.deploy.master.{ApplicationInfo, RecoveryState} import org.apache.spark.deploy.worker.ExecutorRunner -import org.apache.spark.{JsonTestUtils, SparkFunSuite} class JsonProtocolSuite extends SparkFunSuite with JsonTestUtils { diff --git a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala index 8dd31b4b6fdd..f416ace5c2b7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/LogUrlsStandaloneSuite.scala @@ -22,9 +22,9 @@ import java.net.URL import scala.collection.mutable import scala.io.Source -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.scheduler.{SparkListenerExecutorAdded, SparkListener} import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorAdded} +import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.util.SparkConfWithEnv class LogUrlsStandaloneSuite extends SparkFunSuite with LocalSparkContext { diff --git 
a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala index cc30ba223e1c..13cba94578a6 100644 --- a/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/RPackageUtilsSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy -import java.io.{PrintStream, OutputStream, File} +import java.io.{File, OutputStream, PrintStream} import java.net.URI -import java.util.jar.Attributes.Name import java.util.jar.{JarFile, Manifest} +import java.util.jar.Attributes.Name import java.util.zip.ZipFile import scala.collection.JavaConverters._ diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala index 63c346c1b890..4877710c1237 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitUtilsSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.deploy -import java.io.{File, PrintStream, OutputStream} +import java.io.{File, OutputStream, PrintStream} import scala.collection.mutable.ArrayBuffer -import org.scalatest.BeforeAndAfterAll import org.apache.ivy.core.module.descriptor.MDArtifact import org.apache.ivy.core.settings.IvySettings import org.apache.ivy.plugins.resolver.{AbstractResolver, FileSystemResolver, IBiblioResolver} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.deploy.SparkSubmitUtils.MavenCoordinate @@ -171,7 +171,7 @@ class SparkSubmitUtilsSuite extends SparkFunSuite with BeforeAndAfterAll { } test("neglects Spark and Spark's dependencies") { - val components = Seq("bagel_", "catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", + val components = Seq("catalyst_", "core_", "graphx_", "hive_", "mllib_", "repl_", "sql_", "streaming_", "yarn_", "network-common_", "network-shuffle_", "network-yarn_") val coordinates = diff --git a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala index 2fa795f84666..ab3d4cafebef 100644 --- a/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/StandaloneDynamicAllocationSuite.scala @@ -71,15 +71,18 @@ class StandaloneDynamicAllocationSuite } override def afterAll(): Unit = { - masterRpcEnv.shutdown() - workerRpcEnvs.foreach(_.shutdown()) - master.stop() - workers.foreach(_.stop()) - masterRpcEnv = null - workerRpcEnvs = null - master = null - workers = null - super.afterAll() + try { + masterRpcEnv.shutdown() + workerRpcEnvs.foreach(_.shutdown()) + master.stop() + workers.foreach(_.stop()) + masterRpcEnv = null + workerRpcEnvs = null + master = null + workers = null + } finally { + super.afterAll() + } } test("dynamic allocation default behavior") { @@ -365,7 +368,7 @@ class StandaloneDynamicAllocationSuite val executors = getExecutorIds(sc) assert(executors.size === 2) assert(sc.killExecutor(executors.head)) - assert(sc.killExecutor(executors.head)) + assert(!sc.killExecutor(executors.head)) val apps = getApplications() assert(apps.head.executors.size === 1) // The limit should not be lowered twice @@ -386,23 +389,28 @@ class StandaloneDynamicAllocationSuite // the driver refuses to kill executors it does not know about syncExecutors(sc) val executors = 
getExecutorIds(sc) + val executorIdsBefore = executors.toSet assert(executors.size === 2) - // kill executor 1, and replace it + // kill and replace an executor assert(sc.killAndReplaceExecutor(executors.head)) eventually(timeout(10.seconds), interval(10.millis)) { val apps = getApplications() assert(apps.head.executors.size === 2) + val executorIdsAfter = getExecutorIds(sc).toSet + // make sure the executor was killed and replaced + assert(executorIdsBefore != executorIdsAfter) } - var apps = getApplications() - // kill executor 1 - assert(sc.killExecutor(executors.head)) - apps = getApplications() - assert(apps.head.executors.size === 2) - assert(apps.head.getExecutorLimit === 2) - // kill executor 2 - assert(sc.killExecutor(executors(1))) - apps = getApplications() + // kill old executor (which is killedAndReplaced) should fail + assert(!sc.killExecutor(executors.head)) + + // refresh executors list + val newExecutors = getExecutorIds(sc) + syncExecutors(sc) + + // kill newly created executor and do not replace it + assert(sc.killExecutor(newExecutors(1))) + val apps = getApplications() assert(apps.head.executors.size === 1) assert(apps.head.getExecutorLimit === 1) } @@ -430,7 +438,7 @@ class StandaloneDynamicAllocationSuite val executorIdToTaskCount = taskScheduler invokePrivate getMap() executorIdToTaskCount(executors.head) = 1 // kill the busy executor without force; this should fail - assert(killExecutor(sc, executors.head, force = false)) + assert(!killExecutor(sc, executors.head, force = false)) apps = getApplications() assert(apps.head.executors.size === 2) @@ -466,7 +474,7 @@ class StandaloneDynamicAllocationSuite (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala index 1e5c05a73f8a..eb794b6739d5 100644 --- a/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/client/AppClientSuite.scala @@ -63,15 +63,18 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd } override def afterAll(): Unit = { - workerRpcEnvs.foreach(_.shutdown()) - masterRpcEnv.shutdown() - workers.foreach(_.stop()) - master.stop() - workerRpcEnvs = null - masterRpcEnv = null - workers = null - master = null - super.afterAll() + try { + workerRpcEnvs.foreach(_.shutdown()) + masterRpcEnv.shutdown() + workers.foreach(_.stop()) + master.stop() + workerRpcEnvs = null + masterRpcEnv = null + workers = null + master = null + } finally { + super.afterAll() + } } test("interface methods of AppClient using local Master") { @@ -144,7 +147,7 @@ class AppClientSuite extends SparkFunSuite with LocalSparkContext with BeforeAnd (0 until numWorkers).map { i => val rpcEnv = workerRpcEnvs(i) val worker = new Worker(rpcEnv, 0, cores, memory, Array(masterRpcEnv.address), - Worker.SYSTEM_NAME + i, Worker.ENDPOINT_NAME, null, conf, securityManager) + Worker.ENDPOINT_NAME, null, conf, securityManager) rpcEnv.setupEndpoint(Worker.ENDPOINT_NAME, worker) worker } diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala 
b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 5cab17f8a38f..6cbf911395a8 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -23,8 +23,8 @@ import java.net.URI import java.util.concurrent.TimeUnit import java.util.zip.{ZipInputStream, ZipOutputStream} -import scala.io.Source import scala.concurrent.duration._ +import scala.io.Source import scala.language.postfixOps import com.google.common.base.Charsets diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 4b7fd4f13b69..18659fc0c18d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -30,6 +30,7 @@ import org.scalatest.mock.MockitoSugar import org.apache.spark.{JsonTestUtils, SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.ui.{SparkUI, UIUtils} +import org.apache.spark.util.ResetSystemProperties /** * A collection of tests against the historyserver, including comparing responses from the json @@ -43,7 +44,7 @@ import org.apache.spark.ui.{SparkUI, UIUtils} * are considered part of Spark's public api. */ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers with MockitoSugar - with JsonTestUtils { + with JsonTestUtils with ResetSystemProperties { private val logDir = new File("src/test/resources/spark-events") private val expRoot = new File("src/test/resources/HistoryServerExpectations/") diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 242bf4b5566e..10e33a32ba4c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -98,7 +98,7 @@ class MasterSuite extends SparkFunSuite with Matchers with Eventually with Priva Master.startRpcEnvAndEndpoint("127.0.0.1", 0, 0, conf) try { - rpcEnv.setupEndpointRef(Master.SYSTEM_NAME, rpcEnv.address, Master.ENDPOINT_NAME) + rpcEnv.setupEndpointRef(rpcEnv.address, Master.ENDPOINT_NAME) CustomPersistenceEngine.lastInstance.isDefined shouldBe true val persistenceEngine = CustomPersistenceEngine.lastInstance.get diff --git a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala index 7a4472867568..b4deed7f877e 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/PersistenceEngineSuite.scala @@ -25,7 +25,7 @@ import org.apache.curator.test.TestingServer import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.rpc.{RpcEndpoint, RpcEnv} -import org.apache.spark.serializer.{Serializer, JavaSerializer} +import org.apache.spark.serializer.{JavaSerializer, Serializer} import org.apache.spark.util.Utils class PersistenceEngineSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index fba835f054f8..0c9382a92bca 100644 --- 
a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -23,11 +23,11 @@ import scala.io.Source import scala.language.postfixOps import org.json4s.jackson.JsonMethods._ -import org.json4s.JsonAST.{JNothing, JString, JInt} +import org.json4s.JsonAST.{JInt, JNothing, JString} import org.mockito.Mockito.{mock, when} import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SecurityManager, SparkFunSuite} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.DeployMessages.MasterStateResponse import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 9693e32bf6af..ee889bf14454 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -24,16 +24,16 @@ import javax.servlet.http.HttpServletResponse import scala.collection.mutable import com.google.common.base.Charsets -import org.scalatest.BeforeAndAfterEach import org.json4s.JsonAST._ import org.json4s.jackson.JsonMethods._ +import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.rpc._ -import org.apache.spark.util.Utils -import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.{SparkSubmit, SparkSubmitArguments} +import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.DriverState._ +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils /** * Tests for the REST application submission protocol used in standalone cluster mode. 
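The hunk that follows, like the SharedSparkContext, StandaloneDynamicAllocationSuite, and AppClientSuite hunks earlier, applies one recurring shape: suite-local setup runs only after super.beforeAll(), and teardown is wrapped in try/finally so the super call always runs even when cleanup throws. A minimal sketch of that shape, with a hypothetical ManagedResourceSuite trait and a stand-in AutoCloseable resource (not a trait from this patch):

import org.scalatest.{BeforeAndAfterAll, Suite}

// Illustrative only: let ScalaTest mixins initialize first, then acquire;
// on teardown, always unwind the mixin chain even if cleanup fails.
trait ManagedResourceSuite extends BeforeAndAfterAll { self: Suite =>
  private var resource: AutoCloseable = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // stand-in for e.g. a SparkContext or RpcEnv
    resource = new AutoCloseable { override def close(): Unit = () }
  }

  override def afterAll(): Unit = {
    try {
      if (resource != null) {
        resource.close()
        resource = null
      }
    } finally {
      super.afterAll()
    }
  }
}

Without the try/finally, a failing stop() would skip super.afterAll() and leak state into later suites; without the super.beforeAll()-first ordering, mixins such as BeforeAndAfterAll chains would see a half-initialized suite.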
@@ -43,8 +43,12 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { private var server: Option[RestSubmissionServer] = None override def afterEach() { - rpcEnv.foreach(_.shutdown()) - server.foreach(_.stop()) + try { + rpcEnv.foreach(_.shutdown()) + server.foreach(_.stop()) + } finally { + super.afterEach() + } } test("construct submit request") { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala index 7101cb9978df..607c0a4fac46 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/CommandUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.deploy.worker +import org.scalatest.{Matchers, PrivateMethodTester} + import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.util.Utils -import org.scalatest.{Matchers, PrivateMethodTester} class CommandUtilsSuite extends SparkFunSuite with Matchers with PrivateMethodTester { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala index 6258c18d177f..bd8b0655f4bb 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/DriverRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 98664dc1101e..0240bf8aed4c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.deploy.worker import java.io.File -import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} class ExecutorRunnerTest extends SparkFunSuite { test("command includes appId") { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala index faed4bdc6844..101a44edd8ee 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerSuite.scala @@ -19,11 +19,11 @@ package org.apache.spark.deploy.worker import org.scalatest.Matchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.{DriverStateChanged, ExecutorStateChanged} import org.apache.spark.deploy.master.DriverState -import org.apache.spark.deploy.{Command, ExecutorState} import org.apache.spark.rpc.{RpcAddress, RpcEnv} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} class WorkerSuite extends SparkFunSuite with Matchers { @@ -67,7 +67,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { 
conf.set("spark.worker.ui.retainedExecutors", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -93,7 +93,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedExecutors", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { worker.executors += s"app1/$i" -> createExecutorRunner(i) @@ -128,7 +128,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 2.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 5) { val driverId = s"driverId-$i" @@ -154,7 +154,7 @@ class WorkerSuite extends SparkFunSuite with Matchers { conf.set("spark.worker.ui.retainedDrivers", 30.toString) val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) val worker = new Worker(rpcEnv, 50000, 20, 1234 * 5, Array.fill(1)(RpcAddress("1.2.3.4", 1234)), - "sparkWorker1", "Worker", "/tmp", conf, new SecurityManager(conf)) + "Worker", "/tmp", conf, new SecurityManager(conf)) // initialize workers for (i <- 0 until 50) { val driverId = s"driverId-$i" diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 40c24bdecc6c..31bea3293ae7 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.deploy.worker -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointAddress, RpcEnv} class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher shuts down on valid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) workerWatcher.onDisconnected(RpcAddress("1.2.3.4", 1234)) @@ -36,7 +35,7 @@ class WorkerWatcherSuite extends SparkFunSuite { test("WorkerWatcher stays alive on invalid disassociation") { val conf = new SparkConf() val rpcEnv = RpcEnv.create("test", "localhost", 
12345, conf, new SecurityManager(conf)) - val targetWorkerUrl = rpcEnv.uriOf("test", RpcAddress("1.2.3.4", 1234), "Worker") + val targetWorkerUrl = RpcEndpointAddress(RpcAddress("1.2.3.4", 1234), "Worker").toString val otherRpcAddress = RpcAddress("4.3.2.1", 1234) val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl, isTesting = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) diff --git a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala index 8a199459c1dd..d852255a4fd2 100644 --- a/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala +++ b/core/src/test/scala/org/apache/spark/input/WholeTextFileRecordReaderSuite.scala @@ -23,13 +23,12 @@ import java.io.FileOutputStream import scala.collection.immutable.IndexedSeq -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.io.Text +import org.apache.hadoop.io.compress.{CompressionCodecFactory, DefaultCodec, GzipCodec} +import org.scalatest.BeforeAndAfterAll import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.util.Utils -import org.apache.hadoop.io.compress.{DefaultCodec, CompressionCodecFactory, GzipCodec} /** * Tests the correctness of @@ -47,6 +46,7 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl // hard-to-reproduce test failures, since any suites that were run after this one would inherit // the new value of "fs.local.block.size" (see SPARK-5227 and SPARK-5679). To work around this, // we disable FileSystem caching in this suite. + super.beforeAll() val conf = new SparkConf().set("spark.hadoop.fs.file.impl.disable.cache", "true") sc = new SparkContext("local", "test", conf) @@ -59,7 +59,11 @@ class WholeTextFileRecordReaderSuite extends SparkFunSuite with BeforeAndAfterAl } override def afterAll() { - sc.stop() + try { + sc.stop() + } finally { + super.afterAll() + } } private def createNativeFile(inputDir: File, fileName: String, contents: Array[Byte], diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index 1553ab60bdda..9e9c2b0165e1 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -46,7 +46,7 @@ class CompressionCodecSuite extends SparkFunSuite { test("default compression codec") { val codec = CompressionCodec.createCodec(conf) - assert(codec.getClass === classOf[SnappyCompressionCodec]) + assert(codec.getClass === classOf[LZ4CompressionCodec]) testCodec(codec) } @@ -62,12 +62,10 @@ class CompressionCodecSuite extends SparkFunSuite { testCodec(codec) } - test("lz4 does not support concatenation of serialized streams") { + test("lz4 supports concatenation of serialized streams") { val codec = CompressionCodec.createCodec(conf, classOf[LZ4CompressionCodec].getName) assert(codec.getClass === classOf[LZ4CompressionCodec]) - intercept[Exception] { - testConcatenationOfSerializedStreams(codec) - } + testConcatenationOfSerializedStreams(codec) } test("lzf compression codec") { diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 555b640cb424..f2924a6a5c05 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ 
b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import scala.concurrent.duration.Duration import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.Duration import org.mockito.Matchers.{any, anyLong} import org.mockito.Mockito.{mock, when, RETURNS_SMART_NULLS} diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 4b4c3b031132..0e60cc8e7787 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.memory -import org.apache.spark.{SparkEnv, TaskContextImpl, TaskContext} +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} /** * Helper methods for mocking out memory-management-related classes in tests. diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 0706a6e45de8..4a1e49b45df4 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import scala.collection.mutable import org.apache.spark.SparkConf -import org.apache.spark.storage.{BlockStatus, BlockId} +import org.apache.spark.storage.{BlockId, BlockStatus} class TestMemoryManager(conf: SparkConf) extends MemoryManager(conf, numCores = 1, Long.MaxValue, Long.MaxValue) { diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 44eb5a046912..aaf62e0f9106 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -25,17 +25,17 @@ import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, + JobConf, LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, + Reporter, TextInputFormat => OldTextInputFormat} import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} -import org.apache.hadoop.mapred.{JobConf, Reporter, FileSplit => OldFileSplit, - InputSplit => OldInputSplit, LineRecordReader => OldLineRecordReader, - RecordReader => OldRecordReader, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, + TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.hadoop.mapreduce.{TaskAttemptContext, InputSplit => NewInputSplit, - RecordReader => NewRecordReader} import 
org.scalatest.BeforeAndAfter import org.apache.spark.{SharedSparkContext, SparkFunSuite} diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala index 41f2ff725a17..b24f5d732f29 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsConfigSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.metrics -import org.apache.spark.SparkConf - import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkConf import org.apache.spark.SparkFunSuite class MetricsConfigSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala index 9c389c76bf3b..5d8554229dbe 100644 --- a/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/MetricsSystemSuite.scala @@ -17,16 +17,15 @@ package org.apache.spark.metrics +import scala.collection.mutable.ArrayBuffer + +import com.codahale.metrics.MetricRegistry import org.scalatest.{BeforeAndAfter, PrivateMethodTester} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.apache.spark.deploy.master.MasterSource import org.apache.spark.metrics.source.Source -import com.codahale.metrics.MetricRegistry - -import scala.collection.mutable.ArrayBuffer - class MetricsSystemSuite extends SparkFunSuite with BeforeAndAfter with PrivateMethodTester{ var filePath: String = _ var conf: SparkConf = null diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala index 98da94139f7f..47dbcb8fc0ea 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferSecuritySuite.scala @@ -22,20 +22,21 @@ import java.nio._ import java.nio.charset.Charset import java.util.concurrent.TimeUnit -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} +import scala.concurrent.duration._ import scala.util.{Failure, Success, Try} import com.google.common.io.CharStreams -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.BlockFetchingListener -import org.apache.spark.network.{BlockDataManager, BlockTransferService} -import org.apache.spark.storage.{BlockId, ShuffleBlockId} -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito._ import org.scalatest.mock.MockitoSugar import org.scalatest.ShouldMatchers +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.{BlockDataManager, BlockTransferService} +import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.BlockFetchingListener +import org.apache.spark.storage.{BlockId, ShuffleBlockId} + class NettyBlockTransferSecuritySuite extends SparkFunSuite with MockitoSugar with ShouldMatchers { test("security default off") { val conf = new SparkConf() diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index 6f8e8a7ac603..cc1a9e028708 100644 
--- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark.network.netty -import org.apache.spark.network.BlockDataManager -import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} import org.mockito.Mockito.mock import org.scalatest._ +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.network.BlockDataManager + class NettyBlockTransferServiceSuite extends SparkFunSuite with BeforeAndAfterEach @@ -31,14 +32,18 @@ class NettyBlockTransferServiceSuite private var service1: NettyBlockTransferService = _ override def afterEach() { - if (service0 != null) { - service0.close() - service0 = null - } + try { + if (service0 != null) { + service0.close() + service0 = null + } - if (service1 != null) { - service1.close() - service1 = null + if (service1 != null) { + service1.close() + service1 = null + } + } finally { + super.afterEach() } } diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index de015ebd5d23..d18bde790b40 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -34,12 +34,17 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim @transient private var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() sc = new SparkContext("local[2]", "test") } override def afterAll() { - LocalSparkContext.stop(sc) - sc = null + try { + LocalSparkContext.stop(sc) + sc = null + } finally { + super.afterAll() + } } lazy val zeroPartRdd = new EmptyRDD[Int](sc) diff --git a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala index 5103eb74b245..e694f5e5e7ad 100644 --- a/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/LocalCheckpointSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.rdd -import org.apache.spark.{SparkException, SparkContext, LocalSparkContext, SparkFunSuite} - import org.mockito.Mockito.spy + +import org.apache.spark.{LocalSparkContext, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{RDDBlockId, StorageLevel} /** @@ -29,6 +29,7 @@ import org.apache.spark.storage.{RDDBlockId, StorageLevel} class LocalCheckpointSuite extends SparkFunSuite with LocalSparkContext { override def beforeEach(): Unit = { + super.beforeEach() sc = new SparkContext("local[2]", "test") } diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 7d2cfcca9436..16e2d2e636c1 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -17,18 +17,18 @@ package org.apache.spark.rdd -import org.apache.commons.math3.distribution.{PoissonDistribution, BinomialDistribution} -import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.mapred._ -import org.apache.hadoop.util.Progressable - import scala.collection.mutable.{ArrayBuffer, HashSet} import scala.util.Random +import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistribution} import 
org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, -OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, -TaskAttemptContext => NewTaskAttempContext} +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.mapred._ +import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, + OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, + RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} +import org.apache.hadoop.util.Progressable + import org.apache.spark.{Partitioner, SharedSparkContext, SparkFunSuite} import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala index 5f73ec867596..1eebc924a534 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PipedRDDSuite.scala @@ -19,15 +19,15 @@ package org.apache.spark.rdd import java.io.File -import org.apache.hadoop.fs.Path -import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} - import scala.collection.Map import scala.language.postfixOps import scala.sys.process._ import scala.util.Try +import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit, JobConf, TextInputFormat} + import org.apache.spark._ import org.apache.spark.util.Utils diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 007a71f87cf1..ef2ed445005d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -17,14 +17,14 @@ package org.apache.spark.rdd -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import com.esotericsoftware.kryo.KryoException - -import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.collection.JavaConverters._ +import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.ClassTag +import com.esotericsoftware.kryo.KryoException + import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDDSuiteUtils._ @@ -441,66 +441,6 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(prunedData(0) === 10) } - test("mapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.mapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextDouble * t}.collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(2) === prn42_3) - assert(randoms(5) === prn43_3) - } - - test("flatMapWith") { - import java.util.Random - val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val randoms = ones.flatMapWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => - val random = prng.nextDouble() - Seq(random * t, random * t * 10)}. 
- collect() - val prn42_3 = { - val prng42 = new Random(42) - prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble() - } - val prn43_3 = { - val prng43 = new Random(43) - prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble() - } - assert(randoms(5) === prn42_3 * 10) - assert(randoms(11) === prn43_3 * 10) - } - - test("filterWith") { - import java.util.Random - val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2) - @deprecated("suppress compile time deprecation warning", "1.0.0") - val sample = ints.filterWith( - (index: Int) => new Random(index + 42)) - {(t: Int, prng: Random) => prng.nextInt(3) == 0}. - collect() - val checkSample = { - val prng42 = new Random(42) - val prng43 = new Random(43) - Array(1, 2, 3, 4, 5, 6).filter{i => - if (i < 4) 0 == prng42.nextInt(3) else 0 == prng43.nextInt(3) - } - } - assert(sample.size === checkSample.size) - for (i <- 0 until sample.size) assert(sample(i) === checkSample(i)) - } - test("collect large number of empty partitions") { // Regression test for SPARK-4019 assert(sc.makeRDD(0 until 10, 1000).repartition(2001).collect().toSet === (0 until 10).toSet) @@ -542,6 +482,10 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(nums.take(501) === (1 to 501).toArray) assert(nums.take(999) === (1 to 999).toArray) assert(nums.take(1000) === (1 to 999).toArray) + + nums = sc.parallelize(1 to 2, 2) + assert(nums.take(2147483638).size === 2) + assert(nums.takeAsync(2147483638).get.size === 2) } test("top with predefined ordering") { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 6d153eb04e04..64e486d791cd 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.rpc import java.io.{File, NotSerializableException} -import java.util.UUID import java.nio.charset.StandardCharsets.UTF_8 -import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} +import java.util.UUID +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} import scala.collection.mutable import scala.concurrent.Await @@ -44,6 +44,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { var env: RpcEnv = _ override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf() env = createRpcEnv(conf, "local", 0) @@ -53,10 +54,14 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } override def afterAll(): Unit = { - if (env != null) { - env.shutdown() + try { + if (env != null) { + env.shutdown() + } + SparkEnv.set(null) + } finally { + super.afterAll() } - SparkEnv.set(null) } def createRpcEnv(conf: SparkConf, name: String, port: Int, clientMode: Boolean = false): RpcEnv @@ -89,7 +94,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "send-remotely") try { rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { @@ -143,7 +148,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val 
rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-remotely") try { val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) @@ -171,7 +176,7 @@ conf.set("spark.rpc.numRetries", "1") val anotherEnv = createRpcEnv(conf, "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "ask-timeout") try { // Any exception thrown in askWithRetry is wrapped with a SparkException and set as the cause val e = intercept[SparkException] { @@ -430,7 +435,7 @@ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") val ack = Await.result(f, 5 seconds) @@ -470,8 +475,7 @@ val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-remotely-error") + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { @@ -484,10 +488,16 @@ } } - test("network events") { + /** + * Set up an [[RpcEndpoint]] to collect all network events. + * @return the [[RpcEndpointRef]] and a `Seq` that contains network events. + */ + private def setupNetworkEndpoint( + _env: RpcEnv, + name: String): (RpcEndpointRef, Seq[(Any, Any)]) = { val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + val ref = _env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { + override val rpcEnv = _env override def receive: PartialFunction[Any, Unit] = { case "hello" => @@ -507,83 +517,94 @@ } }) + (ref, events) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - // anotherEnv is connected in client mode, so the remote address may be unknown depending on - // the implementation. Account for that when doing checks.
- if (remoteAddress != null) { - assert(events === List(("onConnected", remoteAddress))) - } else { - assert(events.size === 1) - assert(events(0)._1 === "onConnected") + test("network events in server RpcEnv when another RpcEnv is in server mode") { + val serverEnv1 = createRpcEnv(new SparkConf(), "server1", 0, clientMode = false) + val serverEnv2 = createRpcEnv(new SparkConf(), "server2", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(serverEnv1, "network-events") + val (serverRef2, _) = setupNetworkEndpoint(serverEnv2, "network-events") + try { + val serverRefInServer2 = serverEnv1.setupEndpointRef(serverRef2.address, serverRef2.name) + // Send a message to set up the connection + serverRefInServer2.send("hello") + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) } - } - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - // Account for anotherEnv not having an address due to running in client mode. - if (remoteAddress != null) { - assert(events === List( - ("onConnected", remoteAddress), - ("onNetworkError", remoteAddress), - ("onDisconnected", remoteAddress)) || - events === List( - ("onConnected", remoteAddress), - ("onDisconnected", remoteAddress))) - } else { - val eventNames = events.map(_._1) - assert(eventNames === List("onConnected", "onNetworkError", "onDisconnected") || - eventNames === List("onConnected", "onDisconnected")) + serverEnv2.shutdown() + serverEnv2.awaitTermination() + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv2.address))) + assert(events.contains(("onDisconnected", serverEnv2.address))) } + } finally { + serverEnv1.shutdown() + serverEnv2.shutdown() + serverEnv1.awaitTermination() + serverEnv2.awaitTermination() } } - test("network events between non-client-mode RpcEnvs") { - val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] - env.setupEndpoint("network-events-non-client", new ThreadSafeRpcEndpoint { - override val rpcEnv = env + test("network events in server RpcEnv when another RpcEnv is in client mode") { + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (serverRef, events) = setupNetworkEndpoint(serverEnv, "network-events") + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - override def receive: PartialFunction[Any, Unit] = { - case "hello" => - case m => events += "receive" -> m + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address, but we can at least verify the event type + assert(events.map(_._1).contains("onConnected")) } - override def onConnected(remoteAddress: RpcAddress): Unit = { - events += "onConnected" -> remoteAddress - } + clientEnv.shutdown() + clientEnv.awaitTermination() - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - events += "onDisconnected" -> remoteAddress + eventually(timeout(5 seconds), interval(5 millis)) { + // We don't know the exact client address, but we can at least verify the event type + assert(events.map(_._1).contains("onConnected")) + assert(events.map(_._1).contains("onDisconnected")) } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + 
clientEnv.awaitTermination() + serverEnv.awaitTermination() + } + } - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - events += "onNetworkError" -> remoteAddress - } + test("network events in client RpcEnv when another RpcEnv is in server mode") { + val clientEnv = createRpcEnv(new SparkConf(), "client", 0, clientMode = true) + val serverEnv = createRpcEnv(new SparkConf(), "server", 0, clientMode = false) + val (_, events) = setupNetworkEndpoint(clientEnv, "network-events") + val (serverRef, _) = setupNetworkEndpoint(serverEnv, "network-events") + try { + val serverRefInClient = clientEnv.setupEndpointRef(serverRef.address, serverRef.name) + // Send a message to set up the connection + serverRefInClient.send("hello") - }) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + } - val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = false) - // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "network-events-non-client") - val remoteAddress = anotherEnv.address - rpcEndpointRef.send("hello") - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - } + serverEnv.shutdown() + serverEnv.awaitTermination() - anotherEnv.shutdown() - anotherEnv.awaitTermination() - eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.contains(("onConnected", remoteAddress))) - assert(events.contains(("onDisconnected", remoteAddress))) + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.contains(("onConnected", serverEnv.address))) + assert(events.contains(("onDisconnected", serverEnv.address))) + } + } finally { + clientEnv.shutdown() + serverEnv.shutdown() + clientEnv.awaitTermination() + serverEnv.awaitTermination() } } @@ -598,8 +619,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val anotherEnv = createRpcEnv(new SparkConf(), "remote", 0, clientMode = true) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef( - "local", env.address, "sendWithReply-unserializable-error") + val rpcEndpointRef = + anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[Exception] { @@ -636,8 +657,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { case msg: String => message = msg } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "send-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "send-authentication") rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { assert("hello" === message) @@ -668,8 +688,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { } } }) - val rpcEndpointRef = - remoteEnv.setupEndpointRef("authentication-local", localEnv.address, "ask-authentication") + val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") val reply = rpcEndpointRef.askWithRetry[String]("hello") assert("hello" === reply) } finally { @@ -771,6 +790,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val tempDir = Utils.createTempDir() val file = new File(tempDir, "file") Files.write(UUID.randomUUID().toString(), file, UTF_8) + val fileWithSpecialChars = new 
File(tempDir, "file name") + Files.write(UUID.randomUUID().toString(), fileWithSpecialChars, UTF_8) val empty = new File(tempDir, "empty") Files.write("", empty, UTF_8); val jar = new File(tempDir, "jar") @@ -787,6 +808,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { Files.write(UUID.randomUUID().toString(), subFile2, UTF_8) val fileUri = env.fileServer.addFile(file) + val fileWithSpecialCharsUri = env.fileServer.addFile(fileWithSpecialChars) val emptyUri = env.fileServer.addFile(empty) val jarUri = env.fileServer.addJar(jar) val dir1Uri = env.fileServer.addDirectory("/dir1", dir1) @@ -805,6 +827,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val files = Seq( (file, fileUri), + (fileWithSpecialChars, fileWithSpecialCharsUri), (empty, emptyUri), (jar, jarUri), (subFile1, dir1Uri + "/file1"), diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala deleted file mode 100644 index 7aac02775e1b..000000000000 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.rpc.akka - -import org.apache.spark.rpc._ -import org.apache.spark.{SSLSampleConfigs, SecurityManager, SparkConf} - -class AkkaRpcEnvSuite extends RpcEnvSuite { - - override def createRpcEnv(conf: SparkConf, - name: String, - port: Int, - clientMode: Boolean = false): RpcEnv = { - new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf), clientMode)) - } - - test("setupEndpointRef: systemName, address, endpointName") { - val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { - override val rpcEnv = env - - override def receive = { - case _ => - } - }) - val conf = new SparkConf() - val newRpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, new SecurityManager(conf), false)) - try { - val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") - assert(s"akka.tcp://local@${env.address}/user/test_endpoint" === - newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) - } finally { - newRpcEnv.shutdown() - } - } - - test("uriOf") { - val uri = env.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } - - test("uriOf: ssl") { - val conf = SSLSampleConfigs.sparkSSLConfig() - val securityManager = new SecurityManager(conf) - val rpcEnv = new AkkaRpcEnvFactory().create( - RpcEnvConfig(conf, "test", "localhost", 0, securityManager, false)) - try { - val uri = rpcEnv.uriOf("local", RpcAddress("1.2.3.4", 12345), "test_endpoint") - assert("akka.ssl.tcp://local@1.2.3.4:12345/user/test_endpoint" === uri) - } finally { - rpcEnv.shutdown() - } - } - -} diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index 2136795b1881..12113be75c23 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -23,7 +23,7 @@ import java.util.concurrent.atomic.AtomicInteger import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcAddress, TestRpcEndpoint} +import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv, TestRpcEndpoint} class InboxSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala index 56743ba650b4..4fcdb619f930 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcAddressSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.rpc.netty import org.apache.spark.SparkFunSuite +import org.apache.spark.rpc.RpcEndpointAddress class NettyRpcAddressSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index ce83087ec04d..994a58836bd0 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -33,9 +33,9 @@ class NettyRpcEnvSuite extends RpcEnvSuite { } test("non-existent endpoint") { - val uri = env.uriOf("test", env.address, "nonexist-endpoint") + val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString val e = intercept[RpcEndpointNotFoundException] { - 
env.setupEndpointRef("test", env.address, "nonexist-endpoint") + env.setupEndpointRef(env.address, "nonexist-endpoint") } assert(e.getMessage.contains(uri)) } diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala index ebd6f700710b..0c156fef0ae0 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcHandlerSuite.scala @@ -21,11 +21,11 @@ import java.net.InetSocketAddress import java.nio.ByteBuffer import io.netty.channel.Channel -import org.mockito.Mockito._ import org.mockito.Matchers._ +import org.mockito.Mockito._ import org.apache.spark.SparkFunSuite -import org.apache.spark.network.client.{TransportResponseHandler, TransportClient} +import org.apache.spark.network.client.{TransportClient, TransportResponseHandler} import org.apache.spark.network.server.StreamManager import org.apache.spark.rpc._ @@ -43,7 +43,7 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) } @@ -55,10 +55,10 @@ class NettyRpcHandlerSuite extends SparkFunSuite { val channel = mock(classOf[Channel]) val client = new TransportClient(channel, mock(classOf[TransportResponseHandler])) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.receive(client, null, null) + nettyRpcHandler.channelActive(client) when(channel.remoteAddress()).thenReturn(new InetSocketAddress("localhost", 40000)) - nettyRpcHandler.connectionTerminated(client) + nettyRpcHandler.channelInactive(client) verify(dispatcher, times(1)).postToAll(RemoteProcessConnected(RpcAddress("localhost", 40000))) verify(dispatcher, times(1)).postToAll( diff --git a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala index eef6aafa624e..70f40fb26c2f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/CoarseGrainedSchedulerBackendSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.scheduler import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} -import org.apache.spark.util.{SerializableBuffer, AkkaUtils} +import org.apache.spark.util.{AkkaUtils, SerializableBuffer} class CoarseGrainedSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 2869f0fde4c5..370a284d2950 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.mutable.{ArrayBuffer, HashSet, HashMap, Map} +import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.language.reflectiveCalls import 
scala.util.control.NonFatal diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 5cb2d4225d28..43da6fc5b547 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -67,11 +67,11 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit val logPath = new Path(eventLogger.logPath + EventLoggingListener.IN_PROGRESS) assert(fileSystem.exists(logPath)) val logStatus = fileSystem.getFileStatus(logPath) - assert(!logStatus.isDir) + assert(!logStatus.isDirectory) // Verify log is renamed after stop() eventLogger.stop() - assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDir) + assert(!fileSystem.getFileStatus(new Path(eventLogger.logPath)).isDirectory) } test("Basic event logging") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala index 15c8de61b824..56e0f01b3b41 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/MapStatusSuite.scala @@ -17,13 +17,13 @@ package org.apache.spark.scheduler -import org.apache.spark.storage.BlockManagerId +import scala.util.Random -import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.serializer.JavaSerializer import org.roaringbitmap.RoaringBitmap -import scala.util.Random +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.storage.BlockManagerId class MapStatusSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index f33324792495..1dca4bd89fd9 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.io.{ObjectInputStream, ObjectOutputStream, IOException} +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import org.apache.spark.TaskContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala index 1ae5b030f083..9f41aca8a1e1 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorIntegrationSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.scheduler import org.apache.hadoop.mapred.{FileOutputCommitter, TaskAttemptContext} import org.scalatest.concurrent.Timeouts -import org.scalatest.time.{Span, Seconds} +import org.scalatest.time.{Seconds, Span} -import org.apache.spark.{SparkConf, SparkContext, LocalSparkContext, SparkFunSuite, TaskContext} +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite, TaskContext} import org.apache.spark.util.Utils /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 7345508bfe99..c461da65bdc4 
100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -20,22 +20,21 @@ package org.apache.spark.scheduler import java.io.File import java.util.concurrent.TimeoutException +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.apache.hadoop.mapred.{JobConf, OutputCommitter, TaskAttemptContext, TaskAttemptID} import org.mockito.Matchers import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfter -import org.apache.hadoop.mapred.{TaskAttemptID, JobConf, TaskAttemptContext, OutputCommitter} - import org.apache.spark._ -import org.apache.spark.rdd.{RDD, FakeOutputCommitter} +import org.apache.spark.rdd.{FakeOutputCommitter, RDD} import org.apache.spark.util.Utils -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps - /** * Unit tests for the output commit coordination functionality. * diff --git a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala index 103fc19369c9..761e82e6cf1c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ReplayListenerSuite.scala @@ -23,7 +23,6 @@ import java.net.URI import org.json4s.jackson.JsonMethods._ import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkConf, SparkContext, SPARK_VERSION} import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec @@ -115,7 +114,7 @@ class ReplayListenerSuite extends SparkFunSuite with BeforeAndAfter { val applications = fileSystem.listStatus(logDirPath) assert(applications != null && applications.size > 0) val eventLog = applications.sortBy(_.getModificationTime).last - assert(!eventLog.isDir) + assert(!eventLog.isDirectory) // Replay events val logData = EventLoggingListener.openEventLog(eventLog.getPath(), fileSystem) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index f20d5be7c0ee..dc15f5932d6f 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -24,10 +24,9 @@ import scala.collection.JavaConverters._ import org.scalatest.Matchers -import org.apache.spark.SparkException +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.ResetSystemProperties -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Matchers with ResetSystemProperties { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index d83d0aee4225..e5ec44a9f3b6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -17,17 +17,15 @@ package org.apache.spark.scheduler -import org.mockito.Mockito._ import 
org.mockito.Matchers.any - +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ +import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.util.{TaskCompletionListener, TaskCompletionListenerException} -import org.apache.spark.metrics.source.JvmSource - class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -99,14 +97,6 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark }.collect() assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } - - test("TaskContext.attemptId returns taskAttemptId for backwards-compatibility (SPARK-4014)") { - sc = new SparkContext("local", "test") - val attemptIds = sc.parallelize(Seq(1, 2, 3, 4), 4).mapPartitions { iter => - Seq(TaskContext.get().attemptId).iterator - }.collect() - assert(attemptIds.toSet === Set(0, 1, 2, 3)) - } } private object TaskContextSuite { diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala index 525ee0d3bdc5..a4110d2d462d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackendSuite.scala @@ -20,17 +20,17 @@ package org.apache.spark.scheduler.cluster.mesos import java.util import java.util.Collections -import org.apache.mesos.Protos.Value.Scalar -import org.apache.mesos.Protos._ import org.apache.mesos.{Protos, Scheduler, SchedulerDriver} +import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar +import org.mockito.Matchers import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.Matchers import org.scalatest.mock.MockitoSugar import org.scalatest.BeforeAndAfter +import org.apache.spark.{LocalSparkContext, SecurityManager, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.scheduler.TaskSchedulerImpl -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SecurityManager, SparkFunSuite} class CoarseMesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala index c4dc56003120..504e5780f3d8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendSuite.scala @@ -26,19 +26,19 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.Protos._ +import org.apache.mesos.Protos.Value.Scalar import org.apache.mesos.SchedulerDriver +import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Matchers._ import org.mockito.Mockito._ -import org.mockito.{ArgumentCaptor, Matchers} import org.scalatest.mock.MockitoSugar +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.executor.MesosExecutorBackend -import org.apache.spark.scheduler.cluster.ExecutorInfo import org.apache.spark.scheduler.{LiveListenerBus, SparkListenerExecutorAdded, 
TaskDescription, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.cluster.ExecutorInfo class MesosSchedulerBackendSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala index f5cef1caaf1a..98fdc58786ec 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/mesos/MesosClusterSchedulerSuite.scala @@ -21,11 +21,10 @@ import java.util.Date import org.scalatest.mock.MockitoSugar +import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.deploy.Command import org.apache.spark.deploy.mesos.MesosDriverDescription import org.apache.spark.scheduler.cluster.mesos._ -import org.apache.spark.{LocalSparkContext, SparkConf, SparkFunSuite} - class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar { diff --git a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala index 87f25e7245e1..3734f1cb408f 100644 --- a/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/GenericAvroSerializerSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.serializer import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import java.nio.ByteBuffer -import com.esotericsoftware.kryo.io.{Output, Input} -import org.apache.avro.{SchemaBuilder, Schema} +import com.esotericsoftware.kryo.io.{Input, Output} +import org.apache.avro.{Schema, SchemaBuilder} import org.apache.avro.generic.GenericData.Record -import org.apache.spark.{SparkFunSuite, SharedSparkContext} +import org.apache.spark.{SharedSparkContext, SparkFunSuite} class GenericAvroSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") diff --git a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala index 20f45670bc2b..6a6ea42797fb 100644 --- a/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/JavaSerializerSuite.scala @@ -23,13 +23,18 @@ class JavaSerializerSuite extends SparkFunSuite { test("JavaSerializer instances are serializable") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(serializer)) + val obj = instance.deserialize[JavaSerializer](instance.serialize(serializer)) + // enforce class cast + obj.getClass } test("Deserialize object containing a primitive Class as attribute") { val serializer = new JavaSerializer(new SparkConf()) val instance = serializer.newInstance() - instance.deserialize[JavaSerializer](instance.serialize(new ContainsPrimitiveClass())) + val obj = instance.deserialize[ContainsPrimitiveClass](instance.serialize( + new ContainsPrimitiveClass())) + // enforce class cast + obj.getClass } } diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala 
b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala index 935a091f14f9..a0483f648388 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerDistributedSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.serializer -import org.apache.spark.util.Utils - import com.esotericsoftware.kryo.Kryo import org.apache.spark._ import org.apache.spark.serializer.KryoDistributedTest._ +import org.apache.spark.util.Utils class KryoSerializerDistributedSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala index a9b209ccfc76..21251f0b9376 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerResizableOutputSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.serializer import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.SparkContext import org.apache.spark.LocalSparkContext +import org.apache.spark.SparkContext import org.apache.spark.SparkException - class KryoSerializerResizableOutputSuite extends SparkFunSuite { // trial and error showed this will not serialize with 1mb buffer diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index 9fcc22b608c6..f869bcd70861 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.serializer -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileOutputStream, FileInputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, FileInputStream, FileOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -25,14 +25,13 @@ import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} - import org.roaringbitmap.RoaringBitmap import org.apache.spark.{SharedSparkContext, SparkConf, SparkFunSuite} import org.apache.spark.scheduler.HighlyCompressedMapStatus import org.apache.spark.serializer.KryoTest._ -import org.apache.spark.util.Utils import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") @@ -363,19 +362,35 @@ class KryoSerializerSuite extends SparkFunSuite with SharedSparkContext { bitmap.add(1) bitmap.add(3) bitmap.add(5) - bitmap.serialize(new KryoOutputDataOutputBridge(output)) + // Ignore Kryo because it doesn't use writeObject + bitmap.serialize(new KryoOutputObjectOutputBridge(null, output)) output.flush() output.close() val inStream = new FileInputStream(tmpfile) val input = new KryoInput(inStream) val ret = new RoaringBitmap - ret.deserialize(new KryoInputDataInputBridge(input)) + // Ignore Kryo because it doesn't use readObject + ret.deserialize(new KryoInputObjectInputBridge(null, input)) input.close() assert(ret == bitmap) Utils.deleteRecursively(dir) } + test("KryoOutputObjectOutputBridge.writeObject and 
KryoInputObjectInputBridge.readObject") { + val kryo = new KryoSerializer(conf).newKryo() + + val bytesOutput = new ByteArrayOutputStream() + val objectOutput = new KryoOutputObjectOutputBridge(kryo, new KryoOutput(bytesOutput)) + objectOutput.writeObject("test") + objectOutput.close() + + val bytesInput = new ByteArrayInputStream(bytesOutput.toByteArray) + val objectInput = new KryoInputObjectInputBridge(kryo, new KryoInput(bytesInput)) + assert(objectInput.readObject() === "test") + objectInput.close() + } + test("getAutoReset") { val ser = new KryoSerializer(new SparkConf).newInstance().asInstanceOf[KryoSerializerInstance] assert(ser.getAutoReset) diff --git a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala index 2d5e9d66b2e1..683aaa3aab1b 100644 --- a/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/SerializationDebuggerSuite.scala @@ -29,6 +29,7 @@ class SerializationDebuggerSuite extends SparkFunSuite with BeforeAndAfterEach { import SerializationDebugger.find override def beforeEach(): Unit = { + super.beforeEach() SerializationDebugger.enableDebugging = true } diff --git a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala index c1e0a29a34bb..17037870f7a1 100644 --- a/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala +++ b/core/src/test/scala/org/apache/spark/serializer/TestSerializer.scala @@ -17,12 +17,11 @@ package org.apache.spark.serializer -import java.io.{EOFException, OutputStream, InputStream} +import java.io.{EOFException, InputStream, OutputStream} import java.nio.ByteBuffer import scala.reflect.ClassTag - /** * A serializer implementation that always returns two elements in a deserialization stream. 
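
The bridge renames being tested above (KryoOutputDataOutputBridge and KryoInputDataInputBridge becoming KryoOutputObjectOutputBridge and KryoInputObjectInputBridge) widen the adapter: the classes expose Kryo streams through java.io.ObjectOutput/ObjectInput so callers written against Java's interfaces, such as RoaringBitmap.serialize, can write through Kryo, and they now carry a Kryo instance so writeObject/readObject can delegate to it. As a rough sketch of the adapter shape only, not Spark's implementation, the primitive half of such a bridge can be borrowed from DataOutputStream:

    import java.io.{ByteArrayOutputStream, DataOutputStream, ObjectOutput}

    // DataOutputStream already supplies every ObjectOutput method except writeObject;
    // the real bridge implements writeObject by handing the object to Kryo.
    class PrimitiveObjectOutput(out: ByteArrayOutputStream)
      extends DataOutputStream(out) with ObjectOutput {
      override def writeObject(obj: AnyRef): Unit =
        throw new UnsupportedOperationException("object writes go through Kryo in the real bridge")
    }

Reading mirrors this with DataInputStream and ObjectInput.
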
*/ diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index d3b1b2b620b4..e33408b94e2c 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -23,8 +23,8 @@ import java.util.UUID import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.{Mock, MockitoAnnotations} +import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.Matchers._ import org.mockito.Mockito._ import org.mockito.invocation.InvocationOnMock @@ -32,9 +32,9 @@ import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterEach import org.apache.spark._ -import org.apache.spark.executor.{TaskMetrics, ShuffleWriteMetrics} -import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -55,6 +55,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics @@ -119,9 +120,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) - blockIdToFileMap.clear() - temporaryFilesCreated.clear() + try { + Utils.deleteRecursively(tempDir) + blockIdToFileMap.clear() + temporaryFilesCreated.clear() + } finally { + super.afterEach() + } } test("write empty iterator") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index 6e3f500e15dc..3fd6fb456046 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -26,11 +26,11 @@ import org.mockito.Mockito.{mock, when} import org.scalatest.{BeforeAndAfter, Matchers} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.memory.StaticMemoryManager import org.apache.spark.network.BlockTransferService +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 53991d8a1aed..67210e5d4c50 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -25,23 +25,23 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import 
scala.language.postfixOps +import org.mockito.{Matchers => mc} import org.mockito.Mockito.{mock, when} import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.concurrent.Timeouts._ -import org.apache.spark.network.netty.NettyBlockTransferService -import org.apache.spark.rpc.RpcEnv import org.apache.spark._ import org.apache.spark.executor.DataReadMethod import org.apache.spark.memory.StaticMemoryManager +import org.apache.spark.network.netty.NettyBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ - class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterEach with PrivateMethodTester with ResetSystemProperties { @@ -66,7 +66,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE private def makeBlockManager( maxMem: Long, - name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { + name: String = SparkContext.DRIVER_IDENTIFIER, + master: BlockManagerMaster = this.master): BlockManager = { val transfer = new NettyBlockTransferService(conf, securityMgr, numCores = 1) val memManager = new StaticMemoryManager(conf, Long.MaxValue, maxMem, numCores = 1) val blockManager = new BlockManager(name, rpcEnv, master, serializer, conf, @@ -77,6 +78,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def beforeEach(): Unit = { + super.beforeEach() rpcEnv = RpcEnv.create("test", "localhost", 0, conf, securityMgr) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case @@ -95,22 +97,26 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } override def afterEach(): Unit = { - if (store != null) { - store.stop() - store = null - } - if (store2 != null) { - store2.stop() - store2 = null - } - if (store3 != null) { - store3.stop() - store3 = null + try { + if (store != null) { + store.stop() + store = null + } + if (store2 != null) { + store2.stop() + store2 = null + } + if (store3 != null) { + store3.stop() + store3 = null + } + rpcEnv.shutdown() + rpcEnv.awaitTermination() + rpcEnv = null + master = null + } finally { + super.afterEach() } - rpcEnv.shutdown() - rpcEnv.awaitTermination() - rpcEnv = null - master = null } test("StorageLevel object caching") { @@ -451,6 +457,21 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE assert(list2DiskGet.get.readMethod === DataReadMethod.Disk) } + test("optimize a location order of blocks") { + val localHost = Utils.localHostName() + val otherHost = "otherHost" + val bmMaster = mock(classOf[BlockManagerMaster]) + val bmId1 = BlockManagerId("id1", localHost, 1) + val bmId2 = BlockManagerId("id2", localHost, 2) + val bmId3 = BlockManagerId("id3", otherHost, 3) + when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) + + val blockManager = makeBlockManager(128, "exec", bmMaster) + val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) + val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) + } + test("SPARK-9591: getRemoteBytes from another location when Exception throw") { val origTimeoutOpt = 
conf.getOption("spark.network.timeout") try { @@ -484,38 +505,27 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("in-memory LRU storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.getSingle("a1") === None, "a1 was in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY) - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3") === None, "a3 was in store") + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY) } test("in-memory LRU storage with serialization") { + testInMemoryLRUStorage(StorageLevel.MEMORY_ONLY_SER) + } + + private def testInMemoryLRUStorage(storageLevel: StorageLevel): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_ONLY_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) assert(store.getSingle("a2").isDefined, "a2 was not in store") assert(store.getSingle("a3").isDefined, "a3 was not in store") assert(store.getSingle("a1") === None, "a1 was in store") assert(store.getSingle("a2").isDefined, "a2 was not in store") // At this point a2 was gotten last, so LRU will getSingle rid of a3 - store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY_SER) + store.putSingle("a1", a1, storageLevel) assert(store.getSingle("a1").isDefined, "a1 was not in store") assert(store.getSingle("a2").isDefined, "a2 was not in store") assert(store.getSingle("a3") === None, "a3 was in store") @@ -597,62 +607,35 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE } test("disk and memory storage") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getSingle) } test("disk and memory storage with getLocalBytes") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK) - 
assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK, _.getLocalBytes) } test("disk and memory storage with serialization") { - store = makeBlockManager(12000) - val a1 = new Array[Byte](4000) - val a2 = new Array[Byte](4000) - val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getSingle("a2").isDefined, "a2 was not in store") - assert(store.getSingle("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getSingle("a1").isDefined, "a1 was not in store") - assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getSingle) } test("disk and memory storage with serialization and getLocalBytes") { + testDiskAndMemoryStorage(StorageLevel.MEMORY_AND_DISK_SER, _.getLocalBytes) + } + + def testDiskAndMemoryStorage( + storageLevel: StorageLevel, + accessMethod: BlockManager => BlockId => Option[_]): Unit = { store = makeBlockManager(12000) val a1 = new Array[Byte](4000) val a2 = new Array[Byte](4000) val a3 = new Array[Byte](4000) - store.putSingle("a1", a1, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a2", a2, StorageLevel.MEMORY_AND_DISK_SER) - store.putSingle("a3", a3, StorageLevel.MEMORY_AND_DISK_SER) - assert(store.getLocalBytes("a2").isDefined, "a2 was not in store") - assert(store.getLocalBytes("a3").isDefined, "a3 was not in store") - assert(store.memoryStore.getValues("a1") == None, "a1 was in memory store") - assert(store.getLocalBytes("a1").isDefined, "a1 was not in store") + store.putSingle("a1", a1, storageLevel) + store.putSingle("a2", a2, storageLevel) + store.putSingle("a3", a3, storageLevel) + assert(accessMethod(store)("a2").isDefined, "a2 was not in store") + assert(accessMethod(store)("a3").isDefined, "a3 was not in store") + assert(store.memoryStore.getValues("a1").isEmpty, "a1 was in memory store") + assert(accessMethod(store)("a1").isDefined, "a1 was not in store") assert(store.memoryStore.getValues("a1").isDefined, "a1 was not in memory store") } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 688f56f4665f..69e17461df75 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -45,19 +45,27 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B } override def afterAll() { - super.afterAll() - Utils.deleteRecursively(rootDir0) - Utils.deleteRecursively(rootDir1) + try { + Utils.deleteRecursively(rootDir0) + Utils.deleteRecursively(rootDir1) + } finally { + super.afterAll() + } } override def beforeEach() { + super.beforeEach() val conf = testConf.clone conf.set("spark.local.dir", rootDirs) diskBlockManager = new DiskBlockManager(blockManager, conf) } override def afterEach() { - diskBlockManager.stop() + try { 
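
A pattern note that applies here and to several earlier hunks (RpcEnvSuite, BypassMergeSortShuffleWriterSuite, BlockManagerSuite): suite-local cleanup is being wrapped in try/finally so that the parent's afterEach/afterAll hook still runs even when that cleanup throws. Distilled to plain Scala, with stand-in names rather than ScalaTest's:

    trait ParentHooks {
      def afterEach(): Unit = println("parent teardown ran")
    }

    class LeakySuite extends ParentHooks {
      override def afterEach(): Unit = {
        try {
          sys.error("local cleanup failed") // stand-in for deleteRecursively(), stop(), ...
        } finally {
          super.afterEach() // guaranteed to run; the original error still propagates
        }
      }
    }

    object TeardownDemo extends App {
      try new LeakySuite().afterEach()
      catch { case e: RuntimeException => println(s"propagated: ${e.getMessage}") }
    }
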
+ diskBlockManager.stop() + } finally { + super.afterEach() + } } test("basic block creation") { diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 7c19531c1880..5d36617cfc44 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -30,11 +30,16 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { var tempDir: File = _ override def beforeEach(): Unit = { + super.beforeEach() tempDir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(tempDir) + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } } test("verify write metrics") { diff --git a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala index cc50289c7b3e..c7074078d8fd 100644 --- a/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/LocalDirsSuite.scala @@ -19,11 +19,10 @@ package org.apache.spark.storage import java.io.File -import org.apache.spark.util.Utils import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.util.SparkConfWithEnv +import org.apache.spark.util.{SparkConfWithEnv, Utils} /** * Tests for the spark.local.dir and SPARK_LOCAL_DIRS configuration options. diff --git a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala index cc76c141c53c..74eeca282882 100644 --- a/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/PagedTableSuite.scala @@ -64,7 +64,13 @@ class PagedTableSuite extends SparkFunSuite { override def row(t: Int): Seq[Node] = Nil - override def goButtonJavascriptFunction: (String, String) = ("", "") + override def pageSizeFormField: String = "pageSize" + + override def prevPageSizeFormField: String = "prevPageSize" + + override def pageNumberFormField: String = "page" + + override def goButtonFormPath: String = "" } assert(pagedTable.pageNavigation(1, 10, 1) === Nil) diff --git a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala index ceecfd665bf8..aa22f3ba2b4d 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISeleniumSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.ui import java.net.{HttpURLConnection, URL} -import javax.servlet.http.{HttpServletResponse, HttpServletRequest} +import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.io.Source import scala.xml.Node @@ -26,16 +26,16 @@ import scala.xml.Node import com.gargoylesoftware.htmlunit.DefaultCssErrorHandler import org.json4s._ import org.json4s.jackson.JsonMethods -import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.openqa.selenium.{By, WebDriver} +import org.openqa.selenium.htmlunit.HtmlUnitDriver import org.scalatest._ import org.scalatest.concurrent.Eventually._ import org.scalatest.selenium.WebBrowser import org.scalatest.time.SpanSugar._ import org.w3c.css.sac.CSSParseException -import org.apache.spark.LocalSparkContext._ import org.apache.spark._ +import org.apache.spark.LocalSparkContext._ import 
org.apache.spark.api.java.StorageLevels import org.apache.spark.deploy.history.HistoryServerSuite import org.apache.spark.shuffle.FetchFailedException @@ -76,14 +76,19 @@ class UISeleniumSuite extends SparkFunSuite with WebBrowser with Matchers with B override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 8f9502b5673d..2d28b67ef23f 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -26,8 +26,8 @@ import org.eclipse.jetty.servlet.ServletContextHandler import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ -import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.LocalSparkContext._ class UISuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala index 86b078851851..3fb78da0c747 100644 --- a/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.ui.scope import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.scheduler._ -import org.apache.spark.scheduler.SparkListenerStageSubmitted -import org.apache.spark.scheduler.SparkListenerStageCompleted -import org.apache.spark.scheduler.SparkListenerJobStart /** * Tests that this listener populates and cleans up its data structures properly. diff --git a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala index 37e2670de968..4b838a8ab133 100644 --- a/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/storage/StorageTabSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.ui.storage import org.scalatest.BeforeAndAfter + import org.apache.spark.{SparkFunSuite, Success} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala deleted file mode 100644 index 0af4b6098bb0..000000000000 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.util - -import scala.collection.mutable.ArrayBuffer - -import java.util.concurrent.TimeoutException - -import akka.actor.ActorNotFound - -import org.apache.spark._ -import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId} -import org.apache.spark.SSLSampleConfigs._ - - -/** - * Test the AkkaUtils with various security settings. - */ -class AkkaUtilsSuite extends SparkFunSuite with LocalSparkContext with ResetSystemProperties { - - test("remote fetch security bad password") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "true") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off") { - val conf = new SparkConf - conf.set("spark.authenticate", "false") - conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - 
slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0).toSeq === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security pass") { - val conf = new SparkConf - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val goodconf = new SparkConf - goodconf.set("spark.authenticate", "true") - goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf) - - assert(securityManagerGood.isAuthenticationEnabled() === true) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, goodconf, securityManagerGood) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, MapStatus( - BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security on and passwords match - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), - ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch security off client") { - val conf = new SparkConf - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val badconf = new SparkConf - badconf.set("spark.rpc", "akka") - badconf.set("spark.authenticate", "false") - badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, badconf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - 
intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - test("remote fetch ssl on") { - val conf = sparkSSLConfig() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slaves", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === false) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - // this should succeed since security off - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled") { - val conf = sparkSSLConfig() - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - - assert(securityManagerBad.isAuthenticationEnabled() === true) - - masterTracker.registerShuffle(10, 1) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) - masterTracker.registerMapOutput(10, 0, - MapStatus(BlockManagerId("a", "hostA", 1000), Array(1000L))) - masterTracker.incrementEpoch() - slaveTracker.updateEpoch(masterTracker.getEpoch) - - assert(slaveTracker.getMapSizesByExecutorId(10, 0) === - Seq((BlockManagerId("a", "hostA", 1000), 
ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000))))) - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on and security enabled - bad credentials") { - val conf = sparkSSLConfig() - conf.set("spark.rpc", "akka") - conf.set("spark.authenticate", "true") - conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === true) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - slaveConf.set("spark.rpc", "akka") - slaveConf.set("spark.authenticate", "true") - slaveConf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - val slaveTracker = new MapOutputTrackerWorker(conf) - intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerEndpoint = - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - - - test("remote fetch ssl on - untrusted server") { - val conf = sparkSSLConfigUntrusted() - val securityManager = new SecurityManager(conf) - - val hostname = "localhost" - val rpcEnv = RpcEnv.create("spark", hostname, 0, conf, securityManager) - System.setProperty("spark.hostPort", rpcEnv.address.hostPort) - - assert(securityManager.isAuthenticationEnabled() === false) - - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerEndpoint = rpcEnv.setupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint(rpcEnv, masterTracker, conf)) - - val slaveConf = sparkSSLConfig() - .set("spark.rpc.askTimeout", "5s") - .set("spark.rpc.lookupTimeout", "5s") - val securityManagerBad = new SecurityManager(slaveConf) - - val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, slaveConf, securityManagerBad) - try { - slaveRpcEnv.setupEndpointRef("spark", rpcEnv.address, MapOutputTracker.ENDPOINT_NAME) - fail("should receive either ActorNotFound or TimeoutException") - } catch { - case e: ActorNotFound => - case e: TimeoutException => - } - - rpcEnv.shutdown() - slaveRpcEnv.shutdown() - } - -} diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala index 480722a5ac18..932704c1a365 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.util import java.io.NotSerializableException -import java.util.Random -import org.apache.spark.LocalSparkContext._ import org.apache.spark.{SparkContext, SparkException, SparkFunSuite, TaskContext} +import org.apache.spark.LocalSparkContext._ import org.apache.spark.partial.CountEvaluator import org.apache.spark.rdd.RDD @@ -91,11 +90,6 @@ class ClosureCleanerSuite extends SparkFunSuite { expectCorrectException { TestUserClosuresActuallyCleaned.testKeyBy(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitions(rdd) } expectCorrectException { 
TestUserClosuresActuallyCleaned.testMapPartitionsWithIndex(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapPartitionsWithContext(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFlatMapWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testFilterWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testForEachWith(rdd) } - expectCorrectException { TestUserClosuresActuallyCleaned.testMapWith(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions2(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions3(rdd) } expectCorrectException { TestUserClosuresActuallyCleaned.testZipPartitions4(rdd) } @@ -269,21 +263,6 @@ private object TestUserClosuresActuallyCleaned { def testMapPartitionsWithIndex(rdd: RDD[Int]): Unit = { rdd.mapPartitionsWithIndex { (_, it) => return; it }.count() } - def testFlatMapWith(rdd: RDD[Int]): Unit = { - rdd.flatMapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; Seq() }.count() - } - def testMapWith(rdd: RDD[Int]): Unit = { - rdd.mapWith ((index: Int) => new Random(index + 42)){ (_, it) => return; 0 }.count() - } - def testFilterWith(rdd: RDD[Int]): Unit = { - rdd.filterWith ((index: Int) => new Random(index + 42)){ (_, it) => return; true }.count() - } - def testForEachWith(rdd: RDD[Int]): Unit = { - rdd.foreachWith ((index: Int) => new Random(index + 42)){ (_, it) => return } - } - def testMapPartitionsWithContext(rdd: RDD[Int]): Unit = { - rdd.mapPartitionsWithContext { (_, it) => return; it }.count() - } def testZipPartitions2(rdd: RDD[Int]): Unit = { rdd.zipPartitions(rdd) { case (it1, it2) => return; it1 }.count() } diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala index a829b099025e..934385fbcad1 100644 --- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala +++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite2.scala @@ -38,14 +38,19 @@ class ClosureCleanerSuite2 extends SparkFunSuite with BeforeAndAfterAll with Pri private var closureSerializer: SerializerInstance = null override def beforeAll(): Unit = { + super.beforeAll() sc = new SparkContext("local", "test") closureSerializer = sc.env.closureSerializer.newInstance() } override def afterAll(): Unit = { - sc.stop() - sc = null - closureSerializer = null + try { + sc.stop() + sc = null + closureSerializer = null + } finally { + super.afterAll() + } } // Some fields and methods to reference in inner closures later diff --git a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala index 2b76ae1f8a24..98d1b28d5a16 100644 --- a/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/FileAppenderSuite.scala @@ -22,13 +22,12 @@ import java.io._ import scala.collection.mutable.HashSet import scala.reflect._ -import org.scalatest.BeforeAndAfter - import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files +import org.scalatest.BeforeAndAfter import org.apache.spark.{Logging, SparkConf, SparkFunSuite} -import org.apache.spark.util.logging.{RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy, FileAppender} +import org.apache.spark.util.logging.{FileAppender, RollingFileAppender, SizeBasedRollingPolicy, TimeBasedRollingPolicy} class FileAppenderSuite extends 
SparkFunSuite with BeforeAndAfter with Logging { diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index 1939ce5c743b..6566400e6379 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -19,9 +19,6 @@ package org.apache.spark.util import java.util.Properties -import org.apache.spark.scheduler.cluster.ExecutorInfo -import org.apache.spark.shuffle.MetadataFetchFailedException - import scala.collection.Map import org.json4s.jackson.JsonMethods._ @@ -30,6 +27,8 @@ import org.apache.spark._ import org.apache.spark.executor._ import org.apache.spark.rdd.RDDOperationScope import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster.ExecutorInfo +import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage._ class JsonProtocolSuite extends SparkFunSuite { diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 101610e38014..49088aa0a53b 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.util import scala.collection.mutable.ArrayBuffer -import org.scalatest.{BeforeAndAfterEach, BeforeAndAfterAll, PrivateMethodTester} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, PrivateMethodTester} import org.apache.spark.SparkFunSuite @@ -79,6 +79,10 @@ class SizeEstimatorSuite System.setProperty("spark.test.useCompressedOops", "true") } + override def afterEach(): Unit = { + super.afterEach() + } + test("simple classes") { assertResult(16)(SizeEstimator.estimate(new DummyClass1)) assertResult(16)(SizeEstimator.estimate(new DummyClass2)) diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 92ae03896752..6652a41b6990 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.duration._ import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ import scala.util.Random import org.scalatest.concurrent.Eventually._ diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala index 9b3169026cda..25fc15dd54d0 100644 --- a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -17,8 +17,6 @@ package org.apache.spark.util -import java.lang.ref.WeakReference - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.Random @@ -34,10 +32,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { testMap(new TimeStampedHashMap[String, String]()) testMapThreadSafety(new TimeStampedHashMap[String, String]()) - // Test TimeStampedWeakValueHashMap basic functionality - testMap(new TimeStampedWeakValueHashMap[String, String]()) - testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) - test("TimeStampedHashMap - clearing by timestamp") { // clearing by 
insertion time val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) @@ -68,86 +62,6 @@ class TimeStampedHashMapSuite extends SparkFunSuite { assert(map1.get("k2").isDefined) } - test("TimeStampedWeakValueHashMap - clearing by timestamp") { - // clearing by insertion time - val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) - map("k1") = "v1" - assert(map("k1") === "v1") - Thread.sleep(10) - val threshTime = System.currentTimeMillis - assert(map.getTimestamp("k1").isDefined) - assert(map.getTimestamp("k1").get < threshTime) - map.clearOldValues(threshTime) - assert(map.get("k1") === None) - - // clearing by modification time - val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) - map1("k1") = "v1" - map1("k2") = "v2" - assert(map1("k1") === "v1") - Thread.sleep(10) - val threshTime1 = System.currentTimeMillis - Thread.sleep(10) - assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime - assert(map1.getTimestamp("k1").isDefined) - assert(map1.getTimestamp("k1").get < threshTime1) - assert(map1.getTimestamp("k2").isDefined) - assert(map1.getTimestamp("k2").get >= threshTime1) - map1.clearOldValues(threshTime1) // should only clear k1 - assert(map1.get("k1") === None) - assert(map1.get("k2").isDefined) - } - - test("TimeStampedWeakValueHashMap - clearing weak references") { - var strongRef = new Object - val weakRef = new WeakReference(strongRef) - val map = new TimeStampedWeakValueHashMap[String, Object] - map("k1") = strongRef - map("k2") = "v2" - map("k3") = "v3" - val isEquals = map("k1") == strongRef - assert(isEquals) - - // clear strong reference to "k1" - strongRef = null - val startTime = System.currentTimeMillis - System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. - System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. - while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { - System.gc() - System.runFinalization() - Thread.sleep(100) - } - assert(map.getReference("k1").isDefined) - val ref = map.getReference("k1").get - assert(ref.get === null) - assert(map.get("k1") === None) - - // operations should only display non-null entries - assert(map.iterator.forall { case (k, v) => k != "k1" }) - assert(map.filter { case (k, v) => k != "k2" }.size === 1) - assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") - assert(map.toMap.size === 2) - assert(map.toMap.forall { case (k, v) => k != "k1" }) - val buffer = new ArrayBuffer[String] - map.foreach { case (k, v) => buffer += v.toString } - assert(buffer.size === 2) - assert(buffer.forall(_ != "k1")) - val plusMap = map + (("k4", "v4")) - assert(plusMap.size === 3) - assert(plusMap.forall { case (k, v) => k != "k1" }) - val minusMap = map - "k2" - assert(minusMap.size === 1) - assert(minusMap.head._1 == "k3") - - // clear null values - should only clear k1 - map.clearNullValues() - assert(map.getReference("k1") === None) - assert(map.get("k1") === None) - assert(map.get("k2").isDefined) - assert(map.get("k2").get === "v2") - } - /** Test basic operations of a Scala mutable Map. 
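 * (The map under test is passed by name, so each newMap() call below constructs
 * a fresh instance for the assertions to run against.)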
*/ def testMap(hashMapConstructor: => mutable.Map[String, String]) { def newMap() = hashMapConstructor diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 68b0da76bc13..bc926c280c7c 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -17,26 +17,25 @@ package org.apache.spark.util -import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, FileOutputStream} import java.lang.{Double => JDouble, Float => JFloat} import java.net.{BindException, ServerSocket, URI} import java.nio.{ByteBuffer, ByteOrder} import java.text.DecimalFormatSymbols -import java.util.concurrent.TimeUnit import java.util.Locale +import java.util.concurrent.TimeUnit import scala.collection.mutable.ListBuffer import scala.util.Random import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files - +import org.apache.commons.lang3.SystemUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.network.util.ByteUnit -import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.SparkConf class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -734,4 +733,88 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { conf.set("spark.executor.instances", "0")) === true) } + test("encodeFileNameToURIRawPath") { + assert(Utils.encodeFileNameToURIRawPath("abc") === "abc") + assert(Utils.encodeFileNameToURIRawPath("abc xyz") === "abc%20xyz") + assert(Utils.encodeFileNameToURIRawPath("abc:xyz") === "abc:xyz") + } + + test("decodeFileNameInURI") { + assert(Utils.decodeFileNameInURI(new URI("files:///abc/xyz")) === "xyz") + assert(Utils.decodeFileNameInURI(new URI("files:///abc")) === "abc") + assert(Utils.decodeFileNameInURI(new URI("files:///abc%20xyz")) === "abc xyz") + } + + test("Kill process") { + // Verify that we can terminate a process even if it is in a bad state. This is only run + // on UNIX since it does some OS specific things to verify the correct behavior. + if (SystemUtils.IS_OS_UNIX) { + def getPid(p: Process): Int = { + val f = p.getClass().getDeclaredField("pid") + f.setAccessible(true) + f.get(p).asInstanceOf[Int] + } + + def pidExists(pid: Int): Boolean = { + val p = Runtime.getRuntime.exec(s"kill -0 $pid") + p.waitFor() + p.exitValue() == 0 + } + + def signal(pid: Int, s: String): Unit = { + val p = Runtime.getRuntime.exec(s"kill -$s $pid") + p.waitFor() + } + + // Start up a process that runs 'sleep 10'. Terminate the process and assert it takes + // less time and the process is no longer there. + val startTimeMs = System.currentTimeMillis() + val process = new ProcessBuilder("sleep", "10").start() + val pid = getPid(process) + try { + assert(pidExists(pid)) + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val durationMs = System.currentTimeMillis() - startTimeMs + assert(durationMs < 5000) + assert(!pidExists(pid)) + } finally { + // Forcibly kill the test process just in case. + signal(pid, "SIGKILL") + } + + val v: String = System.getProperty("java.version") + if (v >= "1.8.0") { + // Java8 added a way to forcibly terminate a process. 
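+      // (That API is Process#destroyForcibly; on UNIX it delivers SIGKILL, which,
+      // unlike the SIGTERM sent by a plain destroy(), cannot be trapped or ignored.)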
We'll make sure that works by + // creating a very misbehaving process. It ignores SIGTERM and has been SIGSTOPed. On + // older versions of java, this will *not* terminate. + val file = File.createTempFile("temp-file-name", ".tmp") + val cmd = + s""" + |#!/bin/bash + |trap "" SIGTERM + |sleep 10 + """.stripMargin + Files.write(cmd.getBytes(), file) + file.getAbsoluteFile.setExecutable(true) + + val process = new ProcessBuilder(file.getAbsolutePath).start() + val pid = getPid(process) + assert(pidExists(pid)) + try { + signal(pid, "SIGSTOP") + val start = System.currentTimeMillis() + val terminated = Utils.terminateProcess(process, 5000) + assert(terminated.isDefined) + Utils.waitForProcess(process, 5000) + val duration = System.currentTimeMillis() - start + assert(duration < 5000) + assert(!pidExists(pid)) + } finally { + signal(pid, "SIGKILL") + } + } + } + } } diff --git a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala b/core/src/test/scala/org/apache/spark/util/VectorSuite.scala deleted file mode 100644 index 11194cd22a41..000000000000 --- a/core/src/test/scala/org/apache/spark/util/VectorSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.util - -import scala.util.Random - -import org.apache.spark.SparkFunSuite - -/** - * Tests org.apache.spark.util.Vector functionality - */ -@deprecated("suppress compile time deprecation warning", "1.0.0") -class VectorSuite extends SparkFunSuite { - - def verifyVector(vector: Vector, expectedLength: Int): Unit = { - assert(vector.length == expectedLength) - assert(vector.elements.min > 0.0) - assert(vector.elements.max < 1.0) - } - - test("random with default random number generator") { - val vector100 = Vector.random(100) - verifyVector(vector100, 100) - } - - test("random with given random number generator") { - val vector100 = Vector.random(100, new Random(100)) - verifyVector(vector100, 100) - } -} diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala index d7b2d07a4005..a62adf1c2c54 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala @@ -17,15 +17,13 @@ package org.apache.spark.util.collection -import org.apache.spark.memory.MemoryTestingUtils - import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark._ +import org.apache.spark.memory.MemoryTestingUtils import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} - class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext { import TestUtils.{assertNotSpilled, assertSpilled} diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala index 0326ed70b5ed..c12f78447197 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.util.collection.unsafe.sort import com.google.common.primitives.UnsignedBytes import org.scalatest.prop.PropertyChecks + import org.apache.spark.SparkFunSuite import org.apache.spark.unsafe.types.UTF8String diff --git a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala index d6af0aebde73..791491daf081 100644 --- a/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/RandomSamplerSuite.scala @@ -18,9 +18,10 @@ package org.apache.spark.util.random import java.util.Random + import scala.collection.mutable.ArrayBuffer -import org.apache.commons.math3.distribution.PoissonDistribution +import org.apache.commons.math3.distribution.PoissonDistribution import org.scalatest.Matchers import org.apache.spark.SparkFunSuite diff --git a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala index a5b50fce5c0a..853503bbc2bb 100644 --- a/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/random/XORShiftRandomSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.util.random -import org.scalatest.Matchers +import scala.language.reflectiveCalls import org.apache.commons.math3.stat.inference.ChiSquareTest +import org.scalatest.Matchers 
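+// (Spark's style checker orders imports as: java/scala packages first, then
+// third-party libraries, then org.apache.spark, alphabetized within each group.)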
import org.apache.spark.SparkFunSuite import org.apache.spark.util.Utils.times -import scala.language.reflectiveCalls - class XORShiftRandomSuite extends SparkFunSuite with Matchers { private def fixture = new { diff --git a/data/streaming/AFINN-111.txt b/data/streaming/AFINN-111.txt new file mode 100644 index 000000000000..0f6fb8ebaa0b --- /dev/null +++ b/data/streaming/AFINN-111.txt @@ -0,0 +1,2477 @@ +abandon -2 +abandoned -2 +abandons -2 +abducted -2 +abduction -2 +abductions -2 +abhor -3 +abhorred -3 +abhorrent -3 +abhors -3 +abilities 2 +ability 2 +aboard 1 +absentee -1 +absentees -1 +absolve 2 +absolved 2 +absolves 2 +absolving 2 +absorbed 1 +abuse -3 +abused -3 +abuses -3 +abusive -3 +accept 1 +accepted 1 +accepting 1 +accepts 1 +accident -2 +accidental -2 +accidentally -2 +accidents -2 +accomplish 2 +accomplished 2 +accomplishes 2 +accusation -2 +accusations -2 +accuse -2 +accused -2 +accuses -2 +accusing -2 +ache -2 +achievable 1 +aching -2 +acquit 2 +acquits 2 +acquitted 2 +acquitting 2 +acrimonious -3 +active 1 +adequate 1 +admire 3 +admired 3 +admires 3 +admiring 3 +admit -1 +admits -1 +admitted -1 +admonish -2 +admonished -2 +adopt 1 +adopts 1 +adorable 3 +adore 3 +adored 3 +adores 3 +advanced 1 +advantage 2 +advantages 2 +adventure 2 +adventures 2 +adventurous 2 +affected -1 +affection 3 +affectionate 3 +afflicted -1 +affronted -1 +afraid -2 +aggravate -2 +aggravated -2 +aggravates -2 +aggravating -2 +aggression -2 +aggressions -2 +aggressive -2 +aghast -2 +agog 2 +agonise -3 +agonised -3 +agonises -3 +agonising -3 +agonize -3 +agonized -3 +agonizes -3 +agonizing -3 +agree 1 +agreeable 2 +agreed 1 +agreement 1 +agrees 1 +alarm -2 +alarmed -2 +alarmist -2 +alarmists -2 +alas -1 +alert -1 +alienation -2 +alive 1 +allergic -2 +allow 1 +alone -2 +amaze 2 +amazed 2 +amazes 2 +amazing 4 +ambitious 2 +ambivalent -1 +amuse 3 +amused 3 +amusement 3 +amusements 3 +anger -3 +angers -3 +angry -3 +anguish -3 +anguished -3 +animosity -2 +annoy -2 +annoyance -2 +annoyed -2 +annoying -2 +annoys -2 +antagonistic -2 +anti -1 +anticipation 1 +anxiety -2 +anxious -2 +apathetic -3 +apathy -3 +apeshit -3 +apocalyptic -2 +apologise -1 +apologised -1 +apologises -1 +apologising -1 +apologize -1 +apologized -1 +apologizes -1 +apologizing -1 +apology -1 +appalled -2 +appalling -2 +appease 2 +appeased 2 +appeases 2 +appeasing 2 +applaud 2 +applauded 2 +applauding 2 +applauds 2 +applause 2 +appreciate 2 +appreciated 2 +appreciates 2 +appreciating 2 +appreciation 2 +apprehensive -2 +approval 2 +approved 2 +approves 2 +ardent 1 +arrest -2 +arrested -3 +arrests -2 +arrogant -2 +ashame -2 +ashamed -2 +ass -4 +assassination -3 +assassinations -3 +asset 2 +assets 2 +assfucking -4 +asshole -4 +astonished 2 +astound 3 +astounded 3 +astounding 3 +astoundingly 3 +astounds 3 +attack -1 +attacked -1 +attacking -1 +attacks -1 +attract 1 +attracted 1 +attracting 2 +attraction 2 +attractions 2 +attracts 1 +audacious 3 +authority 1 +avert -1 +averted -1 +averts -1 +avid 2 +avoid -1 +avoided -1 +avoids -1 +await -1 +awaited -1 +awaits -1 +award 3 +awarded 3 +awards 3 +awesome 4 +awful -3 +awkward -2 +axe -1 +axed -1 +backed 1 +backing 2 +backs 1 +bad -3 +badass -3 +badly -3 +bailout -2 +bamboozle -2 +bamboozled -2 +bamboozles -2 +ban -2 +banish -1 +bankrupt -3 +bankster -3 +banned -2 +bargain 2 +barrier -2 +bastard -5 +bastards -5 +battle -1 +battles -1 +beaten -2 +beatific 3 +beating -1 +beauties 3 +beautiful 3 +beautifully 3 +beautify 3 +belittle -2 +belittled -2 +beloved 3 +benefit 2 +benefits 2 +benefitted 2 
+benefitting 2 +bereave -2 +bereaved -2 +bereaves -2 +bereaving -2 +best 3 +betray -3 +betrayal -3 +betrayed -3 +betraying -3 +betrays -3 +better 2 +bias -1 +biased -2 +big 1 +bitch -5 +bitches -5 +bitter -2 +bitterly -2 +bizarre -2 +blah -2 +blame -2 +blamed -2 +blames -2 +blaming -2 +bless 2 +blesses 2 +blessing 3 +blind -1 +bliss 3 +blissful 3 +blithe 2 +block -1 +blockbuster 3 +blocked -1 +blocking -1 +blocks -1 +bloody -3 +blurry -2 +boastful -2 +bold 2 +boldly 2 +bomb -1 +boost 1 +boosted 1 +boosting 1 +boosts 1 +bore -2 +bored -2 +boring -3 +bother -2 +bothered -2 +bothers -2 +bothersome -2 +boycott -2 +boycotted -2 +boycotting -2 +boycotts -2 +brainwashing -3 +brave 2 +breakthrough 3 +breathtaking 5 +bribe -3 +bright 1 +brightest 2 +brightness 1 +brilliant 4 +brisk 2 +broke -1 +broken -1 +brooding -2 +bullied -2 +bullshit -4 +bully -2 +bullying -2 +bummer -2 +buoyant 2 +burden -2 +burdened -2 +burdening -2 +burdens -2 +calm 2 +calmed 2 +calming 2 +calms 2 +can't stand -3 +cancel -1 +cancelled -1 +cancelling -1 +cancels -1 +cancer -1 +capable 1 +captivated 3 +care 2 +carefree 1 +careful 2 +carefully 2 +careless -2 +cares 2 +cashing in -2 +casualty -2 +catastrophe -3 +catastrophic -4 +cautious -1 +celebrate 3 +celebrated 3 +celebrates 3 +celebrating 3 +censor -2 +censored -2 +censors -2 +certain 1 +chagrin -2 +chagrined -2 +challenge -1 +chance 2 +chances 2 +chaos -2 +chaotic -2 +charged -3 +charges -2 +charm 3 +charming 3 +charmless -3 +chastise -3 +chastised -3 +chastises -3 +chastising -3 +cheat -3 +cheated -3 +cheater -3 +cheaters -3 +cheats -3 +cheer 2 +cheered 2 +cheerful 2 +cheering 2 +cheerless -2 +cheers 2 +cheery 3 +cherish 2 +cherished 2 +cherishes 2 +cherishing 2 +chic 2 +childish -2 +chilling -1 +choke -2 +choked -2 +chokes -2 +choking -2 +clarifies 2 +clarity 2 +clash -2 +classy 3 +clean 2 +cleaner 2 +clear 1 +cleared 1 +clearly 1 +clears 1 +clever 2 +clouded -1 +clueless -2 +cock -5 +cocksucker -5 +cocksuckers -5 +cocky -2 +coerced -2 +collapse -2 +collapsed -2 +collapses -2 +collapsing -2 +collide -1 +collides -1 +colliding -1 +collision -2 +collisions -2 +colluding -3 +combat -1 +combats -1 +comedy 1 +comfort 2 +comfortable 2 +comforting 2 +comforts 2 +commend 2 +commended 2 +commit 1 +commitment 2 +commits 1 +committed 1 +committing 1 +compassionate 2 +compelled 1 +competent 2 +competitive 2 +complacent -2 +complain -2 +complained -2 +complains -2 +comprehensive 2 +conciliate 2 +conciliated 2 +conciliates 2 +conciliating 2 +condemn -2 +condemnation -2 +condemned -2 +condemns -2 +confidence 2 +confident 2 +conflict -2 +conflicting -2 +conflictive -2 +conflicts -2 +confuse -2 +confused -2 +confusing -2 +congrats 2 +congratulate 2 +congratulation 2 +congratulations 2 +consent 2 +consents 2 +consolable 2 +conspiracy -3 +constrained -2 +contagion -2 +contagions -2 +contagious -1 +contempt -2 +contemptuous -2 +contemptuously -2 +contend -1 +contender -1 +contending -1 +contentious -2 +contestable -2 +controversial -2 +controversially -2 +convince 1 +convinced 1 +convinces 1 +convivial 2 +cool 1 +cool stuff 3 +cornered -2 +corpse -1 +costly -2 +courage 2 +courageous 2 +courteous 2 +courtesy 2 +cover-up -3 +coward -2 +cowardly -2 +coziness 2 +cramp -1 +crap -3 +crash -2 +crazier -2 +craziest -2 +crazy -2 +creative 2 +crestfallen -2 +cried -2 +cries -2 +crime -3 +criminal -3 +criminals -3 +crisis -3 +critic -2 +criticism -2 +criticize -2 +criticized -2 +criticizes -2 +criticizing -2 +critics -2 +cruel -3 +cruelty -3 +crush -1 +crushed -2 +crushes -1 +crushing -1 +cry -1 
+crying -2 +cunt -5 +curious 1 +curse -1 +cut -1 +cute 2 +cuts -1 +cutting -1 +cynic -2 +cynical -2 +cynicism -2 +damage -3 +damages -3 +damn -4 +damned -4 +damnit -4 +danger -2 +daredevil 2 +daring 2 +darkest -2 +darkness -1 +dauntless 2 +dead -3 +deadlock -2 +deafening -1 +dear 2 +dearly 3 +death -2 +debonair 2 +debt -2 +deceit -3 +deceitful -3 +deceive -3 +deceived -3 +deceives -3 +deceiving -3 +deception -3 +decisive 1 +dedicated 2 +defeated -2 +defect -3 +defects -3 +defender 2 +defenders 2 +defenseless -2 +defer -1 +deferring -1 +defiant -1 +deficit -2 +degrade -2 +degraded -2 +degrades -2 +dehumanize -2 +dehumanized -2 +dehumanizes -2 +dehumanizing -2 +deject -2 +dejected -2 +dejecting -2 +dejects -2 +delay -1 +delayed -1 +delight 3 +delighted 3 +delighting 3 +delights 3 +demand -1 +demanded -1 +demanding -1 +demands -1 +demonstration -1 +demoralized -2 +denied -2 +denier -2 +deniers -2 +denies -2 +denounce -2 +denounces -2 +deny -2 +denying -2 +depressed -2 +depressing -2 +derail -2 +derailed -2 +derails -2 +deride -2 +derided -2 +derides -2 +deriding -2 +derision -2 +desirable 2 +desire 1 +desired 2 +desirous 2 +despair -3 +despairing -3 +despairs -3 +desperate -3 +desperately -3 +despondent -3 +destroy -3 +destroyed -3 +destroying -3 +destroys -3 +destruction -3 +destructive -3 +detached -1 +detain -2 +detained -2 +detention -2 +determined 2 +devastate -2 +devastated -2 +devastating -2 +devoted 3 +diamond 1 +dick -4 +dickhead -4 +die -3 +died -3 +difficult -1 +diffident -2 +dilemma -1 +dipshit -3 +dire -3 +direful -3 +dirt -2 +dirtier -2 +dirtiest -2 +dirty -2 +disabling -1 +disadvantage -2 +disadvantaged -2 +disappear -1 +disappeared -1 +disappears -1 +disappoint -2 +disappointed -2 +disappointing -2 +disappointment -2 +disappointments -2 +disappoints -2 +disaster -2 +disasters -2 +disastrous -3 +disbelieve -2 +discard -1 +discarded -1 +discarding -1 +discards -1 +disconsolate -2 +disconsolation -2 +discontented -2 +discord -2 +discounted -1 +discouraged -2 +discredited -2 +disdain -2 +disgrace -2 +disgraced -2 +disguise -1 +disguised -1 +disguises -1 +disguising -1 +disgust -3 +disgusted -3 +disgusting -3 +disheartened -2 +dishonest -2 +disillusioned -2 +disinclined -2 +disjointed -2 +dislike -2 +dismal -2 +dismayed -2 +disorder -2 +disorganized -2 +disoriented -2 +disparage -2 +disparaged -2 +disparages -2 +disparaging -2 +displeased -2 +dispute -2 +disputed -2 +disputes -2 +disputing -2 +disqualified -2 +disquiet -2 +disregard -2 +disregarded -2 +disregarding -2 +disregards -2 +disrespect -2 +disrespected -2 +disruption -2 +disruptions -2 +disruptive -2 +dissatisfied -2 +distort -2 +distorted -2 +distorting -2 +distorts -2 +distract -2 +distracted -2 +distraction -2 +distracts -2 +distress -2 +distressed -2 +distresses -2 +distressing -2 +distrust -3 +distrustful -3 +disturb -2 +disturbed -2 +disturbing -2 +disturbs -2 +dithering -2 +dizzy -1 +dodging -2 +dodgy -2 +does not work -3 +dolorous -2 +dont like -2 +doom -2 +doomed -2 +doubt -1 +doubted -1 +doubtful -1 +doubting -1 +doubts -1 +douche -3 +douchebag -3 +downcast -2 +downhearted -2 +downside -2 +drag -1 +dragged -1 +drags -1 +drained -2 +dread -2 +dreaded -2 +dreadful -3 +dreading -2 +dream 1 +dreams 1 +dreary -2 +droopy -2 +drop -1 +drown -2 +drowned -2 +drowns -2 +drunk -2 +dubious -2 +dud -2 +dull -2 +dumb -3 +dumbass -3 +dump -1 +dumped -2 +dumps -1 +dupe -2 +duped -2 +dysfunction -2 +eager 2 +earnest 2 +ease 2 +easy 1 +ecstatic 4 +eerie -2 +eery -2 +effective 2 +effectively 2 +elated 3 +elation 3 +elegant 2 
+elegantly 2 +embarrass -2 +embarrassed -2 +embarrasses -2 +embarrassing -2 +embarrassment -2 +embittered -2 +embrace 1 +emergency -2 +empathetic 2 +emptiness -1 +empty -1 +enchanted 2 +encourage 2 +encouraged 2 +encouragement 2 +encourages 2 +endorse 2 +endorsed 2 +endorsement 2 +endorses 2 +enemies -2 +enemy -2 +energetic 2 +engage 1 +engages 1 +engrossed 1 +enjoy 2 +enjoying 2 +enjoys 2 +enlighten 2 +enlightened 2 +enlightening 2 +enlightens 2 +ennui -2 +enrage -2 +enraged -2 +enrages -2 +enraging -2 +enrapture 3 +enslave -2 +enslaved -2 +enslaves -2 +ensure 1 +ensuring 1 +enterprising 1 +entertaining 2 +enthral 3 +enthusiastic 3 +entitled 1 +entrusted 2 +envies -1 +envious -2 +envy -1 +envying -1 +erroneous -2 +error -2 +errors -2 +escape -1 +escapes -1 +escaping -1 +esteemed 2 +ethical 2 +euphoria 3 +euphoric 4 +eviction -1 +evil -3 +exaggerate -2 +exaggerated -2 +exaggerates -2 +exaggerating -2 +exasperated 2 +excellence 3 +excellent 3 +excite 3 +excited 3 +excitement 3 +exciting 3 +exclude -1 +excluded -2 +exclusion -1 +exclusive 2 +excuse -1 +exempt -1 +exhausted -2 +exhilarated 3 +exhilarates 3 +exhilarating 3 +exonerate 2 +exonerated 2 +exonerates 2 +exonerating 2 +expand 1 +expands 1 +expel -2 +expelled -2 +expelling -2 +expels -2 +exploit -2 +exploited -2 +exploiting -2 +exploits -2 +exploration 1 +explorations 1 +expose -1 +exposed -1 +exposes -1 +exposing -1 +extend 1 +extends 1 +exuberant 4 +exultant 3 +exultantly 3 +fabulous 4 +fad -2 +fag -3 +faggot -3 +faggots -3 +fail -2 +failed -2 +failing -2 +fails -2 +failure -2 +failures -2 +fainthearted -2 +fair 2 +faith 1 +faithful 3 +fake -3 +fakes -3 +faking -3 +fallen -2 +falling -1 +falsified -3 +falsify -3 +fame 1 +fan 3 +fantastic 4 +farce -1 +fascinate 3 +fascinated 3 +fascinates 3 +fascinating 3 +fascist -2 +fascists -2 +fatalities -3 +fatality -3 +fatigue -2 +fatigued -2 +fatigues -2 +fatiguing -2 +favor 2 +favored 2 +favorite 2 +favorited 2 +favorites 2 +favors 2 +fear -2 +fearful -2 +fearing -2 +fearless 2 +fearsome -2 +fed up -3 +feeble -2 +feeling 1 +felonies -3 +felony -3 +fervent 2 +fervid 2 +festive 2 +fiasco -3 +fidgety -2 +fight -1 +fine 2 +fire -2 +fired -2 +firing -2 +fit 1 +fitness 1 +flagship 2 +flees -1 +flop -2 +flops -2 +flu -2 +flustered -2 +focused 2 +fond 2 +fondness 2 +fool -2 +foolish -2 +fools -2 +forced -1 +foreclosure -2 +foreclosures -2 +forget -1 +forgetful -2 +forgive 1 +forgiving 1 +forgotten -1 +fortunate 2 +frantic -1 +fraud -4 +frauds -4 +fraudster -4 +fraudsters -4 +fraudulence -4 +fraudulent -4 +free 1 +freedom 2 +frenzy -3 +fresh 1 +friendly 2 +fright -2 +frightened -2 +frightening -3 +frikin -2 +frisky 2 +frowning -1 +frustrate -2 +frustrated -2 +frustrates -2 +frustrating -2 +frustration -2 +ftw 3 +fuck -4 +fucked -4 +fucker -4 +fuckers -4 +fuckface -4 +fuckhead -4 +fucking -4 +fucktard -4 +fud -3 +fuked -4 +fuking -4 +fulfill 2 +fulfilled 2 +fulfills 2 +fuming -2 +fun 4 +funeral -1 +funerals -1 +funky 2 +funnier 4 +funny 4 +furious -3 +futile 2 +gag -2 +gagged -2 +gain 2 +gained 2 +gaining 2 +gains 2 +gallant 3 +gallantly 3 +gallantry 3 +generous 2 +genial 3 +ghost -1 +giddy -2 +gift 2 +glad 3 +glamorous 3 +glamourous 3 +glee 3 +gleeful 3 +gloom -1 +gloomy -2 +glorious 2 +glory 2 +glum -2 +god 1 +goddamn -3 +godsend 4 +good 3 +goodness 3 +grace 1 +gracious 3 +grand 3 +grant 1 +granted 1 +granting 1 +grants 1 +grateful 3 +gratification 2 +grave -2 +gray -1 +great 3 +greater 3 +greatest 3 +greed -3 +greedy -2 +green wash -3 +green washing -3 +greenwash -3 +greenwasher -3 +greenwashers -3 
+greenwashing -3 +greet 1 +greeted 1 +greeting 1 +greetings 2 +greets 1 +grey -1 +grief -2 +grieved -2 +gross -2 +growing 1 +growth 2 +guarantee 1 +guilt -3 +guilty -3 +gullibility -2 +gullible -2 +gun -1 +ha 2 +hacked -1 +haha 3 +hahaha 3 +hahahah 3 +hail 2 +hailed 2 +hapless -2 +haplessness -2 +happiness 3 +happy 3 +hard -1 +hardier 2 +hardship -2 +hardy 2 +harm -2 +harmed -2 +harmful -2 +harming -2 +harms -2 +harried -2 +harsh -2 +harsher -2 +harshest -2 +hate -3 +hated -3 +haters -3 +hates -3 +hating -3 +haunt -1 +haunted -2 +haunting 1 +haunts -1 +havoc -2 +healthy 2 +heartbreaking -3 +heartbroken -3 +heartfelt 3 +heaven 2 +heavenly 4 +heavyhearted -2 +hell -4 +help 2 +helpful 2 +helping 2 +helpless -2 +helps 2 +hero 2 +heroes 2 +heroic 3 +hesitant -2 +hesitate -2 +hid -1 +hide -1 +hides -1 +hiding -1 +highlight 2 +hilarious 2 +hindrance -2 +hoax -2 +homesick -2 +honest 2 +honor 2 +honored 2 +honoring 2 +honour 2 +honoured 2 +honouring 2 +hooligan -2 +hooliganism -2 +hooligans -2 +hope 2 +hopeful 2 +hopefully 2 +hopeless -2 +hopelessness -2 +hopes 2 +hoping 2 +horrendous -3 +horrible -3 +horrific -3 +horrified -3 +hostile -2 +huckster -2 +hug 2 +huge 1 +hugs 2 +humerous 3 +humiliated -3 +humiliation -3 +humor 2 +humorous 2 +humour 2 +humourous 2 +hunger -2 +hurrah 5 +hurt -2 +hurting -2 +hurts -2 +hypocritical -2 +hysteria -3 +hysterical -3 +hysterics -3 +idiot -3 +idiotic -3 +ignorance -2 +ignorant -2 +ignore -1 +ignored -2 +ignores -1 +ill -2 +illegal -3 +illiteracy -2 +illness -2 +illnesses -2 +imbecile -3 +immobilized -1 +immortal 2 +immune 1 +impatient -2 +imperfect -2 +importance 2 +important 2 +impose -1 +imposed -1 +imposes -1 +imposing -1 +impotent -2 +impress 3 +impressed 3 +impresses 3 +impressive 3 +imprisoned -2 +improve 2 +improved 2 +improvement 2 +improves 2 +improving 2 +inability -2 +inaction -2 +inadequate -2 +incapable -2 +incapacitated -2 +incensed -2 +incompetence -2 +incompetent -2 +inconsiderate -2 +inconvenience -2 +inconvenient -2 +increase 1 +increased 1 +indecisive -2 +indestructible 2 +indifference -2 +indifferent -2 +indignant -2 +indignation -2 +indoctrinate -2 +indoctrinated -2 +indoctrinates -2 +indoctrinating -2 +ineffective -2 +ineffectively -2 +infatuated 2 +infatuation 2 +infected -2 +inferior -2 +inflamed -2 +influential 2 +infringement -2 +infuriate -2 +infuriated -2 +infuriates -2 +infuriating -2 +inhibit -1 +injured -2 +injury -2 +injustice -2 +innovate 1 +innovates 1 +innovation 1 +innovative 2 +inquisition -2 +inquisitive 2 +insane -2 +insanity -2 +insecure -2 +insensitive -2 +insensitivity -2 +insignificant -2 +insipid -2 +inspiration 2 +inspirational 2 +inspire 2 +inspired 2 +inspires 2 +inspiring 3 +insult -2 +insulted -2 +insulting -2 +insults -2 +intact 2 +integrity 2 +intelligent 2 +intense 1 +interest 1 +interested 2 +interesting 2 +interests 1 +interrogated -2 +interrupt -2 +interrupted -2 +interrupting -2 +interruption -2 +interrupts -2 +intimidate -2 +intimidated -2 +intimidates -2 +intimidating -2 +intimidation -2 +intricate 2 +intrigues 1 +invincible 2 +invite 1 +inviting 1 +invulnerable 2 +irate -3 +ironic -1 +irony -1 +irrational -1 +irresistible 2 +irresolute -2 +irresponsible 2 +irreversible -1 +irritate -3 +irritated -3 +irritating -3 +isolated -1 +itchy -2 +jackass -4 +jackasses -4 +jailed -2 +jaunty 2 +jealous -2 +jeopardy -2 +jerk -3 +jesus 1 +jewel 1 +jewels 1 +jocular 2 +join 1 +joke 2 +jokes 2 +jolly 2 +jovial 2 +joy 3 +joyful 3 +joyfully 3 +joyless -2 +joyous 3 +jubilant 3 +jumpy -1 +justice 2 +justifiably 2 +justified 
2 +keen 1 +kill -3 +killed -3 +killing -3 +kills -3 +kind 2 +kinder 2 +kiss 2 +kudos 3 +lack -2 +lackadaisical -2 +lag -1 +lagged -2 +lagging -2 +lags -2 +lame -2 +landmark 2 +laugh 1 +laughed 1 +laughing 1 +laughs 1 +laughting 1 +launched 1 +lawl 3 +lawsuit -2 +lawsuits -2 +lazy -1 +leak -1 +leaked -1 +leave -1 +legal 1 +legally 1 +lenient 1 +lethargic -2 +lethargy -2 +liar -3 +liars -3 +libelous -2 +lied -2 +lifesaver 4 +lighthearted 1 +like 2 +liked 2 +likes 2 +limitation -1 +limited -1 +limits -1 +litigation -1 +litigious -2 +lively 2 +livid -2 +lmao 4 +lmfao 4 +loathe -3 +loathed -3 +loathes -3 +loathing -3 +lobby -2 +lobbying -2 +lol 3 +lonely -2 +lonesome -2 +longing -1 +loom -1 +loomed -1 +looming -1 +looms -1 +loose -3 +looses -3 +loser -3 +losing -3 +loss -3 +lost -3 +lovable 3 +love 3 +loved 3 +lovelies 3 +lovely 3 +loving 2 +lowest -1 +loyal 3 +loyalty 3 +luck 3 +luckily 3 +lucky 3 +lugubrious -2 +lunatic -3 +lunatics -3 +lurk -1 +lurking -1 +lurks -1 +mad -3 +maddening -3 +made-up -1 +madly -3 +madness -3 +mandatory -1 +manipulated -1 +manipulating -1 +manipulation -1 +marvel 3 +marvelous 3 +marvels 3 +masterpiece 4 +masterpieces 4 +matter 1 +matters 1 +mature 2 +meaningful 2 +meaningless -2 +medal 3 +mediocrity -3 +meditative 1 +melancholy -2 +menace -2 +menaced -2 +mercy 2 +merry 3 +mess -2 +messed -2 +messing up -2 +methodical 2 +mindless -2 +miracle 4 +mirth 3 +mirthful 3 +mirthfully 3 +misbehave -2 +misbehaved -2 +misbehaves -2 +misbehaving -2 +mischief -1 +mischiefs -1 +miserable -3 +misery -2 +misgiving -2 +misinformation -2 +misinformed -2 +misinterpreted -2 +misleading -3 +misread -1 +misreporting -2 +misrepresentation -2 +miss -2 +missed -2 +missing -2 +mistake -2 +mistaken -2 +mistakes -2 +mistaking -2 +misunderstand -2 +misunderstanding -2 +misunderstands -2 +misunderstood -2 +moan -2 +moaned -2 +moaning -2 +moans -2 +mock -2 +mocked -2 +mocking -2 +mocks -2 +mongering -2 +monopolize -2 +monopolized -2 +monopolizes -2 +monopolizing -2 +moody -1 +mope -1 +moping -1 +moron -3 +motherfucker -5 +motherfucking -5 +motivate 1 +motivated 2 +motivating 2 +motivation 1 +mourn -2 +mourned -2 +mournful -2 +mourning -2 +mourns -2 +mumpish -2 +murder -2 +murderer -2 +murdering -3 +murderous -3 +murders -2 +myth -1 +n00b -2 +naive -2 +nasty -3 +natural 1 +naïve -2 +needy -2 +negative -2 +negativity -2 +neglect -2 +neglected -2 +neglecting -2 +neglects -2 +nerves -1 +nervous -2 +nervously -2 +nice 3 +nifty 2 +niggas -5 +nigger -5 +no -1 +no fun -3 +noble 2 +noisy -1 +nonsense -2 +noob -2 +nosey -2 +not good -2 +not working -3 +notorious -2 +novel 2 +numb -1 +nuts -3 +obliterate -2 +obliterated -2 +obnoxious -3 +obscene -2 +obsessed 2 +obsolete -2 +obstacle -2 +obstacles -2 +obstinate -2 +odd -2 +offend -2 +offended -2 +offender -2 +offending -2 +offends -2 +offline -1 +oks 2 +ominous 3 +once-in-a-lifetime 3 +opportunities 2 +opportunity 2 +oppressed -2 +oppressive -2 +optimism 2 +optimistic 2 +optionless -2 +outcry -2 +outmaneuvered -2 +outrage -3 +outraged -3 +outreach 2 +outstanding 5 +overjoyed 4 +overload -1 +overlooked -1 +overreact -2 +overreacted -2 +overreaction -2 +overreacts -2 +oversell -2 +overselling -2 +oversells -2 +oversimplification -2 +oversimplified -2 +oversimplifies -2 +oversimplify -2 +overstatement -2 +overstatements -2 +overweight -1 +oxymoron -1 +pain -2 +pained -2 +panic -3 +panicked -3 +panics -3 +paradise 3 +paradox -1 +pardon 2 +pardoned 2 +pardoning 2 +pardons 2 +parley -1 +passionate 2 +passive -1 +passively -1 +pathetic -2 +pay -1 +peace 2 
+peaceful 2 +peacefully 2 +penalty -2 +pensive -1 +perfect 3 +perfected 2 +perfectly 3 +perfects 2 +peril -2 +perjury -3 +perpetrator -2 +perpetrators -2 +perplexed -2 +persecute -2 +persecuted -2 +persecutes -2 +persecuting -2 +perturbed -2 +pesky -2 +pessimism -2 +pessimistic -2 +petrified -2 +phobic -2 +picturesque 2 +pileup -1 +pique -2 +piqued -2 +piss -4 +pissed -4 +pissing -3 +piteous -2 +pitied -1 +pity -2 +playful 2 +pleasant 3 +please 1 +pleased 3 +pleasure 3 +poised -2 +poison -2 +poisoned -2 +poisons -2 +pollute -2 +polluted -2 +polluter -2 +polluters -2 +pollutes -2 +poor -2 +poorer -2 +poorest -2 +popular 3 +positive 2 +positively 2 +possessive -2 +postpone -1 +postponed -1 +postpones -1 +postponing -1 +poverty -1 +powerful 2 +powerless -2 +praise 3 +praised 3 +praises 3 +praising 3 +pray 1 +praying 1 +prays 1 +prblm -2 +prblms -2 +prepared 1 +pressure -1 +pressured -2 +pretend -1 +pretending -1 +pretends -1 +pretty 1 +prevent -1 +prevented -1 +preventing -1 +prevents -1 +prick -5 +prison -2 +prisoner -2 +prisoners -2 +privileged 2 +proactive 2 +problem -2 +problems -2 +profiteer -2 +progress 2 +prominent 2 +promise 1 +promised 1 +promises 1 +promote 1 +promoted 1 +promotes 1 +promoting 1 +propaganda -2 +prosecute -1 +prosecuted -2 +prosecutes -1 +prosecution -1 +prospect 1 +prospects 1 +prosperous 3 +protect 1 +protected 1 +protects 1 +protest -2 +protesters -2 +protesting -2 +protests -2 +proud 2 +proudly 2 +provoke -1 +provoked -1 +provokes -1 +provoking -1 +pseudoscience -3 +punish -2 +punished -2 +punishes -2 +punitive -2 +pushy -1 +puzzled -2 +quaking -2 +questionable -2 +questioned -1 +questioning -1 +racism -3 +racist -3 +racists -3 +rage -2 +rageful -2 +rainy -1 +rant -3 +ranter -3 +ranters -3 +rants -3 +rape -4 +rapist -4 +rapture 2 +raptured 2 +raptures 2 +rapturous 4 +rash -2 +ratified 2 +reach 1 +reached 1 +reaches 1 +reaching 1 +reassure 1 +reassured 1 +reassures 1 +reassuring 2 +rebellion -2 +recession -2 +reckless -2 +recommend 2 +recommended 2 +recommends 2 +redeemed 2 +refuse -2 +refused -2 +refusing -2 +regret -2 +regretful -2 +regrets -2 +regretted -2 +regretting -2 +reject -1 +rejected -1 +rejecting -1 +rejects -1 +rejoice 4 +rejoiced 4 +rejoices 4 +rejoicing 4 +relaxed 2 +relentless -1 +reliant 2 +relieve 1 +relieved 2 +relieves 1 +relieving 2 +relishing 2 +remarkable 2 +remorse -2 +repulse -1 +repulsed -2 +rescue 2 +rescued 2 +rescues 2 +resentful -2 +resign -1 +resigned -1 +resigning -1 +resigns -1 +resolute 2 +resolve 2 +resolved 2 +resolves 2 +resolving 2 +respected 2 +responsible 2 +responsive 2 +restful 2 +restless -2 +restore 1 +restored 1 +restores 1 +restoring 1 +restrict -2 +restricted -2 +restricting -2 +restriction -2 +restricts -2 +retained -1 +retard -2 +retarded -2 +retreat -1 +revenge -2 +revengeful -2 +revered 2 +revive 2 +revives 2 +reward 2 +rewarded 2 +rewarding 2 +rewards 2 +rich 2 +ridiculous -3 +rig -1 +rigged -1 +right direction 3 +rigorous 3 +rigorously 3 +riot -2 +riots -2 +risk -2 +risks -2 +rob -2 +robber -2 +robed -2 +robing -2 +robs -2 +robust 2 +rofl 4 +roflcopter 4 +roflmao 4 +romance 2 +rotfl 4 +rotflmfao 4 +rotflol 4 +ruin -2 +ruined -2 +ruining -2 +ruins -2 +sabotage -2 +sad -2 +sadden -2 +saddened -2 +sadly -2 +safe 1 +safely 1 +safety 1 +salient 1 +sappy -1 +sarcastic -2 +satisfied 2 +save 2 +saved 2 +scam -2 +scams -2 +scandal -3 +scandalous -3 +scandals -3 +scapegoat -2 +scapegoats -2 +scare -2 +scared -2 +scary -2 +sceptical -2 +scold -2 +scoop 3 +scorn -2 +scornful -2 +scream -2 +screamed -2 +screaming -2 +screams 
-2 +screwed -2 +screwed up -3 +scumbag -4 +secure 2 +secured 2 +secures 2 +sedition -2 +seditious -2 +seduced -1 +self-confident 2 +self-deluded -2 +selfish -3 +selfishness -3 +sentence -2 +sentenced -2 +sentences -2 +sentencing -2 +serene 2 +severe -2 +sexy 3 +shaky -2 +shame -2 +shamed -2 +shameful -2 +share 1 +shared 1 +shares 1 +shattered -2 +shit -4 +shithead -4 +shitty -3 +shock -2 +shocked -2 +shocking -2 +shocks -2 +shoot -1 +short-sighted -2 +short-sightedness -2 +shortage -2 +shortages -2 +shrew -4 +shy -1 +sick -2 +sigh -2 +significance 1 +significant 1 +silencing -1 +silly -1 +sincere 2 +sincerely 2 +sincerest 2 +sincerity 2 +sinful -3 +singleminded -2 +skeptic -2 +skeptical -2 +skepticism -2 +skeptics -2 +slam -2 +slash -2 +slashed -2 +slashes -2 +slashing -2 +slavery -3 +sleeplessness -2 +slick 2 +slicker 2 +slickest 2 +sluggish -2 +slut -5 +smart 1 +smarter 2 +smartest 2 +smear -2 +smile 2 +smiled 2 +smiles 2 +smiling 2 +smog -2 +sneaky -1 +snub -2 +snubbed -2 +snubbing -2 +snubs -2 +sobering 1 +solemn -1 +solid 2 +solidarity 2 +solution 1 +solutions 1 +solve 1 +solved 1 +solves 1 +solving 1 +somber -2 +some kind 0 +son-of-a-bitch -5 +soothe 3 +soothed 3 +soothing 3 +sophisticated 2 +sore -1 +sorrow -2 +sorrowful -2 +sorry -1 +spam -2 +spammer -3 +spammers -3 +spamming -2 +spark 1 +sparkle 3 +sparkles 3 +sparkling 3 +speculative -2 +spirit 1 +spirited 2 +spiritless -2 +spiteful -2 +splendid 3 +sprightly 2 +squelched -1 +stab -2 +stabbed -2 +stable 2 +stabs -2 +stall -2 +stalled -2 +stalling -2 +stamina 2 +stampede -2 +startled -2 +starve -2 +starved -2 +starves -2 +starving -2 +steadfast 2 +steal -2 +steals -2 +stereotype -2 +stereotyped -2 +stifled -1 +stimulate 1 +stimulated 1 +stimulates 1 +stimulating 2 +stingy -2 +stolen -2 +stop -1 +stopped -1 +stopping -1 +stops -1 +stout 2 +straight 1 +strange -1 +strangely -1 +strangled -2 +strength 2 +strengthen 2 +strengthened 2 +strengthening 2 +strengthens 2 +stressed -2 +stressor -2 +stressors -2 +stricken -2 +strike -1 +strikers -2 +strikes -1 +strong 2 +stronger 2 +strongest 2 +struck -1 +struggle -2 +struggled -2 +struggles -2 +struggling -2 +stubborn -2 +stuck -2 +stunned -2 +stunning 4 +stupid -2 +stupidly -2 +suave 2 +substantial 1 +substantially 1 +subversive -2 +success 2 +successful 3 +suck -3 +sucks -3 +suffer -2 +suffering -2 +suffers -2 +suicidal -2 +suicide -2 +suing -2 +sulking -2 +sulky -2 +sullen -2 +sunshine 2 +super 3 +superb 5 +superior 2 +support 2 +supported 2 +supporter 1 +supporters 1 +supporting 1 +supportive 2 +supports 2 +survived 2 +surviving 2 +survivor 2 +suspect -1 +suspected -1 +suspecting -1 +suspects -1 +suspend -1 +suspended -1 +suspicious -2 +swear -2 +swearing -2 +swears -2 +sweet 2 +swift 2 +swiftly 2 +swindle -3 +swindles -3 +swindling -3 +sympathetic 2 +sympathy 2 +tard -2 +tears -2 +tender 2 +tense -2 +tension -1 +terrible -3 +terribly -3 +terrific 4 +terrified -3 +terror -3 +terrorize -3 +terrorized -3 +terrorizes -3 +thank 2 +thankful 2 +thanks 2 +thorny -2 +thoughtful 2 +thoughtless -2 +threat -2 +threaten -2 +threatened -2 +threatening -2 +threatens -2 +threats -2 +thrilled 5 +thwart -2 +thwarted -2 +thwarting -2 +thwarts -2 +timid -2 +timorous -2 +tired -2 +tits -2 +tolerant 2 +toothless -2 +top 2 +tops 2 +torn -2 +torture -4 +tortured -4 +tortures -4 +torturing -4 +totalitarian -2 +totalitarianism -2 +tout -2 +touted -2 +touting -2 +touts -2 +tragedy -2 +tragic -2 +tranquil 2 +trap -1 +trapped -2 +trauma -3 +traumatic -3 +travesty -2 +treason -3 +treasonous -3 +treasure 2 +treasures 
2 +trembling -2 +tremulous -2 +tricked -2 +trickery -2 +triumph 4 +triumphant 4 +trouble -2 +troubled -2 +troubles -2 +true 2 +trust 1 +trusted 2 +tumor -2 +twat -5 +ugly -3 +unacceptable -2 +unappreciated -2 +unapproved -2 +unaware -2 +unbelievable -1 +unbelieving -1 +unbiased 2 +uncertain -1 +unclear -1 +uncomfortable -2 +unconcerned -2 +unconfirmed -1 +unconvinced -1 +uncredited -1 +undecided -1 +underestimate -1 +underestimated -1 +underestimates -1 +underestimating -1 +undermine -2 +undermined -2 +undermines -2 +undermining -2 +undeserving -2 +undesirable -2 +uneasy -2 +unemployment -2 +unequal -1 +unequaled 2 +unethical -2 +unfair -2 +unfocused -2 +unfulfilled -2 +unhappy -2 +unhealthy -2 +unified 1 +unimpressed -2 +unintelligent -2 +united 1 +unjust -2 +unlovable -2 +unloved -2 +unmatched 1 +unmotivated -2 +unprofessional -2 +unresearched -2 +unsatisfied -2 +unsecured -2 +unsettled -1 +unsophisticated -2 +unstable -2 +unstoppable 2 +unsupported -2 +unsure -1 +untarnished 2 +unwanted -2 +unworthy -2 +upset -2 +upsets -2 +upsetting -2 +uptight -2 +urgent -1 +useful 2 +usefulness 2 +useless -2 +uselessness -2 +vague -2 +validate 1 +validated 1 +validates 1 +validating 1 +verdict -1 +verdicts -1 +vested 1 +vexation -2 +vexing -2 +vibrant 3 +vicious -2 +victim -3 +victimize -3 +victimized -3 +victimizes -3 +victimizing -3 +victims -3 +vigilant 3 +vile -3 +vindicate 2 +vindicated 2 +vindicates 2 +vindicating 2 +violate -2 +violated -2 +violates -2 +violating -2 +violence -3 +violent -3 +virtuous 2 +virulent -2 +vision 1 +visionary 3 +visioning 1 +visions 1 +vitality 3 +vitamin 1 +vitriolic -3 +vivacious 3 +vociferous -1 +vulnerability -2 +vulnerable -2 +walkout -2 +walkouts -2 +wanker -3 +want 1 +war -2 +warfare -2 +warm 1 +warmth 2 +warn -2 +warned -2 +warning -3 +warnings -3 +warns -2 +waste -1 +wasted -2 +wasting -2 +wavering -1 +weak -2 +weakness -2 +wealth 3 +wealthy 2 +weary -2 +weep -2 +weeping -2 +weird -2 +welcome 2 +welcomed 2 +welcomes 2 +whimsical 1 +whitewash -3 +whore -4 +wicked -2 +widowed -1 +willingness 2 +win 4 +winner 4 +winning 4 +wins 4 +winwin 3 +wish 1 +wishes 1 +wishing 1 +withdrawal -3 +woebegone -2 +woeful -3 +won 3 +wonderful 4 +woo 3 +woohoo 3 +wooo 4 +woow 4 +worn -1 +worried -3 +worry -3 +worrying -3 +worse -3 +worsen -3 +worsened -3 +worsening -3 +worsens -3 +worshiped 3 +worst -3 +worth 2 +worthless -2 +worthy 2 +wow 4 +wowow 4 +wowww 4 +wrathful -3 +wreck -2 +wrong -2 +wronged -2 +wtf -4 +yeah 1 +yearning 1 +yeees 2 +yes 1 +youthful 2 +yucky -2 +yummy 3 +zealot -2 +zealots -2 +zealous 2 \ No newline at end of file diff --git a/dev/audit-release/audit_release.py b/dev/audit-release/audit_release.py index 27d1dd784ce2..972be30da1eb 100755 --- a/dev/audit-release/audit_release.py +++ b/dev/audit-release/audit_release.py @@ -115,7 +115,7 @@ def ensure_path_not_present(path): # maven that links against them. This will catch issues with messed up # dependencies within those projects. 
modules = [ - "spark-core", "spark-bagel", "spark-mllib", "spark-streaming", "spark-repl", + "spark-core", "spark-mllib", "spark-streaming", "spark-repl", "spark-graphx", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-mqtt", "spark-streaming-twitter", "spark-streaming-zeromq", "spark-catalyst", "spark-sql", "spark-hive", "spark-streaming-kinesis-asl" diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index cb79e9eba06e..b1895b16b1b6 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -166,9 +166,6 @@ if [[ "$1" == "package" ]]; then # We increment the Zinc port each time to avoid OOM's and other craziness if multiple builds # share the same Zinc server. - make_binary_release "hadoop1" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver" "3030" & - make_binary_release "hadoop1-scala2.11" "-Psparkr -Phadoop-1 -Phive -Dscala-2.11" "3031" & - make_binary_release "cdh4" "-Psparkr -Phadoop-1 -Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" "3032" & make_binary_release "hadoop2.3" "-Psparkr -Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" "3033" & make_binary_release "hadoop2.4" "-Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" "3034" & make_binary_release "hadoop2.6" "-Psparkr -Phadoop-2.6 -Phive -Phive-thriftserver -Pyarn" "3034" & diff --git a/dev/create-release/release-tag.sh b/dev/create-release/release-tag.sh index b0a3374becc6..d404939d1cae 100755 --- a/dev/create-release/release-tag.sh +++ b/dev/create-release/release-tag.sh @@ -64,9 +64,6 @@ git commit -a -m "Preparing Spark release $RELEASE_TAG" echo "Creating tag $RELEASE_TAG at the head of $GIT_BRANCH" git tag $RELEASE_TAG -# TODO: It would be nice to do some verifications here -# i.e. 
check whether ec2 scripts have the new version - # Create next version $MVN versions:set -DnewVersion=$NEXT_VERSION | grep -v "no value" # silence logs git commit -a -m "Preparing development version $NEXT_VERSION" diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py index 7f152b7f5355..5d0ac16b3b0a 100755 --- a/dev/create-release/releaseutils.py +++ b/dev/create-release/releaseutils.py @@ -159,7 +159,6 @@ def get_commits(tag): "build": CORE_COMPONENT, "deploy": CORE_COMPONENT, "documentation": CORE_COMPONENT, - "ec2": "EC2", "examples": CORE_COMPONENT, "graphx": "GraphX", "input/output": CORE_COMPONENT, diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 new file mode 100644 index 000000000000..e4373f79f792 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -0,0 +1,191 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-runtime-3.5.2.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math-2.1.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +gmbal-api-only-3.0.0-b023.jar +grizzly-framework-2.1.2.jar +grizzly-http-2.1.2.jar +grizzly-http-server-2.1.2.jar +grizzly-http-servlet-2.1.2.jar +grizzly-rcm-2.1.2.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.2.0.jar +hadoop-auth-2.2.0.jar +hadoop-client-2.2.0.jar +hadoop-common-2.2.0.jar +hadoop-hdfs-2.2.0.jar +hadoop-mapreduce-client-app-2.2.0.jar +hadoop-mapreduce-client-common-2.2.0.jar +hadoop-mapreduce-client-core-2.2.0.jar +hadoop-mapreduce-client-jobclient-2.2.0.jar +hadoop-mapreduce-client-shuffle-2.2.0.jar +hadoop-yarn-api-2.2.0.jar +hadoop-yarn-client-2.2.0.jar +hadoop-yarn-common-2.2.0.jar +hadoop-yarn-server-common-2.2.0.jar +hadoop-yarn-server-web-proxy-2.2.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javax.servlet-3.1.jar +javax.servlet-api-3.0.1.jar +javolution-5.5.1.jar 
+jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-grizzly2-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jersey-test-framework-core-1.9.jar +jersey-test-framework-grizzly2-1.9.jar +jets3t-0.7.1.jar +jettison-1.1.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.1.jar +management-api-3.0.0-b012.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 new file mode 100644 index 000000000000..7478181406d0 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -0,0 +1,182 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-runtime-3.5.2.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar 
+curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.3.0.jar +hadoop-auth-2.3.0.jar +hadoop-client-2.3.0.jar +hadoop-common-2.3.0.jar +hadoop-hdfs-2.3.0.jar +hadoop-mapreduce-client-app-2.3.0.jar +hadoop-mapreduce-client-common-2.3.0.jar +hadoop-mapreduce-client-core-2.3.0.jar +hadoop-mapreduce-client-jobclient-2.3.0.jar +hadoop-mapreduce-client-shuffle-2.3.0.jar +hadoop-yarn-api-2.3.0.jar +hadoop-yarn-client-2.3.0.jar +hadoop-yarn-common-2.3.0.jar +hadoop-yarn-server-common-2.3.0.jar +hadoop-yarn-server-web-proxy-2.3.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 new file mode 100644 index 000000000000..faffb8bf398a --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -0,0 +1,183 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar 
+akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-runtime-3.5.2.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.4.0.jar +curator-framework-2.4.0.jar +curator-recipes-2.4.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.4.0.jar +hadoop-auth-2.4.0.jar +hadoop-client-2.4.0.jar +hadoop-common-2.4.0.jar +hadoop-hdfs-2.4.0.jar +hadoop-mapreduce-client-app-2.4.0.jar +hadoop-mapreduce-client-common-2.4.0.jar +hadoop-mapreduce-client-core-2.4.0.jar +hadoop-mapreduce-client-jobclient-2.4.0.jar +hadoop-mapreduce-client-shuffle-2.4.0.jar +hadoop-yarn-api-2.4.0.jar +hadoop-yarn-client-2.4.0.jar +hadoop-yarn-common-2.4.0.jar +hadoop-yarn-server-common-2.4.0.jar +hadoop-yarn-server-web-proxy-2.4.0.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar +jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar 
+parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.5.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 new file mode 100644 index 000000000000..e703c7acd387 --- /dev/null +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -0,0 +1,190 @@ +JavaEWAH-0.3.2.jar +RoaringBitmap-0.5.11.jar +ST4-4.0.4.jar +activation-1.1.1.jar +akka-actor_2.10-2.3.11.jar +akka-remote_2.10-2.3.11.jar +akka-slf4j_2.10-2.3.11.jar +antlr-runtime-3.5.2.jar +aopalliance-1.0.jar +apache-log4j-extras-1.2.17.jar +apacheds-i18n-2.0.0-M15.jar +apacheds-kerberos-codec-2.0.0-M15.jar +api-asn1-api-1.0.0-M20.jar +api-util-1.0.0-M20.jar +arpack_combined_all-0.1.jar +asm-3.1.jar +asm-commons-3.1.jar +asm-tree-3.1.jar +avro-1.7.7.jar +avro-ipc-1.7.7-tests.jar +avro-ipc-1.7.7.jar +avro-mapred-1.7.7-hadoop2.jar +base64-2.3.8.jar +bcprov-jdk15on-1.51.jar +bonecp-0.8.0.RELEASE.jar +breeze-macros_2.10-0.11.2.jar +breeze_2.10-0.11.2.jar +calcite-avatica-1.2.0-incubating.jar +calcite-core-1.2.0-incubating.jar +calcite-linq4j-1.2.0-incubating.jar +chill-java-0.5.0.jar +chill_2.10-0.5.0.jar +commons-beanutils-1.7.0.jar +commons-beanutils-core-1.8.0.jar +commons-cli-1.2.jar +commons-codec-1.10.jar +commons-collections-3.2.2.jar +commons-compiler-2.7.6.jar +commons-compress-1.4.1.jar +commons-configuration-1.6.jar +commons-dbcp-1.4.jar +commons-digester-1.8.jar +commons-httpclient-3.1.jar +commons-io-2.4.jar +commons-lang-2.6.jar +commons-lang3-3.3.2.jar +commons-logging-1.1.3.jar +commons-math3-3.4.1.jar +commons-net-2.2.jar +commons-pool-1.5.4.jar +compress-lzf-1.0.3.jar +config-1.2.1.jar +core-1.1.2.jar +curator-client-2.6.0.jar +curator-framework-2.6.0.jar +curator-recipes-2.6.0.jar +datanucleus-api-jdo-3.2.6.jar +datanucleus-core-3.2.10.jar +datanucleus-rdbms-3.2.9.jar +derby-10.10.1.1.jar +eigenbase-properties-1.1.5.jar +geronimo-annotation_1.0_spec-1.1.1.jar +geronimo-jaspic_1.0_spec-1.0.jar +geronimo-jta_1.1_spec-1.1.1.jar +groovy-all-2.1.6.jar +gson-2.2.4.jar +guice-3.0.jar +guice-servlet-3.0.jar +hadoop-annotations-2.6.0.jar +hadoop-auth-2.6.0.jar +hadoop-client-2.6.0.jar +hadoop-common-2.6.0.jar +hadoop-hdfs-2.6.0.jar +hadoop-mapreduce-client-app-2.6.0.jar +hadoop-mapreduce-client-common-2.6.0.jar +hadoop-mapreduce-client-core-2.6.0.jar +hadoop-mapreduce-client-jobclient-2.6.0.jar +hadoop-mapreduce-client-shuffle-2.6.0.jar +hadoop-yarn-api-2.6.0.jar +hadoop-yarn-client-2.6.0.jar +hadoop-yarn-common-2.6.0.jar +hadoop-yarn-server-common-2.6.0.jar +hadoop-yarn-server-web-proxy-2.6.0.jar +htrace-core-3.0.4.jar +httpclient-4.3.2.jar +httpcore-4.3.2.jar +ivy-2.4.0.jar +jackson-annotations-2.4.4.jar +jackson-core-2.4.4.jar +jackson-core-asl-1.9.13.jar +jackson-databind-2.4.4.jar 
+jackson-jaxrs-1.9.13.jar +jackson-mapper-asl-1.9.13.jar +jackson-module-scala_2.10-2.4.4.jar +jackson-xc-1.9.13.jar +janino-2.7.8.jar +jansi-1.4.jar +java-xmlbuilder-1.0.jar +javax.inject-1.jar +javax.servlet-3.0.0.v201112011016.jar +javolution-5.5.1.jar +jaxb-api-2.2.2.jar +jaxb-impl-2.2.3-1.jar +jcl-over-slf4j-1.7.10.jar +jdo-api-3.0.1.jar +jersey-client-1.9.jar +jersey-core-1.9.jar +jersey-guice-1.9.jar +jersey-json-1.9.jar +jersey-server-1.9.jar +jets3t-0.9.3.jar +jettison-1.1.jar +jetty-6.1.26.jar +jetty-all-7.6.0.v20120127.jar +jetty-util-6.1.26.jar +jline-2.10.5.jar +jline-2.12.jar +joda-time-2.9.jar +jodd-core-3.5.2.jar +jpam-1.1.jar +json-20090211.jar +json4s-ast_2.10-3.2.10.jar +json4s-core_2.10-3.2.10.jar +json4s-jackson_2.10-3.2.10.jar +jsr305-1.3.9.jar +jta-1.1.jar +jtransforms-2.4.0.jar +jul-to-slf4j-1.7.10.jar +kryo-2.21.jar +leveldbjni-all-1.8.jar +libfb303-0.9.2.jar +libthrift-0.9.2.jar +log4j-1.2.17.jar +lz4-1.3.0.jar +mail-1.4.7.jar +mesos-0.21.1-shaded-protobuf.jar +metrics-core-3.1.2.jar +metrics-graphite-3.1.2.jar +metrics-json-3.1.2.jar +metrics-jvm-3.1.2.jar +minlog-1.2.jar +mx4j-3.0.2.jar +netty-3.8.0.Final.jar +netty-all-4.0.29.Final.jar +objenesis-1.2.jar +opencsv-2.3.jar +oro-2.0.8.jar +paranamer-2.6.jar +parquet-column-1.7.0.jar +parquet-common-1.7.0.jar +parquet-encoding-1.7.0.jar +parquet-format-2.3.0-incubating.jar +parquet-generator-1.7.0.jar +parquet-hadoop-1.7.0.jar +parquet-hadoop-bundle-1.6.0.jar +parquet-jackson-1.7.0.jar +pmml-agent-1.2.7.jar +pmml-model-1.2.7.jar +pmml-schema-1.2.7.jar +protobuf-java-2.5.0.jar +py4j-0.9.jar +pyrolite-4.9.jar +quasiquotes_2.10-2.0.0-M8.jar +reflectasm-1.07-shaded.jar +scala-compiler-2.10.5.jar +scala-library-2.10.5.jar +scala-reflect-2.10.5.jar +scalap-2.10.5.jar +servlet-api-2.5.jar +slf4j-api-1.7.10.jar +slf4j-log4j12-1.7.10.jar +snappy-0.2.jar +snappy-java-1.1.2.jar +spire-macros_2.10-0.7.4.jar +spire_2.10-0.7.4.jar +stax-api-1.0-2.jar +stax-api-1.0.1.jar +stream-2.7.0.jar +super-csv-2.2.0.jar +tachyon-client-0.8.2.jar +tachyon-underfs-hdfs-0.8.2.jar +tachyon-underfs-local-0.8.2.jar +tachyon-underfs-s3-0.8.2.jar +uncommons-maths-1.2.2a.jar +unused-1.0.0.jar +xbean-asm5-shaded-4.4.jar +xercesImpl-2.9.1.jar +xmlenc-0.52.jar +xz-1.0.jar +zookeeper-3.4.6.jar diff --git a/dev/lint-python b/dev/lint-python index 0b97213ae3df..1765a07d2f22 100755 --- a/dev/lint-python +++ b/dev/lint-python @@ -19,7 +19,7 @@ SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )" SPARK_ROOT_DIR="$(dirname "$SCRIPT_DIR")" -PATHS_TO_CHECK="./python/pyspark/ ./ec2/spark_ec2.py ./examples/src/main/python/ ./dev/sparktestsupport" +PATHS_TO_CHECK="./python/pyspark/ ./examples/src/main/python/ ./dev/sparktestsupport" PATHS_TO_CHECK="$PATHS_TO_CHECK ./dev/run-tests.py ./python/run-tests.py ./dev/run-tests-jenkins.py" PEP8_REPORT_PATH="$SPARK_ROOT_DIR/dev/pep8-report.txt" PYLINT_REPORT_PATH="$SPARK_ROOT_DIR/dev/pylint-report.txt" diff --git a/dev/run-tests-jenkins.py b/dev/run-tests-jenkins.py index 7aecea25b209..c44e522c0475 100755 --- a/dev/run-tests-jenkins.py +++ b/dev/run-tests-jenkins.py @@ -124,6 +124,7 @@ def run_tests(tests_timeout): ERROR_CODES["BLOCK_R_STYLE"]: 'R style tests', ERROR_CODES["BLOCK_DOCUMENTATION"]: 'to generate documentation', ERROR_CODES["BLOCK_BUILD"]: 'to build', + ERROR_CODES["BLOCK_BUILD_TESTS"]: 'build dependency tests', ERROR_CODES["BLOCK_MIMA"]: 'MiMa tests', ERROR_CODES["BLOCK_SPARK_UNIT_TESTS"]: 'Spark unit tests', ERROR_CODES["BLOCK_PYSPARK_UNIT_TESTS"]: 'PySpark unit tests', @@ -163,14 +164,14 @@ def main(): if 
"test-maven" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_TOOL"] = "maven" # Switch the Hadoop profile based on the PR title: - if "test-hadoop1.0" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop1.0" - if "test-hadoop2.0" in ghprb_pull_title: - os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.0" if "test-hadoop2.2" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.2" if "test-hadoop2.3" in ghprb_pull_title: os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.3" + if "test-hadoop2.4" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.4" + if "test-hadoop2.6" in ghprb_pull_title: + os.environ["AMPLAB_JENKINS_BUILD_PROFILE"] = "hadoop2.6" build_display_name = os.environ["BUILD_DISPLAY_NAME"] build_url = os.environ["BUILD_URL"] @@ -197,7 +198,6 @@ def main(): pr_tests = [ "pr_merge_ability", "pr_public_classes" - # DISABLED (pwendell) "pr_new_dependencies" ] # `bind_message_base` returns a function to generate messages for Github posting diff --git a/dev/run-tests.py b/dev/run-tests.py index e7e10f1d8c72..8726889cbc77 100755 --- a/dev/run-tests.py +++ b/dev/run-tests.py @@ -148,7 +148,7 @@ def determine_java_executable(): return java_exe if java_exe else which("java") -JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch', 'update']) +JavaVersion = namedtuple('JavaVersion', ['major', 'minor', 'patch']) def determine_java_version(java_exe): @@ -164,14 +164,13 @@ def determine_java_version(java_exe): # find raw version string, eg 'java version "1.8.0_25"' raw_version_str = next(x for x in raw_output_lines if " version " in x) - match = re.search('(\d+)\.(\d+)\.(\d+)_(\d+)', raw_version_str) + match = re.search('(\d+)\.(\d+)\.(\d+)', raw_version_str) major = int(match.group(1)) minor = int(match.group(2)) patch = int(match.group(3)) - update = int(match.group(4)) - return JavaVersion(major, minor, patch, update) + return JavaVersion(major, minor, patch) # ------------------------------------------------------------------------------------------------- # Functions for running the other build and test scripts @@ -296,15 +295,14 @@ def exec_sbt(sbt_args=()): def get_hadoop_profiles(hadoop_version): """ - For the given Hadoop version tag, return a list of SBT profile flags for + For the given Hadoop version tag, return a list of Maven/SBT profile flags for building and testing against that Hadoop version. 
""" sbt_maven_hadoop_profiles = { - "hadoop1.0": ["-Phadoop-1", "-Dhadoop.version=1.2.1"], - "hadoop2.0": ["-Phadoop-1", "-Dhadoop.version=2.0.0-mr1-cdh4.1.1"], "hadoop2.2": ["-Pyarn", "-Phadoop-2.2"], - "hadoop2.3": ["-Pyarn", "-Phadoop-2.3", "-Dhadoop.version=2.3.0"], + "hadoop2.3": ["-Pyarn", "-Phadoop-2.3"], + "hadoop2.4": ["-Pyarn", "-Phadoop-2.4"], "hadoop2.6": ["-Pyarn", "-Phadoop-2.6"], } @@ -420,6 +418,12 @@ def run_python_tests(test_modules, parallelism): run_cmd(command) +def run_build_tests(): + set_title_and_block("Running build tests", "BLOCK_BUILD_TESTS") + run_cmd([os.path.join(SPARK_HOME, "dev", "test-dependencies.sh")]) + pass + + def run_sparkr_tests(): set_title_and_block("Running SparkR tests", "BLOCK_SPARKR_UNIT_TESTS") @@ -528,7 +532,8 @@ def main(): if not changed_files or any(f.endswith(".scala") for f in changed_files): run_scala_style_checks() if not changed_files or any(f.endswith(".java") for f in changed_files): - run_java_style_checks() + # run_java_style_checks() + pass if not changed_files or any(f.endswith(".py") for f in changed_files): run_python_style_checks() if not changed_files or any(f.endswith(".R") for f in changed_files): @@ -539,6 +544,9 @@ def main(): # if "DOCS" in changed_modules and test_env == "amplab_jenkins": # build_spark_documentation() + if any(m.should_run_build_tests for m in test_modules): + run_build_tests() + # spark build build_apache_spark(build_tool, hadoop_version) diff --git a/dev/sparktestsupport/__init__.py b/dev/sparktestsupport/__init__.py index 0e8032d13341..89015f8c4fb9 100644 --- a/dev/sparktestsupport/__init__.py +++ b/dev/sparktestsupport/__init__.py @@ -32,5 +32,6 @@ "BLOCK_PYSPARK_UNIT_TESTS": 19, "BLOCK_SPARKR_UNIT_TESTS": 20, "BLOCK_JAVA_STYLE": 21, + "BLOCK_BUILD_TESTS": 22, "BLOCK_TIMEOUT": 124 } diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index d65547e04db4..1fc659616412 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -31,7 +31,7 @@ class Module(object): def __init__(self, name, dependencies, source_file_regexes, build_profile_flags=(), environ={}, sbt_test_goals=(), python_test_goals=(), blacklisted_python_implementations=(), - test_tags=(), should_run_r_tests=False): + test_tags=(), should_run_r_tests=False, should_run_build_tests=False): """ Define a new module. @@ -53,6 +53,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= :param test_tags A set of tags that will be excluded when running unit tests if the module is not explicitly changed. :param should_run_r_tests: If true, changes in this module will trigger all R tests. + :param should_run_build_tests: If true, changes in this module will trigger build tests. 
""" self.name = name self.dependencies = dependencies @@ -64,6 +65,7 @@ def __init__(self, name, dependencies, source_file_regexes, build_profile_flags= self.blacklisted_python_implementations = blacklisted_python_implementations self.test_tags = test_tags self.should_run_r_tests = should_run_r_tests + self.should_run_build_tests = should_run_build_tests self.dependent_modules = set() for dep in dependencies: @@ -394,16 +396,16 @@ def contains_file(self, filename): ] ) - -ec2 = Module( - name="ec2", +build = Module( + name="build", dependencies=[], source_file_regexes=[ - "ec2/", - ] + ".*pom.xml", + "dev/test-dependencies.sh", + ], + should_run_build_tests=True ) - yarn = Module( name="yarn", dependencies=[], @@ -433,5 +435,6 @@ def contains_file(self, filename): "test", ], python_test_goals=list(itertools.chain.from_iterable(m.python_test_goals for m in all_modules)), - should_run_r_tests=True + should_run_r_tests=True, + should_run_build_tests=True ) diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh new file mode 100755 index 000000000000..424ce6ad7663 --- /dev/null +++ b/dev/test-dependencies.sh @@ -0,0 +1,120 @@ +#!/usr/bin/env bash + +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +set -e + +FWDIR="$(cd "`dirname $0`"/..; pwd)" +cd "$FWDIR" + +# Explicitly set locale in order to make `sort` output consistent across machines. +# See https://stackoverflow.com/questions/28881 for more details. +export LC_ALL=C + +# TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. + +# NOTE: These should match those in the release publishing script +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pyarn -Phive" +MVN="build/mvn --force" +HADOOP_PROFILES=( + hadoop-2.2 + hadoop-2.3 + hadoop-2.4 + hadoop-2.6 +) + +# We'll switch the version to a temp. one, publish POMs using that new version, then switch back to +# the old version. We need to do this because the `dependency:build-classpath` task needs to +# resolve Spark's internal submodule dependencies. + +# From http://stackoverflow.com/a/26514030 +set +e +OLD_VERSION=$($MVN -q \ + -Dexec.executable="echo" \ + -Dexec.args='${project.version}' \ + --non-recursive \ + org.codehaus.mojo:exec-maven-plugin:1.3.1:exec) +if [ $? 
!= 0 ]; then + echo -e "Error while getting version string from Maven:\n$OLD_VERSION" + exit 1 +fi +set -e +TEMP_VERSION="spark-$(python -S -c "import random; print(random.randrange(100000, 999999))")" + +function reset_version { + # Delete the temporary POMs that we wrote to the local Maven repo: + find "$HOME/.m2/" | grep "$TEMP_VERSION" | xargs rm -rf + + # Restore the original version number: + $MVN -q versions:set -DnewVersion=$OLD_VERSION -DgenerateBackupPoms=false > /dev/null +} +trap reset_version EXIT + +$MVN -q versions:set -DnewVersion=$TEMP_VERSION -DgenerateBackupPoms=false > /dev/null + +# Generate manifests for each Hadoop profile: +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + echo "Performing Maven install for $HADOOP_PROFILE" + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE jar:jar install:install -q \ + -pl '!assembly' \ + -pl '!examples' \ + -pl '!external/flume-assembly' \ + -pl '!external/kafka-assembly' \ + -pl '!external/twitter' \ + -pl '!external/flume' \ + -pl '!external/mqtt' \ + -pl '!external/mqtt-assembly' \ + -pl '!external/zeromq' \ + -pl '!external/kafka' \ + -pl '!tags' \ + -DskipTests + + echo "Generating dependency manifest for $HADOOP_PROFILE" + mkdir -p dev/pr-deps + $MVN $HADOOP2_MODULE_PROFILES -P$HADOOP_PROFILE dependency:build-classpath -pl assembly \ + | grep "Building Spark Project Assembly" -A 5 \ + | tail -n 1 | tr ":" "\n" | rev | cut -d "/" -f 1 | rev | sort \ + | grep -v spark > dev/pr-deps/spark-deps-$HADOOP_PROFILE +done + +if [[ $@ == **replace-manifest** ]]; then + echo "Replacing manifests and creating new files at dev/deps" + rm -rf dev/deps + mv dev/pr-deps dev/deps + exit 0 +fi + +for HADOOP_PROFILE in "${HADOOP_PROFILES[@]}"; do + set +e + dep_diff="$( + git diff \ + --no-index \ + dev/deps/spark-deps-$HADOOP_PROFILE \ + dev/pr-deps/spark-deps-$HADOOP_PROFILE \ + )" + set -e + if [ "$dep_diff" != "" ]; then + echo "Spark's published dependencies DO NOT MATCH the manifest file (dev/spark-deps)." + echo "To update the manifest file, run './dev/test-dependencies.sh --replace-manifest'." + echo "$dep_diff" + rm -rf dev/pr-deps + exit 1 + fi +done + +exit 0 diff --git a/dev/tests/pr_new_dependencies.sh b/dev/tests/pr_new_dependencies.sh deleted file mode 100755 index fdfb3c62aff5..000000000000 --- a/dev/tests/pr_new_dependencies.sh +++ /dev/null @@ -1,117 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# -# This script follows the base format for testing pull requests against -# another branch and returning results to be published. More details can be -# found at dev/run-tests-jenkins. 
-# -# Arg1: The Github Pull Request Actual Commit -#+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# Arg3: Current PR Commit Hash -#+ the PR hash for the current commit -# - -ghprbActualCommit="$1" -sha1="$2" -current_pr_head="$3" - -MVN_BIN="build/mvn" -CURR_CP_FILE="my-classpath.txt" -MASTER_CP_FILE="master-classpath.txt" - -# First switch over to the master branch -git checkout -f master -# Find and copy all pom.xml files into a *.gate file that we can check -# against through various `git` changes -find -name "pom.xml" -exec cp {} {}.gate \; -# Switch back to the current PR -git checkout -f "${current_pr_head}" - -# Check if any *.pom files from the current branch are different from the master -difference_q="" -for p in $(find -name "pom.xml"); do - [[ -f "${p}" && -f "${p}.gate" ]] && \ - difference_q="${difference_q}$(diff $p.gate $p)" -done - -# If no pom files were changed we can easily say no new dependencies were added -if [ -z "${difference_q}" ]; then - echo " * This patch does not change any dependencies." -else - # Else we need to manually build spark to determine what, if any, dependencies - # were added into the Spark assembly jar - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${CURR_CP_FILE} - - # Checkout the master branch to compare against - git checkout -f master - - ${MVN_BIN} clean package dependency:build-classpath -DskipTests 2>/dev/null | \ - sed -n -e '/Building Spark Project Assembly/,$p' | \ - grep --context=1 -m 2 "Dependencies classpath:" | \ - head -n 3 | \ - tail -n 1 | \ - tr ":" "\n" | \ - rev | \ - cut -d "/" -f 1 | \ - rev | \ - sort > ${MASTER_CP_FILE} - - DIFF_RESULTS="`diff ${CURR_CP_FILE} ${MASTER_CP_FILE}`" - - if [ -z "${DIFF_RESULTS}" ]; then - echo " * This patch does not change any dependencies." 
- else - # Pretty print the new dependencies - added_deps=$(echo "${DIFF_RESULTS}" | grep "<" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - removed_deps=$(echo "${DIFF_RESULTS}" | grep ">" | cut -d' ' -f2 | awk '{printf " * \`"$1"\`\\n"}') - added_deps_text=" * This patch **adds the following new dependencies:**\n${added_deps}" - removed_deps_text=" * This patch **removes the following dependencies:**\n${removed_deps}" - - # Construct the final returned message with proper - return_mssg="" - [ -n "${added_deps}" ] && return_mssg="${added_deps_text}" - if [ -n "${removed_deps}" ]; then - if [ -n "${return_mssg}" ]; then - return_mssg="${return_mssg}\n${removed_deps_text}" - else - return_mssg="${removed_deps_text}" - fi - fi - echo "${return_mssg}" - fi - - # Remove the files we've left over - [ -f "${CURR_CP_FILE}" ] && rm -f "${CURR_CP_FILE}" - [ -f "${MASTER_CP_FILE}" ] && rm -f "${MASTER_CP_FILE}" - - # Clean up our mess from the Maven builds just in case - ${MVN_BIN} clean &>/dev/null -fi diff --git a/dev/tests/pr_public_classes.sh b/dev/tests/pr_public_classes.sh index 927295b88c96..41c5d3ee8cb3 100755 --- a/dev/tests/pr_public_classes.sh +++ b/dev/tests/pr_public_classes.sh @@ -24,36 +24,44 @@ # # Arg1: The Github Pull Request Actual Commit #+ known as `ghprbActualCommit` in `run-tests-jenkins` -# Arg2: The SHA1 hash -#+ known as `sha1` in `run-tests-jenkins` -# - -# We diff master...$ghprbActualCommit because that gets us changes introduced in the PR -#+ and not anything else added to master since the PR was branched. ghprbActualCommit="$1" -sha1="$2" + +# $ghprbActualCommit is an automatic merge commit generated by GitHub; its parents are some Spark +# master commit and the tip of the pull request branch. + +# By diffing $ghprbActualCommit^...$ghprbActualCommit and filtering to examine the diffs of only +# non-test files, we can get the changes introduced in the PR and not anything else added to master +# since the PR was branched.
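For illustration, here is a rough Python sketch of the diff-and-filter step this comment describes. The helper name and the hard-coded extension list are assumptions for the sketch, not part of the repository:

```python
# Sketch only: mirrors `git diff $ghprbActualCommit^...$ghprbActualCommit --name-only`
# followed by the grep filters used in the script. The helper name is hypothetical.
import subprocess

def pr_source_files(merge_commit):
    """Return the non-test .py/.java/.scala files touched by the PR."""
    out = subprocess.check_output(
        ["git", "diff", "--name-only",
         "{0}^...{0}".format(merge_commit)]).decode("utf-8")
    return [path for path in out.splitlines()
            if path.endswith((".py", ".java", ".scala"))  # code files only
            and "/test" not in path]                      # skip test directories
```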
+ +# Handle differences between GNU and BSD sed +if [[ $(uname) == "Darwin" ]]; then + SED='sed -E' +else + SED='sed -r' +fi source_files=$( - git diff master...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit --name-only `# diff patch against master from branch point` \ | grep -v -e "\/test" `# ignore files in test directories` \ | grep -e "\.py$" -e "\.java$" -e "\.scala$" `# include only code files` \ | tr "\n" " " ) + new_public_classes=$( - git diff master...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ + git diff $ghprbActualCommit^...$ghprbActualCommit ${source_files} `# diff patch against master from branch point` \ | grep "^\+" `# filter in only added lines` \ - | sed -r -e "s/^\+//g" `# remove the leading +` \ + | $SED -e "s/^\+//g" `# remove the leading +` \ | grep -e "trait " -e "class " `# filter in lines with these key words` \ | grep -e "{" -e "(" `# filter in lines with these key words, too` \ | grep -v -e "\@\@" -e "private" `# exclude lines with these words` \ | grep -v -e "^// " -e "^/\*" -e "^ \* " `# exclude comment lines` \ - | sed -r -e "s/\{.*//g" `# remove from the { onwards` \ - | sed -r -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ - | sed -r -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ - | sed -r -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ - | sed -r -e "s/^/ \* /g" `# prepend ' *' to start of line` \ - | sed -r -e "s/$/\\\n/g" `# append newline to end of line` \ + | $SED -e "s/\{.*//g" `# remove from the { onwards` \ + | $SED -e "s/\}//g" `# just in case, remove }; they mess the JSON` \ + | $SED -e "s/\"/\\\\\"/g" `# escape double quotes; they mess the JSON` \ + | $SED -e "s/^(.*)$/\`\1\`/g" `# surround with backticks for style` \ + | $SED -e "s/^/ \* /g" `# prepend ' *' to start of line` \ + | $SED -e "s/$/\\\n/g" `# append newline to end of line` \ | tr -d "\n" `# remove actual LF characters` ) @@ -61,5 +69,5 @@ if [ -z "$new_public_classes" ]; then echo " * This patch adds no public classes." 
else public_classes_note=" * This patch adds the following public classes _(experimental)_:" - echo "${public_classes_note}\n${new_public_classes}" + echo -e "${public_classes_note}\n${new_public_classes}" fi diff --git a/docker-integration-tests/pom.xml b/docker-integration-tests/pom.xml index 39d3f344615e..78b638ecfa63 100644 --- a/docker-integration-tests/pom.xml +++ b/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala index c503c4a13b48..f73231fc80a0 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/DockerJDBCIntegrationSuite.scala @@ -30,8 +30,8 @@ import org.scalatest.concurrent.Eventually import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite -import org.apache.spark.util.DockerUtils import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.DockerUtils abstract class DatabaseOnDocker { /** diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala index 6eb6b3391a4a..559dc1fed163 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/PostgresIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.Connection import java.util.Properties import org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.{Literal, If} +import org.apache.spark.sql.catalyst.expressions.{If, Literal} import org.apache.spark.tags.DockerTest @DockerTest diff --git a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala index 87271776d856..fda377e03235 100644 --- a/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala +++ b/docker-integration-tests/src/test/scala/org/apache/spark/util/DockerUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.util -import java.net.{Inet4Address, NetworkInterface, InetAddress} +import java.net.{Inet4Address, InetAddress, NetworkInterface} import scala.collection.JavaConverters._ import scala.sys.process._ diff --git a/docs/_config.yml b/docs/_config.yml index 2c70b76be8b7..dc25ff2c16c5 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -14,8 +14,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala, and Mesos. -SPARK_VERSION: 1.6.0-SNAPSHOT -SPARK_VERSION_SHORT: 1.6.0 +SPARK_VERSION: 2.0.0-SNAPSHOT +SPARK_VERSION_SHORT: 2.0.0 SCALA_BINARY_VERSION: "2.10" SCALA_VERSION: "2.10.5" MESOS_VERSION: 0.21.0 diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html index 3089474c1338..d493f62f0e57 100755 --- a/docs/_layouts/global.html +++ b/docs/_layouts/global.html @@ -75,7 +75,6 @@
 DataFrames, Datasets and SQL
 MLlib (Machine Learning)
 GraphX (Graph Processing)
-Bagel (Pregel on Spark)
 SparkR (R on Spark)
@@ -99,8 +98,6 @@
 Spark Standalone
 Mesos
 YARN
-
-Amazon EC2
diff --git a/docs/bagel-programming-guide.md b/docs/bagel-programming-guide.md deleted file mode 100644 index 347ca4a7af98..000000000000 --- a/docs/bagel-programming-guide.md +++ /dev/null @@ -1,159 +0,0 @@ ---- -layout: global -displayTitle: Bagel Programming Guide -title: Bagel --- - -**Bagel is deprecated, and superseded by [GraphX](graphx-programming-guide.html).** - -Bagel is a Spark implementation of Google's [Pregel](http://portal.acm.org/citation.cfm?id=1807184) graph processing framework. Bagel currently supports basic graph computation, combiners, and aggregators. - -In the Pregel programming model, jobs run as a sequence of iterations called _supersteps_. In each superstep, each vertex in the graph runs a user-specified function that can update state associated with the vertex and send messages to other vertices for use in the *next* iteration. - -This guide shows the programming model and features of Bagel by walking through an example implementation of PageRank on Bagel. - -# Linking with Bagel - -To use Bagel in your program, add the following SBT or Maven dependency: - - groupId = org.apache.spark - artifactId = spark-bagel_{{site.SCALA_BINARY_VERSION}} - version = {{site.SPARK_VERSION}} - -# Programming Model - -Bagel operates on a graph represented as a [distributed dataset](programming-guide.html) of (K, V) pairs, where keys are vertex IDs and values are vertices plus their associated state. In each superstep, Bagel runs a user-specified compute function on each vertex that takes as input the current vertex state and a list of messages sent to that vertex during the previous superstep, and returns the new vertex state and a list of outgoing messages. - -For example, we can use Bagel to implement PageRank. Here, vertices represent pages, edges represent links between pages, and messages represent shares of PageRank sent to the pages that a particular page links to. - -We first extend the default `Vertex` class to store a `Double` -representing the current PageRank of the vertex, and similarly extend -the `Message` and `Edge` classes. Note that these need to be marked `@serializable` to allow Spark to transfer them across machines. We also import the Bagel types and implicit conversions. - -{% highlight scala %} -import org.apache.spark.bagel._ -import org.apache.spark.bagel.Bagel._ - -@serializable class PREdge(val targetId: String) extends Edge - -@serializable class PRVertex( - val id: String, val rank: Double, val outEdges: Seq[Edge], - val active: Boolean) extends Vertex - -@serializable class PRMessage( - val targetId: String, val rankShare: Double) extends Message -{% endhighlight %} - -Next, we load a sample graph from a text file as a distributed dataset and package it into `PRVertex` objects. We also cache the distributed dataset because Bagel will use it multiple times and we'd like to avoid recomputing it. - -{% highlight scala %} -val input = sc.textFile("data/mllib/pagerank_data.txt") - -val numVerts = input.count() - -val verts = input.map(line => { - val fields = line.split('\t') - val (id, linksStr) = (fields(0), fields(1)) - val links = linksStr.split(',').map(new PREdge(_)) - (id, new PRVertex(id, 1.0 / numVerts, links, true)) -}).cache -{% endhighlight %} - -We run the Bagel job, passing in `verts`, an empty distributed dataset of messages, and a custom compute function that runs PageRank for 10 iterations.
- -{% highlight scala %} -val emptyMsgs = sc.parallelize(List[(String, PRMessage)]()) - -def compute(self: PRVertex, msgs: Option[Seq[PRMessage]], superstep: Int) -: (PRVertex, Iterable[PRMessage]) = { - val msgSum = msgs.getOrElse(List()).map(_.rankShare).sum - val newRank = - if (msgSum != 0) - 0.15 / numVerts + 0.85 * msgSum - else - self.rank - val halt = superstep >= 10 - val msgsOut = - if (!halt) - self.outEdges.map(edge => - new PRMessage(edge.targetId, newRank / self.outEdges.size)) - else - List() - (new PRVertex(self.id, newRank, self.outEdges, !halt), msgsOut) -} -{% endhighlight %} - -val result = Bagel.run(sc, verts, emptyMsgs)()(compute) - -Finally, we print the results. - -{% highlight scala %} -println(result.map(v => "%s\t%s\n".format(v.id, v.rank)).collect.mkString) -{% endhighlight %} - -## Combiners - -Sending a message to another vertex generally involves expensive communication over the network. For certain algorithms, it's possible to reduce the amount of communication using _combiners_. For example, if the compute function receives integer messages and only uses their sum, it's possible for Bagel to combine multiple messages to the same vertex by summing them. - -For combiner support, Bagel can optionally take a set of combiner functions that convert messages to their combined form. - -_Example: PageRank with combiners_ - -## Aggregators - -Aggregators perform a reduce across all vertices after each superstep, and provide the result to each vertex in the next superstep. - -For aggregator support, Bagel can optionally take an aggregator function that reduces across each vertex. - -_Example_ - -## Operations - -Here are the actions and types in the Bagel API. See [Bagel.scala](https://github.com/apache/spark/blob/master/bagel/src/main/scala/org/apache/spark/bagel/Bagel.scala) for details. - -### Actions - -{% highlight scala %} -/*** Full form ***/ - -Bagel.run(sc, vertices, messages, combiner, aggregator, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], aggregated: Option[A], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -/*** Abbreviated forms ***/ - -Bagel.run(sc, vertices, messages, combiner, partitioner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, combiner, numSplits)(compute) -// where compute takes (vertex: V, combinedMessages: Option[C], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) - -Bagel.run(sc, vertices, messages, numSplits)(compute) -// where compute takes (vertex: V, messages: Option[Array[M]], superstep: Int) -// and returns (newVertex: V, outMessages: Array[M]) -{% endhighlight %} - -### Types - -{% highlight scala %} -trait Combiner[M, C] { - def createCombiner(msg: M): C - def mergeMsg(combiner: C, msg: M): C - def mergeCombiners(a: C, b: C): C -} - -trait Aggregator[V, A] { - def createAggregator(vert: V): A - def mergeAggregators(a: A, b: A): A -} - -trait Vertex { - def active: Boolean -} - -trait Message[K] { - def targetId: K -} -{% endhighlight %} diff --git a/docs/building-spark.md b/docs/building-spark.md index 3d38edbdad4b..785988902da8 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -33,13 +33,13 @@ to the `sharedSettings` val. 
See also [this PR](https://github.com/apache/spark/ # Building a Runnable Distribution -To create a Spark distribution like those distributed by the -[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as -to be runnable, use `make-distribution.sh` in the project root directory. It can be configured +To create a Spark distribution like those distributed by the +[Spark Downloads](http://spark.apache.org/downloads.html) page, and that is laid out so as +to be runnable, use `make-distribution.sh` in the project root directory. It can be configured with Maven profile settings and so on like the direct Maven build. Example: ./make-distribution.sh --name custom-spark --tgz -Psparkr -Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn - + For more information on usage, run `./make-distribution.sh --help` # Setting up Maven's Memory Usage @@ -74,7 +74,6 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro Hadoop versionProfile required - 1.x to 2.1.xhadoop-1 2.2.xhadoop-2.2 2.3.xhadoop-2.3 2.4.xhadoop-2.4 @@ -82,15 +81,6 @@ Because HDFS is not protocol-compatible across versions, if you want to read fro -For Apache Hadoop versions 1.x, Cloudera CDH "mr1" distributions, and other Hadoop versions without YARN, use: - -{% highlight bash %} -# Apache Hadoop 1.2.1 -mvn -Dhadoop.version=1.2.1 -Phadoop-1 -DskipTests clean package - -# Cloudera CDH 4.2.0 with MapReduce v1 -mvn -Dhadoop.version=2.0.0-mr1-cdh4.2.0 -Phadoop-1 -DskipTests clean package -{% endhighlight %} You can enable the `yarn` profile and optionally set the `yarn.version` property if it is different from `hadoop.version`. Spark only supports YARN versions 2.2.0 and later. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index faaf154d243f..2810112f5294 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -53,8 +53,6 @@ The system currently supports three cluster managers: and service applications. * [Hadoop YARN](running-on-yarn.html) -- the resource manager in Hadoop 2. -In addition, Spark's [EC2 launch scripts](ec2-scripts.html) make it easy to launch a standalone -cluster on Amazon EC2. # Submitting Applications diff --git a/docs/configuration.md b/docs/configuration.md index 38d3d059f9d3..08392c39187b 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -595,7 +595,7 @@ Apart from these, the following properties are also available, and may be useful spark.io.compression.codec - snappy + lz4 The codec used to compress internal data such as RDD partitions, broadcast variables and shuffle outputs. By default, Spark provides three codecs: lz4, lzf, @@ -687,9 +687,10 @@ Apart from these, the following properties are also available, and may be useful spark.rdd.compress false - Whether to compress serialized RDD partitions (e.g. for - StorageLevel.MEMORY_ONLY_SER). Can save substantial space at the cost of some - extra CPU time. + Whether to compress serialized RDD partitions (e.g. for + StorageLevel.MEMORY_ONLY_SER in Java + and Scala or StorageLevel.MEMORY_ONLY in Python). + Can save substantial space at the cost of some extra CPU time. @@ -749,7 +750,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.offHeap.enabled - true + false If true, Spark will attempt to use off-heap memory for certain operations. If off-heap memory use is enabled, then spark.memory.offHeap.size must be positive. 
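To make the configuration table above concrete, here is a minimal PySpark sketch that sets the properties just discussed. The values are illustrative assumptions, not recommendations; per the change above, `spark.io.compression.codec` now defaults to lz4 and `spark.memory.offHeap.enabled` to false:

```python
# Minimal sketch: setting the configuration properties documented above.
# All values here are illustrative, not recommended settings.
from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .setAppName("config-example")
        .set("spark.io.compression.codec", "lz4")        # the new default codec
        .set("spark.rdd.compress", "true")               # trade CPU time for space
        .set("spark.memory.offHeap.enabled", "true")     # default is now false
        .set("spark.memory.offHeap.size", "536870912"))  # must be positive when enabled
sc = SparkContext(conf=conf)
print(sc.getConf().get("spark.io.compression.codec"))    # -> lz4
sc.stop()
```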
@@ -822,24 +823,6 @@ Apart from these, the following properties are also available, and may be useful too small, BlockManager might take a performance hit. - - spark.broadcast.factory - org.apache.spark.broadcast.
    TorrentBroadcastFactory - - Which broadcast implementation to use. - - - - spark.cleaner.ttl - (infinite) - - Duration (seconds) of how long Spark will remember any metadata (stages generated, tasks - generated, etc.). Periodic cleanups will ensure that metadata older than this duration will be - forgotten. This is useful for running Spark for many hours / days (for example, running 24/7 in - case of Spark Streaming applications). Note that any RDD that persists in memory for more than - this duration will be cleared as well. - - spark.executor.cores 1 in YARN mode, all the available cores on the worker in standalone mode. @@ -1016,14 +999,6 @@ Apart from these, the following properties are also available, and may be useful Port for all block managers to listen on. These exist on both the driver and the executors. - - spark.broadcast.port - (random) - - Port for the driver's HTTP broadcast server to listen on. - This is not relevant for torrent broadcast. - - spark.driver.host (local hostname) @@ -1443,8 +1418,8 @@ Apart from these, the following properties are also available, and may be useful

Use spark.ssl.YYY.XXX settings to override the global configuration for a particular protocol denoted by YYY. Currently YYY can be - either akka for Akka based connections or fs for broadcast and - file server.

    + either akka for Akka based connections or fs for file + server.

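
As a hedged sketch of the hierarchical override described above (the keystore paths are hypothetical placeholders):

{% highlight scala %}
import org.apache.spark.SparkConf

// Defaults under spark.ssl.* apply to all supported protocols; a
// protocol-specific key such as spark.ssl.fs.* overrides them for that
// protocol only.
val conf = new SparkConf()
  .set("spark.ssl.enabled", "true")
  .set("spark.ssl.keyStore", "/path/to/default-keystore.jks")       // placeholder path
  .set("spark.ssl.fs.keyStore", "/path/to/fileserver-keystore.jks") // fs-specific override
{% endhighlight %}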
    @@ -1599,6 +1574,24 @@ Apart from these, the following properties are also available, and may be useful How many batches the Spark Streaming UI and status APIs remember before garbage collecting. + + spark.streaming.driver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the driver. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the metadata WAL + on the driver. + + + + spark.streaming.receiver.writeAheadLog.closeFileAfterWrite + false + + Whether to close the file after writing a write ahead log record on the receivers. Set this to 'true' + when you want to use S3 (or any file system that does not support flushing) for the data WAL + on the receivers. + + #### SparkR diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md deleted file mode 100644 index 7f60f82b966f..000000000000 --- a/docs/ec2-scripts.md +++ /dev/null @@ -1,192 +0,0 @@ ---- -layout: global -title: Running Spark on EC2 ---- - -The `spark-ec2` script, located in Spark's `ec2` directory, allows you -to launch, manage and shut down Spark clusters on Amazon EC2. It automatically -sets up Spark and HDFS on the cluster for you. This guide describes -how to use `spark-ec2` to launch clusters, how to run jobs on them, and how -to shut them down. It assumes you've already signed up for an EC2 account -on the [Amazon Web Services site](http://aws.amazon.com/). - -`spark-ec2` is designed to manage multiple named clusters. You can -launch a new cluster (telling the script its size and giving it a name), -shutdown an existing cluster, or log into a cluster. Each cluster is -identified by placing its machines into EC2 security groups whose names -are derived from the name of the cluster. For example, a cluster named -`test` will contain a master node in a security group called -`test-master`, and a number of slave nodes in a security group called -`test-slaves`. The `spark-ec2` script will create these security groups -for you based on the cluster name you request. You can also use them to -identify machines belonging to each cluster in the Amazon EC2 Console. - - -# Before You Start - -- Create an Amazon EC2 key pair for yourself. This can be done by - logging into your Amazon Web Services account through the [AWS - console](http://aws.amazon.com/console/), clicking Key Pairs on the - left sidebar, and creating and downloading a key. Make sure that you - set the permissions for the private key file to `600` (i.e. only you - can read and write it) so that `ssh` will work. -- Whenever you want to use the `spark-ec2` script, set the environment - variables `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` to your - Amazon EC2 access key ID and secret access key. These can be - obtained from the [AWS homepage](http://aws.amazon.com/) by clicking - Account \> Security Credentials \> Access Credentials. - -# Launching a Cluster - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run - `./spark-ec2 -k -i -s launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), and `` is the name to give to your - cluster. 
- - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a launch my-spark-cluster - ``` - -- After everything launches, check that the cluster scheduler is up and sees - all the slaves by going to its web UI, which will be printed at the end of - the script (typically `http://:8080`). - -You can also run `./spark-ec2 --help` to see more usage options. The -following options are worth pointing out: - -- `--instance-type=` can be used to specify an EC2 -instance type to use. For now, the script only supports 64-bit instance -types, and the default type is `m1.large` (which has 2 cores and 7.5 GB -RAM). Refer to the Amazon pages about [EC2 instance -types](http://aws.amazon.com/ec2/instance-types) and [EC2 -pricing](http://aws.amazon.com/ec2/#pricing) for information about other -instance types. -- `--region=` specifies an EC2 region in which to launch -instances. The default region is `us-east-1`. -- `--zone=` can be used to specify an EC2 availability zone -to launch instances in. Sometimes, you will get an error because there -is not enough capacity in one zone, and you should try to launch in -another. -- `--ebs-vol-size=` will attach an EBS volume with a given amount - of space to each node so that you can have a persistent HDFS cluster - on your nodes across cluster restarts (see below). -- `--spot-price=` will launch the worker nodes as - [Spot Instances](http://aws.amazon.com/ec2/spot-instances/), - bidding for the given maximum price (in dollars). -- `--spark-version=` will pre-load the cluster with the - specified version of Spark. The `` can be a version number - (e.g. "0.7.3") or a specific git hash. By default, a recent - version will be used. -- `--spark-git-repo=` will let you run a custom version of - Spark that is built from the given git repository. By default, the - [Apache Github mirror](https://github.com/apache/spark) will be used. - When using a custom Spark version, `--spark-version` must be set to git - commit hash, such as 317e114, instead of a version number. -- If one of your launches fails due to e.g. not having the right -permissions on your private key file, you can run `launch` with the -`--resume` option to restart the setup process on an existing cluster. - -# Launching a Cluster in a VPC - -- Run - `./spark-ec2 -k -i -s --vpc-id= --subnet-id= launch `, - where `` is the name of your EC2 key pair (that you gave it - when you created it), `` is the private key file for your - key pair, `` is the number of slave nodes to launch (try - 1 at first), `` is the name of your VPC, `` is the - name of your subnet, and `` is the name to give to your - cluster. - - For example: - - ```bash - export AWS_SECRET_ACCESS_KEY=AaBbCcDdEeFGgHhIiJjKkLlMmNnOoPpQqRrSsTtU -export AWS_ACCESS_KEY_ID=ABCDEFG1234567890123 -./spark-ec2 --key-pair=awskey --identity-file=awskey.pem --region=us-west-1 --zone=us-west-1a --vpc-id=vpc-a28d24c7 --subnet-id=subnet-4eb27b39 --spark-version=1.1.0 launch my-spark-cluster - ``` - -# Running Applications - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 -k -i login ` to - SSH into the cluster, where `` and `` are as - above. (This is just for convenience; you could also use - the EC2 console.) 
-- To deploy code or data within your cluster, you can log in and use the - provided script `~/spark-ec2/copy-dir`, which, - given a directory path, RSYNCs it to the same location on all the slaves. -- If your application needs to access large datasets, the fastest way to do - that is to load them from Amazon S3 or an Amazon EBS device into an - instance of the Hadoop Distributed File System (HDFS) on your nodes. - The `spark-ec2` script already sets up a HDFS instance for you. It's - installed in `/root/ephemeral-hdfs`, and can be accessed using the - `bin/hadoop` script in that directory. Note that the data in this - HDFS goes away when you stop and restart a machine. -- There is also a *persistent HDFS* instance in - `/root/persistent-hdfs` that will keep data across cluster restarts. - Typically each node has relatively little space of persistent data - (about 3 GB), but you can use the `--ebs-vol-size` option to - `spark-ec2` to attach a persistent EBS volume to each node for - storing the persistent HDFS. -- Finally, if you get errors while running your application, look at the slave's logs - for that application inside of the scheduler work directory (/root/spark/work). You can - also view the status of the cluster using the web UI: `http://:8080`. - -# Configuration - -You can edit `/root/spark/conf/spark-env.sh` on each machine to set Spark configuration options, such -as JVM options. This file needs to be copied to **every machine** to reflect the change. The easiest way to -do this is to use a script we provide called `copy-dir`. First edit your `spark-env.sh` file on the master, -then run `~/spark-ec2/copy-dir /root/spark/conf` to RSYNC it to all the workers. - -The [configuration guide](configuration.html) describes the available configuration options. - -# Terminating a Cluster - -***Note that there is no way to recover data on EC2 nodes after shutting -them down! Make sure you have copied everything important off the nodes -before stopping them.*** - -- Go into the `ec2` directory in the release of Spark you downloaded. -- Run `./spark-ec2 destroy `. - -# Pausing and Restarting Clusters - -The `spark-ec2` script also supports pausing a cluster. In this case, -the VMs are stopped but not terminated, so they -***lose all data on ephemeral disks*** but keep the data in their -root partitions and their `persistent-hdfs`. Stopped machines will not -cost you any EC2 cycles, but ***will*** continue to cost money for EBS -storage. - -- To stop one of your clusters, go into the `ec2` directory and run -`./spark-ec2 --region= stop `. -- To restart it later, run -`./spark-ec2 -i --region= start `. -- To ultimately destroy the cluster and stop consuming EBS space, run -`./spark-ec2 --region= destroy ` as described in the previous -section. - -# Limitations - -- Support for "cluster compute" nodes is limited -- there's no way to specify a - locality group. However, you can launch slave nodes in your - `-slaves` group manually and then use `spark-ec2 launch - --resume` to start a cluster with them. - -If you have a patch or suggestion for one of these limitations, feel free to -[contribute](contributing-to-spark.html) it! - -# Accessing Data in S3 - -Spark's file interface allows it to process data in Amazon S3 using the same URI formats that are supported for Hadoop. You can specify a path in S3 as input through a URI of the form `s3n:///path`. To provide AWS credentials for S3 access, launch the Spark cluster with the option `--copy-aws-credentials`. 
Full instructions on S3 access using the Hadoop input libraries can be found on the [Hadoop S3 page](http://wiki.apache.org/hadoop/AmazonS3). - -In addition to using a single input file, you can also use a directory of files as input by simply giving the path to the directory. diff --git a/docs/index.md b/docs/index.md index ae26f97c86c2..9dfc52a2bdc9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -64,7 +64,7 @@ To run Spark interactively in a R interpreter, use `bin/sparkR`: ./bin/sparkR --master local[2] Example applications are also provided in R. For example, - + ./bin/spark-submit examples/src/main/r/dataframe.R # Launching on a Cluster @@ -73,7 +73,6 @@ The Spark [cluster mode overview](cluster-overview.html) explains the key concep Spark can run both by itself, or over several existing cluster managers. It currently provides several options for deployment: -* [Amazon EC2](ec2-scripts.html): our EC2 scripts let you launch a cluster in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): simplest way to deploy Spark on a private cluster * [Apache Mesos](running-on-mesos.html) * [Hadoop YARN](running-on-yarn.html) @@ -103,7 +102,7 @@ options for deployment: * [Cluster Overview](cluster-overview.html): overview of concepts and components when running on a cluster * [Submitting Applications](submitting-applications.html): packaging and deploying applications * Deployment modes: - * [Amazon EC2](ec2-scripts.html): scripts that let you launch a cluster on EC2 in about 5 minutes + * [Amazon EC2](https://github.com/amplab/spark-ec2): scripts that let you launch a cluster on EC2 in about 5 minutes * [Standalone Deploy Mode](spark-standalone.html): launch a standalone cluster quickly without a third-party cluster manager * [Mesos](running-on-mesos.html): deploy a private cluster using [Apache Mesos](http://mesos.apache.org) diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md index 36327c6efeaf..6c587b3f0d8d 100644 --- a/docs/job-scheduling.md +++ b/docs/job-scheduling.md @@ -91,7 +91,7 @@ pre-packaged distribution. 2. Add this jar to the classpath of all `NodeManager`s in your cluster. 3. In the `yarn-site.xml` on each node, add `spark_shuffle` to `yarn.nodemanager.aux-services`, then set `yarn.nodemanager.aux-services.spark_shuffle.class` to -`org.apache.spark.network.yarn.YarnShuffleService` and `spark.shuffle.service.enabled` to true. +`org.apache.spark.network.yarn.YarnShuffleService`. 4. Restart all `NodeManager`s in your cluster. All other relevant configurations are optional and under the `spark.dynamicAllocation.*` and diff --git a/docs/ml-classification-regression.md b/docs/ml-classification-regression.md index d63438bf74c1..8ffc997b4bf5 100644 --- a/docs/ml-classification-regression.md +++ b/docs/ml-classification-regression.md @@ -535,7 +535,9 @@ The main differences between this API and the [original MLlib Decision Tree API] * use of DataFrame metadata to distinguish continuous and categorical features -The Pipelines API for Decision Trees offers a bit more functionality than the original API. In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities). +The Pipelines API for Decision Trees offers a bit more functionality than the original API. +In particular, for classification, users can get the predicted probability of each class (a.k.a. class conditional probabilities); +for regression, users can get the biased sample variance of prediction. 
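
As a minimal illustration of reading those class conditional probabilities through the Pipelines API (DataFrames `training` and `test` with "label" and "features" columns are assumed):

{% highlight scala %}
import org.apache.spark.ml.classification.DecisionTreeClassifier

val dt = new DecisionTreeClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
val model = dt.fit(training)

// The "probability" column holds the class conditional probabilities
// alongside the hard "prediction" for each row.
model.transform(test).select("prediction", "probability").show(5)
{% endhighlight %}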
Ensembles of trees (Random Forests and Gradient-Boosted Trees) are described below in the [Tree ensembles section](#tree-ensembles). @@ -605,6 +607,13 @@ All output columns are optional; to exclude an output column, set its correspond Vector of length # classes equal to rawPrediction normalized to a multinomial distribution Classification only + + varianceCol + Double + + The biased sample variance of prediction + Regression only + diff --git a/docs/ml-guide.md b/docs/ml-guide.md index 44a316a07dfe..1343753bce24 100644 --- a/docs/ml-guide.md +++ b/docs/ml-guide.md @@ -628,7 +628,7 @@ Currently, `spark.ml` supports model selection using the [`CrossValidator`](api/ The `Evaluator` can be a [`RegressionEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.RegressionEvaluator) for regression problems, a [`BinaryClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.BinaryClassificationEvaluator) for binary data, or a [`MultiClassClassificationEvaluator`](api/scala/index.html#org.apache.spark.ml.evaluation.MultiClassClassificationEvaluator) -for multiclass problems. The default metric used to choose the best `ParamMap` can be overriden by the `setMetric` +for multiclass problems. The default metric used to choose the best `ParamMap` can be overridden by the `setMetricName` method in each of these evaluators. The `ParamMap` which produces the best evaluation metric (averaged over the `$k$` folds) is selected as the best model. @@ -951,4 +951,4 @@ model.transform(test) {% endhighlight %} - \ No newline at end of file + diff --git a/docs/programming-guide.md b/docs/programming-guide.md index f823b89a4b5e..bad25e63e89e 100644 --- a/docs/programming-guide.md +++ b/docs/programming-guide.md @@ -806,7 +806,7 @@ However, in `cluster` mode, what happens is more complicated, and the above may What is happening here is that the variables within the closure sent to each executor are now copies and thus, when **counter** is referenced within the `foreach` function, it's no longer the **counter** on the driver node. There is still a **counter** in the memory of the driver node but this is no longer visible to the executors! The executors only see the copy from the serialized closure. Thus, the final value of **counter** will still be zero since all operations on **counter** were referencing the value within the serialized closure. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#AccumLink). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios, one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures (constructs like loops or locally defined methods) should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1091,7 +1091,7 @@ for details.
foreach(func) - Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems. + Run a function func on each element of the dataset. This is usually done for side effects such as updating an Accumulator or interacting with external storage systems.
    Note: modifying variables other than Accumulators outside of the foreach() may result in undefined behavior. See Understanding closures for more details. @@ -1196,14 +1196,14 @@ storage levels is: partitions that don't fit on disk, and read them from there when they're needed. - MEMORY_ONLY_SER + MEMORY_ONLY_SER
    (Java and Scala) Store RDD as serialized Java objects (one byte array per partition). This is generally more space-efficient than deserialized objects, especially when using a fast serializer, but more CPU-intensive to read. - MEMORY_AND_DISK_SER + MEMORY_AND_DISK_SER
    (Java and Scala) Similar to MEMORY_ONLY_SER, but spill partitions that don't fit in memory to disk instead of recomputing them on the fly each time they're needed. @@ -1230,7 +1230,9 @@ storage levels is: -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, so it does not matter whether you choose a serialized level.* +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +`MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, `DISK_ONLY_2` and `OFF_HEAP`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1243,7 +1245,7 @@ efficiency. We recommend going through the following process to select one: This is the most CPU-efficient option, allowing operations on the RDDs to run as fast as possible. * If not, try using `MEMORY_ONLY_SER` and [selecting a fast serialization library](tuning.html) to -make the objects much more space-efficient, but still reasonably fast to access. +make the objects much more space-efficient, but still reasonably fast to access. (Java and Scala) * Don't spill to disk unless the functions that computed your datasets are expensive, or they filter a large amount of the data. Otherwise, recomputing a partition may be as fast as reading it from @@ -1336,7 +1338,7 @@ run on the cluster so that `v` is not shipped to the nodes more than once. In ad `v` should not be modified after it is broadcast in order to ensure that all nodes get the same value of the broadcast variable (e.g. if the variable is shipped to a new node later). -## Accumulators +## Accumulators Accumulators are variables that are only "added" to through an associative operation and can therefore be efficiently supported in parallel. They can be used to implement counters (as in diff --git a/docs/security.md b/docs/security.md index 0bfc791c5744..1b7741d4dd93 100644 --- a/docs/security.md +++ b/docs/security.md @@ -23,7 +23,7 @@ If your applications are using event logging, the directory where the event logs ## Encryption -Spark supports SSL for Akka and HTTP (for broadcast and file server) protocols. SASL encryption is +Spark supports SSL for Akka and HTTP (for file server) protocols. SASL encryption is supported for the block transfer service. Encryption is not yet supported for the WebUI. Encryption is not yet supported for data stored by Spark in temporary local storage, such as shuffle @@ -32,7 +32,7 @@ to configure your cluster manager to store application data on encrypted disks. ### SSL Configuration -Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. 
The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for broadcast and file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). +Configuration for SSL is organized hierarchically. The user can configure the default SSL settings which will be used for all the supported communication protocols unless they are overwritten by protocol-specific settings. This way the user can easily provide the common settings for all the protocols without disabling the ability to configure each one individually. The common SSL settings are at `spark.ssl` namespace in Spark configuration, while Akka SSL configuration is at `spark.ssl.akka` and HTTP for file server SSL configuration is at `spark.ssl.fs`. The full breakdown can be found on the [configuration page](configuration.html). SSL must be configured on each node and configured for each component involved in communication using the particular protocol. @@ -160,15 +160,6 @@ configure those ports. spark.fileserver.port Jetty-based. Only used if Akka RPC backend is configured. - - Executor - Driver - (random) - HTTP Broadcast - spark.broadcast.port - Jetty-based. Not used by TorrentBroadcast, which sends data through the block manager - instead. - Executor / Driver Executor / Driver diff --git a/docs/sparkr.md b/docs/sparkr.md index 9ddd2eda3fe8..ea81532c611e 100644 --- a/docs/sparkr.md +++ b/docs/sparkr.md @@ -385,12 +385,12 @@ The following functions are masked by the SparkR package: Since part of SparkR is modeled on the `dplyr` package, certain functions in SparkR share the same names with those in `dplyr`. Depending on the load order of the two packages, some functions from the package loaded first are masked by those in the package loaded after. In such case, prefix such calls with the package name, for instance, `SparkR::cume_dist(x)` or `dplyr::cume_dist(x)`. - + You can inspect the search path in R with [`search()`](https://stat.ethz.ch/R-manual/R-devel/library/base/html/search.html) # Migration Guide -## Upgrading From SparkR 1.6 to 1.7 +## Upgrading From SparkR 1.5.x to 1.6 - - Until Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.7 to `error` to match the Scala API. + - Before Spark 1.6, the default mode for writes was `append`. It was changed in Spark 1.6.0 to `error` to match the Scala API. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 3f9a831eddc8..b05883361643 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1895,9 +1895,7 @@ the Data Sources API. The following options are supported: driver - The class name of the JDBC driver needed to connect to this URL. This class will be loaded - on the master and workers before running an JDBC commands to allow the driver to - register itself with the JDBC subsystem. + The class name of the JDBC driver to use to connect to this URL. 
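
For illustration, a minimal sketch of passing the `driver` option through the data sources API (the URL, table name, and driver class are hypothetical placeholders, and a `sqlContext` is assumed):

{% highlight scala %}
// Read a table over JDBC, naming the driver class explicitly.
val jdbcDF = sqlContext.read
  .format("jdbc")
  .option("url", "jdbc:postgresql://dbserver:5432/mydb") // placeholder URL
  .option("dbtable", "schema.tablename")                 // placeholder table
  .option("driver", "org.postgresql.Driver")             // JDBC driver class name
  .load()
{% endhighlight %}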
diff --git a/docs/streaming-custom-receivers.md b/docs/streaming-custom-receivers.md index a75587a92adc..97db865daa37 100644 --- a/docs/streaming-custom-receivers.md +++ b/docs/streaming-custom-receivers.md @@ -257,9 +257,9 @@ The following table summarizes the characteristics of both types of receivers ## Implementing and Using a Custom Actor-based Receiver -Custom [Akka Actors](http://doc.akka.io/docs/akka/2.2.4/scala/actors.html) can also be used to +Custom [Akka Actors](http://doc.akka.io/docs/akka/2.3.11/scala/actors.html) can also be used to receive data. The [`ActorHelper`](api/scala/index.html#org.apache.spark.streaming.receiver.ActorHelper) -trait can be applied on any Akka actor, which allows received data to be stored in Spark using +trait can be mixed in to any Akka actor, which allows received data to be stored in Spark using `store(...)` methods. The supervisor strategy of this actor can be configured to handle failures, etc. {% highlight scala %} @@ -273,8 +273,8 @@ class CustomActor extends Actor with ActorHelper { And a new input stream can be created with this custom actor as {% highlight scala %} -// Assuming ssc is the StreamingContext -val lines = ssc.actorStream[String](Props(new CustomActor()), "CustomReceiver") +val ssc: StreamingContext = ... +val lines = ssc.actorStream[String](Props[CustomActor], "CustomReceiver") {% endhighlight %} See [ActorWordCount.scala](https://github.com/apache/spark/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala) diff --git a/docs/streaming-kafka-integration.md b/docs/streaming-kafka-integration.md index 5be73c42560f..9454714eeb9c 100644 --- a/docs/streaming-kafka-integration.md +++ b/docs/streaming-kafka-integration.md @@ -104,6 +104,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class] ]( streamingContext, [map of Kafka parameters], [set of topics to consume]) + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/scala/index.html#org.apache.spark.streaming.kafka.KafkaUtils$) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/DirectKafkaWordCount.scala). @@ -115,6 +116,7 @@ Next, we discuss how to use this approach in your streaming application. [key class], [value class], [key decoder class], [value decoder class], [map of Kafka parameters], [set of topics to consume]); + You can also pass a `messageHandler` to `createDirectStream` to access `MessageAndMetadata` that contains metadata about the current message and transform it to any desired type. See the [API docs](api/java/index.html?org/apache/spark/streaming/kafka/KafkaUtils.html) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaDirectKafkaWordCount.java). @@ -123,6 +125,7 @@ Next, we discuss how to use this approach in your streaming application. from pyspark.streaming.kafka import KafkaUtils directKafkaStream = KafkaUtils.createDirectStream(ssc, [topic], {"metadata.broker.list": brokers}) + You can also pass a `messageHandler` to `createDirectStream` to access `KafkaMessageAndMetadata` that contains metadata about the current message and transform it to any desired type. By default, the Python API will decode Kafka data as UTF8 encoded strings. 
You can specify your custom decoding function to decode the byte arrays in Kafka records to an arbitrary data type. See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kafka.KafkaUtils) and the [example]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/direct_kafka_wordcount.py). diff --git a/docs/streaming-kinesis-integration.md b/docs/streaming-kinesis-integration.md index 238a911a9199..07194b0a6b75 100644 --- a/docs/streaming-kinesis-integration.md +++ b/docs/streaming-kinesis-integration.md @@ -23,7 +23,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m **Note that by linking to this library, you will include [ASL](https://aws.amazon.com/asl/)-licensed code in your application.** -2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream as follows: +2. **Programming:** In the streaming application code, import `KinesisUtils` and create the input DStream of byte arrays as follows:
    @@ -36,7 +36,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/scala/index.html#org.apache.spark.streaming.kinesis.KinesisUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the Running the Example section for instructions on how to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/scala/org/apache/spark/examples/streaming/KinesisWordCountASL.scala). Refer to the [Running the Example](#running-the-example) subsection for instructions on how to run the example.
    @@ -49,7 +49,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2); See the [API docs](api/java/index.html?org/apache/spark/streaming/kinesis/KinesisUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/java/org/apache/spark/examples/streaming/JavaKinesisWordCountASL.java). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
    @@ -60,18 +60,47 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2) See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.kinesis.KinesisUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the next subsection for instructions to run the example. + and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/extras/kinesis-asl/src/main/python/examples/streaming/kinesis_wordcount_asl.py). Refer to the [Running the Example](#running-the-example) subsection for instructions to run the example.
- - `streamingContext`: StreamingContext containg an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + You may also provide a "message handler function" that takes a Kinesis `Record` and returns a generic object `T`, in case you would like to use other data included in a `Record` such as the partition key. This is currently only supported in Scala and Java. - - `[Kineiss app name]`: The application name that will be used to checkpoint the Kinesis +
    +
    + + import org.apache.spark.streaming.Duration + import org.apache.spark.streaming.kinesis._ + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream + + val kinesisStream = KinesisUtils.createStream[T]( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler]) + +
    +
+ + import org.apache.spark.streaming.Duration; + import org.apache.spark.streaming.kinesis.*; + import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; + + JavaReceiverInputDStream<T> kinesisStream = KinesisUtils.createStream( + streamingContext, [Kinesis app name], [Kinesis stream name], [endpoint URL], + [region name], [initial position], [checkpoint interval], StorageLevel.MEMORY_AND_DISK_2, + [message handler], [class T]); + +
    +
+ + - `streamingContext`: StreamingContext containing an application name used by Kinesis to tie this Kinesis application to the Kinesis stream + + - `[Kinesis app name]`: The application name that will be used to checkpoint the Kinesis sequence numbers in the DynamoDB table. - The application name must be unique for a given account and region. - If the table exists but has incorrect checkpoint information (for a different stream, or - old expired sequenced numbers), then there may be temporary errors. + old expired sequence numbers), then there may be temporary errors. - `[Kinesis stream name]`: The Kinesis stream that this streaming application will pull data from. @@ -83,6 +112,8 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m - `[initial position]`: Can be either `InitialPositionInStream.TRIM_HORIZON` or `InitialPositionInStream.LATEST` (see Kinesis Checkpointing section and Amazon Kinesis API documentation for more details). + - `[message handler]`: A function that takes a Kinesis `Record` and outputs a generic `T`. + In other versions of the API, you can also specify the AWS access key and secret key directly. 3. **Deploying:** Package `spark-streaming-kinesis-asl_{{site.SCALA_BINARY_VERSION}}` and its dependencies (except `spark-core_{{site.SCALA_BINARY_VERSION}}` and `spark-streaming_{{site.SCALA_BINARY_VERSION}}` which are provided by `spark-submit`) into the application JAR. Then use `spark-submit` to launch your application (see [Deploying section](streaming-programming-guide.html#deploying-applications) in the main programming guide). @@ -99,7 +130,7 @@ A Kinesis stream can be set up at one of the valid Kinesis endpoints with 1 or m Spark Streaming Kinesis Architecture

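
As a hedged sketch of what such a message handler might look like in Scala (the pairing of the partition key with a UTF-8 decoded payload is an illustrative choice, not part of the API):

{% highlight scala %}
import java.nio.charset.StandardCharsets
import com.amazonaws.services.kinesis.model.Record

// Fills the [message handler] slot above: extract the partition key
// together with the decoded payload of each record.
def handler(record: Record): (String, String) = {
  val bytes = new Array[Byte](record.getData.remaining())
  record.getData.get(bytes)
  (record.getPartitionKey, new String(bytes, StandardCharsets.UTF_8))
}
{% endhighlight %}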
@@ -165,11 +196,16 @@ To run the example, This will push 1000 lines per second of 10 random numbers per line to the Kinesis stream. This data should then be received and processed by the running example. +#### Record De-aggregation + +When data is generated using the [Kinesis Producer Library (KPL)](http://docs.aws.amazon.com/kinesis/latest/dev/developing-producers-with-kpl.html), messages may be aggregated for cost savings. Spark Streaming will automatically +de-aggregate records during consumption. + #### Kinesis Checkpointing - Each Kinesis input DStream periodically stores the current position of the stream in the backing DynamoDB table. This allows the system to recover from failures and continue processing where the DStream left off. - Checkpointing too frequently will cause excess load on the AWS checkpoint storage layer and may lead to AWS throttling. The provided example handles this throttling with a random-backoff-retry strategy. - If no Kinesis checkpoint info exists when the input DStream starts, it will start either from the oldest record available (InitialPositionInStream.TRIM_HORIZON) or from the latest tip (InitialPositionInStream.LATEST). This is configurable. -- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). +- InitialPositionInStream.LATEST could lead to missed records if data is added to the stream while no input DStreams are running (and no checkpoint info is being stored). - InitialPositionInStream.TRIM_HORIZON may lead to duplicate processing of records where the impact is dependent on checkpoint frequency and processing idempotency. diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md index ed6b28c28213..8fd075d02b78 100644 --- a/docs/streaming-programming-guide.md +++ b/docs/streaming-programming-guide.md @@ -881,7 +881,6 @@ Scala code, take a look at the example
{% highlight java %}
-import com.google.common.base.Optional;
 Function2<List<Integer>, Optional<Integer>, Optional<Integer>> updateFunction =
   new Function2<List<Integer>, Optional<Integer>, Optional<Integer>>() {
     @Override public Optional<Integer> call(List<Integer> values, Optional<Integer> state) {
@@ -1415,6 +1414,171 @@ Note that the connections in the pool should be lazily created on demand and tim

 ***

+## Accumulators and Broadcast Variables
+
+[Accumulators](programming-guide.html#accumulators) and [Broadcast variables](programming-guide.html#broadcast-variables) cannot be recovered from checkpoint in Spark Streaming. If you enable checkpointing and use Accumulators or Broadcast variables as well, you'll have to create lazily instantiated singleton instances for them so that they can be re-instantiated after the driver restarts on failure. This is shown in the following example.
+
    +
    +{% highlight scala %} + +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect() + val output = "Counts at time " + time + " " + counts +}) + +{% endhighlight %} + +See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala). +
    +
+{% highlight java %}
+
+class JavaWordBlacklist {
+
+  private static volatile Broadcast<List<String>> instance = null;
+
+  public static Broadcast<List<String>> getInstance(JavaSparkContext jsc) {
+    if (instance == null) {
+      synchronized (JavaWordBlacklist.class) {
+        if (instance == null) {
+          List<String> wordBlacklist = Arrays.asList("a", "b", "c");
+          instance = jsc.broadcast(wordBlacklist);
+        }
+      }
+    }
+    return instance;
+  }
+}
+
+class JavaDroppedWordsCounter {
+
+  private static volatile Accumulator<Integer> instance = null;
+
+  public static Accumulator<Integer> getInstance(JavaSparkContext jsc) {
+    if (instance == null) {
+      synchronized (JavaDroppedWordsCounter.class) {
+        if (instance == null) {
+          instance = jsc.accumulator(0, "WordsInBlacklistCounter");
+        }
+      }
+    }
+    return instance;
+  }
+}
+
+wordCounts.foreachRDD(new Function2<JavaPairRDD<String, Integer>, Time, Void>() {
+  @Override
+  public Void call(JavaPairRDD<String, Integer> rdd, Time time) throws IOException {
+    // Get or register the blacklist Broadcast
+    final Broadcast<List<String>> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context()));
+    // Get or register the droppedWordsCounter Accumulator
+    final Accumulator<Integer> droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context()));
+    // Use blacklist to drop words and use droppedWordsCounter to count them
+    String counts = rdd.filter(new Function<Tuple2<String, Integer>, Boolean>() {
+      @Override
+      public Boolean call(Tuple2<String, Integer> wordCount) throws Exception {
+        if (blacklist.value().contains(wordCount._1())) {
+          droppedWordsCounter.add(wordCount._2());
+          return false;
+        } else {
+          return true;
+        }
+      }
+    }).collect().toString();
+    String output = "Counts at time " + time + " " + counts;
+    return null;
+  }
+});
+
+{% endhighlight %}
+
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java).
+
    +
+{% highlight python %}
+
+def getWordBlacklist(sparkContext):
+    if ('wordBlacklist' not in globals()):
+        globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"])
+    return globals()['wordBlacklist']
+
+def getDroppedWordsCounter(sparkContext):
+    if ('droppedWordsCounter' not in globals()):
+        globals()['droppedWordsCounter'] = sparkContext.accumulator(0)
+    return globals()['droppedWordsCounter']
+
+def echo(time, rdd):
+    # Get or register the blacklist Broadcast
+    blacklist = getWordBlacklist(rdd.context)
+    # Get or register the droppedWordsCounter Accumulator
+    droppedWordsCounter = getDroppedWordsCounter(rdd.context)
+
+    # Use blacklist to drop words and use droppedWordsCounter to count them
+    def filterFunc(wordCount):
+        if wordCount[0] in blacklist.value:
+            droppedWordsCounter.add(wordCount[1])
+            return False
+        else:
+            return True
+
+    counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect())
+
+wordCounts.foreachRDD(echo)
+
+{% endhighlight %}
+
+See the full [source code]({{site.SPARK_GITHUB_URL}}/blob/master/examples/src/main/python/streaming/recoverable_network_wordcount.py).
+
    +
    + +*** + ## DataFrame and SQL Operations You can easily use [DataFrames and SQL](sql-programming-guide.html) operations on streaming data. You have to create a SQLContext using the SparkContext that the StreamingContext is using. Furthermore this has to done such that it can be restarted on driver failures. This is done by creating a lazily instantiated singleton instance of SQLContext. This is shown in the following example. It modifies the earlier [word count example](#a-quick-example) to generate word counts using DataFrames and SQL. Each RDD is converted to a DataFrame, registered as a temporary table and then queried using SQL. @@ -1820,7 +1984,11 @@ To run a Spark Streaming applications, you need to have the following. to increase aggregate throughput. Additionally, it is recommended that the replication of the received data within Spark be disabled when the write ahead log is enabled as the log is already stored in a replicated storage system. This can be done by setting the storage level for the - input stream to `StorageLevel.MEMORY_AND_DISK_SER`. + input stream to `StorageLevel.MEMORY_AND_DISK_SER`. While using S3 (or any file system that + does not support flushing) for _write ahead logs_, please remember to enable + `spark.streaming.driver.writeAheadLog.closeFileAfterWrite` and + `spark.streaming.receiver.writeAheadLog.closeFileAfterWrite`. See + [Spark Streaming Configuration](configuration.html#spark-streaming) for more details. - *Setting the max receiving rate* - If the cluster resources is not large enough for the streaming application to process data as fast as it is being received, the receivers can be rate limited @@ -1858,12 +2026,6 @@ contains serialized Scala/Java/Python objects and trying to deserialize objects modified classes may lead to errors. In this case, either start the upgraded app with a different checkpoint directory, or delete the previous checkpoint directory. -### Other Considerations -{:.no_toc} -If the data is being received by the receivers faster than what can be processed, -you can limit the rate by setting the [configuration parameter](configuration.html#spark-streaming) -`spark.streaming.receiver.maxRate`. - *** ## Monitoring Applications diff --git a/ec2/README b/ec2/README deleted file mode 100644 index 72434f24bf98..000000000000 --- a/ec2/README +++ /dev/null @@ -1,4 +0,0 @@ -This folder contains a script, spark-ec2, for launching Spark clusters on -Amazon EC2. Usage instructions are available online at: - -http://spark.apache.org/docs/latest/ec2-scripts.html diff --git a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh b/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh deleted file mode 100644 index 4f3e8da809f7..000000000000 --- a/ec2/deploy.generic/root/spark-ec2/ec2-variables.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/usr/bin/env bash - -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -# - -# These variables are automatically filled in by the spark-ec2 script. -export MASTERS="{{master_list}}" -export SLAVES="{{slave_list}}" -export HDFS_DATA_DIRS="{{hdfs_data_dirs}}" -export MAPRED_LOCAL_DIRS="{{mapred_local_dirs}}" -export SPARK_LOCAL_DIRS="{{spark_local_dirs}}" -export MODULES="{{modules}}" -export SPARK_VERSION="{{spark_version}}" -export TACHYON_VERSION="{{tachyon_version}}" -export HADOOP_MAJOR_VERSION="{{hadoop_major_version}}" -export SWAP_MB="{{swap}}" -export SPARK_WORKER_INSTANCES="{{spark_worker_instances}}" -export SPARK_MASTER_OPTS="{{spark_master_opts}}" -export AWS_ACCESS_KEY_ID="{{aws_access_key_id}}" -export AWS_SECRET_ACCESS_KEY="{{aws_secret_access_key}}" diff --git a/ec2/spark-ec2 b/ec2/spark-ec2 deleted file mode 100755 index 26e7d2265569..000000000000 --- a/ec2/spark-ec2 +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/sh - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -# Preserve the user's CWD so that relative paths are passed correctly to -#+ the underlying Python script. -SPARK_EC2_DIR="$(dirname "$0")" - -python -Wdefault "${SPARK_EC2_DIR}/spark_ec2.py" "$@" diff --git a/ec2/spark_ec2.py b/ec2/spark_ec2.py deleted file mode 100755 index 19d5980560fe..000000000000 --- a/ec2/spark_ec2.py +++ /dev/null @@ -1,1530 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- - -# -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# - -from __future__ import division, print_function, with_statement - -import codecs -import hashlib -import itertools -import logging -import os -import os.path -import pipes -import random -import shutil -import string -from stat import S_IRUSR -import subprocess -import sys -import tarfile -import tempfile -import textwrap -import time -import warnings -from datetime import datetime -from optparse import OptionParser -from sys import stderr - -if sys.version < "3": - from urllib2 import urlopen, Request, HTTPError -else: - from urllib.request import urlopen, Request - from urllib.error import HTTPError - raw_input = input - xrange = range - -SPARK_EC2_VERSION = "1.6.0" -SPARK_EC2_DIR = os.path.dirname(os.path.realpath(__file__)) - -VALID_SPARK_VERSIONS = set([ - "0.7.3", - "0.8.0", - "0.8.1", - "0.9.0", - "0.9.1", - "0.9.2", - "1.0.0", - "1.0.1", - "1.0.2", - "1.1.0", - "1.1.1", - "1.2.0", - "1.2.1", - "1.3.0", - "1.3.1", - "1.4.0", - "1.4.1", - "1.5.0", - "1.5.1", - "1.5.2", - "1.6.0", -]) - -SPARK_TACHYON_MAP = { - "1.0.0": "0.4.1", - "1.0.1": "0.4.1", - "1.0.2": "0.4.1", - "1.1.0": "0.5.0", - "1.1.1": "0.5.0", - "1.2.0": "0.5.0", - "1.2.1": "0.5.0", - "1.3.0": "0.5.0", - "1.3.1": "0.5.0", - "1.4.0": "0.6.4", - "1.4.1": "0.6.4", - "1.5.0": "0.7.1", - "1.5.1": "0.7.1", - "1.5.2": "0.7.1", - "1.6.0": "0.8.2", -} - -DEFAULT_SPARK_VERSION = SPARK_EC2_VERSION -DEFAULT_SPARK_GITHUB_REPO = "https://github.com/apache/spark" - -# Default location to get the spark-ec2 scripts (and ami-list) from -DEFAULT_SPARK_EC2_GITHUB_REPO = "https://github.com/amplab/spark-ec2" -DEFAULT_SPARK_EC2_BRANCH = "branch-1.5" - - -def setup_external_libs(libs): - """ - Download external libraries from PyPI to SPARK_EC2_DIR/lib/ and prepend them to our PATH. - """ - PYPI_URL_PREFIX = "https://pypi.python.org/packages/source" - SPARK_EC2_LIB_DIR = os.path.join(SPARK_EC2_DIR, "lib") - - if not os.path.exists(SPARK_EC2_LIB_DIR): - print("Downloading external libraries that spark-ec2 needs from PyPI to {path}...".format( - path=SPARK_EC2_LIB_DIR - )) - print("This should be a one-time operation.") - os.mkdir(SPARK_EC2_LIB_DIR) - - for lib in libs: - versioned_lib_name = "{n}-{v}".format(n=lib["name"], v=lib["version"]) - lib_dir = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name) - - if not os.path.isdir(lib_dir): - tgz_file_path = os.path.join(SPARK_EC2_LIB_DIR, versioned_lib_name + ".tar.gz") - print(" - Downloading {lib}...".format(lib=lib["name"])) - download_stream = urlopen( - "{prefix}/{first_letter}/{lib_name}/{lib_name}-{lib_version}.tar.gz".format( - prefix=PYPI_URL_PREFIX, - first_letter=lib["name"][:1], - lib_name=lib["name"], - lib_version=lib["version"] - ) - ) - with open(tgz_file_path, "wb") as tgz_file: - tgz_file.write(download_stream.read()) - with open(tgz_file_path, "rb") as tar: - if hashlib.md5(tar.read()).hexdigest() != lib["md5"]: - print("ERROR: Got wrong md5sum for {lib}.".format(lib=lib["name"]), file=stderr) - sys.exit(1) - tar = tarfile.open(tgz_file_path) - tar.extractall(path=SPARK_EC2_LIB_DIR) - tar.close() - os.remove(tgz_file_path) - print(" - Finished downloading {lib}.".format(lib=lib["name"])) - sys.path.insert(1, lib_dir) - - -# Only PyPI libraries are supported. 
-external_libs = [ - { - "name": "boto", - "version": "2.34.0", - "md5": "5556223d2d0cc4d06dd4829e671dcecd" - } -] - -setup_external_libs(external_libs) - -import boto -from boto.ec2.blockdevicemapping import BlockDeviceMapping, BlockDeviceType, EBSBlockDeviceType -from boto import ec2 - - -class UsageError(Exception): - pass - - -# Configure and parse our command-line arguments -def parse_args(): - parser = OptionParser( - prog="spark-ec2", - version="%prog {v}".format(v=SPARK_EC2_VERSION), - usage="%prog [options] \n\n" - + " can be: launch, destroy, login, stop, start, get-master, reboot-slaves") - - parser.add_option( - "-s", "--slaves", type="int", default=1, - help="Number of slaves to launch (default: %default)") - parser.add_option( - "-w", "--wait", type="int", - help="DEPRECATED (no longer necessary) - Seconds to wait for nodes to start") - parser.add_option( - "-k", "--key-pair", - help="Key pair to use on instances") - parser.add_option( - "-i", "--identity-file", - help="SSH private key file to use for logging into instances") - parser.add_option( - "-p", "--profile", default=None, - help="If you have multiple profiles (AWS or boto config), you can configure " + - "additional, named profiles by using this option (default: %default)") - parser.add_option( - "-t", "--instance-type", default="m1.large", - help="Type of instance to launch (default: %default). " + - "WARNING: must be 64-bit; small instances won't work") - parser.add_option( - "-m", "--master-instance-type", default="", - help="Master instance type (leave empty for same as instance-type)") - parser.add_option( - "-r", "--region", default="us-east-1", - help="EC2 region used to launch instances in, or to find them in (default: %default)") - parser.add_option( - "-z", "--zone", default="", - help="Availability zone to launch instances in, or 'all' to spread " + - "slaves across multiple (an additional $0.01/Gb for bandwidth" + - "between zones applies) (default: a single zone chosen at random)") - parser.add_option( - "-a", "--ami", - help="Amazon Machine Image ID to use") - parser.add_option( - "-v", "--spark-version", default=DEFAULT_SPARK_VERSION, - help="Version of Spark to use: 'X.Y.Z' or a specific git hash (default: %default)") - parser.add_option( - "--spark-git-repo", - default=DEFAULT_SPARK_GITHUB_REPO, - help="Github repo from which to checkout supplied commit hash (default: %default)") - parser.add_option( - "--spark-ec2-git-repo", - default=DEFAULT_SPARK_EC2_GITHUB_REPO, - help="Github repo from which to checkout spark-ec2 (default: %default)") - parser.add_option( - "--spark-ec2-git-branch", - default=DEFAULT_SPARK_EC2_BRANCH, - help="Github repo branch of spark-ec2 to use (default: %default)") - parser.add_option( - "--deploy-root-dir", - default=None, - help="A directory to copy into / on the first master. " + - "Must be absolute. Note that a trailing slash is handled as per rsync: " + - "If you omit it, the last directory of the --deploy-root-dir path will be created " + - "in / before copying its contents. If you append the trailing slash, " + - "the directory is not created and its contents are copied directly into /. " + - "(default: %default).") - parser.add_option( - "--hadoop-major-version", default="1", - help="Major version of Hadoop. 
Valid options are 1 (Hadoop 1.0.4), 2 (CDH 4.2.0), yarn " + - "(Hadoop 2.4.0) (default: %default)") - parser.add_option( - "-D", metavar="[ADDRESS:]PORT", dest="proxy_port", - help="Use SSH dynamic port forwarding to create a SOCKS proxy at " + - "the given local address (for use with login)") - parser.add_option( - "--resume", action="store_true", default=False, - help="Resume installation on a previously launched cluster " + - "(for debugging)") - parser.add_option( - "--ebs-vol-size", metavar="SIZE", type="int", default=0, - help="Size (in GB) of each EBS volume.") - parser.add_option( - "--ebs-vol-type", default="standard", - help="EBS volume type (e.g. 'gp2', 'standard').") - parser.add_option( - "--ebs-vol-num", type="int", default=1, - help="Number of EBS volumes to attach to each node as /vol[x]. " + - "The volumes will be deleted when the instances terminate. " + - "Only possible on EBS-backed AMIs. " + - "EBS volumes are only attached if --ebs-vol-size > 0. " + - "Only support up to 8 EBS volumes.") - parser.add_option( - "--placement-group", type="string", default=None, - help="Which placement group to try and launch " + - "instances into. Assumes placement group is already " + - "created.") - parser.add_option( - "--swap", metavar="SWAP", type="int", default=1024, - help="Swap space to set up per node, in MB (default: %default)") - parser.add_option( - "--spot-price", metavar="PRICE", type="float", - help="If specified, launch slaves as spot instances with the given " + - "maximum price (in dollars)") - parser.add_option( - "--ganglia", action="store_true", default=True, - help="Setup Ganglia monitoring on cluster (default: %default). NOTE: " + - "the Ganglia page will be publicly accessible") - parser.add_option( - "--no-ganglia", action="store_false", dest="ganglia", - help="Disable Ganglia monitoring for the cluster") - parser.add_option( - "-u", "--user", default="root", - help="The SSH user you want to connect as (default: %default)") - parser.add_option( - "--delete-groups", action="store_true", default=False, - help="When destroying a cluster, delete the security groups that were created") - parser.add_option( - "--use-existing-master", action="store_true", default=False, - help="Launch fresh slaves, but use an existing stopped master if possible") - parser.add_option( - "--worker-instances", type="int", default=1, - help="Number of instances per worker: variable SPARK_WORKER_INSTANCES. 
Not used if YARN " + - "is used as Hadoop major version (default: %default)") - parser.add_option( - "--master-opts", type="string", default="", - help="Extra options to give to master through SPARK_MASTER_OPTS variable " + - "(e.g -Dspark.worker.timeout=180)") - parser.add_option( - "--user-data", type="string", default="", - help="Path to a user-data file (most AMIs interpret this as an initialization script)") - parser.add_option( - "--authorized-address", type="string", default="0.0.0.0/0", - help="Address to authorize on created security groups (default: %default)") - parser.add_option( - "--additional-security-group", type="string", default="", - help="Additional security group to place the machines in") - parser.add_option( - "--additional-tags", type="string", default="", - help="Additional tags to set on the machines; tags are comma-separated, while name and " + - "value are colon separated; ex: \"Task:MySparkProject,Env:production\"") - parser.add_option( - "--copy-aws-credentials", action="store_true", default=False, - help="Add AWS credentials to hadoop configuration to allow Spark to access S3") - parser.add_option( - "--subnet-id", default=None, - help="VPC subnet to launch instances in") - parser.add_option( - "--vpc-id", default=None, - help="VPC to launch instances in") - parser.add_option( - "--private-ips", action="store_true", default=False, - help="Use private IPs for instances rather than public if VPC/subnet " + - "requires that.") - parser.add_option( - "--instance-initiated-shutdown-behavior", default="stop", - choices=["stop", "terminate"], - help="Whether instances should terminate when shut down or just stop") - parser.add_option( - "--instance-profile-name", default=None, - help="IAM profile name to launch instances under") - - (opts, args) = parser.parse_args() - if len(args) != 2: - parser.print_help() - sys.exit(1) - (action, cluster_name) = args - - # Boto config check - # http://boto.cloudhackers.com/en/latest/boto_config_tut.html - home_dir = os.getenv('HOME') - if home_dir is None or not os.path.isfile(home_dir + '/.boto'): - if not os.path.isfile('/etc/boto.cfg'): - # If there is no boto config, check aws credentials - if not os.path.isfile(home_dir + '/.aws/credentials'): - if os.getenv('AWS_ACCESS_KEY_ID') is None: - print("ERROR: The environment variable AWS_ACCESS_KEY_ID must be set", - file=stderr) - sys.exit(1) - if os.getenv('AWS_SECRET_ACCESS_KEY') is None: - print("ERROR: The environment variable AWS_SECRET_ACCESS_KEY must be set", - file=stderr) - sys.exit(1) - return (opts, action, cluster_name) - - -# Get the EC2 security group of the given name, creating it if it doesn't exist -def get_or_make_group(conn, name, vpc_id): - groups = conn.get_all_security_groups() - group = [g for g in groups if g.name == name] - if len(group) > 0: - return group[0] - else: - print("Creating security group " + name) - return conn.create_security_group(name, "Spark EC2 group", vpc_id) - - -def get_validate_spark_version(version, repo): - if "." 
in version: - version = version.replace("v", "") - if version not in VALID_SPARK_VERSIONS: - print("Don't know about Spark version: {v}".format(v=version), file=stderr) - sys.exit(1) - return version - else: - github_commit_url = "{repo}/commit/{commit_hash}".format(repo=repo, commit_hash=version) - request = Request(github_commit_url) - request.get_method = lambda: 'HEAD' - try: - response = urlopen(request) - except HTTPError as e: - print("Couldn't validate Spark commit: {url}".format(url=github_commit_url), - file=stderr) - print("Received HTTP response code of {code}.".format(code=e.code), file=stderr) - sys.exit(1) - return version - - -# Source: http://aws.amazon.com/amazon-linux-ami/instance-type-matrix/ -# Last Updated: 2015-06-19 -# For easy maintainability, please keep this manually-inputted dictionary sorted by key. -EC2_INSTANCE_TYPES = { - "c1.medium": "pvm", - "c1.xlarge": "pvm", - "c3.large": "pvm", - "c3.xlarge": "pvm", - "c3.2xlarge": "pvm", - "c3.4xlarge": "pvm", - "c3.8xlarge": "pvm", - "c4.large": "hvm", - "c4.xlarge": "hvm", - "c4.2xlarge": "hvm", - "c4.4xlarge": "hvm", - "c4.8xlarge": "hvm", - "cc1.4xlarge": "hvm", - "cc2.8xlarge": "hvm", - "cg1.4xlarge": "hvm", - "cr1.8xlarge": "hvm", - "d2.xlarge": "hvm", - "d2.2xlarge": "hvm", - "d2.4xlarge": "hvm", - "d2.8xlarge": "hvm", - "g2.2xlarge": "hvm", - "g2.8xlarge": "hvm", - "hi1.4xlarge": "pvm", - "hs1.8xlarge": "pvm", - "i2.xlarge": "hvm", - "i2.2xlarge": "hvm", - "i2.4xlarge": "hvm", - "i2.8xlarge": "hvm", - "m1.small": "pvm", - "m1.medium": "pvm", - "m1.large": "pvm", - "m1.xlarge": "pvm", - "m2.xlarge": "pvm", - "m2.2xlarge": "pvm", - "m2.4xlarge": "pvm", - "m3.medium": "hvm", - "m3.large": "hvm", - "m3.xlarge": "hvm", - "m3.2xlarge": "hvm", - "m4.large": "hvm", - "m4.xlarge": "hvm", - "m4.2xlarge": "hvm", - "m4.4xlarge": "hvm", - "m4.10xlarge": "hvm", - "r3.large": "hvm", - "r3.xlarge": "hvm", - "r3.2xlarge": "hvm", - "r3.4xlarge": "hvm", - "r3.8xlarge": "hvm", - "t1.micro": "pvm", - "t2.micro": "hvm", - "t2.small": "hvm", - "t2.medium": "hvm", - "t2.large": "hvm", -} - - -def get_tachyon_version(spark_version): - return SPARK_TACHYON_MAP.get(spark_version, "") - - -# Attempt to resolve an appropriate AMI given the architecture and region of the request. -def get_spark_ami(opts): - if opts.instance_type in EC2_INSTANCE_TYPES: - instance_type = EC2_INSTANCE_TYPES[opts.instance_type] - else: - instance_type = "pvm" - print("Don't recognize %s, assuming type is pvm" % opts.instance_type, file=stderr) - - # URL prefix from which to fetch AMI information - ami_prefix = "{r}/{b}/ami-list".format( - r=opts.spark_ec2_git_repo.replace("https://github.com", "https://raw.github.com", 1), - b=opts.spark_ec2_git_branch) - - ami_path = "%s/%s/%s" % (ami_prefix, opts.region, instance_type) - reader = codecs.getreader("ascii") - try: - ami = reader(urlopen(ami_path)).read().strip() - except: - print("Could not resolve AMI at: " + ami_path, file=stderr) - sys.exit(1) - - print("Spark AMI: " + ami) - return ami - - -# Launch a cluster of the given name, by setting up its security groups, -# and then starting new instances in them. -# Returns a tuple of EC2 reservation objects for the master and slaves -# Fails if there already instances running in the cluster's groups. 
-def launch_cluster(conn, opts, cluster_name): - if opts.identity_file is None: - print("ERROR: Must provide an identity file (-i) for ssh connections.", file=stderr) - sys.exit(1) - - if opts.key_pair is None: - print("ERROR: Must provide a key pair name (-k) to use on instances.", file=stderr) - sys.exit(1) - - user_data_content = None - if opts.user_data: - with open(opts.user_data) as user_data_file: - user_data_content = user_data_file.read() - - print("Setting up security groups...") - master_group = get_or_make_group(conn, cluster_name + "-master", opts.vpc_id) - slave_group = get_or_make_group(conn, cluster_name + "-slaves", opts.vpc_id) - authorized_address = opts.authorized_address - if master_group.rules == []: # Group was just now created - if opts.vpc_id is None: - master_group.authorize(src_group=master_group) - master_group.authorize(src_group=slave_group) - else: - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - master_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - master_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - master_group.authorize('tcp', 22, 22, authorized_address) - master_group.authorize('tcp', 8080, 8081, authorized_address) - master_group.authorize('tcp', 18080, 18080, authorized_address) - master_group.authorize('tcp', 19999, 19999, authorized_address) - master_group.authorize('tcp', 50030, 50030, authorized_address) - master_group.authorize('tcp', 50070, 50070, authorized_address) - master_group.authorize('tcp', 60070, 60070, authorized_address) - master_group.authorize('tcp', 4040, 4045, authorized_address) - # Rstudio (GUI for R) needs port 8787 for web access - master_group.authorize('tcp', 8787, 8787, authorized_address) - # HDFS NFS gateway requires 111,2049,4242 for tcp & udp - master_group.authorize('tcp', 111, 111, authorized_address) - master_group.authorize('udp', 111, 111, authorized_address) - master_group.authorize('tcp', 2049, 2049, authorized_address) - master_group.authorize('udp', 2049, 2049, authorized_address) - master_group.authorize('tcp', 4242, 4242, authorized_address) - master_group.authorize('udp', 4242, 4242, authorized_address) - # RM in YARN mode uses 8088 - master_group.authorize('tcp', 8088, 8088, authorized_address) - if opts.ganglia: - master_group.authorize('tcp', 5080, 5080, authorized_address) - if slave_group.rules == []: # Group was just now created - if opts.vpc_id is None: - slave_group.authorize(src_group=master_group) - slave_group.authorize(src_group=slave_group) - else: - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=master_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=master_group) - slave_group.authorize(ip_protocol='icmp', from_port=-1, to_port=-1, - src_group=slave_group) - slave_group.authorize(ip_protocol='tcp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize(ip_protocol='udp', from_port=0, to_port=65535, - src_group=slave_group) - slave_group.authorize('tcp', 22, 22, authorized_address) - 
slave_group.authorize('tcp', 8080, 8081, authorized_address) - slave_group.authorize('tcp', 50060, 50060, authorized_address) - slave_group.authorize('tcp', 50075, 50075, authorized_address) - slave_group.authorize('tcp', 60060, 60060, authorized_address) - slave_group.authorize('tcp', 60075, 60075, authorized_address) - - # Check if instances are already running in our groups - existing_masters, existing_slaves = get_existing_cluster(conn, opts, cluster_name, - die_on_error=False) - if existing_slaves or (existing_masters and not opts.use_existing_master): - print("ERROR: There are already instances running in group %s or %s" % - (master_group.name, slave_group.name), file=stderr) - sys.exit(1) - - # Figure out Spark AMI - if opts.ami is None: - opts.ami = get_spark_ami(opts) - - # we use group ids to work around https://github.com/boto/boto/issues/350 - additional_group_ids = [] - if opts.additional_security_group: - additional_group_ids = [sg.id - for sg in conn.get_all_security_groups() - if opts.additional_security_group in (sg.name, sg.id)] - print("Launching instances...") - - try: - image = conn.get_all_images(image_ids=[opts.ami])[0] - except: - print("Could not find AMI " + opts.ami, file=stderr) - sys.exit(1) - - # Create block device mapping so that we can add EBS volumes if asked to. - # The first drive is attached as /dev/sds, 2nd as /dev/sdt, ... /dev/sdz - block_map = BlockDeviceMapping() - if opts.ebs_vol_size > 0: - for i in range(opts.ebs_vol_num): - device = EBSBlockDeviceType() - device.size = opts.ebs_vol_size - device.volume_type = opts.ebs_vol_type - device.delete_on_termination = True - block_map["/dev/sd" + chr(ord('s') + i)] = device - - # AWS ignores the AMI-specified block device mapping for M3 (see SPARK-3342). - if opts.instance_type.startswith('m3.'): - for i in range(get_num_disks(opts.instance_type)): - dev = BlockDeviceType() - dev.ephemeral_name = 'ephemeral%d' % i - # The first ephemeral drive is /dev/sdb. 
- name = '/dev/sd' + string.ascii_letters[i + 1] - block_map[name] = dev - - # Launch slaves - if opts.spot_price is not None: - # Launch spot instances with the requested price - print("Requesting %d slaves as spot instances with price $%.3f" % - (opts.slaves, opts.spot_price)) - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - my_req_ids = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - slave_reqs = conn.request_spot_instances( - price=opts.spot_price, - image_id=opts.ami, - launch_group="launch-group-%s" % cluster_name, - placement=zone, - count=num_slaves_this_zone, - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_profile_name=opts.instance_profile_name) - my_req_ids += [req.id for req in slave_reqs] - i += 1 - - print("Waiting for spot instances to be granted...") - try: - while True: - time.sleep(10) - reqs = conn.get_all_spot_instance_requests() - id_to_req = {} - for r in reqs: - id_to_req[r.id] = r - active_instance_ids = [] - for i in my_req_ids: - if i in id_to_req and id_to_req[i].state == "active": - active_instance_ids.append(id_to_req[i].instance_id) - if len(active_instance_ids) == opts.slaves: - print("All %d slaves granted" % opts.slaves) - reservations = conn.get_all_reservations(active_instance_ids) - slave_nodes = [] - for r in reservations: - slave_nodes += r.instances - break - else: - print("%d of %d slaves granted, waiting longer" % ( - len(active_instance_ids), opts.slaves)) - except: - print("Canceling spot instance requests") - conn.cancel_spot_instance_requests(my_req_ids) - # Log a warning if any of these requests actually launched instances: - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - running = len(master_nodes) + len(slave_nodes) - if running: - print(("WARNING: %d instances are still running" % running), file=stderr) - sys.exit(0) - else: - # Launch non-spot instances - zones = get_zones(conn, opts) - num_zones = len(zones) - i = 0 - slave_nodes = [] - for zone in zones: - num_slaves_this_zone = get_partition(opts.slaves, num_zones, i) - if num_slaves_this_zone > 0: - slave_res = image.run( - key_name=opts.key_pair, - security_group_ids=[slave_group.id] + additional_group_ids, - instance_type=opts.instance_type, - placement=zone, - min_count=num_slaves_this_zone, - max_count=num_slaves_this_zone, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - slave_nodes += slave_res.instances - print("Launched {s} slave{plural_s} in {z}, regid = {r}".format( - s=num_slaves_this_zone, - plural_s=('' if num_slaves_this_zone == 1 else 's'), - z=zone, - r=slave_res.id)) - i += 1 - - # Launch or resume masters - if existing_masters: - print("Starting master...") - for inst in existing_masters: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - master_nodes = existing_masters - else: - master_type = opts.master_instance_type - if master_type == "": - master_type = opts.instance_type - if opts.zone == 'all': - opts.zone = random.choice(conn.get_all_zones()).name - master_res = image.run( - key_name=opts.key_pair, - 
security_group_ids=[master_group.id] + additional_group_ids, - instance_type=master_type, - placement=opts.zone, - min_count=1, - max_count=1, - block_device_map=block_map, - subnet_id=opts.subnet_id, - placement_group=opts.placement_group, - user_data=user_data_content, - instance_initiated_shutdown_behavior=opts.instance_initiated_shutdown_behavior, - instance_profile_name=opts.instance_profile_name) - - master_nodes = master_res.instances - print("Launched master in %s, regid = %s" % (zone, master_res.id)) - - # This wait time corresponds to SPARK-4983 - print("Waiting for AWS to propagate instance metadata...") - time.sleep(15) - - # Give the instances descriptive names and set additional tags - additional_tags = {} - if opts.additional_tags.strip(): - additional_tags = dict( - map(str.strip, tag.split(':', 1)) for tag in opts.additional_tags.split(',') - ) - - for master in master_nodes: - master.add_tags( - dict(additional_tags, Name='{cn}-master-{iid}'.format(cn=cluster_name, iid=master.id)) - ) - - for slave in slave_nodes: - slave.add_tags( - dict(additional_tags, Name='{cn}-slave-{iid}'.format(cn=cluster_name, iid=slave.id)) - ) - - # Return all the instances - return (master_nodes, slave_nodes) - - -def get_existing_cluster(conn, opts, cluster_name, die_on_error=True): - """ - Get the EC2 instances in an existing cluster if available. - Returns a tuple of lists of EC2 instance objects for the masters and slaves. - """ - print("Searching for existing cluster {c} in region {r}...".format( - c=cluster_name, r=opts.region)) - - def get_instances(group_names): - """ - Get all non-terminated instances that belong to any of the provided security groups. - - EC2 reservation filters and instance states are documented here: - http://docs.aws.amazon.com/cli/latest/reference/ec2/describe-instances.html#options - """ - reservations = conn.get_all_reservations( - filters={"instance.group-name": group_names}) - instances = itertools.chain.from_iterable(r.instances for r in reservations) - return [i for i in instances if i.state not in ["shutting-down", "terminated"]] - - master_instances = get_instances([cluster_name + "-master"]) - slave_instances = get_instances([cluster_name + "-slaves"]) - - if any((master_instances, slave_instances)): - print("Found {m} master{plural_m}, {s} slave{plural_s}.".format( - m=len(master_instances), - plural_m=('' if len(master_instances) == 1 else 's'), - s=len(slave_instances), - plural_s=('' if len(slave_instances) == 1 else 's'))) - - if not master_instances and die_on_error: - print("ERROR: Could not find a master for cluster {c} in region {r}.".format( - c=cluster_name, r=opts.region), file=sys.stderr) - sys.exit(1) - - return (master_instances, slave_instances) - - -# Deploy configuration files and run setup scripts on a newly launched -# or started EC2 cluster. 
-def setup_cluster(conn, master_nodes, slave_nodes, opts, deploy_ssh_key): - master = get_dns_name(master_nodes[0], opts.private_ips) - if deploy_ssh_key: - print("Generating cluster's SSH key on master...") - key_setup = """ - [ -f ~/.ssh/id_rsa ] || - (ssh-keygen -q -t rsa -N '' -f ~/.ssh/id_rsa && - cat ~/.ssh/id_rsa.pub >> ~/.ssh/authorized_keys) - """ - ssh(master, opts, key_setup) - dot_ssh_tar = ssh_read(master, opts, ['tar', 'c', '.ssh']) - print("Transferring cluster's SSH key to slaves...") - for slave in slave_nodes: - slave_address = get_dns_name(slave, opts.private_ips) - print(slave_address) - ssh_write(slave_address, opts, ['tar', 'x'], dot_ssh_tar) - - modules = ['spark', 'ephemeral-hdfs', 'persistent-hdfs', - 'mapreduce', 'spark-standalone', 'tachyon', 'rstudio'] - - if opts.hadoop_major_version == "1": - modules = list(filter(lambda x: x != "mapreduce", modules)) - - if opts.ganglia: - modules.append('ganglia') - - # Clear SPARK_WORKER_INSTANCES if running on YARN - if opts.hadoop_major_version == "yarn": - opts.worker_instances = "" - - # NOTE: We should clone the repository before running deploy_files to - # prevent ec2-variables.sh from being overwritten - print("Cloning spark-ec2 scripts from {r}/tree/{b} on master...".format( - r=opts.spark_ec2_git_repo, b=opts.spark_ec2_git_branch)) - ssh( - host=master, - opts=opts, - command="rm -rf spark-ec2" - + " && " - + "git clone {r} -b {b} spark-ec2".format(r=opts.spark_ec2_git_repo, - b=opts.spark_ec2_git_branch) - ) - - print("Deploying files to master...") - deploy_files( - conn=conn, - root_dir=SPARK_EC2_DIR + "/" + "deploy.generic", - opts=opts, - master_nodes=master_nodes, - slave_nodes=slave_nodes, - modules=modules - ) - - if opts.deploy_root_dir is not None: - print("Deploying {s} to master...".format(s=opts.deploy_root_dir)) - deploy_user_files( - root_dir=opts.deploy_root_dir, - opts=opts, - master_nodes=master_nodes - ) - - print("Running setup on master...") - setup_spark_cluster(master, opts) - print("Done!") - - -def setup_spark_cluster(master, opts): - ssh(master, opts, "chmod u+x spark-ec2/setup.sh") - ssh(master, opts, "spark-ec2/setup.sh") - print("Spark standalone cluster started at http://%s:8080" % master) - - if opts.ganglia: - print("Ganglia started at http://%s:5080/ganglia" % master) - - -def is_ssh_available(host, opts, print_ssh_output=True): - """ - Check if SSH is available on a host. - """ - s = subprocess.Popen( - ssh_command(opts) + ['-t', '-t', '-o', 'ConnectTimeout=3', - '%s@%s' % (opts.user, host), stringify_command('true')], - stdout=subprocess.PIPE, - stderr=subprocess.STDOUT # we pipe stderr through stdout to preserve output order - ) - cmd_output = s.communicate()[0] # [1] is stderr, which we redirected to stdout - - if s.returncode != 0 and print_ssh_output: - # extra leading newline is for spacing in wait_for_cluster_state() - print(textwrap.dedent("""\n - Warning: SSH connection error. (This could be temporary.) - Host: {h} - SSH return code: {r} - SSH output: {o} - """).format( - h=host, - r=s.returncode, - o=cmd_output.strip() - )) - - return s.returncode == 0 - - -def is_cluster_ssh_available(cluster_instances, opts): - """ - Check if SSH is available on all the instances in a cluster. 
- """ - for i in cluster_instances: - dns_name = get_dns_name(i, opts.private_ips) - if not is_ssh_available(host=dns_name, opts=opts): - return False - else: - return True - - -def wait_for_cluster_state(conn, opts, cluster_instances, cluster_state): - """ - Wait for all the instances in the cluster to reach a designated state. - - cluster_instances: a list of boto.ec2.instance.Instance - cluster_state: a string representing the desired state of all the instances in the cluster - value can be 'ssh-ready' or a valid value from boto.ec2.instance.InstanceState such as - 'running', 'terminated', etc. - (would be nice to replace this with a proper enum: http://stackoverflow.com/a/1695250) - """ - sys.stdout.write( - "Waiting for cluster to enter '{s}' state.".format(s=cluster_state) - ) - sys.stdout.flush() - - start_time = datetime.now() - num_attempts = 0 - - while True: - time.sleep(5 * num_attempts) # seconds - - for i in cluster_instances: - i.update() - - max_batch = 100 - statuses = [] - for j in xrange(0, len(cluster_instances), max_batch): - batch = [i.id for i in cluster_instances[j:j + max_batch]] - statuses.extend(conn.get_all_instance_status(instance_ids=batch)) - - if cluster_state == 'ssh-ready': - if all(i.state == 'running' for i in cluster_instances) and \ - all(s.system_status.status == 'ok' for s in statuses) and \ - all(s.instance_status.status == 'ok' for s in statuses) and \ - is_cluster_ssh_available(cluster_instances, opts): - break - else: - if all(i.state == cluster_state for i in cluster_instances): - break - - num_attempts += 1 - - sys.stdout.write(".") - sys.stdout.flush() - - sys.stdout.write("\n") - - end_time = datetime.now() - print("Cluster is now in '{s}' state. Waited {t} seconds.".format( - s=cluster_state, - t=(end_time - start_time).seconds - )) - - -# Get number of local disks available for a given EC2 instance type. -def get_num_disks(instance_type): - # Source: http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/InstanceStorage.html - # Last Updated: 2015-06-19 - # For easy maintainability, please keep this manually-inputted dictionary sorted by key. - disks_by_instance = { - "c1.medium": 1, - "c1.xlarge": 4, - "c3.large": 2, - "c3.xlarge": 2, - "c3.2xlarge": 2, - "c3.4xlarge": 2, - "c3.8xlarge": 2, - "c4.large": 0, - "c4.xlarge": 0, - "c4.2xlarge": 0, - "c4.4xlarge": 0, - "c4.8xlarge": 0, - "cc1.4xlarge": 2, - "cc2.8xlarge": 4, - "cg1.4xlarge": 2, - "cr1.8xlarge": 2, - "d2.xlarge": 3, - "d2.2xlarge": 6, - "d2.4xlarge": 12, - "d2.8xlarge": 24, - "g2.2xlarge": 1, - "g2.8xlarge": 2, - "hi1.4xlarge": 2, - "hs1.8xlarge": 24, - "i2.xlarge": 1, - "i2.2xlarge": 2, - "i2.4xlarge": 4, - "i2.8xlarge": 8, - "m1.small": 1, - "m1.medium": 1, - "m1.large": 2, - "m1.xlarge": 4, - "m2.xlarge": 1, - "m2.2xlarge": 1, - "m2.4xlarge": 2, - "m3.medium": 1, - "m3.large": 1, - "m3.xlarge": 2, - "m3.2xlarge": 2, - "m4.large": 0, - "m4.xlarge": 0, - "m4.2xlarge": 0, - "m4.4xlarge": 0, - "m4.10xlarge": 0, - "r3.large": 1, - "r3.xlarge": 1, - "r3.2xlarge": 1, - "r3.4xlarge": 1, - "r3.8xlarge": 2, - "t1.micro": 0, - "t2.micro": 0, - "t2.small": 0, - "t2.medium": 0, - "t2.large": 0, - } - if instance_type in disks_by_instance: - return disks_by_instance[instance_type] - else: - print("WARNING: Don't know number of disks on instance type %s; assuming 1" - % instance_type, file=stderr) - return 1 - - -# Deploy the configuration file templates in a given local directory to -# a cluster, filling in any template parameters with information about the -# cluster (e.g. 
lists of masters and slaves). Files are only deployed to -# the first master instance in the cluster, and we expect the setup -# script to be run on that instance to copy them to other nodes. -# -# root_dir should be an absolute path to the directory with the files we want to deploy. -def deploy_files(conn, root_dir, opts, master_nodes, slave_nodes, modules): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - - num_disks = get_num_disks(opts.instance_type) - hdfs_data_dirs = "/mnt/ephemeral-hdfs/data" - mapred_local_dirs = "/mnt/hadoop/mrlocal" - spark_local_dirs = "/mnt/spark" - if num_disks > 1: - for i in range(2, num_disks + 1): - hdfs_data_dirs += ",/mnt%d/ephemeral-hdfs/data" % i - mapred_local_dirs += ",/mnt%d/hadoop/mrlocal" % i - spark_local_dirs += ",/mnt%d/spark" % i - - cluster_url = "%s:7077" % active_master - - if "." in opts.spark_version: - # Pre-built Spark deploy - spark_v = get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - tachyon_v = get_tachyon_version(spark_v) - else: - # Spark-only custom deploy - spark_v = "%s|%s" % (opts.spark_git_repo, opts.spark_version) - tachyon_v = "" - print("Deploying Spark via git hash; Tachyon won't be set up") - modules = filter(lambda x: x != "tachyon", modules) - - master_addresses = [get_dns_name(i, opts.private_ips) for i in master_nodes] - slave_addresses = [get_dns_name(i, opts.private_ips) for i in slave_nodes] - worker_instances_str = "%d" % opts.worker_instances if opts.worker_instances else "" - template_vars = { - "master_list": '\n'.join(master_addresses), - "active_master": active_master, - "slave_list": '\n'.join(slave_addresses), - "cluster_url": cluster_url, - "hdfs_data_dirs": hdfs_data_dirs, - "mapred_local_dirs": mapred_local_dirs, - "spark_local_dirs": spark_local_dirs, - "swap": str(opts.swap), - "modules": '\n'.join(modules), - "spark_version": spark_v, - "tachyon_version": tachyon_v, - "hadoop_major_version": opts.hadoop_major_version, - "spark_worker_instances": worker_instances_str, - "spark_master_opts": opts.master_opts - } - - if opts.copy_aws_credentials: - template_vars["aws_access_key_id"] = conn.aws_access_key_id - template_vars["aws_secret_access_key"] = conn.aws_secret_access_key - else: - template_vars["aws_access_key_id"] = "" - template_vars["aws_secret_access_key"] = "" - - # Create a temp directory in which we will place all the files to be - # deployed after we substitue template parameters in them - tmp_dir = tempfile.mkdtemp() - for path, dirs, files in os.walk(root_dir): - if path.find(".svn") == -1: - dest_dir = os.path.join('/', path[len(root_dir):]) - local_dir = tmp_dir + dest_dir - if not os.path.exists(local_dir): - os.makedirs(local_dir) - for filename in files: - if filename[0] not in '#.~' and filename[-1] != '~': - dest_file = os.path.join(dest_dir, filename) - local_file = tmp_dir + dest_file - with open(os.path.join(path, filename)) as src: - with open(local_file, "w") as dest: - text = src.read() - for key in template_vars: - text = text.replace("{{" + key + "}}", template_vars[key]) - dest.write(text) - dest.close() - # rsync the whole directory over to the master machine - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s/" % tmp_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - # Remove the temp directory we created above - shutil.rmtree(tmp_dir) - - -# Deploy a given local directory to a cluster, WITHOUT parameter substitution. 
-# Note that unlike deploy_files, this works for binary files. -# Also, it is up to the user to add (or not) the trailing slash in root_dir. -# Files are only deployed to the first master instance in the cluster. -# -# root_dir should be an absolute path. -def deploy_user_files(root_dir, opts, master_nodes): - active_master = get_dns_name(master_nodes[0], opts.private_ips) - command = [ - 'rsync', '-rv', - '-e', stringify_command(ssh_command(opts)), - "%s" % root_dir, - "%s@%s:/" % (opts.user, active_master) - ] - subprocess.check_call(command) - - -def stringify_command(parts): - if isinstance(parts, str): - return parts - else: - return ' '.join(map(pipes.quote, parts)) - - -def ssh_args(opts): - parts = ['-o', 'StrictHostKeyChecking=no'] - parts += ['-o', 'UserKnownHostsFile=/dev/null'] - if opts.identity_file is not None: - parts += ['-i', opts.identity_file] - return parts - - -def ssh_command(opts): - return ['ssh'] + ssh_args(opts) - - -# Run a command on a host through ssh, retrying up to five times -# and then throwing an exception if ssh continues to fail. -def ssh(host, opts, command): - tries = 0 - while True: - try: - return subprocess.check_call( - ssh_command(opts) + ['-t', '-t', '%s@%s' % (opts.user, host), - stringify_command(command)]) - except subprocess.CalledProcessError as e: - if tries > 5: - # If this was an ssh failure, provide the user with hints. - if e.returncode == 255: - raise UsageError( - "Failed to SSH to remote host {0}.\n" - "Please check that you have provided the correct --identity-file and " - "--key-pair parameters and try again.".format(host)) - else: - raise e - print("Error executing remote command, retrying after 30 seconds: {0}".format(e), - file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Backported from Python 2.7 for compatiblity with 2.6 (See SPARK-1990) -def _check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - -def ssh_read(host, opts, command): - return _check_output( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)]) - - -def ssh_write(host, opts, command, arguments): - tries = 0 - while True: - proc = subprocess.Popen( - ssh_command(opts) + ['%s@%s' % (opts.user, host), stringify_command(command)], - stdin=subprocess.PIPE) - proc.stdin.write(arguments) - proc.stdin.close() - status = proc.wait() - if status == 0: - break - elif tries > 5: - raise RuntimeError("ssh_write failed with error %s" % proc.returncode) - else: - print("Error {0} while executing remote command, retrying after 30 seconds". 
- format(status), file=stderr) - time.sleep(30) - tries = tries + 1 - - -# Gets a list of zones to launch instances in -def get_zones(conn, opts): - if opts.zone == 'all': - zones = [z.name for z in conn.get_all_zones()] - else: - zones = [opts.zone] - return zones - - -# Gets the number of items in a partition -def get_partition(total, num_partitions, current_partitions): - num_slaves_this_zone = total // num_partitions - if (total % num_partitions) - current_partitions > 0: - num_slaves_this_zone += 1 - return num_slaves_this_zone - - -# Gets the IP address, taking into account the --private-ips flag -def get_ip_address(instance, private_ips=False): - ip = instance.ip_address if not private_ips else \ - instance.private_ip_address - return ip - - -# Gets the DNS name, taking into account the --private-ips flag -def get_dns_name(instance, private_ips=False): - dns = instance.public_dns_name if not private_ips else \ - instance.private_ip_address - if not dns: - raise UsageError("Failed to determine hostname of {0}.\n" - "Please check that you provided --private-ips if " - "necessary".format(instance)) - return dns - - -def real_main(): - (opts, action, cluster_name) = parse_args() - - # Input parameter validation - get_validate_spark_version(opts.spark_version, opts.spark_git_repo) - - if opts.wait is not None: - # NOTE: DeprecationWarnings are silent in 2.7+ by default. - # To show them, run Python with the -Wdefault switch. - # See: https://docs.python.org/3.5/whatsnew/2.7.html - warnings.warn( - "This option is deprecated and has no effect. " - "spark-ec2 automatically waits as long as necessary for clusters to start up.", - DeprecationWarning - ) - - if opts.identity_file is not None: - if not os.path.exists(opts.identity_file): - print("ERROR: The identity file '{f}' doesn't exist.".format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - file_mode = os.stat(opts.identity_file).st_mode - if not (file_mode & S_IRUSR) or not oct(file_mode)[-2:] == '00': - print("ERROR: The identity file must be accessible only by you.", file=stderr) - print('You can fix this with: chmod 400 "{f}"'.format(f=opts.identity_file), - file=stderr) - sys.exit(1) - - if opts.instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for instance-type: {t}".format( - t=opts.instance_type), file=stderr) - - if opts.master_instance_type != "": - if opts.master_instance_type not in EC2_INSTANCE_TYPES: - print("Warning: Unrecognized EC2 instance type for master-instance-type: {t}".format( - t=opts.master_instance_type), file=stderr) - # Since we try instance types even if we can't resolve them, we check if they resolve first - # and, if they do, see if they resolve to the same virtualization type. 
- if opts.instance_type in EC2_INSTANCE_TYPES and \ - opts.master_instance_type in EC2_INSTANCE_TYPES: - if EC2_INSTANCE_TYPES[opts.instance_type] != \ - EC2_INSTANCE_TYPES[opts.master_instance_type]: - print("Error: spark-ec2 currently does not support having a master and slaves " - "with different AMI virtualization types.", file=stderr) - print("master instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.master_instance_type]), file=stderr) - print("slave instance virtualization type: {t}".format( - t=EC2_INSTANCE_TYPES[opts.instance_type]), file=stderr) - sys.exit(1) - - if opts.ebs_vol_num > 8: - print("ebs-vol-num cannot be greater than 8", file=stderr) - sys.exit(1) - - # Prevent breaking ami_prefix (/, .git and startswith checks) - # Prevent forks with non spark-ec2 names for now. - if opts.spark_ec2_git_repo.endswith("/") or \ - opts.spark_ec2_git_repo.endswith(".git") or \ - not opts.spark_ec2_git_repo.startswith("https://github.com") or \ - not opts.spark_ec2_git_repo.endswith("spark-ec2"): - print("spark-ec2-git-repo must be a github repo and it must not have a trailing / or .git. " - "Furthermore, we currently only support forks named spark-ec2.", file=stderr) - sys.exit(1) - - if not (opts.deploy_root_dir is None or - (os.path.isabs(opts.deploy_root_dir) and - os.path.isdir(opts.deploy_root_dir) and - os.path.exists(opts.deploy_root_dir))): - print("--deploy-root-dir must be an absolute path to a directory that exists " - "on the local file system", file=stderr) - sys.exit(1) - - try: - if opts.profile is None: - conn = ec2.connect_to_region(opts.region) - else: - conn = ec2.connect_to_region(opts.region, profile_name=opts.profile) - except Exception as e: - print((e), file=stderr) - sys.exit(1) - - # Select an AZ at random if it was not specified. - if opts.zone == "": - opts.zone = random.choice(conn.get_all_zones()).name - - if action == "launch": - if opts.slaves <= 0: - print("ERROR: You have to start at least 1 slave", file=sys.stderr) - sys.exit(1) - if opts.resume: - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - else: - (master_nodes, slave_nodes) = launch_cluster(conn, opts, cluster_name) - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - setup_cluster(conn, master_nodes, slave_nodes, opts, True) - - elif action == "destroy": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - - if any(master_nodes + slave_nodes): - print("The following instances will be terminated:") - for inst in master_nodes + slave_nodes: - print("> %s" % get_dns_name(inst, opts.private_ips)) - print("ALL DATA ON ALL NODES WILL BE LOST!!") - - msg = "Are you sure you want to destroy the cluster {c}? 
(y/N) ".format(c=cluster_name) - response = raw_input(msg) - if response == "y": - print("Terminating master...") - for inst in master_nodes: - inst.terminate() - print("Terminating slaves...") - for inst in slave_nodes: - inst.terminate() - - # Delete security groups as well - if opts.delete_groups: - group_names = [cluster_name + "-master", cluster_name + "-slaves"] - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='terminated' - ) - print("Deleting security groups (this will take some time)...") - attempt = 1 - while attempt <= 3: - print("Attempt %d" % attempt) - groups = [g for g in conn.get_all_security_groups() if g.name in group_names] - success = True - # Delete individual rules in all groups before deleting groups to - # remove dependencies between them - for group in groups: - print("Deleting rules in security group " + group.name) - for rule in group.rules: - for grant in rule.grants: - success &= group.revoke(ip_protocol=rule.ip_protocol, - from_port=rule.from_port, - to_port=rule.to_port, - src_group=grant) - - # Sleep for AWS eventual-consistency to catch up, and for instances - # to terminate - time.sleep(30) # Yes, it does have to be this long :-( - for group in groups: - try: - # It is needed to use group_id to make it work with VPC - conn.delete_security_group(group_id=group.id) - print("Deleted security group %s" % group.name) - except boto.exception.EC2ResponseError: - success = False - print("Failed to delete security group %s" % group.name) - - # Unfortunately, group.revoke() returns True even if a rule was not - # deleted, so this needs to be rerun if something fails - if success: - break - - attempt += 1 - - if not success: - print("Failed to delete all security groups after 3 tries.") - print("Try re-running in a few minutes.") - - elif action == "login": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. Maybe you meant to specify --private-ips?") - else: - master = get_dns_name(master_nodes[0], opts.private_ips) - print("Logging into master " + master + "...") - proxy_opt = [] - if opts.proxy_port is not None: - proxy_opt = ['-D', opts.proxy_port] - subprocess.check_call( - ssh_command(opts) + proxy_opt + ['-t', '-t', "%s@%s" % (opts.user, master)]) - - elif action == "reboot-slaves": - response = raw_input( - "Are you sure you want to reboot the cluster " + - cluster_name + " slaves?\n" + - "Reboot cluster slaves " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Rebooting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - print("Rebooting " + inst.id) - inst.reboot() - - elif action == "get-master": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - if not master_nodes[0].public_dns_name and not opts.private_ips: - print("Master has no public DNS name. 
Maybe you meant to specify --private-ips?") - else: - print(get_dns_name(master_nodes[0], opts.private_ips)) - - elif action == "stop": - response = raw_input( - "Are you sure you want to stop the cluster " + - cluster_name + "?\nDATA ON EPHEMERAL DISKS WILL BE LOST, " + - "BUT THE CLUSTER WILL KEEP USING SPACE ON\n" + - "AMAZON EBS IF IT IS EBS-BACKED!!\n" + - "All data on spot-instance slaves will be lost.\n" + - "Stop cluster " + cluster_name + " (y/N): ") - if response == "y": - (master_nodes, slave_nodes) = get_existing_cluster( - conn, opts, cluster_name, die_on_error=False) - print("Stopping master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.stop() - print("Stopping slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - if inst.spot_instance_request_id: - inst.terminate() - else: - inst.stop() - - elif action == "start": - (master_nodes, slave_nodes) = get_existing_cluster(conn, opts, cluster_name) - print("Starting slaves...") - for inst in slave_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - print("Starting master...") - for inst in master_nodes: - if inst.state not in ["shutting-down", "terminated"]: - inst.start() - wait_for_cluster_state( - conn=conn, - opts=opts, - cluster_instances=(master_nodes + slave_nodes), - cluster_state='ssh-ready' - ) - - # Determine types of running instances - existing_master_type = master_nodes[0].instance_type - existing_slave_type = slave_nodes[0].instance_type - # Setting opts.master_instance_type to the empty string indicates we - # have the same instance type for the master and the slaves - if existing_master_type == existing_slave_type: - existing_master_type = "" - opts.master_instance_type = existing_master_type - opts.instance_type = existing_slave_type - - setup_cluster(conn, master_nodes, slave_nodes, opts, False) - - else: - print("Invalid action: %s" % action, file=stderr) - sys.exit(1) - - -def main(): - try: - real_main() - except UsageError as e: - print("\nError:\n", e, file=stderr) - sys.exit(1) - - -if __name__ == "__main__": - logging.basicConfig() - main() diff --git a/examples/pom.xml b/examples/pom.xml index f5ab2a7fdc09..1a0d5e585464 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml @@ -53,12 +53,6 @@ ${project.version} provided - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-hive_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java index 980a9108af53..3d8babba04a5 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaBinaryClassificationMetricsExample.java @@ -56,6 +56,7 @@ public static void main(String[] args) { // Compute raw scores on the test set. 
JavaRDD> predictionAndLabels = test.map( new Function>() { + @Override public Tuple2 call(LabeledPoint p) { Double prediction = model.predict(p.features()); return new Tuple2(prediction, p.label()); @@ -68,26 +69,27 @@ public Tuple2 call(LabeledPoint p) { // Precision by threshold JavaRDD> precision = metrics.precisionByThreshold().toJavaRDD(); - System.out.println("Precision by threshold: " + precision.toArray()); + System.out.println("Precision by threshold: " + precision.collect()); // Recall by threshold JavaRDD> recall = metrics.recallByThreshold().toJavaRDD(); - System.out.println("Recall by threshold: " + recall.toArray()); + System.out.println("Recall by threshold: " + recall.collect()); // F Score by threshold JavaRDD> f1Score = metrics.fMeasureByThreshold().toJavaRDD(); - System.out.println("F1 Score by threshold: " + f1Score.toArray()); + System.out.println("F1 Score by threshold: " + f1Score.collect()); JavaRDD> f2Score = metrics.fMeasureByThreshold(2.0).toJavaRDD(); - System.out.println("F2 Score by threshold: " + f2Score.toArray()); + System.out.println("F2 Score by threshold: " + f2Score.collect()); // Precision-recall curve JavaRDD> prc = metrics.pr().toJavaRDD(); - System.out.println("Precision-recall curve: " + prc.toArray()); + System.out.println("Precision-recall curve: " + prc.collect()); // Thresholds JavaRDD thresholds = precision.map( new Function, Double>() { + @Override public Double call(Tuple2 t) { return new Double(t._1().toString()); } @@ -96,7 +98,7 @@ public Double call(Tuple2 t) { // ROC Curve JavaRDD> roc = metrics.roc().toJavaRDD(); - System.out.println("ROC curve: " + roc.toArray()); + System.out.println("ROC curve: " + roc.collect()); // AUPRC System.out.println("Area under precision-recall curve = " + metrics.areaUnderPR()); @@ -106,8 +108,7 @@ public Double call(Tuple2 t) { // Save and load model model.save(sc, "target/tmp/LogisticRegressionModel"); - LogisticRegressionModel sameModel = LogisticRegressionModel.load(sc, - "target/tmp/LogisticRegressionModel"); + LogisticRegressionModel.load(sc, "target/tmp/LogisticRegressionModel"); // $example off$ } } diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java index 47ab3fc35824..4ad210476333 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaRankingMetricsExample.java @@ -41,6 +41,7 @@ public static void main(String[] args) { JavaRDD data = sc.textFile(path); JavaRDD ratings = data.map( new Function() { + @Override public Rating call(String line) { String[] parts = line.split("::"); return new Rating(Integer.parseInt(parts[0]), Integer.parseInt(parts[1]), Double @@ -57,13 +58,14 @@ public Rating call(String line) { JavaRDD> userRecs = model.recommendProductsForUsers(10).toJavaRDD(); JavaRDD> userRecsScaled = userRecs.map( new Function, Tuple2>() { + @Override public Tuple2 call(Tuple2 t) { Rating[] scaledRatings = new Rating[t._2().length]; for (int i = 0; i < scaledRatings.length; i++) { double newRating = Math.max(Math.min(t._2()[i].rating(), 1.0), 0.0); scaledRatings[i] = new Rating(t._2()[i].user(), t._2()[i].product(), newRating); } - return new Tuple2(t._1(), scaledRatings); + return new Tuple2<>(t._1(), scaledRatings); } } ); @@ -72,6 +74,7 @@ public Tuple2 call(Tuple2 t) { // Map ratings to 1 or 0, 1 indicating a movie that should be recommended JavaRDD 
binarizedRatings = ratings.map( new Function() { + @Override public Rating call(Rating r) { double binaryRating; if (r.rating() > 0.0) { @@ -87,6 +90,7 @@ public Rating call(Rating r) { // Group ratings by common user JavaPairRDD> userMovies = binarizedRatings.groupBy( new Function() { + @Override public Object call(Rating r) { return r.user(); } @@ -96,8 +100,9 @@ public Object call(Rating r) { // Get true relevant documents from all user ratings JavaPairRDD> userMoviesList = userMovies.mapValues( new Function, List>() { + @Override public List call(Iterable docs) { - List products = new ArrayList(); + List products = new ArrayList<>(); for (Rating r : docs) { if (r.rating() > 0.0) { products.add(r.product()); @@ -111,8 +116,9 @@ public List call(Iterable docs) { // Extract the product id from each recommendation JavaPairRDD> userRecommendedList = userRecommended.mapValues( new Function>() { + @Override public List call(Rating[] docs) { - List products = new ArrayList(); + List products = new ArrayList<>(); for (Rating r : docs) { products.add(r.product()); } @@ -124,7 +130,7 @@ public List call(Rating[] docs) { userRecommendedList).values(); // Instantiate the metrics object - RankingMetrics metrics = RankingMetrics.of(relevantDocs); + RankingMetrics metrics = RankingMetrics.of(relevantDocs); // Precision and NDCG at k Integer[] kVector = {1, 3, 5}; @@ -139,6 +145,7 @@ public List call(Rating[] docs) { // Evaluate the model using numerical ratings and regression metrics JavaRDD> userProducts = ratings.map( new Function>() { + @Override public Tuple2 call(Rating r) { return new Tuple2(r.user(), r.product()); } @@ -147,18 +154,20 @@ public Tuple2 call(Rating r) { JavaPairRDD, Object> predictions = JavaPairRDD.fromJavaRDD( model.predict(JavaRDD.toRDD(userProducts)).toJavaRDD().map( new Function, Object>>() { + @Override public Tuple2, Object> call(Rating r) { return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); + new Tuple2<>(r.user(), r.product()), r.rating()); } } )); JavaRDD> ratesAndPreds = JavaPairRDD.fromJavaRDD(ratings.map( new Function, Object>>() { + @Override public Tuple2, Object> call(Rating r) { return new Tuple2, Object>( - new Tuple2(r.user(), r.product()), r.rating()); + new Tuple2<>(r.user(), r.product()), r.rating()); } } )).join(predictions).values(); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java new file mode 100644 index 000000000000..2377207779fe --- /dev/null +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaActorWordCount.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.examples.streaming; + +import java.util.Arrays; + +import scala.Tuple2; + +import akka.actor.ActorSelection; +import akka.actor.Props; + +import org.apache.spark.SparkConf; +import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.streaming.Duration; +import org.apache.spark.streaming.api.java.JavaDStream; +import org.apache.spark.streaming.api.java.JavaStreamingContext; +import org.apache.spark.streaming.receiver.JavaActorReceiver; + +/** + * A sample actor as receiver, is also simplest. This receiver actor + * goes and subscribe to a typical publisher/feeder actor and receives + * data. + * + * @see [[org.apache.spark.examples.streaming.FeederActor]] + */ +class JavaSampleActorReceiver extends JavaActorReceiver { + + private final String urlOfPublisher; + + public JavaSampleActorReceiver(String urlOfPublisher) { + this.urlOfPublisher = urlOfPublisher; + } + + private ActorSelection remotePublisher; + + @Override + public void preStart() { + remotePublisher = getContext().actorSelection(urlOfPublisher); + remotePublisher.tell(new SubscribeReceiver(getSelf()), getSelf()); + } + + public void onReceive(Object msg) throws Exception { + store((T) msg); + } + + @Override + public void postStop() { + remotePublisher.tell(new UnsubscribeReceiver(getSelf()), getSelf()); + } +} + +/** + * A sample word count program demonstrating the use of plugging in + * Actor as Receiver + * Usage: JavaActorWordCount + * and describe the AkkaSystem that Spark Sample feeder is running on. + * + * To run this example locally, you may run Feeder Actor as + *
+ *     $ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999
+ *
+ * and then run the example + *
+ *     $ bin/run-example org.apache.spark.examples.streaming.JavaActorWordCount localhost 9999
+ *
    + */ +public class JavaActorWordCount { + + public static void main(String[] args) { + if (args.length < 2) { + System.err.println("Usage: JavaActorWordCount "); + System.exit(1); + } + + StreamingExamples.setStreamingLogLevels(); + + final String host = args[0]; + final String port = args[1]; + SparkConf sparkConf = new SparkConf().setAppName("JavaActorWordCount"); + // Create the context and set the batch size + JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000)); + + String feederActorURI = "akka.tcp://test@" + host + ":" + port + "/user/FeederActor"; + + /* + * Following is the use of actorStream to plug in custom actor as receiver + * + * An important point to note: + * Since Actor may exist outside the spark framework, It is thus user's responsibility + * to ensure the type safety, i.e type of data received and InputDstream + * should be same. + * + * For example: Both actorStream and JavaSampleActorReceiver are parameterized + * to same type to ensure type safety. + */ + JavaDStream lines = jssc.actorStream( + Props.create(JavaSampleActorReceiver.class, feederActorURI), "SampleReceiver"); + + // compute wordcount + lines.flatMap(new FlatMapFunction() { + @Override + public Iterable call(String s) { + return Arrays.asList(s.split("\\s+")); + } + }).mapToPair(new PairFunction() { + @Override + public Tuple2 call(String s) { + return new Tuple2(s, 1); + } + }).reduceByKey(new Function2() { + @Override + public Integer call(Integer i1, Integer i2) { + return i1 + i2; + } + }).print(); + + jssc.start(); + jssc.awaitTermination(); + } +} diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java index bceda97f058e..bc963a02be60 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaRecoverableNetworkWordCount.java @@ -21,17 +21,23 @@ import java.io.IOException; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.List; import java.util.regex.Pattern; import scala.Tuple2; import com.google.common.collect.Lists; import com.google.common.io.Files; +import org.apache.spark.Accumulator; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.FlatMapFunction; +import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.function.Function2; import org.apache.spark.api.java.function.PairFunction; +import org.apache.spark.api.java.function.VoidFunction2; +import org.apache.spark.broadcast.Broadcast; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.Time; import org.apache.spark.streaming.api.java.JavaDStream; @@ -41,7 +47,48 @@ import org.apache.spark.streaming.api.java.JavaStreamingContextFactory; /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. 
+ */ +class JavaWordBlacklist { + + private static volatile Broadcast> instance = null; + + public static Broadcast> getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaWordBlacklist.class) { + if (instance == null) { + List wordBlacklist = Arrays.asList("a", "b", "c"); + instance = jsc.broadcast(wordBlacklist); + } + } + } + return instance; + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +class JavaDroppedWordsCounter { + + private static volatile Accumulator instance = null; + + public static Accumulator getInstance(JavaSparkContext jsc) { + if (instance == null) { + synchronized (JavaDroppedWordsCounter.class) { + if (instance == null) { + instance = jsc.accumulator(0, "WordsInBlacklistCounter"); + } + } + } + return instance; + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. * * Usage: JavaRecoverableNetworkWordCount * and describe the TCP server that Spark Streaming would connect to receive @@ -108,14 +155,30 @@ public Integer call(Integer i1, Integer i2) { } }); - wordCounts.foreachRDD(new Function2, Time, Void>() { + wordCounts.foreachRDD(new VoidFunction2, Time>() { @Override - public Void call(JavaPairRDD rdd, Time time) throws IOException { - String counts = "Counts at time " + time + " " + rdd.collect(); - System.out.println(counts); + public void call(JavaPairRDD rdd, Time time) throws IOException { + // Get or register the blacklist Broadcast + final Broadcast> blacklist = JavaWordBlacklist.getInstance(new JavaSparkContext(rdd.context())); + // Get or register the droppedWordsCounter Accumulator + final Accumulator droppedWordsCounter = JavaDroppedWordsCounter.getInstance(new JavaSparkContext(rdd.context())); + // Use blacklist to drop words and use droppedWordsCounter to count them + String counts = rdd.filter(new Function, Boolean>() { + @Override + public Boolean call(Tuple2 wordCount) { + if (blacklist.value().contains(wordCount._1())) { + droppedWordsCounter.add(wordCount._2()); + return false; + } else { + return true; + } + } + }).collect().toString(); + String output = "Counts at time " + time + " " + counts; + System.out.println(output); + System.out.println("Dropped " + droppedWordsCounter.value() + " word(s) totally"); System.out.println("Appending to " + outputFile.getAbsolutePath()); - Files.append(counts + "\n", outputFile, Charset.defaultCharset()); - return null; + Files.append(output + "\n", outputFile, Charset.defaultCharset()); } }); diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java index 3515d7be45d3..084f68a8be43 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java @@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.FlatMapFunction; import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.Function2; +import org.apache.spark.api.java.function.VoidFunction2; import org.apache.spark.sql.SQLContext; import org.apache.spark.sql.DataFrame; import org.apache.spark.api.java.StorageLevels; @@ -78,13 +78,14 @@ public Iterable 
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
index 3515d7be45d3..084f68a8be43 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaSqlNetworkWordCount.java
@@ -26,7 +26,7 @@
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.FlatMapFunction;
 import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.VoidFunction2;
 import org.apache.spark.sql.SQLContext;
 import org.apache.spark.sql.DataFrame;
 import org.apache.spark.api.java.StorageLevels;
@@ -78,13 +78,14 @@ public Iterable<String> call(String x) {
     });
 
     // Convert RDDs of the words DStream to DataFrame and run SQL query
-    words.foreachRDD(new Function2<JavaRDD<String>, Time, Void>() {
+    words.foreachRDD(new VoidFunction2<JavaRDD<String>, Time>() {
       @Override
-      public Void call(JavaRDD<String> rdd, Time time) {
+      public void call(JavaRDD<String> rdd, Time time) {
         SQLContext sqlContext = JavaSQLContextSingleton.getInstance(rdd.context());
 
         // Convert JavaRDD[String] to JavaRDD[bean class] to DataFrame
         JavaRDD<JavaRecord> rowRDD = rdd.map(new Function<String, JavaRecord>() {
+          @Override
           public JavaRecord call(String word) {
             JavaRecord record = new JavaRecord();
             record.setWord(word);
@@ -101,7 +102,6 @@ public JavaRecord call(String word) {
           sqlContext.sql("select word, count(*) as total from words group by word");
         System.out.println("========= " + time + "=========");
         wordCountsDataFrame.show();
-        return null;
       }
     });
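The Scala counterpart of this per-batch SQL pattern is usually written with SQLContext.getOrCreate, which plays the same role as the JavaSQLContextSingleton used above. A sketch, assuming words: DStream[String]:

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext
    import org.apache.spark.streaming.Time

    // SQLContext.getOrCreate keeps a single instance per SparkContext,
    // so the context survives across batches and driver restarts.
    words.foreachRDD { (rdd: RDD[String], time: Time) =>
      val sqlContext = SQLContext.getOrCreate(rdd.sparkContext)
      import sqlContext.implicits._

      // One string column named "word", then count occurrences with plain SQL
      rdd.toDF("word").registerTempTable("words")
      val wordCountsDataFrame =
        sqlContext.sql("select word, count(*) as total from words group by word")
      println(s"========= $time =========")
      wordCountsDataFrame.show()
    }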
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
index 14997c64d505..f52cc7c20576 100644
--- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaStatefulNetworkWordCount.java
@@ -23,17 +23,14 @@
 
 import scala.Tuple2;
 
-import com.google.common.base.Optional;
-import com.google.common.collect.Lists;
-
 import org.apache.spark.SparkConf;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.Optional;
 import org.apache.spark.api.java.StorageLevels;
 import org.apache.spark.streaming.Durations;
 import org.apache.spark.streaming.State;
 import org.apache.spark.streaming.StateSpec;
-import org.apache.spark.streaming.Time;
 import org.apache.spark.streaming.api.java.*;
 
 /**
@@ -67,8 +64,8 @@ public static void main(String[] args) {
     // Initial state RDD input to mapWithState
     @SuppressWarnings("unchecked")
-    List<Tuple2<String, Integer>> tuples = Arrays.asList(new Tuple2<String, Integer>("hello", 1),
-        new Tuple2<String, Integer>("world", 1));
+    List<Tuple2<String, Integer>> tuples =
+        Arrays.asList(new Tuple2<>("hello", 1), new Tuple2<>("world", 1));
     JavaPairRDD<String, Integer> initialRDD = ssc.sparkContext().parallelizePairs(tuples);
 
     JavaReceiverInputDStream<String> lines = ssc.socketTextStream(
@@ -77,7 +74,7 @@ public static void main(String[] args) {
     JavaDStream<String> words = lines.flatMap(new FlatMapFunction<String, String>() {
       @Override
       public Iterable<String> call(String x) {
-        return Lists.newArrayList(SPACE.split(x));
+        return Arrays.asList(SPACE.split(x));
       }
     });
 
@@ -85,18 +82,17 @@ public Iterable<String> call(String x) {
         new PairFunction<String, String, Integer>() {
           @Override
           public Tuple2<String, Integer> call(String s) {
-            return new Tuple2<String, Integer>(s, 1);
+            return new Tuple2<>(s, 1);
           }
         });
 
     // Update the cumulative count function
-    final Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
+    Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>> mappingFunc =
         new Function3<String, Optional<Integer>, State<Integer>, Tuple2<String, Integer>>() {
-          @Override
           public Tuple2<String, Integer> call(String word, Optional<Integer> one, State<Integer> state) {
-            int sum = one.or(0) + (state.exists() ? state.get() : 0);
-            Tuple2<String, Integer> output = new Tuple2<String, Integer>(word, sum);
+            int sum = one.orElse(0) + (state.exists() ? state.get() : 0);
+            Tuple2<String, Integer> output = new Tuple2<>(word, sum);
             state.update(sum);
             return output;
           }
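In Scala the same mapWithState update function is an ordinary closure and the StateSpec wiring is identical. A sketch, assuming wordDstream: DStream[(String, Int)] and initialRDD: RDD[(String, Int)] as in the Java code above:

    import org.apache.spark.streaming.{State, StateSpec}

    val mappingFunc = (word: String, one: Option[Int], state: State[Int]) => {
      // Current batch count plus whatever the state has accumulated so far
      val sum = one.getOrElse(0) + state.getOption.getOrElse(0)
      state.update(sum)
      (word, sum)
    }

    val stateDstream = wordDstream.mapWithState(
      StateSpec.function(mappingFunc).initialState(initialRDD))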
diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java
new file mode 100644
index 000000000000..d869768026ae
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaTwitterHashTagJoinSentiments.java
@@ -0,0 +1,174 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.streaming;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.function.FlatMapFunction;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.api.java.function.PairFunction;
+import org.apache.spark.api.java.function.VoidFunction;
+import org.apache.spark.streaming.Duration;
+import org.apache.spark.streaming.api.java.JavaDStream;
+import org.apache.spark.streaming.api.java.JavaPairDStream;
+import org.apache.spark.streaming.api.java.JavaReceiverInputDStream;
+import org.apache.spark.streaming.api.java.JavaStreamingContext;
+import org.apache.spark.streaming.twitter.TwitterUtils;
+import scala.Tuple2;
+import twitter4j.Status;
+
+import java.util.Arrays;
+import java.util.List;
+
+/**
+ * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of
+ * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN)
+ */
+public class JavaTwitterHashTagJoinSentiments {
+
+  public static void main(String[] args) {
+    if (args.length < 4) {
+      System.err.println("Usage: JavaTwitterHashTagJoinSentiments <consumer key>" +
+        " <consumer secret> <access token> <access token secret> [<filters>]");
+      System.exit(1);
+    }
+
+    StreamingExamples.setStreamingLogLevels();
+
+    String consumerKey = args[0];
+    String consumerSecret = args[1];
+    String accessToken = args[2];
+    String accessTokenSecret = args[3];
+    String[] filters = Arrays.copyOfRange(args, 4, args.length);
+
+    // Set the system properties so that Twitter4j library used by Twitter stream
+    // can use them to generate OAuth credentials
+    System.setProperty("twitter4j.oauth.consumerKey", consumerKey);
+    System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret);
+    System.setProperty("twitter4j.oauth.accessToken", accessToken);
+    System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret);
+
+    SparkConf sparkConf = new SparkConf().setAppName("JavaTwitterHashTagJoinSentiments");
+    JavaStreamingContext jssc = new JavaStreamingContext(sparkConf, new Duration(2000));
+    JavaReceiverInputDStream<Status> stream = TwitterUtils.createStream(jssc, filters);
+
+    JavaDStream<String> words = stream.flatMap(new FlatMapFunction<Status, String>() {
+      @Override
+      public Iterable<String> call(Status s) {
+        return Arrays.asList(s.getText().split(" "));
+      }
+    });
+
+    JavaDStream<String> hashTags = words.filter(new Function<String, Boolean>() {
+      @Override
+      public Boolean call(String word) {
+        return word.startsWith("#");
+      }
+    });
+
+    // Read in the word-sentiment list and create a static RDD from it
+    String wordSentimentFilePath = "data/streaming/AFINN-111.txt";
+    final JavaPairRDD<String, Double> wordSentiments = jssc.sparkContext().textFile(wordSentimentFilePath)
+      .mapToPair(new PairFunction<String, String, Double>(){
+        @Override
+        public Tuple2<String, Double> call(String line) {
+          String[] columns = line.split("\t");
+          return new Tuple2<>(columns[0], Double.parseDouble(columns[1]));
+        }
+      });
+
+    JavaPairDStream<String, Integer> hashTagCount = hashTags.mapToPair(
+      new PairFunction<String, String, Integer>() {
+        @Override
+        public Tuple2<String, Integer> call(String s) {
+          // leave out the # character
+          return new Tuple2<>(s.substring(1), 1);
+        }
+      });
+
+    JavaPairDStream<String, Integer> hashTagTotals = hashTagCount.reduceByKeyAndWindow(
+      new Function2<Integer, Integer, Integer>() {
+        @Override
+        public Integer call(Integer a, Integer b) {
+          return a + b;
+        }
+      }, new Duration(10000));
+
+    // Determine the hash tags with the highest sentiment values by joining the streaming RDD
+    // with the static RDD inside the transform() method and then multiplying
+    // the frequency of the hash tag by its sentiment value
+    JavaPairDStream<String, Tuple2<Double, Integer>> joinedTuples =
+      hashTagTotals.transformToPair(new Function<JavaPairRDD<String, Integer>,
+        JavaPairRDD<String, Tuple2<Double, Integer>>>() {
+        @Override
+        public JavaPairRDD<String, Tuple2<Double, Integer>> call(
+            JavaPairRDD<String, Integer> topicCount) {
+          return wordSentiments.join(topicCount);
+        }
+      });
+
+    JavaPairDStream<String, Double> topicHappiness = joinedTuples.mapToPair(
+      new PairFunction<Tuple2<String, Tuple2<Double, Integer>>, String, Double>() {
+        @Override
+        public Tuple2<String, Double> call(Tuple2<String,
+            Tuple2<Double, Integer>> topicAndTuplePair) {
+          Tuple2<Double, Integer> happinessAndCount = topicAndTuplePair._2();
+          return new Tuple2<>(topicAndTuplePair._1(),
+            happinessAndCount._1() * happinessAndCount._2());
+        }
+      });
+
+    JavaPairDStream<Double, String> happinessTopicPairs = topicHappiness.mapToPair(
+      new PairFunction<Tuple2<String, Double>, Double, String>() {
+        @Override
+        public Tuple2<Double, String> call(Tuple2<String, Double> topicHappiness) {
+          return new Tuple2<>(topicHappiness._2(),
+            topicHappiness._1());
+        }
+      });
+
+    JavaPairDStream<Double, String> happiest10 = happinessTopicPairs.transformToPair(
+      new Function<JavaPairRDD<Double, String>, JavaPairRDD<Double, String>>() {
+        @Override
+        public JavaPairRDD<Double, String> call(
+            JavaPairRDD<Double, String> happinessAndTopics) {
+          return happinessAndTopics.sortByKey(false);
+        }
+      }
+    );
+
+    // Print hash tags with the most positive sentiment values
+    happiest10.foreachRDD(new VoidFunction<JavaPairRDD<Double, String>>() {
+      @Override
+      public void call(JavaPairRDD<Double, String> happinessTopicPairs) {
+        List<Tuple2<Double, String>> topList = happinessTopicPairs.take(10);
+        System.out.println(
+          String.format("\nHappiest topics in last 10 seconds (%s total):",
+            happinessTopicPairs.count()));
+        for (Tuple2<Double, String> pair : topList) {
+          System.out.println(
+            String.format("%s (%s happiness)", pair._2(), pair._1()));
+        }
+      }
+    });
+
+    jssc.start();
+    jssc.awaitTermination();
+  }
+}
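The core of this example is joining a windowed stream against a static RDD inside transform, then re-keying by score so that sortByKey ranks the result. In Scala the same three steps collapse to a few lines; a sketch, assuming wordSentiments: RDD[(String, Double)] and hashTagTotals: DStream[(String, Int)]:

    val joinedTuples = hashTagTotals.transform { topicCount =>
      wordSentiments.join(topicCount)  // RDD[(String, (Double, Int))]
    }

    // Weight each tag's frequency by its sentiment score, flip the pair so
    // sortByKey can rank by happiness, and print the top ten per batch.
    joinedTuples
      .map { case (topic, (happiness, count)) => (happiness * count, topic) }
      .transform(_.sortByKey(ascending = false))
      .foreachRDD { rdd =>
        rdd.take(10).foreach { case (score, topic) =>
          println(s"$topic ($score happiness)")
        }
      }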
diff --git a/examples/src/main/python/streaming/network_wordjoinsentiments.py b/examples/src/main/python/streaming/network_wordjoinsentiments.py
new file mode 100644
index 000000000000..b85517dfdd91
--- /dev/null
+++ b/examples/src/main/python/streaming/network_wordjoinsentiments.py
@@ -0,0 +1,77 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+ Shows the most positive words in UTF8 encoded, '\n' delimited text directly received from the
+ network every 5 seconds. The streaming data is joined with a static RDD of the AFINN word list
+ (http://neuro.imm.dtu.dk/wiki/AFINN)
+
+ Usage: network_wordjoinsentiments.py <hostname> <port>
+   <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive data.
+
+ To run this on your local machine, you need to first run a Netcat server
+    `$ nc -lk 9999`
+ and then run the example
+    `$ bin/spark-submit examples/src/main/python/streaming/network_wordjoinsentiments.py \
+    localhost 9999`
+"""
+
+from __future__ import print_function
+
+import sys
+
+from pyspark import SparkContext
+from pyspark.streaming import StreamingContext
+
+
+def print_happiest_words(rdd):
+    top_list = rdd.take(5)
+    print("Happiest topics in the last 5 seconds (%d total):" % rdd.count())
+    for tuple in top_list:
+        print("%s (%d happiness)" % (tuple[1], tuple[0]))
+
+if __name__ == "__main__":
+    if len(sys.argv) != 3:
+        print("Usage: network_wordjoinsentiments.py <hostname> <port>", file=sys.stderr)
+        exit(-1)
+
+    sc = SparkContext(appName="PythonStreamingNetworkWordJoinSentiments")
+    ssc = StreamingContext(sc, 5)
+
+    # Read in the word-sentiment list and create a static RDD from it
+    word_sentiments_file_path = "data/streaming/AFINN-111.txt"
+    word_sentiments = ssc.sparkContext.textFile(word_sentiments_file_path) \
+        .map(lambda line: tuple(line.split("\t")))
+
+    lines = ssc.socketTextStream(sys.argv[1], int(sys.argv[2]))
+
+    word_counts = lines.flatMap(lambda line: line.split(" ")) \
+        .map(lambda word: (word, 1)) \
+        .reduceByKey(lambda a, b: a + b)
+
+    # Determine the words with the highest sentiment values by joining the streaming RDD
+    # with the static RDD inside the transform() method and then multiplying
+    # the frequency of the words by its sentiment value
+    happiest_words = word_counts.transform(lambda rdd: word_sentiments.join(rdd)) \
+        .map(lambda (word, tuple): (word, float(tuple[0]) * tuple[1])) \
+        .map(lambda (word, happiness): (happiness, word)) \
+        .transform(lambda rdd: rdd.sortByKey(False))
+
+    happiest_words.foreachRDD(print_happiest_words)
+
+    ssc.start()
+    ssc.awaitTermination()
diff --git a/examples/src/main/python/streaming/recoverable_network_wordcount.py b/examples/src/main/python/streaming/recoverable_network_wordcount.py
index ac91f0a06b17..52b2639cdf55 100644
--- a/examples/src/main/python/streaming/recoverable_network_wordcount.py
+++ b/examples/src/main/python/streaming/recoverable_network_wordcount.py
@@ -44,6 +44,20 @@
 from pyspark.streaming import StreamingContext
 
 
+# Get or register a Broadcast variable
+def getWordBlacklist(sparkContext):
+    if ('wordBlacklist' not in globals()):
+        globals()['wordBlacklist'] = sparkContext.broadcast(["a", "b", "c"])
+    return globals()['wordBlacklist']
+
+
+# Get or register an Accumulator
+def getDroppedWordsCounter(sparkContext):
+    if ('droppedWordsCounter' not in globals()):
+        globals()['droppedWordsCounter'] = sparkContext.accumulator(0)
+    return globals()['droppedWordsCounter']
+
+
 def createContext(host, port, outputPath):
     # If you do not see this printed, that means the StreamingContext has been loaded
     # from the new checkpoint
@@ -60,8 +74,22 @@ def createContext(host, port, outputPath):
     wordCounts = words.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x + y)
 
     def echo(time, rdd):
-        counts = "Counts at time %s %s" % (time, rdd.collect())
+        # Get or register the blacklist Broadcast
+        blacklist = getWordBlacklist(rdd.context)
+        # Get or register the droppedWordsCounter Accumulator
+        droppedWordsCounter = getDroppedWordsCounter(rdd.context)
+
+        # Use blacklist to drop words and use droppedWordsCounter to count them
+        def filterFunc(wordCount):
+            if wordCount[0] in blacklist.value:
+                droppedWordsCounter.add(wordCount[1])
+                return False
+            else:
+                return True
+
+        counts = "Counts at time %s %s" % (time, rdd.filter(filterFunc).collect())
         print(counts)
+        print("Dropped %d word(s) totally" % droppedWordsCounter.value)
         print("Appending to " + os.path.abspath(outputPath))
         with open(outputPath, 'a') as f:
             f.write(counts + "\n")
diff --git a/examples/src/main/python/transitive_closure.py b/examples/src/main/python/transitive_closure.py
index 7bf5fb6ddfe2..3d61250d8b23 100755
--- a/examples/src/main/python/transitive_closure.py
+++ b/examples/src/main/python/transitive_closure.py
@@ -30,8 +30,8 @@ def generateGraph():
     edges = set()
     while len(edges) < numEdges:
-        src = rand.randrange(0, numEdges)
-        dst = rand.randrange(0, numEdges)
+        src = rand.randrange(0, numVertices)
+        dst = rand.randrange(0, numVertices)
         if src != dst:
             edges.add((src, dst))
     return edges
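The transitive_closure.py fix above corrects the sampling range: edge endpoints must be drawn from the vertex range, not from the edge count, otherwise vertices numbered at or above numEdges can never appear in the graph. A hypothetical Scala sketch of the corrected generator:

    import scala.collection.mutable
    import scala.util.Random

    def generateGraph(numVertices: Int, numEdges: Int, rand: Random): Set[(Int, Int)] = {
      val edges = mutable.Set.empty[(Int, Int)]
      while (edges.size < numEdges) {
        val src = rand.nextInt(numVertices)  // endpoints drawn from the vertex range
        val dst = rand.nextInt(numVertices)
        if (src != dst) {
          edges += ((src, dst))              // self-loops are skipped
        }
      }
      edges.toSet
    }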
diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
index d812262fd87d..3da5236745b5 100644
--- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala
@@ -21,16 +21,14 @@ package org.apache.spark.examples
 import org.apache.spark.{SparkConf, SparkContext}
 
 /**
- * Usage: BroadcastTest [slices] [numElem] [broadcastAlgo] [blockSize]
+ * Usage: BroadcastTest [slices] [numElem] [blockSize]
  */
 object BroadcastTest {
   def main(args: Array[String]) {
 
-    val bcName = if (args.length > 2) args(2) else "Http"
-    val blockSize = if (args.length > 3) args(3) else "4096"
+    val blockSize = if (args.length > 2) args(2) else "4096"
 
     val sparkConf = new SparkConf().setAppName("Broadcast Test")
-      .set("spark.broadcast.factory", s"org.apache.spark.broadcast.${bcName}BroadcastFactory")
       .set("spark.broadcast.blockSize", blockSize)
     val sc = new SparkContext(sparkConf)
 
@@ -44,7 +42,7 @@ object BroadcastTest {
       println("===========")
       val startTime = System.nanoTime
       val barr1 = sc.broadcast(arr1)
-      val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.size)
+      val observedSizes = sc.parallelize(1 to 10, slices).map(_ => barr1.value.length)
       // Collect the small RDD so we can print the observed sizes locally.
       observedSizes.collect().foreach(i => println(i))
       println("Iteration %d took %.0f milliseconds".format(i, (System.nanoTime - startTime) / 1E6))
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
index d1b9b8d398dd..973b005f91f6 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala
@@ -16,22 +16,20 @@
  */
 
 // scalastyle:off println
-// scalastyle:off jobcontext
 package org.apache.spark.examples
 
 import java.nio.ByteBuffer
 import java.util.Collections
 
 import org.apache.cassandra.hadoop.ConfigHelper
-import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat
 import org.apache.cassandra.hadoop.cql3.CqlConfigHelper
 import org.apache.cassandra.hadoop.cql3.CqlOutputFormat
+import org.apache.cassandra.hadoop.cql3.CqlPagingInputFormat
 import org.apache.cassandra.utils.ByteBufferUtil
 import org.apache.hadoop.mapreduce.Job
 
 import org.apache.spark.{SparkConf, SparkContext}
 
-
 /*
   Need to create following keyspace and column family in cassandra before running this example
   Start CQL shell using ./bin/cqlsh and execute following commands
@@ -80,7 +78,7 @@ object CassandraCQLTest {
     val InputColumnFamily = "ordercf"
     val OutputColumnFamily = "salecount"
 
-    val job = new Job()
+    val job = Job.getInstance()
     job.setInputFormatClass(classOf[CqlPagingInputFormat])
     val configuration = job.getConfiguration
     ConfigHelper.setInputInitialAddress(job.getConfiguration(), cHost)
@@ -137,4 +135,3 @@ object CassandraCQLTest {
   }
 }
 // scalastyle:on println
-// scalastyle:on jobcontext
diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
index 1e679bfb5534..6a8f73ad000f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala
@@ -16,7 +16,6 @@
  */
 
 // scalastyle:off println
-// scalastyle:off jobcontext
 package org.apache.spark.examples
 
 import java.nio.ByteBuffer
@@ -24,9 +23,9 @@ import java.util.Arrays
 import java.util.SortedMap
 
 import org.apache.cassandra.db.IColumn
+import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
 import org.apache.cassandra.hadoop.ColumnFamilyOutputFormat
 import org.apache.cassandra.hadoop.ConfigHelper
-import org.apache.cassandra.hadoop.ColumnFamilyInputFormat
 import org.apache.cassandra.thrift._
 import org.apache.cassandra.utils.ByteBufferUtil
 import org.apache.hadoop.mapreduce.Job
@@ -59,7 +58,7 @@ object CassandraTest {
     val sc = new SparkContext(sparkConf)
 
     // Build the job configuration with ConfigHelper provided by Cassandra
-    val job = new Job()
+    val job = Job.getInstance()
     job.setInputFormatClass(classOf[ColumnFamilyInputFormat])
 
     val host: String = args(1)
@@ -131,7 +130,6 @@ object CassandraTest {
   }
 }
 // scalastyle:on println
-// scalastyle:on jobcontext
 
 /*
 create keyspace casDemo;
diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
index d651fe4d6ee7..b26db0b2462e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala
@@ -22,7 +22,7 @@ import java.io.File
 
 import scala.io.Source._
 
-import org.apache.spark.{SparkContext, SparkConf}
+import 
org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala index 244742327a90..65d748958606 100644 --- a/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/HBaseTest.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.{HBaseConfiguration, HTableDescriptor, TableName} +import org.apache.hadoop.hbase.client.HBaseAdmin import org.apache.hadoop.hbase.mapreduce.TableInputFormat import org.apache.spark._ - object HBaseTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("HBaseTest") diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala index 9c8aae53cf48..a3901850f283 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalFileLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala index e7b28d38bdfc..407e3e08b968 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalKMeans.scala @@ -23,7 +23,7 @@ import java.util.Random import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import breeze.linalg.{Vector, DenseVector, squaredDistance} +import breeze.linalg.{squaredDistance, DenseVector, Vector} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala index 4f6b092a59ca..58adbabe4454 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalLR.scala @@ -20,7 +20,7 @@ package org.apache.spark.examples import java.util.Random -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} /** * Logistic regression based classification. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 61ce9db914f9..a797111dbad1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.rdd.RDD import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.rdd.RDD /** * Usage: MultiBroadcastTest [slices] [numElem] diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala index 505ea5a4c7a8..e4486b949fb3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkHdfsLR.scala @@ -22,12 +22,10 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo - /** * Logistic regression based classification. @@ -74,12 +72,9 @@ object SparkHdfsLR { val sparkConf = new SparkConf().setAppName("SparkHdfsLR") val inputPath = args(0) val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) + val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).cache() + val points = lines.map(parsePoint).cache() val ITERATIONS = args(1).toInt // Initialize w to a random value diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala index c56e1124ad41..1ea9121e2749 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkKMeans.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples -import breeze.linalg.{Vector, DenseVector, squaredDistance} +import breeze.linalg.{squaredDistance, DenseVector, Vector} import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala index d265c227f4ed..132800e6e4ca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkLR.scala @@ -22,7 +22,7 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.spark._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala index 0fd79660dd19..018bdf6d3103 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkPageRank.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import org.apache.spark.SparkContext._ import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.SparkContext._ /** * Computes the 
PageRank of URLs from an input file. Input file should diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 95072071ccdd..b92740f1fbcb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples -import scala.util.Random import scala.collection.mutable +import scala.util.Random import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala index cfbdae02212a..8b739c9d7c1d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTachyonHdfsLR.scala @@ -22,14 +22,12 @@ import java.util.Random import scala.math.exp -import breeze.linalg.{Vector, DenseVector} +import breeze.linalg.{DenseVector, Vector} import org.apache.hadoop.conf.Configuration import org.apache.spark._ -import org.apache.spark.scheduler.InputFormatInfo import org.apache.spark.storage.StorageLevel - /** * Logistic regression based classification. * This example uses Tachyon to persist rdds during computation. @@ -71,12 +69,9 @@ object SparkTachyonHdfsLR { val inputPath = args(0) val sparkConf = new SparkConf().setAppName("SparkTachyonHdfsLR") val conf = new Configuration() - val sc = new SparkContext(sparkConf, - InputFormatInfo.computePreferredLocations( - Seq(new InputFormatInfo(conf, classOf[org.apache.hadoop.mapred.TextInputFormat], inputPath)) - )) + val sc = new SparkContext(sparkConf) val lines = sc.textFile(inputPath) - val points = lines.map(parsePoint _).persist(StorageLevel.OFF_HEAP) + val points = lines.map(parsePoint).persist(StorageLevel.OFF_HEAP) val ITERATIONS = args(1).toInt // Initialize w to a random value diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 8dd6c9706e7d..39cb83d9eeb7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -19,11 +19,12 @@ package org.apache.spark.examples.graphx import scala.collection.mutable + import org.apache.spark._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ -import org.apache.spark.graphx.lib._ import org.apache.spark.graphx.PartitionStrategy._ +import org.apache.spark.graphx.lib._ +import org.apache.spark.storage.StorageLevel /** * Driver program for running graph algorithms. 
diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 46e52aacd90b..41ca5cbb9f08 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -18,11 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.graphx +import java.io.{FileOutputStream, PrintWriter} + +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.graphx.{GraphXUtils, PartitionStrategy} -import org.apache.spark.{SparkContext, SparkConf} import org.apache.spark.graphx.util.GraphGenerators -import java.io.{PrintWriter, FileOutputStream} /** * The SynthBenchmark application can be used to run various GraphX algorithms on diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala index f4b3613ccb94..21f58ddf3cfb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/AFTSurvivalRegressionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.regression.AFTSurvivalRegression import org.apache.spark.mllib.linalg.Vectors // $example off$ +import org.apache.spark.sql.SQLContext /** * An example for AFTSurvivalRegression. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala index e724aa587294..2ed8101c133c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BinarizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Binarizer // $example off$ import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} object BinarizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala index 7c75e3d72b47..6f6236a2b058 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Bucketizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object BucketizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala index a8d2bc4907e8..2be61537e613 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSqSelectorExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ChiSqSelector import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object ChiSqSelectorExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala index ba916f66c4c0..7d07fc7dd113 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CountVectorizerExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - object CountVectorizerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala index 14b358d46f6a..bca301d412f4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala @@ -23,7 +23,7 @@ import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.LogisticRegression import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator} +import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.sql.{Row, SQLContext} diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala index 314c2c28a2a1..dc26b55a768a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DCTExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.DCT import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object DCTExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index db024b5cad93..224d8da5f0ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -18,15 +18,15 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import 
org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.classification.DecisionTreeClassificationModel -import org.apache.spark.ml.feature.{StringIndexer, IndexToString, VectorIndexer} +import org.apache.spark.ml.classification.DecisionTreeClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator +import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object DecisionTreeClassificationExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala index c4e98dfaca6c..a37d12aa636c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala @@ -27,14 +27,13 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.examples.mllib.AbstractParams import org.apache.spark.ml.{Pipeline, PipelineStage, Transformer} import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier} -import org.apache.spark.ml.feature.{VectorIndexer, StringIndexer} +import org.apache.spark.ml.feature.{StringIndexer, VectorIndexer} import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor} import org.apache.spark.ml.util.MetadataUtils -import org.apache.spark.mllib.evaluation.{RegressionMetrics, MulticlassMetrics} +import org.apache.spark.mllib.evaluation.{MulticlassMetrics, RegressionMetrics} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.sql.{SQLContext, DataFrame} - +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example runner for decision trees. 
Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ad01f55df72b..ad32e5635a3e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -17,15 +17,17 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} + +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.regression.DecisionTreeRegressor -import org.apache.spark.ml.regression.DecisionTreeRegressionModel -import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.evaluation.RegressionEvaluator +import org.apache.spark.ml.feature.VectorIndexer +import org.apache.spark.ml.regression.DecisionTreeRegressionModel +import org.apache.spark.ml.regression.DecisionTreeRegressor // $example off$ +import org.apache.spark.sql.SQLContext + object DecisionTreeRegressionExample { def main(args: Array[String]): Unit = { val conf = new SparkConf().setAppName("DecisionTreeRegressionExample") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala index 872de51dc75d..629d322c4357 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ElementwiseProductExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.ElementwiseProduct import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object ElementwiseProductExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 474af7db4b49..cd62a803820c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.classification.{GBTClassificationModel, GBTClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object GradientBoostedTreeClassifierExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index da1cd9c2ce52..b8cf9629bbda 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{GBTRegressionModel, GBTRegressor} // $example off$ +import org.apache.spark.sql.SQLContext object GradientBoostedTreeRegressorExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala index 52537e5bb568..4cea09ba1265 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/IndexToStringExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.ml.feature.{StringIndexer, IndexToString} +import org.apache.spark.ml.feature.{IndexToString, StringIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object IndexToStringExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala index 419ce3d87a6a..f9ddac77090e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LDAExample.scala @@ -18,10 +18,10 @@ package org.apache.spark.examples.ml // scalastyle:off println -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.clustering.LDA +import org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.types.{StructField, StructType} // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala index 22c824cea84d..c7352b3e7ab9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LinearRegressionWithElasticNetExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.regression.LinearRegression // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object LinearRegressionWithElasticNetExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala index 4c420421b670..04c60c0c1d06 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionSummaryExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.{BinaryLogisticRegressionSummary, LogisticRegression} // $example off$ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.functions.max -import org.apache.spark.{SparkConf, SparkContext} object LogisticRegressionSummaryExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala index 9ee995b52c90..f632960f26ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/LogisticRegressionWithElasticNetExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.classification.LogisticRegression // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object LogisticRegressionWithElasticNetExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala index fb7f28c9886b..9a03f69f5af0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinMaxScalerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.MinMaxScaler // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object MinMaxScalerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala index 3ae53e57dbdb..02ed746954f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MovieLensALS.scala @@ -50,7 +50,7 @@ object MovieLensALS { def parseMovie(str: String): Movie = { val fields = str.split("::") assert(fields.size == 3) - Movie(fields(0).toInt, fields(1), fields(2).split("|")) + Movie(fields(0).toInt, fields(1), fields(2).split("\\|")) } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 9c98076bd24b..d7d1e82f6f84 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.sql.SQLContext +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import 
org.apache.spark.ml.classification.MultilayerPerceptronClassifier import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator // $example off$ +import org.apache.spark.sql.SQLContext /** * An example for Multilayer Perceptron Classification. diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala index 8a85f71b56f3..77b913aaa3fa 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NGramExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.NGram // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object NGramExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala index 1990b55e8c5e..6b33c16c7403 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NormalizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Normalizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object NormalizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala index 66602e211850..cb9fe65a85e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneHotEncoderExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object OneHotEncoderExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index b46faea5713f..ccee3b2aef98 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -22,10 +22,10 @@ import java.util.concurrent.TimeUnit.{NANOSECONDS => NANO} import scopt.OptionParser -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.examples.mllib.AbstractParams -import org.apache.spark.ml.classification.{OneVsRest, LogisticRegression} +import org.apache.spark.ml.classification.{LogisticRegression, OneVsRest} import org.apache.spark.ml.util.MetadataUtils import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala 
b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala index 4c806f71a32c..535652ec6c79 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PCAExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PCA import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object PCAExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala index 39fb79af3576..3014008ea0ce 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/PolynomialExpansionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.PolynomialExpansion import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object PolynomialExpansionExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index 8f29b7eaa6d2..e64e673a485e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.QuantileDiscretizer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object QuantileDiscretizerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala index 286866edea50..bec831d51c58 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RFormulaExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.RFormula // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object RFormulaExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index e79176ca6ca1..6c9b52cf259e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import 
org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.classification.{RandomForestClassificationModel, Rand import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorIndexer} // $example off$ +import org.apache.spark.sql.SQLContext object RandomForestClassifierExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index acec1437a1af..4d2db017f346 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -18,7 +18,6 @@ // scalastyle:off println package org.apache.spark.examples.ml -import org.apache.spark.sql.SQLContext import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.Pipeline @@ -26,6 +25,7 @@ import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.feature.VectorIndexer import org.apache.spark.ml.regression.{RandomForestRegressionModel, RandomForestRegressor} // $example off$ +import org.apache.spark.sql.SQLContext object RandomForestRegressorExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala index 014abd1fdbc6..202925acadff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/SQLTransformerExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.SQLTransformer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} - object SQLTransformerExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala index e0a41e383a7e..e3439677e78d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StandardScalerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StandardScaler // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StandardScalerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala index 655ffce08d3a..8199be12c155 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StopWordsRemoverExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import 
org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StopWordsRemover // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StopWordsRemoverExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala index 9fa494cd2473..3f0e870c8dc6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/StringIndexerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.StringIndexer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object StringIndexerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index 40c33e4e7d44..28115f939082 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object TfIdfExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala index 01e0d1388a2f..c667728d6326 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TokenizerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.{RegexTokenizer, Tokenizer} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object TokenizerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala index cd1b0e9358be..fbba17eba6a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala @@ -17,11 +17,11 @@ package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.evaluation.RegressionEvaluator import org.apache.spark.ml.regression.LinearRegression import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} /** * A simple example demonstrating model selection using TrainValidationSplit. 
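Taken together, the example hunks in this patch apply one import-ordering convention rather than any functional change: java and scala imports first, then third-party libraries, then org.apache.spark, with groups blank-line separated and alphabetized, and selector lists sorted with lowercase package or method names (e.g. `impurity`, `actorRef2Scala`) leading. A hypothetical file illustrating the layout these hunks converge on (names are illustrative, not from the patch):

```scala
package org.apache.spark.examples.ml

// 1. java imports, then scala imports
import java.util.Random

import scala.collection.mutable.ArrayBuffer

// 2. third-party libraries, alphabetized
import scopt.OptionParser

// 3. org.apache.spark imports, alphabetized by package;
//    names inside braces are sorted as well
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext
```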
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala index d527924419f8..768a8c069047 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorAssemblerExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.mllib.linalg.Vectors // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorAssemblerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index 685891c164e7..3bef37ba360b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.VectorIndexer // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorIndexerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala index 04f19829eff8..01377d80e7e5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorSlicerExample.scala @@ -18,6 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute} import org.apache.spark.ml.feature.VectorSlicer @@ -26,7 +27,6 @@ import org.apache.spark.sql.Row import org.apache.spark.sql.types.StructType // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object VectorSlicerExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala index 631ab4c8efa0..e77aa59ba32b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/Word2VecExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.ml +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.ml.feature.Word2Vec // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object Word2VecExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ca22ddafc3c4..11e18c9f040b 100644 --- 
a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.AssociationRules import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object AssociationRulesExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala index 1a4016f76c2a..2282bd2b7d68 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala @@ -24,8 +24,8 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, SVMWithSGD} import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics +import org.apache.spark.mllib.optimization.{L1Updater, SquaredL2Updater} import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SquaredL2Updater, L1Updater} /** * An example app for binary classification. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index 13a37827ab93..ade33fc5090f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object BinaryClassificationMetricsExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala index 3a596cccb87d..53d0b8fc208e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BisectingKMeansExample.scala @@ -18,11 +18,11 @@ package org.apache.spark.examples.mllib // scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.clustering.BisectingKMeans import org.apache.spark.mllib.linalg.{Vector, Vectors} // $example off$ -import org.apache.spark.{SparkConf, SparkContext} /** * An example demonstrating a bisecting k-means clustering in spark.mllib. 
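For orientation, every example object touched here shares the same driver boilerplate, which is why the `SparkConf`, `SparkContext` and `SQLContext` imports recur in hunk after hunk; a minimal sketch of that shared harness (the object name is illustrative):

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object ExampleHarness {
  def main(args: Array[String]): Unit = {
    // The pre-2.0 pattern used by all of these examples: a SparkConf feeds a
    // SparkContext, and a SQLContext is layered on top for DataFrame work.
    val conf = new SparkConf().setAppName("ExampleHarness")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)

    // ... example-specific body goes here ...

    sc.stop()
  }
}
```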
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala index 026d4ecc6d10..e003f35ed399 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/Correlations.scala @@ -20,10 +20,9 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala index 69988cc1b933..eda211b5a8df 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/CosineSimilarity.scala @@ -20,10 +20,10 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.distributed.{MatrixEntry, RowMatrix} -import org.apache.spark.{SparkConf, SparkContext} /** * Compute the similar columns of a matrix, using cosine similarity. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index d427bbadaa0c..c6c7c6f5e2ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object DecisionTreeClassificationExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index fb05e7d9c506..9c8baed3b866 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.DecisionTree import org.apache.spark.mllib.tree.model.DecisionTreeModel import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object DecisionTreeRegressionExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index cc6bce3cb7c9..c263f4f595a3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ 
b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -26,7 +26,7 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.{DecisionTree, RandomForest, impurity} +import org.apache.spark.mllib.tree.{impurity, DecisionTree, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.util.MLUtils diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 14b930550d55..a7a3eade04a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -20,8 +20,8 @@ package org.apache.spark.examples.mllib import scopt.OptionParser -import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.fpm.FPGrowth /** * Example for mining frequent itemsets using FP-growth. diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala index e16a6bf03357..b0144ef53313 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostedTreesRunner.scala @@ -23,10 +23,9 @@ import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.tree.GradientBoostedTrees -import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo} +import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy} import org.apache.spark.util.Utils - /** * An example runner for Gradient Boosting using decision trees as weak learners. 
Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 139e1f909bdc..0ec2e11214e8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index 3dc86da8e4d2..b87ba0defe69 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.GradientBoostedTrees import org.apache.spark.mllib.tree.configuration.BoostingStrategy diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 52ac9ae7dd2d..3834ea807acb 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.regression.{IsotonicRegression, IsotonicRegressionModel} // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object IsotonicRegressionExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index 61d2e7715f53..75a0419da5ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -18,6 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionModel import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics @@ -26,8 +27,6 @@ import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Up import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object LBFGSExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala index 70010b05e434..d28323555b99 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala 
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala @@ -18,16 +18,16 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.log4j.{Level, Logger} import scopt.OptionParser -import org.apache.log4j.{Level, Logger} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.ml.Pipeline import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover} import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer} import org.apache.spark.mllib.linalg.Vector import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{SparkConf, SparkContext} /** * An example Latent Dirichlet Allocation (LDA) app. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala index 8878061a0970..f87611f5d461 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala @@ -22,9 +22,9 @@ import org.apache.log4j.{Level, Logger} import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.mllib.optimization.{L1Updater, SimpleUpdater, SquaredL2Updater} import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.optimization.{SimpleUpdater, SquaredL2Updater, L1Updater} /** * An example app for linear regression. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala index 4503c15360ad..c0d447bf69dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultiLabelMetricsExample.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.evaluation.MultilabelMetrics import org.apache.spark.rdd.RDD // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object MultiLabelMetricsExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala index 090444924598..4f925ede24d8 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MulticlassMetricsExample.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLUtils // $example off$ -import org.apache.spark.{SparkContext, SparkConf} object MulticlassMetricsExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala 
index 5f839c75dd58..3c598172dadf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/MultivariateSummarizer.scala @@ -20,11 +20,10 @@ package org.apache.spark.examples.mllib import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for summarizing multivariate data from a file. Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala index a7a47c2a3556..8bae1b9d1832 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/NaiveBayesExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel} import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint // $example off$ -import org.apache.spark.{SparkConf, SparkContext} object NaiveBayesExample { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala index 072322395461..9208d8e24588 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PowerIterationClusteringExample.scala @@ -21,9 +21,9 @@ package org.apache.spark.examples.mllib import org.apache.log4j.{Level, Logger} import scopt.OptionParser +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.clustering.PowerIterationClustering import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} /** * An example Power Iteration Clustering http://www.icml2010.org/papers/387.pdf app. 
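A note on the recurring `// $example on$` / `// $example off$` pairs in these files: they appear to delimit the snippet that the documentation build extracts from each example, which would explain why the reordered `SparkConf` and `SQLContext` imports are consistently kept outside the markers. A schematic sketch (file contents hypothetical):

```scala
package org.apache.spark.examples.mllib

import org.apache.spark.{SparkConf, SparkContext} // boilerplate: outside the snippet
// $example on$
import org.apache.spark.mllib.fpm.PrefixSpan      // published with the doc snippet
// $example off$

object PrefixSpanExampleSkeleton {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("PrefixSpanExampleSkeleton"))
    // $example on$
    // ... the lines shown in the generated docs sit between these markers ...
    // $example off$
    sc.stop()
  }
}
```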
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index d237232c430c..ef86eab9e4ec 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -18,12 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.PrefixSpan // $example off$ -import org.apache.spark.{SparkConf, SparkContext} - object PrefixSpanExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index 5e55abd5121c..7805153ba7b9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.RandomForest import org.apache.spark.mllib.tree.model.RandomForestModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index a54fb3ab7e37..655a277e28ae 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.tree.RandomForest import org.apache.spark.mllib.tree.model.RandomForestModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala index bee85ba0f996..7ccbb5a0640c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomRDDGeneration.scala @@ -18,11 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.rdd.RDD -import org.apache.spark.{SparkConf, SparkContext} - /** * An example app for randomly generated RDDs. 
Run with * {{{ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala index cffa03d5cc9f..fdb01b86dd78 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RankingMetricsExample.scala @@ -18,12 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.mllib.evaluation.{RegressionMetrics, RankingMetrics} +import org.apache.spark.mllib.evaluation.{RankingMetrics, RegressionMetrics} import org.apache.spark.mllib.recommendation.{ALS, Rating} // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkContext, SparkConf} object RankingMetricsExample { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 64e460246544..bc946951aebf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -18,7 +18,7 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.{SparkContext, SparkConf} +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.recommendation.ALS import org.apache.spark.mllib.recommendation.MatrixFactorizationModel diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala index 47d44532521c..ace16ff1ea22 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RegressionMetricsExample.scala @@ -18,13 +18,13 @@ package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ -import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.evaluation.RegressionMetrics +import org.apache.spark.mllib.regression.LinearRegressionWithSGD import org.apache.spark.mllib.util.MLUtils // $example off$ import org.apache.spark.sql.SQLContext -import org.apache.spark.{SparkConf, SparkContext} object RegressionMetricsExample { def main(args: Array[String]) : Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala index 6963f43e082c..c4e5e965b8f4 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SampledRDDs.scala @@ -18,11 +18,11 @@ // scalastyle:off println package org.apache.spark.examples.mllib -import org.apache.spark.mllib.util.MLUtils import scopt.OptionParser import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ +import org.apache.spark.mllib.util.MLUtils /** * An example app for randomly generated and sampled RDDs. 
Run with diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b4e06afa7410..ab15ac2c54d3 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -18,13 +18,12 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.{SparkConf, SparkContext} // $example on$ import org.apache.spark.mllib.fpm.FPGrowth import org.apache.spark.rdd.RDD // $example off$ -import org.apache.spark.{SparkContext, SparkConf} - object SimpleFPGrowth { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala index b4a5dca031ab..e5592966f13f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLinearRegression.scala @@ -18,9 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LabeledPoint, StreamingLinearRegressionWithSGD} -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala index b42f4cb5f933..a8b144a19722 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingLogisticRegression.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.mllib +import org.apache.spark.SparkConf +import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.classification.StreamingLogisticRegressionWithSGD -import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} /** diff --git a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala index 0a25ee7ae56f..e252ca882e53 100644 --- a/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala +++ b/examples/src/main/scala/org/apache/spark/examples/pythonconverters/HBaseConverters.scala @@ -20,12 +20,13 @@ package org.apache.spark.examples.pythonconverters import scala.collection.JavaConverters._ import scala.util.parsing.json.JSONObject -import org.apache.spark.api.python.Converter +import org.apache.hadoop.hbase.CellUtil +import org.apache.hadoop.hbase.KeyValue.Type import org.apache.hadoop.hbase.client.{Put, Result} import org.apache.hadoop.hbase.io.ImmutableBytesWritable import org.apache.hadoop.hbase.util.Bytes -import org.apache.hadoop.hbase.KeyValue.Type -import org.apache.hadoop.hbase.CellUtil + +import org.apache.spark.api.python.Converter /** * Implementation of [[org.apache.spark.api.python.Converter]] that converts all diff --git 
a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala index bf40bd1ef13d..4e427f54daa5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/hive/HiveFromSpark.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.sql.hive -import com.google.common.io.{ByteStreams, Files} - import java.io.File +import com.google.common.io.{ByteStreams, Files} + import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.sql._ import org.apache.spark.sql.hive.HiveContext diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala index e9c990719876..88cdc6bc144e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ActorWordCount.scala @@ -22,13 +22,12 @@ import scala.collection.mutable.LinkedList import scala.reflect.ClassTag import scala.util.Random -import akka.actor.{Actor, ActorRef, Props, actorRef2Scala} +import akka.actor.{actorRef2Scala, Actor, ActorRef, Props} -import org.apache.spark.{SparkConf, SecurityManager} +import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.streaming.{Seconds, StreamingContext} -import org.apache.spark.streaming.StreamingContext.toPairDStreamFunctions +import org.apache.spark.streaming.receiver.ActorReceiver import org.apache.spark.util.AkkaUtils -import org.apache.spark.streaming.receiver.ActorHelper case class SubscribeReceiver(receiverActor: ActorRef) case class UnsubscribeReceiver(receiverActor: ActorRef) @@ -62,15 +61,13 @@ class FeederActor extends Actor { }.start() def receive: Receive = { - case SubscribeReceiver(receiverActor: ActorRef) => println("received subscribe from %s".format(receiverActor.toString)) - receivers = LinkedList(receiverActor) ++ receivers + receivers = LinkedList(receiverActor) ++ receivers case UnsubscribeReceiver(receiverActor: ActorRef) => println("received unsubscribe from %s".format(receiverActor.toString)) - receivers = receivers.dropWhile(x => x eq receiverActor) - + receivers = receivers.dropWhile(x => x eq receiverActor) } } @@ -82,7 +79,7 @@ class FeederActor extends Actor { * @see [[org.apache.spark.examples.streaming.FeederActor]] */ class SampleActorReceiver[T: ClassTag](urlOfPublisher: String) -extends Actor with ActorHelper { +extends ActorReceiver { lazy private val remotePublisher = context.actorSelection(urlOfPublisher) @@ -129,9 +126,9 @@ object FeederActor { * <hostname> and <port> describe the AkkaSystem that Spark Sample feeder is running on.
* * To run this example locally, you may run Feeder Actor as - * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor 127.0.1.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.FeederActor localhost 9999` * and then run the example - * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount 127.0.1.1 9999` + * `$ bin/run-example org.apache.spark.examples.streaming.ActorWordCount localhost 9999` */ object ActorWordCount { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 28e9bf520e56..ad13d437dd54 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -18,10 +18,10 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import java.io.{InputStreamReader, BufferedReader, InputStream} +import java.io.{BufferedReader, InputStream, InputStreamReader} import java.net.Socket -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.receiver.Receiver diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala index 2bdbc37e2a28..fe3b79ed5d29 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala @@ -18,12 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import java.net.InetSocketAddress + import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.flume._ import org.apache.spark.util.IntParam -import java.net.InetSocketAddress /** * Produces a count of events received from Flume. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala index b40d17e9c2fa..e7f9bf36e35c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/KafkaWordCount.scala @@ -20,11 +20,11 @@ package org.apache.spark.examples.streaming import java.util.HashMap -import org.apache.kafka.clients.producer.{ProducerConfig, KafkaProducer, ProducerRecord} +import org.apache.kafka.clients.producer.{KafkaProducer, ProducerConfig, ProducerRecord} +import org.apache.spark.SparkConf import org.apache.spark.streaming._ import org.apache.spark.streaming.kafka._ -import org.apache.spark.SparkConf /** * Consumes messages from one or more topics in Kafka and does wordcount. 
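The ActorWordCount hunk above swaps the removed `Actor with ActorHelper` mix-in for the `ActorReceiver` base class imported from org.apache.spark.streaming.receiver. A minimal sketch of a receiver written against the new base, mirroring `SampleActorReceiver`; the class and message names here are illustrative:

```scala
import akka.actor.ActorRef

import org.apache.spark.streaming.receiver.ActorReceiver

case class SubscribeReceiver(receiverActor: ActorRef)
case class UnsubscribeReceiver(receiverActor: ActorRef)

// After this change a custom actor receiver extends ActorReceiver directly
// instead of mixing the old ActorHelper trait into a plain akka Actor.
class EchoReceiver(urlOfPublisher: String) extends ActorReceiver {

  private lazy val remotePublisher = context.actorSelection(urlOfPublisher)

  override def preStart(): Unit = remotePublisher ! SubscribeReceiver(context.self)

  def receive: Receive = {
    // store() hands each received element to Spark Streaming's input blocks
    case msg: String => store(msg)
  }

  override def postStop(): Unit = remotePublisher ! UnsubscribeReceiver(context.self)
}
```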
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala index 9a57fe286d1a..15b57fccb407 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/NetworkWordCount.scala @@ -19,8 +19,8 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} /** * Counts words in UTF8 encoded, '\n' delimited text received from the network every second. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 9916882e4f94..05f8e65d65a2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -23,13 +23,55 @@ import java.nio.charset.Charset import com.google.common.io.Files -import org.apache.spark.SparkConf +import org.apache.spark.{Accumulator, SparkConf, SparkContext} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} import org.apache.spark.util.IntParam /** - * Counts words in text encoded with UTF8 received from the network every second. + * Use this singleton to get or register a Broadcast variable. + */ +object WordBlacklist { + + @volatile private var instance: Broadcast[Seq[String]] = null + + def getInstance(sc: SparkContext): Broadcast[Seq[String]] = { + if (instance == null) { + synchronized { + if (instance == null) { + val wordBlacklist = Seq("a", "b", "c") + instance = sc.broadcast(wordBlacklist) + } + } + } + instance + } +} + +/** + * Use this singleton to get or register an Accumulator. + */ +object DroppedWordsCounter { + + @volatile private var instance: Accumulator[Long] = null + + def getInstance(sc: SparkContext): Accumulator[Long] = { + if (instance == null) { + synchronized { + if (instance == null) { + instance = sc.accumulator(0L, "WordsInBlacklistCounter") + } + } + } + instance + } +} + +/** + * Counts words in text encoded with UTF8 received from the network every second. This example also + * shows how to use lazily instantiated singleton instances for Accumulator and Broadcast so that + * they can be registered on driver failures. 
* * Usage: RecoverableNetworkWordCount <hostname> <port> <checkpoint-directory> <output-file> * <hostname> and <port> describe the TCP server that Spark Streaming would connect to receive @@ -75,10 +117,24 @@ object RecoverableNetworkWordCount { val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { - val counts = "Counts at time " + time + " " + rdd.collect().mkString("[", ", ", "]") - println(counts) + // Get or register the blacklist Broadcast + val blacklist = WordBlacklist.getInstance(rdd.sparkContext) + // Get or register the droppedWordsCounter Accumulator + val droppedWordsCounter = DroppedWordsCounter.getInstance(rdd.sparkContext) + // Use blacklist to drop words and use droppedWordsCounter to count them + val counts = rdd.filter { case (word, count) => + if (blacklist.value.contains(word)) { + droppedWordsCounter += count + false + } else { + true + } + }.collect().mkString("[", ", ", "]") + val output = "Counts at time " + time + " " + counts + println(output) + println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) - Files.append(counts + "\n", outputFile, Charset.defaultCharset()) + Files.append(output + "\n", outputFile, Charset.defaultCharset()) }) ssc } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index ed617754cbf1..9aa0f54312d2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -21,10 +21,10 @@ package org.apache.spark.examples.streaming import org.apache.spark.SparkConf import org.apache.spark.SparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, Seconds, StreamingContext} -import org.apache.spark.util.IntParam import org.apache.spark.sql.SQLContext import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, Time} +import org.apache.spark.util.IntParam /** * Use DataFrames and SQL to count words in UTF8 encoded, '\n' delimited text received from the diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala index 2dce1820d973..c85d6843dc99 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StatefulNetworkWordCount.scala @@ -18,8 +18,8 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import org.apache.spark.SparkConf import org.apache.spark.HashPartitioner +import org.apache.spark.SparkConf import org.apache.spark.streaming._ /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala index 8396e65d0d58..22a5654405dd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/StreamingExamples.scala @@ -17,10 +17,10 @@ package org.apache.spark.examples.streaming -import org.apache.spark.Logging - import org.apache.log4j.{Level, Logger} +import org.apache.spark.Logging + /** Utility functions for Spark
Streaming examples. */ object StreamingExamples extends Logging { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala index 49826ede7041..0ec6214fdef1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterAlgebirdHLL.scala @@ -18,13 +18,13 @@ // scalastyle:off println package org.apache.spark.examples.streaming -import com.twitter.algebird.HyperLogLogMonoid import com.twitter.algebird.HyperLogLog._ +import com.twitter.algebird.HyperLogLogMonoid +import org.apache.spark.SparkConf import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.twitter._ -import org.apache.spark.SparkConf // scalastyle:off /** diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala new file mode 100644 index 000000000000..edf0e0b7b2b4 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/TwitterHashTagJoinSentiments.scala @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.streaming + +import org.apache.spark.SparkConf +import org.apache.spark.streaming.{Seconds, StreamingContext} +import org.apache.spark.streaming.twitter.TwitterUtils + +/** + * Displays the most positive hash tags by joining the streaming Twitter data with a static RDD of + * the AFINN word list (http://neuro.imm.dtu.dk/wiki/AFINN) + */ +object TwitterHashTagJoinSentiments { + def main(args: Array[String]) { + if (args.length < 4) { + System.err.println("Usage: TwitterHashTagJoinSentiments <consumer key> <consumer secret> " + + "<access token> <access token secret> [<filters>]") + System.exit(1) + } + + StreamingExamples.setStreamingLogLevels() + + val Array(consumerKey, consumerSecret, accessToken, accessTokenSecret) = args.take(4) + val filters = args.takeRight(args.length - 4) + + // Set the system properties so that Twitter4j library used by Twitter stream + // can use them to generate OAuth credentials + System.setProperty("twitter4j.oauth.consumerKey", consumerKey) + System.setProperty("twitter4j.oauth.consumerSecret", consumerSecret) + System.setProperty("twitter4j.oauth.accessToken", accessToken) + System.setProperty("twitter4j.oauth.accessTokenSecret", accessTokenSecret) + + val sparkConf = new SparkConf().setAppName("TwitterHashTagJoinSentiments") + val ssc = new StreamingContext(sparkConf, Seconds(2)) + val stream = TwitterUtils.createStream(ssc, None, filters) + + val hashTags = stream.flatMap(status => status.getText.split(" ").filter(_.startsWith("#"))) + + // Read in the word-sentiment list and create a static RDD from it + val wordSentimentFilePath = "data/streaming/AFINN-111.txt" + val wordSentiments = ssc.sparkContext.textFile(wordSentimentFilePath).map { line => + val Array(word, happinessValue) = line.split("\t") + (word, happinessValue.toInt) + }.cache() + + // Determine the hash tags with the highest sentiment values by joining the streaming RDD + // with the static RDD inside the transform() method and then multiplying + // the frequency of the hash tag by its sentiment value + val happiest60 = hashTags.map(hashTag => (hashTag.tail, 1)) + .reduceByKeyAndWindow(_ + _, Seconds(60)) + .transform{topicCount => wordSentiments.join(topicCount)} + .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} + .map{case (topic, happinessValue) => (happinessValue, topic)} + .transform(_.sortByKey(false)) + + val happiest10 = hashTags.map(hashTag => (hashTag.tail, 1)) + .reduceByKeyAndWindow(_ + _, Seconds(10)) + .transform{topicCount => wordSentiments.join(topicCount)} + .map{case (topic, tuple) => (topic, tuple._1 * tuple._2)} + .map{case (topic, happinessValue) => (happinessValue, topic)} + .transform(_.sortByKey(false)) + + // Print hash tags with the most positive sentiment values + happiest60.foreachRDD(rdd => { + val topList = rdd.take(10) + println("\nHappiest topics in last 60 seconds (%s total):".format(rdd.count())) + topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} + }) + + happiest10.foreachRDD(rdd => { + val topList = rdd.take(10) + println("\nHappiest topics in last 10 seconds (%s total):".format(rdd.count())) + topList.foreach{case (happiness, tag) => println("%s (%s happiness)".format(tag, happiness))} + }) + + ssc.start() + ssc.awaitTermination() + } +} +// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala index 6ac9a72c3794..96448905760f 100644 ---
a/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/ZeroMQWordCount.scala @@ -18,18 +18,18 @@ // scalastyle:off println package org.apache.spark.examples.streaming +import scala.language.implicitConversions + import akka.actor.ActorSystem import akka.actor.actorRef2Scala +import akka.util.ByteString import akka.zeromq._ import akka.zeromq.Subscribe -import akka.util.ByteString +import org.apache.spark.SparkConf import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.zeromq._ -import scala.language.implicitConversions -import org.apache.spark.SparkConf - /** * A simple publisher for demonstration purposes, repeatedly publishes random Messages * every one second. diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 2fcccb22dddf..ce1a62060ef6 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -18,9 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import java.net.ServerSocket import java.io.PrintWriter -import util.Random +import java.net.ServerSocket +import java.util.Random /** Represents a page view on a website with associated dimension data. */ class PageView(val url : String, val status : Int, val zipCode : Int, val userID : Int) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 723616817f6a..4b43550a065b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -18,8 +18,9 @@ // scalastyle:off println package org.apache.spark.examples.streaming.clickstream -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.examples.streaming.StreamingExamples +import org.apache.spark.streaming.{Seconds, StreamingContext} + // scalastyle:off /** Analyses a streaming dataset of web page views. This class demonstrates several types of * operators available in Spark streaming. 
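The RecoverableNetworkWordCount rewrite above only matters when the driver restarts from a checkpoint: a restored context skips the setup code that would otherwise create broadcasts and accumulators, which is why the lazily initialized `getInstance` singletons are fetched from inside the output operation. A hedged sketch of the surrounding recovery pattern (object name and checkpoint path are illustrative):

```scala
import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

object RecoveryHarness {
  def createContext(checkpointDir: String): StreamingContext = {
    // Runs only on a cold start; after a driver failure the context is
    // rebuilt from checkpoint data and this factory is skipped, so anything
    // created here would be lost. The getInstance singletons above re-create
    // the Broadcast/Accumulator on first use after recovery instead.
    val ssc = new StreamingContext(new SparkConf().setAppName("RecoveryHarness"), Seconds(1))
    ssc.checkpoint(checkpointDir)
    ssc
  }

  def main(args: Array[String]): Unit = {
    val checkpointDir = "/tmp/recovery-checkpoint" // illustrative path
    // Restores from the checkpoint if one exists, otherwise invokes the factory.
    val ssc = StreamingContext.getOrCreate(checkpointDir, () => createContext(checkpointDir))
    ssc.start()
    ssc.awaitTermination()
  }
}
```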
diff --git a/external/flume-assembly/pom.xml b/external/flume-assembly/pom.xml index dceedcf23ed5..b2c377fe4cc9 100644 --- a/external/flume-assembly/pom.xml +++ b/external/flume-assembly/pom.xml @@ -21,7 +21,7 @@ <parent> <groupId>org.apache.spark</groupId> <artifactId>spark-parent_2.10</artifactId> - <version>1.6.0-SNAPSHOT</version> + <version>2.0.0-SNAPSHOT</version> <relativePath>../../pom.xml</relativePath> </parent> diff --git a/external/flume-sink/pom.xml b/external/flume-sink/pom.xml index 75113ff753e7..4b6485ee0a71 100644 --- a/external/flume-sink/pom.xml +++ b/external/flume-sink/pom.xml @@ -21,7 +21,7 @@ <parent> <groupId>org.apache.spark</groupId> <artifactId>spark-parent_2.10</artifactId> - <version>1.6.0-SNAPSHOT</version> + <version>2.0.0-SNAPSHOT</version> <relativePath>../../pom.xml</relativePath> </parent> diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index 7ad43b1d7b0a..b15c2097e550 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{Callable, CountDownLatch, TimeUnit} import scala.util.control.Breaks -import org.apache.flume.{Transaction, Channel} +import org.apache.flume.{Channel, Transaction} // Flume forces transactions to be thread-local (horrible, I know!) // So the sink basically spawns a new thread to pull the events out within a transaction. diff --git a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala index 941fde45cd7b..7f6cecf9cd18 100644 --- a/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala +++ b/external/flume-sink/src/test/scala/org/apache/spark/streaming/flume/sink/SparkSinkSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume.sink import java.net.InetSocketAddress +import java.util.concurrent.{CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{TimeUnit, CountDownLatch, Executors} import scala.collection.JavaConverters._ import scala.concurrent.{ExecutionContext, Future} diff --git a/external/flume/pom.xml b/external/flume/pom.xml index 57f83607365d..a79656c6f7d9 100644 --- a/external/flume/pom.xml +++ b/external/flume/pom.xml @@ -21,7 +21,7 @@ <parent> <groupId>org.apache.spark</groupId> <artifactId>spark-parent_2.10</artifactId> - <version>1.6.0-SNAPSHOT</version> + <version>2.0.0-SNAPSHOT</version> <relativePath>../../pom.xml</relativePath> </parent> diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala index 48df27b26867..5c773d4b07cf 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/EventTransformer.scala @@ -17,12 +17,12 @@ package org.apache.spark.streaming.flume -import java.io.{ObjectOutput, ObjectInput} +import java.io.{ObjectInput, ObjectOutput} import scala.collection.JavaConverters._ -import org.apache.spark.util.Utils import org.apache.spark.Logging +import org.apache.spark.util.Utils /** * A simple object that provides the implementation of readExternal and writeExternal for both diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 2b9116eb3c79..1bfa35a8b3d1 100644 ---
a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -17,29 +17,27 @@ package org.apache.spark.streaming.flume +import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.net.InetSocketAddress -import java.io.{ObjectInput, ObjectOutput, Externalizable} import java.nio.ByteBuffer import java.util.concurrent.Executors import scala.collection.JavaConverters._ import scala.reflect.ClassTag -import org.apache.flume.source.avro.AvroSourceProtocol -import org.apache.flume.source.avro.AvroFlumeEvent -import org.apache.flume.source.avro.Status -import org.apache.avro.ipc.specific.SpecificResponder import org.apache.avro.ipc.NettyServer +import org.apache.avro.ipc.specific.SpecificResponder +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol, Status} +import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} +import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory +import org.jboss.netty.handler.codec.compression._ + import org.apache.spark.Logging -import org.apache.spark.util.Utils import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.receiver.Receiver - -import org.jboss.netty.channel.{ChannelPipeline, ChannelPipelineFactory, Channels} -import org.jboss.netty.channel.socket.nio.NioServerSocketChannelFactory -import org.jboss.netty.handler.codec.compression._ +import org.apache.spark.util.Utils private[streaming] class FlumeInputDStream[T: ClassTag]( diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 6737750c3d63..d9c25e86540d 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -32,8 +32,8 @@ import org.apache.spark.Logging import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.flume.sink._ +import org.apache.spark.streaming.receiver.Receiver /** * A [[ReceiverInputDStream]] that can be used to read data from several Flume agents running diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala index fe5dcc8e4b9d..3f87ce46e595 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeTestUtils.scala @@ -29,7 +29,7 @@ import org.apache.avro.ipc.NettyTransceiver import org.apache.avro.ipc.specific.SpecificRequestor import org.apache.commons.lang3.RandomUtils import org.apache.flume.source.avro -import org.apache.flume.source.avro.{AvroSourceProtocol, AvroFlumeEvent} +import org.apache.flume.source.avro.{AvroFlumeEvent, AvroSourceProtocol} import org.jboss.netty.channel.ChannelPipeline import org.jboss.netty.channel.socket.SocketChannel import 
org.jboss.netty.channel.socket.nio.NioClientSocketChannelFactory diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index c719b80aca7e..3e3ed712f0db 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume +import java.io.{ByteArrayOutputStream, DataOutputStream} import java.net.InetSocketAddress -import java.io.{DataOutputStream, ByteArrayOutputStream} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ @@ -30,7 +30,6 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream - object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index bfe7548d4f50..9515d07c5ee5 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.flume -import java.util.concurrent._ import java.util.{Collections, List => JList, Map => JMap} +import java.util.concurrent._ import scala.collection.mutable.ArrayBuffer @@ -28,7 +28,7 @@ import org.apache.flume.Context import org.apache.flume.channel.MemoryChannel import org.apache.flume.conf.Configurables -import org.apache.spark.streaming.flume.sink.{SparkSinkConfig, SparkSink} +import org.apache.spark.streaming.flume.sink.{SparkSink, SparkSinkConfig} /** * Share codes for Scala and Python unit tests diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala index 79077e4a49e1..57374ef51543 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/TestOutputStream.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming import java.io.{IOException, ObjectInputStream} +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.dstream.{DStream, ForEachDStream} import org.apache.spark.util.Utils -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - /** * This is a output stream just for the testsuites. All the output is collected into a * ArrayBuffer. This buffer is wiped clean on being restored from checkpoint. 
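Besides the import shuffling, the Kafka test suites a few hunks below migrate their Java callbacks from `Function<T, Void>` to the `VoidFunction<T>` interface, so side-effecting `foreachRDD` bodies no longer have to end with `return null`. A hedged sketch of the new shape (hypothetical helper, assuming a `JavaDStream<String>`):

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.function.VoidFunction;
    import org.apache.spark.streaming.api.java.JavaDStream;

    class ForeachSketch {
      static void attach(JavaDStream<String> stream) {
        stream.foreachRDD(new VoidFunction<JavaRDD<String>>() {
          @Override
          public void call(JavaRDD<String> rdd) {
            // side effects only; no trailing "return null" as with Function<T, Void>
            System.out.println("batch size: " + rdd.count());
          }
        });
      }
    }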
diff --git a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala index bb951a6ef100..60db846ffb7a 100644 --- a/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala +++ b/external/flume/src/test/scala/org/apache/spark/streaming/flume/FlumePollingStreamSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.flume import java.net.InetSocketAddress import scala.collection.JavaConverters._ -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.concurrent.duration._ import scala.language.postfixOps @@ -30,8 +30,8 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.network.util.JavaUtils import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext, TestOutputStream} import org.apache.spark.streaming.dstream.ReceiverInputDStream -import org.apache.spark.streaming.{Seconds, TestOutputStream, StreamingContext} import org.apache.spark.util.{ManualClock, Utils} class FlumePollingStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { diff --git a/external/kafka-assembly/pom.xml b/external/kafka-assembly/pom.xml index a9ed39ef8c9a..0c466b3c4ac3 100644 --- a/external/kafka-assembly/pom.xml +++ b/external/kafka-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/pom.xml b/external/kafka/pom.xml index 79258c126e04..5180ab6dbafb 100644 --- a/external/kafka/pom.xml +++ b/external/kafka/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala index 8465432c5850..c4e18d92eefa 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaCluster.scala @@ -17,14 +17,17 @@ package org.apache.spark.streaming.kafka -import scala.util.control.NonFatal -import scala.util.Random -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConverters._ import java.util.Properties + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer +import scala.util.Random +import scala.util.control.NonFatal + import kafka.api._ import kafka.common.{ErrorMapping, OffsetAndMetadata, OffsetMetadataAndError, TopicAndPartition} import kafka.consumer.{ConsumerConfig, SimpleConsumer} + import org.apache.spark.SparkException /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala index 38730fecf332..67f2360896b1 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaInputDStream.scala @@ -22,7 +22,7 @@ import java.util.Properties import scala.collection.Map import scala.reflect.{classTag, ClassTag} -import kafka.consumer.{KafkaStream, Consumer, ConsumerConfig, ConsumerConnector} +import 
kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala index ea5f842c6caf..603be2281820 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDD.scala @@ -20,11 +20,6 @@ package org.apache.spark.streaming.kafka import scala.collection.mutable.ArrayBuffer import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} -import org.apache.spark.partial.{PartialResult, BoundedDouble} -import org.apache.spark.rdd.RDD -import org.apache.spark.util.NextIterator - import kafka.api.{FetchRequestBuilder, FetchResponse} import kafka.common.{ErrorMapping, TopicAndPartition} import kafka.consumer.SimpleConsumer @@ -32,6 +27,11 @@ import kafka.message.{MessageAndMetadata, MessageAndOffset} import kafka.serializer.Decoder import kafka.utils.VerifiableProperties +import org.apache.spark.{Logging, Partition, SparkContext, SparkException, TaskContext} +import org.apache.spark.partial.{BoundedDouble, PartialResult} +import org.apache.spark.rdd.RDD +import org.apache.spark.util.NextIterator + /** * A batch-oriented interface for consuming from Kafka. * Starting and ending offsets are specified in advance, diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala index 45a6982b9afe..a76fa6671a4b 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaTestUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.streaming.kafka import java.io.File import java.lang.{Integer => JInt} import java.net.InetSocketAddress -import java.util.concurrent.TimeoutException import java.util.{Map => JMap, Properties} +import java.util.concurrent.TimeoutException import scala.annotation.tailrec import scala.collection.JavaConverters._ @@ -37,9 +37,9 @@ import kafka.utils.{ZKStringSerializer, ZkUtils} import org.I0Itec.zkclient.ZkClient import org.apache.zookeeper.server.{NIOServerCnxnFactory, ZooKeeperServer} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.streaming.Time import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} /** * This is a helper class for Kafka test suites. 
This has the functionality to set up diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala index fe572220528d..0cb875c9758f 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaUtils.scala @@ -27,19 +27,19 @@ import scala.reflect.ClassTag import com.google.common.base.Charsets.UTF_8 import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata -import kafka.serializer.{DefaultDecoder, Decoder, StringDecoder} -import net.razorvine.pickle.{Opcodes, Pickler, IObjectPickler} +import kafka.serializer.{Decoder, DefaultDecoder, StringDecoder} +import net.razorvine.pickle.{IObjectPickler, Opcodes, Pickler} -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.{SparkContext, SparkException} -import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaRDD, JavaSparkContext} +import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java._ import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} +import org.apache.spark.streaming.util.WriteAheadLogUtils object KafkaUtils { /** diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala index 764d170934aa..a872781b78ee 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/ReliableKafkaReceiver.scala @@ -18,10 +18,10 @@ package org.apache.spark.streaming.kafka import java.util.Properties -import java.util.concurrent.{ThreadPoolExecutor, ConcurrentHashMap} +import java.util.concurrent.{ConcurrentHashMap, ThreadPoolExecutor} -import scala.collection.{Map, mutable} -import scala.reflect.{ClassTag, classTag} +import scala.collection.{mutable, Map} +import scala.reflect.{classTag, ClassTag} import kafka.common.TopicAndPartition import kafka.consumer.{Consumer, ConsumerConfig, ConsumerConnector, KafkaStream} diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java index fbdfbf7e509b..4891e4f4a17b 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaDirectKafkaStreamSuite.java @@ -35,6 +35,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.Durations; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; @@ -130,17 +131,15 @@ public String call(MessageAndMetadata<String, String> msgAndMd) { JavaDStream<String> unifiedStream = stream1.union(stream2); final Set<String> result =
Collections.synchronizedSet(new HashSet<String>()); - unifiedStream.foreachRDD( - new Function<JavaRDD<String>, Void>() { + unifiedStream.foreachRDD(new VoidFunction<JavaRDD<String>>() { @Override - public Void call(JavaRDD<String> rdd) { + public void call(JavaRDD<String> rdd) { result.addAll(rdd.collect()); for (OffsetRange o : offsetRanges.get()) { System.out.println( o.topic() + " " + o.partition() + " " + o.fromOffset() + " " + o.untilOffset() ); } - return null; } } ); diff --git a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java index 1e69de46cd35..617c92a008fc 100644 --- a/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java +++ b/external/kafka/src/test/java/org/apache/spark/streaming/kafka/JavaKafkaStreamSuite.java @@ -31,6 +31,7 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.storage.StorageLevel; import org.apache.spark.streaming.Duration; import org.apache.spark.streaming.api.java.JavaDStream; @@ -103,10 +104,9 @@ public String call(Tuple2<String, String> tuple2) { } ); - words.countByValue().foreachRDD( - new Function<JavaPairRDD<String, Long>, Void>() { + words.countByValue().foreachRDD(new VoidFunction<JavaPairRDD<String, Long>>() { @Override - public Void call(JavaPairRDD<String, Long> rdd) { + public void call(JavaPairRDD<String, Long> rdd) { List<Tuple2<String, Long>> ret = rdd.collect(); for (Tuple2<String, Long> r : ret) { if (result.containsKey(r._1())) { @@ -115,8 +115,6 @@ public Void call(JavaPairRDD<String, Long> rdd) { result.put(r._1(), r._2()); } } - - return null; } } ); diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala index 02225d5aa7cc..655b161734d9 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/DirectKafkaStreamSuite.scala @@ -20,9 +20,6 @@ package org.apache.spark.streaming.kafka import java.io.File import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset -import org.apache.spark.streaming.scheduler.rate.RateEstimator - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ @@ -38,7 +35,9 @@ import org.apache.spark.{Logging, SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Milliseconds, StreamingContext, Time} import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.streaming.kafka.KafkaCluster.LeaderOffset import org.apache.spark.streaming.scheduler._ +import org.apache.spark.streaming.scheduler.rate.RateEstimator import org.apache.spark.util.Utils class DirectKafkaStreamSuite diff --git a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala index f52a738afd65..5e539c1d790c 100644 --- a/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala +++ b/external/kafka/src/test/scala/org/apache/spark/streaming/kafka/KafkaRDDSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.streaming.kafka import scala.util.Random -import kafka.serializer.StringDecoder import kafka.common.TopicAndPartition import kafka.message.MessageAndMetadata
+import kafka.serializer.StringDecoder import org.scalatest.BeforeAndAfterAll import org.apache.spark._ diff --git a/external/mqtt-assembly/pom.xml b/external/mqtt-assembly/pom.xml index 89713a28ca6a..c4a1ae26ea69 100644 --- a/external/mqtt-assembly/pom.xml +++ b/external/mqtt-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/mqtt/pom.xml b/external/mqtt/pom.xml index 59fba8b826b4..d3a2bf5825b0 100644 --- a/external/mqtt/pom.xml +++ b/external/mqtt/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml @@ -51,7 +51,7 @@ org.eclipse.paho org.eclipse.paho.client.mqttv3 - 1.0.1 + 1.0.2 org.scalacheck diff --git a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala index 1618e2c088b7..26c6dc45d511 100644 --- a/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala +++ b/external/mqtt/src/test/scala/org/apache/spark/streaming/mqtt/MQTTTestUtils.scala @@ -27,8 +27,8 @@ import org.apache.commons.lang3.RandomUtils import org.eclipse.paho.client.mqttv3._ import org.eclipse.paho.client.mqttv3.persist.MqttDefaultFilePersistence -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.Utils /** * Share codes for Scala and Python unit tests diff --git a/external/twitter/pom.xml b/external/twitter/pom.xml index 087270de90b3..7b628b09ea6a 100644 --- a/external/twitter/pom.xml +++ b/external/twitter/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala index 9a85a6597c27..a48eec70b9f7 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterInputDStream.scala @@ -19,13 +19,13 @@ package org.apache.spark.streaming.twitter import twitter4j._ import twitter4j.auth.Authorization -import twitter4j.conf.ConfigurationBuilder import twitter4j.auth.OAuthAuthorization +import twitter4j.conf.ConfigurationBuilder +import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.storage.StorageLevel -import org.apache.spark.Logging import org.apache.spark.streaming.receiver.Receiver /* A stream of Twitter statuses, potentially filtered by one or more keywords. 
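These Twitter hunks (TwitterInputDStream above, TwitterUtils and its suite below) are import reordering only; the receiver's behavior is unchanged. For orientation, a minimal usage sketch of the API these files expose (not part of the diff; assumes the twitter4j.oauth.* system properties are set, which is what passing None for the Authorization selects):

    import org.apache.spark.SparkConf
    import org.apache.spark.streaming.{Seconds, StreamingContext}
    import org.apache.spark.streaming.twitter.TwitterUtils

    object TwitterSketch {
      def main(args: Array[String]): Unit = {
        val ssc = new StreamingContext(new SparkConf().setAppName("TwitterSketch"), Seconds(2))
        val statuses = TwitterUtils.createStream(ssc, None, Seq("spark"))  // filter by keyword
        statuses.map(_.getText).print()
        ssc.start()
        ssc.awaitTermination()
      }
    }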
diff --git a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala index c6a9a2b73714..3e843e947da6 100644 --- a/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala +++ b/external/twitter/src/main/scala/org/apache/spark/streaming/twitter/TwitterUtils.scala @@ -19,10 +19,11 @@ package org.apache.spark.streaming.twitter import twitter4j.Status import twitter4j.auth.Authorization + import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaReceiverInputDStream, JavaDStream, JavaStreamingContext} -import org.apache.spark.streaming.dstream.{ReceiverInputDStream, DStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaReceiverInputDStream, JavaStreamingContext} +import org.apache.spark.streaming.dstream.{DStream, ReceiverInputDStream} object TwitterUtils { /** diff --git a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala index d9acb568879f..7e5fc0cbb9b3 100644 --- a/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala +++ b/external/twitter/src/test/scala/org/apache/spark/streaming/twitter/TwitterStreamSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.streaming.twitter - import org.scalatest.BeforeAndAfter import twitter4j.Status -import twitter4j.auth.{NullAuthorization, Authorization} +import twitter4j.auth.{Authorization, NullAuthorization} import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Seconds, StreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream class TwitterStreamSuite extends SparkFunSuite with BeforeAndAfter with Logging { diff --git a/external/zeromq/pom.xml b/external/zeromq/pom.xml index 02d6b8128157..a72598844907 100644 --- a/external/zeromq/pom.xml +++ b/external/zeromq/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala index 588e6bac7b14..506ba8782d3d 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQReceiver.scala @@ -19,12 +19,11 @@ package org.apache.spark.streaming.zeromq import scala.reflect.ClassTag -import akka.actor.Actor import akka.util.ByteString import akka.zeromq._ import org.apache.spark.Logging -import org.apache.spark.streaming.receiver.ActorHelper +import org.apache.spark.streaming.receiver.ActorReceiver /** * A receiver to subscribe to ZeroMQ stream. 
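The ZeroMQReceiver change (imports above, class signature in the hunk below) is a real API migration rather than a style fix: the `Actor with ActorHelper` mix-in is replaced by the `ActorReceiver` base class, which is itself an Akka `Actor` exposing `store(...)`. A hedged sketch of the new pattern (hypothetical EchoReceiver):

    import org.apache.spark.streaming.receiver.ActorReceiver

    class EchoReceiver extends ActorReceiver {
      def receive: Receive = {
        case s: String => store(s)  // hand each message to Spark's block generator
      }
    }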
@@ -33,7 +32,7 @@ private[streaming] class ZeroMQReceiver[T: ClassTag]( publisherUrl: String, subscribe: Subscribe, bytesToObjects: Seq[ByteString] => Iterator[T]) - extends Actor with ActorHelper with Logging { + extends ActorReceiver with Logging { override def preStart(): Unit = { ZeroMQExtension(context.system) diff --git a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala index 4ea218eaa4de..63cd8a2721f0 100644 --- a/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala +++ b/external/zeromq/src/main/scala/org/apache/spark/streaming/zeromq/ZeroMQUtils.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.zeromq -import scala.reflect.ClassTag import scala.collection.JavaConverters._ +import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} import akka.util.ByteString diff --git a/extras/java8-tests/pom.xml b/extras/java8-tests/pom.xml index 4ce90e75fd35..4dfe3b654df1 100644 --- a/extras/java8-tests/pom.xml +++ b/extras/java8-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java index 14975265ab2c..27d494ce355f 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/Java8APISuite.java @@ -24,7 +24,6 @@ import scala.Tuple2; import com.google.common.collect.Iterables; -import com.google.common.base.Optional; import com.google.common.io.Files; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.Text; @@ -38,6 +37,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.util.Utils; diff --git a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java index 89e0c7fdf7ee..604d818ef194 100644 --- a/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java +++ b/extras/java8-tests/src/test/java/org/apache/spark/streaming/Java8APISuite.java @@ -22,7 +22,6 @@ import scala.Tuple2; -import com.google.common.base.Optional; import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.junit.Assert; @@ -439,9 +438,14 @@ public void testPairFlatMap() { */ public static <T extends Comparable<T>> void assertOrderInvariantEquals( List<List<T>> expected, List<List<T>> actual) { - expected.forEach((List<T> list) -> Collections.sort(list)); - actual.forEach((List<T> list) -> Collections.sort(list)); - Assert.assertEquals(expected, actual); + expected.forEach(list -> Collections.sort(list)); + List<List<T>> sortedActual = new ArrayList<>(); + actual.forEach(list -> { + List<T> sortedList = new ArrayList<>(list); + Collections.sort(sortedList); + sortedActual.add(sortedList); + }); + Assert.assertEquals(expected, sortedActual); } @Test diff --git a/extras/kinesis-asl-assembly/pom.xml b/extras/kinesis-asl-assembly/pom.xml index 61ba4787fbf9..601080c2e6fb 100644 --- a/extras/kinesis-asl-assembly/pom.xml +++ b/extras/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff
--git a/extras/kinesis-asl/pom.xml b/extras/kinesis-asl/pom.xml index 519a920279c9..20e2c5e0ffbe 100644 --- a/extras/kinesis-asl/pom.xml +++ b/extras/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml @@ -59,11 +59,6 @@ amazon-kinesis-client ${aws.kinesis.client.version} - - com.amazonaws - aws-java-sdk - ${aws.java.sdk.version} - com.amazonaws amazon-kinesis-producer diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 691c1790b207..3996f168e69e 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -70,26 +70,26 @@ class KinesisBackedBlockRDDPartition( */ private[kinesis] class KinesisBackedBlockRDD[T: ClassTag]( - @transient sc: SparkContext, + sc: SparkContext, val regionName: String, val endpointUrl: String, - @transient blockIds: Array[BlockId], + @transient private val _blockIds: Array[BlockId], @transient val arrayOfseqNumberRanges: Array[SequenceNumberRanges], - @transient isBlockIdValid: Array[Boolean] = Array.empty, + @transient private val isBlockIdValid: Array[Boolean] = Array.empty, val retryTimeoutMs: Int = 10000, val messageHandler: Record => T = KinesisUtils.defaultMessageHandler _, val awsCredentialsOption: Option[SerializableAWSCredentials] = None - ) extends BlockRDD[T](sc, blockIds) { + ) extends BlockRDD[T](sc, _blockIds) { - require(blockIds.length == arrayOfseqNumberRanges.length, + require(_blockIds.length == arrayOfseqNumberRanges.length, "Number of blockIds is not equal to the number of sequence number ranges") override def isValid(): Boolean = true override def getPartitions: Array[Partition] = { - Array.tabulate(blockIds.length) { i => + Array.tabulate(_blockIds.length) { i => val isValid = if (isBlockIdValid.length == 0) true else isBlockIdValid(i) - new KinesisBackedBlockRDDPartition(i, blockIds(i), isValid, arrayOfseqNumberRanges(i)) + new KinesisBackedBlockRDDPartition(i, _blockIds(i), isValid, arrayOfseqNumberRanges(i)) } } diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala index 72ab6357a53b..3321c7527edb 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisInputDStream.scala @@ -30,7 +30,7 @@ import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.{Duration, StreamingContext, Time} private[kinesis] class KinesisInputDStream[T: ClassTag]( - @transient _ssc: StreamingContext, + _ssc: StreamingContext, streamName: String, endpointUrl: String, regionName: String, diff --git a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 80edda59e171..abb9b6cd32f1 100644 --- a/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/extras/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -185,6 +185,7 @@ private[kinesis] class KinesisReceiver[T]( 
workerThread.setName(s"Kinesis Receiver ${streamId}") workerThread.setDaemon(true) workerThread.start() + logInfo(s"Started receiver with workerId $workerId") } diff --git a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 6fe24fe81165..78263f9dca65 100644 --- a/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/extras/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -137,8 +137,8 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun // Verify that the generated KinesisBackedBlockRDD has the all the right information val blockInfos = Seq(blockInfo1, blockInfo2) val nonEmptyRDD = kinesisStream.createBlockRDD(time, blockInfos) - nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] - val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] + nonEmptyRDD shouldBe a [KinesisBackedBlockRDD[_]] + val kinesisRDD = nonEmptyRDD.asInstanceOf[KinesisBackedBlockRDD[_]] assert(kinesisRDD.regionName === dummyRegionName) assert(kinesisRDD.endpointUrl === dummyEndpointUrl) assert(kinesisRDD.retryTimeoutMs === batchDuration.milliseconds) @@ -203,7 +203,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun Seconds(10), StorageLevel.MEMORY_ONLY, addFive, awsCredentials.getAWSAccessKeyId, awsCredentials.getAWSSecretKey) - stream shouldBe a [ReceiverInputDStream[Int]] + stream shouldBe a [ReceiverInputDStream[_]] val collected = new mutable.HashSet[Int] with mutable.SynchronizedSet[Int] stream.foreachRDD { rdd => @@ -272,7 +272,7 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) extends KinesisFun times.foreach { time => val (arrayOfSeqNumRanges, data) = collectedData(time) val rdd = recoveredKinesisStream.getOrCompute(time).get.asInstanceOf[RDD[Array[Byte]]] - rdd shouldBe a [KinesisBackedBlockRDD[Array[Byte]]] + rdd shouldBe a [KinesisBackedBlockRDD[_]] // Verify the recovered sequence ranges val kRdd = rdd.asInstanceOf[KinesisBackedBlockRDD[Array[Byte]]] diff --git a/extras/spark-ganglia-lgpl/pom.xml b/extras/spark-ganglia-lgpl/pom.xml index 87a4f05a0596..b046a10a04d5 100644 --- a/extras/spark-ganglia-lgpl/pom.xml +++ b/extras/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 8cd66c5b2e82..388a0ef06a2b 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala index ee7302a1edbf..45526bf062fa 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/EdgeRDD.scala @@ -24,12 +24,11 @@ import org.apache.spark.Dependency import org.apache.spark.Partition import org.apache.spark.SparkContext import org.apache.spark.TaskContext -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.EdgePartition import org.apache.spark.graphx.impl.EdgePartitionBuilder import org.apache.spark.graphx.impl.EdgeRDDImpl +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * `EdgeRDD[ED, VD]` extends 
`RDD[Edge[ED]]` by storing the edges in columnar format on each diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala deleted file mode 100644 index 563c948957ec..000000000000 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphKryoRegistrator.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.graphx - -import com.esotericsoftware.kryo.Kryo - -import org.apache.spark.serializer.KryoRegistrator -import org.apache.spark.util.BoundedPriorityQueue -import org.apache.spark.util.collection.BitSet - -import org.apache.spark.graphx.impl._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.OpenHashSet - -/** - * Registers GraphX classes with Kryo for improved performance. - */ -@deprecated("Register GraphX classes with Kryo using GraphXUtils.registerKryoClasses", "1.2.0") -class GraphKryoRegistrator extends KryoRegistrator { - - def registerClasses(kryo: Kryo) { - kryo.register(classOf[Edge[Object]]) - kryo.register(classOf[(VertexId, Object)]) - kryo.register(classOf[EdgePartition[Object, Object]]) - kryo.register(classOf[BitSet]) - kryo.register(classOf[VertexIdToIndexMap]) - kryo.register(classOf[VertexAttributeBlock[Object]]) - kryo.register(classOf[PartitionStrategy]) - kryo.register(classOf[BoundedPriorityQueue[Object]]) - kryo.register(classOf[EdgeDirection]) - kryo.register(classOf[GraphXPrimitiveKeyOpenHashMap[VertexId, Int]]) - kryo.register(classOf[OpenHashSet[Int]]) - kryo.register(classOf[OpenHashSet[Long]]) - } -} diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala index 21187be7678a..1672f7d27c40 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphLoader.scala @@ -17,9 +17,9 @@ package org.apache.spark.graphx -import org.apache.spark.storage.StorageLevel import org.apache.spark.{Logging, SparkContext} import org.apache.spark.graphx.impl.{EdgePartitionBuilder, GraphImpl} +import org.apache.spark.storage.StorageLevel /** * Provides utilities for loading [[Graph]]s from files. 
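The KinesisStreamSuite changes above (`a [KinesisBackedBlockRDD[Array[Byte]]]` becoming `a [KinesisBackedBlockRDD[_]]`) deserve a note: JVM type erasure means a runtime class check can only see the raw `KinesisBackedBlockRDD`, never its element type, so the old matcher compiled with an unchecked warning while asserting nothing about `Array[Byte]`. The wildcard says exactly what is actually tested. A two-line illustration of the erasure behavior (REPL-style, not from the diff):

    val xs: Any = List("a", "b")
    xs.isInstanceOf[List[Int]]  // true at runtime: the Int parameter is erased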
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index 9827dfab8684..fc36e12dd2ae 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -22,9 +22,8 @@ import scala.util.Random import org.apache.spark.SparkException import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import org.apache.spark.graphx.lib._ +import org.apache.spark.rdd.RDD /** * Contains additional functionality for [[Graph]]. All operations are expressed in terms of the diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala index 2cb07937eaa2..8ec33e140000 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphXUtils.scala @@ -18,12 +18,10 @@ package org.apache.spark.graphx import org.apache.spark.SparkConf - import org.apache.spark.graphx.impl._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.util.collection.{OpenHashSet, BitSet} import org.apache.spark.util.BoundedPriorityQueue +import org.apache.spark.util.collection.{BitSet, OpenHashSet} object GraphXUtils { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala index 2ca60d51f833..b90886031009 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala @@ -18,8 +18,8 @@ package org.apache.spark.graphx import scala.reflect.ClassTag -import org.apache.spark.Logging +import org.apache.spark.Logging /** * Implements a Pregel-like bulk-synchronous message-passing API. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala index 1ef7a78fbcd0..53a9f92b82bc 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/VertexRDD.scala @@ -21,13 +21,12 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ -import org.apache.spark.rdd._ -import org.apache.spark.storage.StorageLevel - import org.apache.spark.graphx.impl.RoutingTablePartition import org.apache.spark.graphx.impl.ShippableVertexPartition import org.apache.spark.graphx.impl.VertexAttributeBlock import org.apache.spark.graphx.impl.VertexRDDImpl +import org.apache.spark.rdd._ +import org.apache.spark.storage.StorageLevel /** * Extends `RDD[(VertexId, VD)]` by ensuring that there is only one entry for each vertex and by diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala index 906d42328fcb..b122969b817f 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgePartitionBuilder.scala @@ -21,7 +21,7 @@ import scala.reflect.ClassTag import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap -import org.apache.spark.util.collection.{SortDataFormat, Sorter, PrimitiveVector} +import org.apache.spark.util.collection.{PrimitiveVector, SortDataFormat, Sorter} /** Constructs an EdgePartition from scratch. 
*/ private[graphx] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala index c88b2f65a86c..6e153b7e803e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/EdgeRDDImpl.scala @@ -19,12 +19,11 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} -import org.apache.spark.{OneToOneDependency, HashPartitioner} +import org.apache.spark.{HashPartitioner, OneToOneDependency} +import org.apache.spark.graphx._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class EdgeRDDImpl[ED: ClassTag, VD: ClassTag] private[graphx] ( @transient override val partitionsRDD: RDD[(PartitionID, EdgePartition[ED, VD])], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala index da95314440d8..81182adbc638 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/GraphImpl.scala @@ -21,12 +21,11 @@ import scala.reflect.{classTag, ClassTag} import org.apache.spark.HashPartitioner import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, ShuffledRDD} -import org.apache.spark.storage.StorageLevel import org.apache.spark.graphx._ import org.apache.spark.graphx.impl.GraphImpl._ import org.apache.spark.graphx.util.BytecodeUtils - +import org.apache.spark.rdd.{RDD, ShuffledRDD} +import org.apache.spark.storage.StorageLevel /** * An implementation of [[org.apache.spark.graphx.Graph]] to support computation on graphs. 
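With GraphKryoRegistrator deleted above, callers who registered GraphX classes with Kryo should go through `GraphXUtils.registerKryoClasses`, as the deprecation notice already advised. A minimal sketch of the replacement call:

    import org.apache.spark.SparkConf
    import org.apache.spark.graphx.GraphXUtils

    val conf = new SparkConf().setAppName("graphx-kryo")
    GraphXUtils.registerKryoClasses(conf)  // registers GraphX's internal classes with Kryo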
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala index 1df86449fa0c..f79f9c7ec448 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ReplicatedVertexView.scala @@ -20,9 +20,8 @@ package org.apache.spark.graphx.impl import scala.reflect.{classTag, ClassTag} import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.RDD - import org.apache.spark.graphx._ +import org.apache.spark.rdd.RDD /** * Manages shipping vertex attributes to the edge partitions of an diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala index 4f1260a5a67b..3fd76902af64 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/RoutingTablePartition.scala @@ -20,15 +20,13 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag import org.apache.spark.Partitioner +import org.apache.spark.graphx._ +import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage +import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap import org.apache.spark.rdd.RDD import org.apache.spark.rdd.ShuffledRDD import org.apache.spark.util.collection.{BitSet, PrimitiveVector} -import org.apache.spark.graphx._ -import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap - -import org.apache.spark.graphx.impl.RoutingTablePartition.RoutingTableMessage - private[graphx] object RoutingTablePartition { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala index aa320088f208..3f203c4eca48 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/ShippableVertexPartition.scala @@ -19,10 +19,9 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.{BitSet, PrimitiveVector} - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.{BitSet, PrimitiveVector} /** Stores vertex attributes to ship to an edge partition. */ private[graphx] diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala index fbe53acfc32a..4512bc17399a 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartition.scala @@ -19,10 +19,9 @@ package org.apache.spark.graphx.impl import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartition { /** Construct a `VertexPartition` from the given vertices. 
*/ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala index 5ad6390a56c4..8d608c99b1a1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBase.scala @@ -20,10 +20,9 @@ package org.apache.spark.graphx.impl import scala.language.higherKinds import scala.reflect.ClassTag -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet private[graphx] object VertexPartitionBase { /** diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala index b90f9fa32705..f508b483a2f1 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexPartitionBaseOps.scala @@ -22,10 +22,9 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.apache.spark.Logging -import org.apache.spark.util.collection.BitSet - import org.apache.spark.graphx._ import org.apache.spark.graphx.util.collection.GraphXPrimitiveKeyOpenHashMap +import org.apache.spark.util.collection.BitSet /** * An class containing additional operations for subclasses of VertexPartitionBase that provide diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala index 7f4e7e9d79d6..d5accdfbf7e9 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/VertexRDDImpl.scala @@ -21,11 +21,10 @@ import scala.reflect.ClassTag import org.apache.spark._ import org.apache.spark.SparkContext._ +import org.apache.spark.graphx._ import org.apache.spark.rdd._ import org.apache.spark.storage.StorageLevel -import org.apache.spark.graphx._ - class VertexRDDImpl[VD] private[graphx] ( @transient val partitionsRDD: RDD[ShippableVertexPartition[VD]], val targetStorageLevel: StorageLevel = StorageLevel.MEMORY_ONLY) diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala index a3ad6bed1c99..7a53eca7eac6 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/LabelPropagation.scala @@ -18,6 +18,7 @@ package org.apache.spark.graphx.lib import scala.reflect.ClassTag + import org.apache.spark.graphx._ /** Label Propagation algorithm. 
*/ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala index 52b237fc1509..35b26c998e1d 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/PageRank.scala @@ -17,8 +17,8 @@ package org.apache.spark.graphx.lib -import scala.reflect.ClassTag import scala.language.postfixOps +import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.graphx._ diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala index 9cb24ed080e1..16300e074079 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/SVDPlusPlus.scala @@ -21,8 +21,8 @@ import scala.util.Random import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.rdd._ import org.apache.spark.graphx._ +import org.apache.spark.rdd._ /** Implementation of SVD++ algorithm. */ object SVDPlusPlus { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala index 179f2843818e..f0c6bcb93445 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ShortestPaths.scala @@ -17,9 +17,10 @@ package org.apache.spark.graphx.lib -import org.apache.spark.graphx._ import scala.reflect.ClassTag +import org.apache.spark.graphx._ + /** * Computes shortest paths to the given set of landmark vertices, returning a graph where each * vertex attribute is a map containing the shortest-path distance to each reachable landmark. diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala index 989e22630526..280b6c5578fe 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/GraphGenerators.scala @@ -23,14 +23,11 @@ import scala.reflect.ClassTag import scala.util._ import org.apache.spark._ -import org.apache.spark.serializer._ -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.graphx._ -import org.apache.spark.graphx.Graph -import org.apache.spark.graphx.Edge import org.apache.spark.graphx.impl.GraphImpl +import org.apache.spark.rdd.RDD +import org.apache.spark.serializer._ /** A collection of graph generating functions. */ object GraphGenerators extends Logging { diff --git a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala index e2754ea699da..972237da1cb2 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/util/collection/GraphXPrimitiveKeyOpenHashMap.scala @@ -17,10 +17,10 @@ package org.apache.spark.graphx.util.collection -import org.apache.spark.util.collection.OpenHashSet - import scala.reflect._ +import org.apache.spark.util.collection.OpenHashSet + /** * A fast hash map implementation for primitive, non-null keys. This hash map supports * insertions and updates, but not deletions. 
This map is about an order of magnitude diff --git a/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala new file mode 100644 index 000000000000..bff9f328d490 --- /dev/null +++ b/graphx/src/test/scala/org/apache/spark/graphx/GraphLoaderSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.graphx + +import java.io.File +import java.io.FileOutputStream +import java.io.OutputStreamWriter + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +class GraphLoaderSuite extends SparkFunSuite with LocalSparkContext { + + test("GraphLoader.edgeListFile") { + withSpark { sc => + val tmpDir = Utils.createTempDir() + val graphFile = new File(tmpDir.getAbsolutePath, "graph.txt") + val writer = new OutputStreamWriter(new FileOutputStream(graphFile)) + for (i <- (1 until 101)) writer.write(s"$i 0\n") + writer.close() + try { + val graph = GraphLoader.edgeListFile(sc, tmpDir.getAbsolutePath) + val neighborAttrSums = graph.aggregateMessages[Int]( + ctx => ctx.sendToDst(ctx.srcAttr), + _ + _) + assert(neighborAttrSums.collect.toSet === Set((0: VertexId, 100))) + } finally { + Utils.deleteRecursively(tmpDir) + } + } + } +} diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala index 7435647c6d9e..a73dfd219ea4 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgePartitionSuite.scala @@ -21,11 +21,10 @@ import scala.reflect.ClassTag import scala.util.Random import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class EdgePartitionSuite extends SparkFunSuite { def makeEdgePartition[A: ClassTag](xs: Iterable[(Int, Int, A)]): EdgePartition[A, Int] = { diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala index 1203f8959f50..0fb8451fdcab 100644 --- a/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala +++ b/graphx/src/test/scala/org/apache/spark/graphx/impl/VertexPartitionSuite.scala @@ -18,11 +18,10 @@ package org.apache.spark.graphx.impl import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.graphx._ import org.apache.spark.serializer.JavaSerializer import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.graphx._ - class VertexPartitionSuite extends 
SparkFunSuite { test("isDefined, filter") { diff --git a/launcher/pom.xml b/launcher/pom.xml index 5739bfc16958..135866cea2e7 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java index 55fe156cf665..68af14397ba8 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java +++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java @@ -146,7 +146,7 @@ List buildClassPath(String appClassPath) throws IOException { boolean isTesting = "1".equals(getenv("SPARK_TESTING")); if (prependClasses || isTesting) { String scala = getScalaVersion(); - List projects = Arrays.asList("core", "repl", "mllib", "bagel", "graphx", + List projects = Arrays.asList("core", "repl", "mllib", "graphx", "streaming", "tools", "sql/catalyst", "sql/core", "sql/hive", "sql/hive-thriftserver", "yarn", "launcher", "network/common", "network/shuffle", "network/yarn"); if (prependClasses) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index d099ee9aa9da..414ffc2c84e5 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -293,9 +293,7 @@ private class ServerConnection extends LauncherConnection { protected void handle(Message msg) throws IOException { try { if (msg instanceof Hello) { - synchronized (timeout) { - timeout.cancel(); - } + timeout.cancel(); timeout = null; Hello hello = (Hello) msg; ChildProcAppHandle handle = pending.remove(hello.secret); diff --git a/launcher/src/main/java/org/apache/spark/launcher/Main.java b/launcher/src/main/java/org/apache/spark/launcher/Main.java index a4e3acc674f3..e751e948e356 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/Main.java +++ b/launcher/src/main/java/org/apache/spark/launcher/Main.java @@ -151,7 +151,7 @@ private static class MainClassOptionParser extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { - if (opt == CLASS) { + if (CLASS.equals(opt)) { className = value; } return false; diff --git a/make-distribution.sh b/make-distribution.sh index e64ceb802464..327659298e4d 100755 --- a/make-distribution.sh +++ b/make-distribution.sh @@ -58,7 +58,7 @@ while (( "$#" )); do --hadoop) echo "Error: '--hadoop' is no longer supported:" echo "Error: use Maven profiles and options -Dhadoop.version and -Dyarn.version instead." - echo "Error: Related profiles include hadoop-1, hadoop-2.2, hadoop-2.3 and hadoop-2.4." + echo "Error: Related profiles include hadoop-2.2, hadoop-2.3 and hadoop-2.4." exit_with_usage ;; --with-yarn) @@ -159,7 +159,7 @@ fi # Build uber fat JAR cd "$SPARK_HOME" -export MAVEN_OPTS="-Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m" +export MAVEN_OPTS="${MAVEN_OPTS:--Xmx2g -XX:MaxPermSize=512M -XX:ReservedCodeCacheSize=512m}" # Store the command as an array because $MVN variable might have spaces in it. # Normal quoting tricks don't work. 
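The GraphLoaderSuite test added above leans on two GraphLoader defaults that the assertion does not spell out: edgeListFile gives every edge and every vertex the attribute 1, and aggregateMessages reduces messages per destination vertex. A minimal sketch of the same computation, assuming only a live SparkContext sc and an edge-list file like the one the test writes (the path is illustrative):

import org.apache.spark.graphx.GraphLoader

// 100 lines of the form "i 0" become 100 edges i -> 0, all attributes set to 1.
val graph = GraphLoader.edgeListFile(sc, "/tmp/graph-fixture/graph.txt")
// Every edge sends its source's attribute (1) to its destination, and the sums
// are taken per vertex, so vertex 0 ends up with 100, the value the test asserts.
val neighborAttrSums = graph.aggregateMessages[Int](
  ctx => ctx.sendToDst(ctx.srcAttr),
  _ + _)
neighborAttrSums.collect()  // Array((0, 100))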
@@ -212,7 +212,6 @@ cp "$SPARK_HOME/README.md" "$DISTDIR" cp -r "$SPARK_HOME/bin" "$DISTDIR" cp -r "$SPARK_HOME/python" "$DISTDIR" cp -r "$SPARK_HOME/sbin" "$DISTDIR" -cp -r "$SPARK_HOME/ec2" "$DISTDIR" # Copy SparkR if it exists if [ -d "$SPARK_HOME"/R/lib/SparkR ]; then mkdir -p "$DISTDIR"/R/lib diff --git a/mllib/pom.xml b/mllib/pom.xml index df50aca1a3f7..42af2b8b3e41 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 4b2b3f8489fd..32570a16e670 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -26,11 +26,9 @@ import org.apache.hadoop.fs.Path import org.json4s._ import org.json4s.jackson.JsonMethods._ -import org.apache.spark.{SparkContext, Logging} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.{Logging, SparkContext} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.param.{Param, ParamMap, Params} -import org.apache.spark.ml.util.MLReader -import org.apache.spark.ml.util.MLWriter import org.apache.spark.ml.util._ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -167,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M } override def transformSchema(schema: StructType): StructType = { + validateParams() val theStages = $(stages) require(theStages.toSet.size == theStages.length, "Cannot have duplicate components in a pipeline.") @@ -298,6 +297,7 @@ class PipelineModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala index e0dcd427fae2..d1388b5e2eb5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala @@ -24,9 +24,9 @@ import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * (private[ml]) Trait for parameters for prediction (regression and classification). 
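The validateParams() calls added in the Pipeline hunks above, and repeated across the mllib feature, clustering, and regression hunks below, all follow one pattern: transformSchema is the first step of both fitting and transforming, so checking params there fails fast, before any rows are scanned. Params.validateParams() is a no-op by default and is meant to be overridden for checks that span several params (single-param validity is already enforced when set() is called). A minimal sketch of the pattern, not taken from this patch; the trait, param names, and output column are illustrative:

import org.apache.spark.ml.param.{DoubleParam, Params}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

private trait ClampParams extends Params {
  val lower: DoubleParam = new DoubleParam(this, "lower", "lower clamp bound")
  val upper: DoubleParam = new DoubleParam(this, "upper", "upper clamp bound")

  // Cross-param check: exactly the kind of validation the added calls exercise.
  override def validateParams(): Unit = {
    require($(lower) <= $(upper), s"lower (${$(lower)}) must be <= upper (${$(upper)})")
  }

  // Called from transformSchema, so a bad (lower, upper) pair fails before any I/O.
  protected def validateAndTransformSchema(schema: StructType): StructType = {
    validateParams()
    StructType(schema.fields :+ StructField("clamped", DoubleType, nullable = false))
  }
}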
@@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { + validateParams() // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType) if (fitting) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala index 3c7bcf7590e6..fdce273193b7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala @@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] protected def validateInputType(inputType: DataType): Unit = {} override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType validateInputType(inputType) if (schema.fieldNames.contains($(outputCol))) { @@ -115,8 +116,8 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]] override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - dataset.withColumn($(outputCol), - callUDF(this.createTransformFunc, outputDataType, dataset($(inputCol)))) + val transformUDF = udf(this.createTransformFunc, outputDataType) + dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol)))) } override def copy(extra: ParamMap): T = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index b5258ff34847..d02806a6ea22 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml.ann -import breeze.linalg.{*, DenseMatrix => BDM, DenseVector => BDV, Vector => BV, axpy => Baxpy, - sum => Bsum} +import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV, + Vector => BV} import breeze.numerics.{log => Blog, sigmoid => Bsigmoid} import org.apache.spark.mllib.linalg.{Vector, Vectors} diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala index a7c10333c0d5..521d209a8f0e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.attribute import scala.annotation.varargs import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.types.{DoubleType, NumericType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, NumericType, StructField} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala index 7ac21d7d563f..f6964054db83 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/attribute/package.scala @@ -17,8 +17,8 @@ package org.apache.spark.ml -import org.apache.spark.sql.DataFrame import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} +import org.apache.spark.sql.DataFrame /** * ==ML attributes== diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala 
b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 45df557a8990..8186afc17a53 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} import org.apache.spark.ml.param.shared.HasRawPredictionCol import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.linalg.{Vector, VectorUDT} @@ -26,7 +26,6 @@ import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} - /** * (private[spark]) Params for classification. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index cda2bca58c50..74bf07c3f1ff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -33,7 +33,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.loss.{LogLoss => OldLogLoss, Loss => OldLoss} import org.apache.spark.mllib.tree.model.{GradientBoostedTreesModel => OldGBTModel} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Row, DataFrame} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.DoubleType diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala index a691aa005ef5..719d1076fee8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala @@ -20,12 +20,12 @@ package org.apache.spark.ml.classification import scala.collection.JavaConverters._ import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared.{HasTol, HasMaxIter, HasSeed} -import org.apache.spark.ml.{PredictorParams, PredictionModel, Predictor} -import org.apache.spark.ml.param.{IntParam, ParamValidators, IntArrayParam, ParamMap} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} +import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer} +import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol} import org.apache.spark.ml.util.Identifiable -import org.apache.spark.ml.ann.{FeedForwardTrainer, FeedForwardTopology} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.sql.DataFrame diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index fdd1851ae550..865614aa5c8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.SchemaUtils -import org.apache.spark.mllib.linalg.{DenseVector, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors, VectorUDT} import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index d6d85ad2533a..f7d662df2fe5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, RandomForestParams, TreeClassifierParams, TreeEnsembleModel} +import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.model.{RandomForestModel => OldRandomForestModel} diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala index 71e968497500..dc6d5d928097 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala @@ -20,15 +20,15 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params} +import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel} import org.apache.spark.mllib.linalg.{Vector, VectorUDT} +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.sql.{DataFrame, Row} /** * Common params for KMeans and KMeansModel @@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 830510b1698d..99383e77f7eb 100644 --- 
a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -18,19 +18,20 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path + import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasSeed, HasMaxIter} import org.apache.spark.ml.param._ +import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasFeaturesCol, HasMaxIter, HasSeed} import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel, EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel, LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel, OnlineLDAOptimizer => OldOnlineLDAOptimizer} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors, Matrix, Vector} +import org.apache.spark.mllib.linalg.{Matrix, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions.{col, monotonicallyIncreasingId, udf} import org.apache.spark.sql.types.StructType @@ -262,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index bfb70963b151..f71726f110e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -39,8 +39,7 @@ class BinaryClassificationEvaluator @Since("1.4.0") (@Since("1.4.0") override va def this() = this(Identifiable.randomUID("binEval")) /** - * param for metric name in evaluation - * Default: areaUnderROC + * param for metric name in evaluation (supports `"areaUnderROC"` (default), `"areaUnderPR"`) * @group param */ @Since("1.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index c44db0ec595e..a921153b9474 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.evaluation import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.param.{ParamMap, ParamValidators, Param} +import org.apache.spark.ml.param.{Param, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol} -import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, SchemaUtils, Identifiable} +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable, SchemaUtils} import org.apache.spark.mllib.evaluation.MulticlassMetrics -import org.apache.spark.sql.{Row, DataFrame} 
+import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.types.DoubleType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala index 63c06581482e..544cf05a30d4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.BinaryAttribute import org.apache.spark.ml.param._ @@ -72,6 +72,7 @@ final class Binarizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 324353a96afb..0c75317d8270 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import java.{util => ju} import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Model import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.param._ @@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) SchemaUtils.appendColumn(schema, prepOutputField(schema)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index dfec03828f4b..7b565ef3ed92 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) @@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) val newField = prepOutputField(schema) val outputFields = schema.fields :+ newField diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala index b9e2144c0ad4..10dcda2382f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} +import 
org.apache.spark.mllib.linalg.{Vectors, VectorUDT} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame import org.apache.spark.sql.functions._ @@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala index 6bed72164a1d..a6f878151de7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/DCT.scala @@ -19,11 +19,11 @@ package org.apache.spark.ml.feature import edu.emory.mathcs.jtransforms.dct._ -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.BooleanParam import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.types.DataType /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index a359cb8f37ec..07a12df32035 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -19,7 +19,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.Experimental import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{Vector, VectorUDT} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 9e15835429a3..8af00581f7e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} @@ -69,6 +69,7 @@ class HashingTF(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[ArrayType], s"The input column must be ArrayType, but got $inputType.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala index f7b0f29a27c2..9e7eee4f2998 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala @@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol * Validate and transform the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala index 2181119f04a5..7d2a1da990fc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.ml.Transformer -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala index c2866f5eceff..ad0458d0d0e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature - import org.apache.hadoop.fs.Path import org.apache.spark.annotation.{Experimental, Since} @@ -25,7 +24,7 @@ import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{DoubleParam, ParamMap, Params} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.stat.Statistics import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -60,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala index 65414ecbefbb..f8bc7e3f0c03 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala index c2d514fd9629..a603b3f83320 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Normalizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param.{DoubleParam, ParamValidators} import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala index d70164eaf022..342540418fa8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param._ @@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer def setOutputCol(value: String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val outputColName = $(outputCol) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala index 53d33ea2b8f7..7020397f3b06 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala @@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -130,6 +131,7 @@ class PCAModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -167,14 +169,37 @@ object PCAModel extends MLReadable[PCAModel] { private val className = classOf[PCAModel].getName + /** + * Loads a [[PCAModel]] from data located at the input path. 
Note that the model includes an + * `explainedVariance` member that is not recorded by Spark 1.6 and earlier. A model + * can be loaded from such older data but will have an empty vector for + * `explainedVariance`. + * + * @param path path to serialized model data + * @return a [[PCAModel]] + */ override def load(path: String): PCAModel = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + + // explainedVariance field is not present in Spark <= 1.6 + val versionRegex = "([0-9]+)\\.([0-9]+).*".r + val hasExplainedVariance = metadata.sparkVersion match { + case versionRegex(major, minor) => + (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)) + case _ => false + } + val dataPath = new Path(path, "data").toString - val Row(pc: DenseMatrix, explainedVariance: DenseVector) = - sqlContext.read.parquet(dataPath) - .select("pc", "explainedVariance") - .head() - val model = new PCAModel(metadata.uid, pc, explainedVariance) + val model = if (hasExplainedVariance) { + val Row(pc: DenseMatrix, explainedVariance: DenseVector) = + sqlContext.read.parquet(dataPath) + .select("pc", "explainedVariance") + .head() + new PCAModel(metadata.uid, pc, explainedVariance) + } else { + val Row(pc: DenseMatrix) = sqlContext.read.parquet(dataPath).select("pc").head() + new PCAModel(metadata.uid, pc, Vectors.dense(Array.empty[Double]).asInstanceOf[DenseVector]) + } DefaultParamsReader.getAndSetParams(model, metadata) model } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala index 08610593fadd..42b26c8ee836 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PolynomialExpansion.scala @@ -19,9 +19,9 @@ package org.apache.spark.ml.feature import scala.collection.mutable -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer -import org.apache.spark.ml.param.{ParamMap, IntParam, ParamValidators} +import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.types.DataType diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index 7bf67c6325a3..8fd0ce2f2e26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -20,14 +20,14 @@ package org.apache.spark.ml.feature import scala.collection.mutable import org.apache.spark.Logging -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ -import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom /** @@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String) def setOutputCol(value: 
String): this.type = set(outputCol, value) override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType) val inputFields = schema.fields require(inputFields.forall(_.name != $(outputCol)), diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 5c43a41bee3b..f9952434d298 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -21,8 +21,8 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.{Estimator, Model, Pipeline, PipelineModel, PipelineStage, Transformer} +import org.apache.spark.ml.attribute.AttributeGroup import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasLabelCol} import org.apache.spark.ml.util.Identifiable @@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R // optimistic schema; does not contain any ML attributes override def transformSchema(schema: StructType): StructType = { + validateParams() if (hasLabelCol(schema)) { StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true)) } else { @@ -178,6 +179,7 @@ class RFormulaModel private[feature]( } override def transformSchema(schema: StructType): StructType = { + validateParams() checkCanTransform(schema) val withFeatures = pipelineModel.transformSchema(schema) if (hasLabelCol(withFeatures)) { @@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer { } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name))) } @@ -288,6 +291,7 @@ private class VectorAttributeRewriter( } override def transformSchema(schema: StructType): StructType = { + validateParams() StructType( schema.fields.filter(_.name != vectorCol) ++ schema.fields.filter(_.name == vectorCol)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala index c09f4d076c96..af6494b234ce 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala @@ -18,11 +18,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, Experimental} -import org.apache.spark.ml.param.{ParamMap, Param} +import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.param.{Param, ParamMap} import org.apache.spark.ml.Transformer import org.apache.spark.ml.util._ -import org.apache.spark.sql.{SQLContext, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.types.StructType /** @@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor @Since("1.6.0") override def transformSchema(schema: StructType): StructType = { + validateParams() val sc = SparkContext.getOrCreate() val sqlContext = SQLContext.getOrCreate(sc) val dummyRDD = sc.parallelize(Seq(Row.empty)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala index d76a9c6275e6..6a0b6c240ec6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala @@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") @@ -143,6 +144,7 @@ class StandardScalerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.isInstanceOf[VectorUDT], s"Input column ${$(inputCol)} must be a vector column") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala index 318808596dc6..b93c9ed382bd 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.param.{BooleanParam, ParamMap, StringArrayParam} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} @@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputType = schema($(inputCol)).dataType require(inputType.sameType(ArrayType(StringType)), s"Input type must be ArrayType(StringType) but got $inputType.") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 5c40c35eeaa4..912bd95a2ec7 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha /** Validates and transforms the input schema. 
*/ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType], @@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String) final def getLabels: Array[String] = $(labels) override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColName = $(inputCol) val inputDataType = schema(inputColName).dataType require(inputDataType.isInstanceOf[NumericType], diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala index 8ad7bbedaab5..8456a0e91580 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.UnaryTransformer import org.apache.spark.ml.param._ import org.apache.spark.ml.util._ diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala index 801096fed27b..0b215659b367 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder import org.apache.spark.SparkException -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute, UnresolvedAttribute} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() val inputColNames = $(inputCols) val outputColName = $(outputCol) val inputDataTypes = inputColNames.map(name => schema(name).dataType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index a637a6f2881d..2a5268406ddf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod } override def transformSchema(schema: StructType): StructType = { + validateParams() // We do not transfer feature metadata since we do not know what types of features we will // produce in transform(). 
val dataType = new VectorUDT @@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() val dataType = new VectorUDT require(isDefined(inputCol), s"VectorIndexerModel requires input column parameter: $inputCol") diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 5410a50bc2e4..300d63bd3a0d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.{Since, Experimental} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.{Attribute, AttributeGroup} -import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.param.{IntArrayParam, ParamMap, StringArrayParam} +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg._ import org.apache.spark.sql.DataFrame @@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String) } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT) if (schema.fieldNames.contains($(outputCol))) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index f105a983a34f..2b6b3c3a0fc5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -26,7 +26,7 @@ import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params * Validate and transform the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true)) SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ee7e89edd879..c0546695e487 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -859,8 +859,12 @@ final class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) * Filters this param map for the given parent. */ def filter(parent: Params): ParamMap = { - val filtered = map.filterKeys(_.parent == parent) - new ParamMap(filtered.asInstanceOf[mutable.Map[Param[Any], Any]]) + // Don't use filterKeys because mutable.Map#filterKeys + // returns the instance of collections.Map, not mutable.Map. + // Otherwise, we get ClassCastException. 
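+ // For example, mutable.Map(1 -> "a").filterKeys(_ > 0) is typed as
+ // collection.Map and is a lazy, unmemoized view over the original map,
+ // so casting it back to mutable.Map[Param[Any], Any] throws at runtime.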
+ // Not using filterKeys also avoids SI-6654 + val filtered = map.filter { case (k, _) => k.parent == parent.uid } + new ParamMap(filtered) } /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index c7bca1243092..4aff749ff75a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -44,6 +44,7 @@ private[shared] object SharedParamsCodeGen { " probabilities. Note: Not all models output well-calibrated probability estimates!" + " These probabilities should be treated as confidences, not precise probabilities", Some("\"probability\"")), + ParamDesc[String]("varianceCol", "Column name for the biased sample variance of prediction"), ParamDesc[Double]("threshold", "threshold in binary classification prediction, in range [0, 1]", Some("0.5"), isValid = "ParamValidators.inRange(0, 1)", finalMethods = false), diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index cb2a060a34dd..c088c16d1b05 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -138,6 +138,21 @@ private[ml] trait HasProbabilityCol extends Params { final def getProbabilityCol: String = $(probabilityCol) } +/** + * Trait for shared param varianceCol. + */ +private[ml] trait HasVarianceCol extends Params { + + /** + * Param for Column name for the biased sample variance of prediction. + * @group param + */ + final val varianceCol: Param[String] = new Param[String](this, "varianceCol", "Column name for the biased sample variance of prediction") + + /** @group getParam */ + final def getVarianceCol: String = $(varianceCol) +} + /** * Trait for shared param threshold (default: 0.5).
*/ diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 4d82b90bfdf2..551e75dc0a02 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.api.r +import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.ml.attribute._ -import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} +import org.apache.spark.ml.feature.RFormula import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} -import org.apache.spark.ml.{Pipeline, PipelineModel} import org.apache.spark.sql.DataFrame private[r] object SparkRWrappers { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index b798aa1fab76..472c1854d3d1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -31,7 +31,7 @@ import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.apache.spark.{Logging, Partitioner} -import org.apache.spark.annotation.{Since, DeveloperApi, Experimental} +import org.apache.spark.annotation.{DeveloperApi, Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w * @return output schema */ protected def validateAndTransformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) val ratingType = schema($(ratingCol)).dataType @@ -213,6 +214,7 @@ class ALSModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(userCol), IntegerType) SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType) SchemaUtils.appendColumn(schema, $(predictionCol), FloatType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index aedfb48058dc..e8a1ff2278a9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -23,18 +23,18 @@ import breeze.linalg.{DenseVector => BDV} import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS} import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import 
org.apache.spark.storage.StorageLevel -import org.apache.spark.{Logging, SparkException} /** * Params for accelerated failure time (AFT) regression. @@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params protected def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 477030d9ea3e..18c94f36387b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeRegressorParams} +import org.apache.spark.ml.tree._ import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} import org.apache.spark.mllib.linalg.Vector @@ -29,6 +29,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions._ /** * :: Experimental :: @@ -40,7 +41,7 @@ import org.apache.spark.sql.DataFrame @Experimental final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: String) extends Predictor[Vector, DecisionTreeRegressor, DecisionTreeRegressionModel] - with DecisionTreeParams with TreeRegressorParams { + with DecisionTreeRegressorParams { @Since("1.4.0") def this() = this(Identifiable.randomUID("dtr")) @@ -73,6 +74,9 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val override def setSeed(value: Long): this.type = super.setSeed(value) + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -113,7 +117,10 @@ final class DecisionTreeRegressionModel private[ml] ( override val rootNode: Node, override val numFeatures: Int) extends PredictionModel[Vector, DecisionTreeRegressionModel] - with DecisionTreeModel with Serializable { + with DecisionTreeModel with DecisionTreeRegressorParams with Serializable { + + /** @group setParam */ + def setVarianceCol(value: String): this.type = set(varianceCol, value) require(rootNode != null, "DecisionTreeClassificationModel given null rootNode, but it requires a non-null rootNode.") @@ -129,6 +136,29 @@ final class DecisionTreeRegressionModel private[ml] ( rootNode.predictImpl(features).prediction } + /** We need to update this function if we ever add other impurity measures. 
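+ * For regression trees the only supported impurity is variance, so the
+ * impurityStats of the leaf reached by a row is the biased sample variance of
+ * the training labels in that leaf; this is the value reported through varianceCol.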
*/ + protected def predictVariance(features: Vector): Double = { + rootNode.predictImpl(features).impurityStats.calculate() + } + + override def transform(dataset: DataFrame): DataFrame = { + transformSchema(dataset.schema, logging = true) + transformImpl(dataset) + } + + override protected def transformImpl(dataset: DataFrame): DataFrame = { + val predictUDF = udf { (features: Vector) => predict(features) } + val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) } + var output = dataset + if ($(predictionCol).nonEmpty) { + output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol)))) + } + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + output = output.withColumn($(varianceCol), predictVarianceUDF(col($(featuresCol)))) + } + output + } + @Since("1.4.0") override def copy(extra: ParamMap): DecisionTreeRegressionModel = { copyValues(new DecisionTreeRegressionModel(uid, rootNode, numFeatures), extra).setParent(parent) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index bbb1c7ac0a51..1573bb4c1b74 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -21,18 +21,18 @@ import org.apache.hadoop.fs.Path import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.regression.IsotonicRegressionModel.IsotonicRegressionModelWriter import org.apache.spark.ml.util._ -import org.apache.spark.ml.{Estimator, Model} -import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors} +import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT} import org.apache.spark.mllib.regression.{IsotonicRegression => MLlibIsotonicRegression} import org.apache.spark.mllib.regression.{IsotonicRegressionModel => MLlibIsotonicRegressionModel} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.functions.{col, lit, udf} import org.apache.spark.sql.types.{DoubleType, StructType} -import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.storage.StorageLevel /** @@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures protected[ml] def validateAndTransformSchema( schema: StructType, fitting: Boolean): StructType = { + validateParams() if (fitting) { SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType) if (hasWeightCol) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 5e5850963edc..c54e08b2ad9a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -25,9 +25,9 @@ import breeze.stats.distributions.StudentsT import org.apache.hadoop.fs.Path import org.apache.spark.{Logging, SparkException} +import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.optim.WeightedLeastSquares -import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param.ParamMap import 
org.apache.spark.ml.param.shared._ @@ -534,7 +534,8 @@ class LinearRegressionSummary private[regression] ( @transient private val metrics = new RegressionMetrics( predictions .select(predictionCol, labelCol) - .map { case Row(pred: Double, label: Double) => (pred, label) } ) + .map { case Row(pred: Double, label: Double) => (pred, label) }, + !model.getFitIntercept) /** * Returns the explained variance regression score. diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala index c72ef2968032..cf189e8e96f9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala @@ -18,7 +18,7 @@ package org.apache.spark.ml.regression import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.ml.{PredictionModel, PredictorParams, Predictor} +import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams} /** diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index 11b9815ecc83..1bed542c4031 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, VectorUDT} import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrameReader, DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, DataFrameReader, Row, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DoubleType, StructField, StructType} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index d89682611e3f..6507a8ad7cf3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict, ImpurityStats} +import org.apache.spark.mllib.tree.model.{ImpurityStats, + InformationGainStats => OldInformationGainStats, Node => OldNode, Predict => OldPredict} /** * :: DeveloperApi :: @@ -386,9 +386,9 @@ private[tree] object LearningNode { var levelsToGo = indexToLevel(nodeIndex) while (levelsToGo > 0) { if ((nodeIndex & (1 << levelsToGo - 1)) == 0) { - tmpNode = tmpNode.leftChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.leftChild.get } else { - tmpNode = tmpNode.rightChild.asInstanceOf[LearningNode] + tmpNode = tmpNode.rightChild.get } levelsToGo -= 1 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 1ee01131d633..172ea5282056 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -21,7 +21,7 @@ import java.io.IOException import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import 
org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 4a3b12d1440b..6e87302c7779 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -26,10 +26,10 @@ import org.apache.spark.Logging import org.apache.spark.ml.classification.DecisionTreeClassificationModel import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, +import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index b77191156f68..40ed95773e14 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -17,7 +17,7 @@ package org.apache.spark.ml.tree -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * Abstraction for Decision Tree models. diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 1da97db9277d..7443097492d8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -20,9 +20,11 @@ package org.apache.spark.ml.tree import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.util.SchemaUtils import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy} import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} +import org.apache.spark.sql.types.{DoubleType, DataType, StructType} /** * Parameters for Decision Tree-based algorithms. @@ -256,6 +258,22 @@ private[ml] object TreeRegressorParams { final val supportedImpurities: Array[String] = Array("variance").map(_.toLowerCase) } +private[ml] trait DecisionTreeRegressorParams extends DecisionTreeParams + with TreeRegressorParams with HasVarianceCol { + + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { + val newSchema = super.validateAndTransformSchema(schema, fitting, featuresDataType) + if (isDefined(varianceCol) && $(varianceCol).nonEmpty) { + SchemaUtils.appendColumn(newSchema, $(varianceCol), DoubleType) + } else { + newSchema + } + } +} + /** * Parameters for Decision Tree-based ensemble algorithms. 
* diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 40f8857fc586..3eac616aeaf8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import com.github.fommil.netlib.F2jBLAS import org.apache.hadoop.fs.Path -import org.json4s.jackson.JsonMethods._ import org.json4s.{DefaultFormats, JObject} +import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{Experimental, Since} @@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] ( @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index adf06302047a..4f67e8c21994 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -19,8 +19,8 @@ package org.apache.spark.ml.tuning import org.apache.spark.Logging import org.apache.spark.annotation.{Experimental, Since} -import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.{Estimator, Model} +import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} import org.apache.spark.ml.util.Identifiable import org.apache.spark.sql.DataFrame @@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() $(estimator).transformSchema(schema) } @@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] ( @Since("1.5.0") override def transformSchema(schema: StructType): StructType = { + validateParams() bestModel.transformSchema(schema) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 8897ab0825ac..553f25417241 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.tuning import org.apache.spark.annotation.DeveloperApi import org.apache.spark.ml.Estimator import org.apache.spark.ml.evaluation.Evaluator -import org.apache.spark.ml.param.{ParamMap, Param, Params} +import org.apache.spark.ml.param.{Param, ParamMap, Params} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala index bc6041b22173..6530870b83a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PowerIterationClusteringModelWrapper.scala @@ -17,8 +17,8 @@ package 
org.apache.spark.mllib.api.python -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.clustering.PowerIterationClusteringModel +import org.apache.spark.rdd.RDD /** * A Wrapper of PowerIterationClusteringModel to provide helper method for Python diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 29160a10e16b..061db56c7493 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -42,18 +42,17 @@ import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.random.{RandomRDDs => RG} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ +import org.apache.spark.mllib.stat.{ + KernelDensity, MultivariateStatisticalSummary, Statistics} import org.apache.spark.mllib.stat.correlation.CorrelationNames import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.stat.test.{ChiSqTestResult, KolmogorovSmirnovTestResult} -import org.apache.spark.mllib.stat.{ - KernelDensity, MultivariateStatisticalSummary, Statistics} +import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} import org.apache.spark.mllib.tree.configuration.{Algo, BoostingStrategy, Strategy} import org.apache.spark.mllib.tree.impurity._ import org.apache.spark.mllib.tree.loss.Losses import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel, RandomForestModel} -import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest} -import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.mllib.util.LinearDataGenerator +import org.apache.spark.mllib.util.{LinearDataGenerator, MLUtils} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SQLContext} import org.apache.spark.storage.StorageLevel @@ -1438,9 +1437,19 @@ private[spark] object SerDe extends Serializable { if (args.length != 3) { throw new PickleException("should be 3") } - new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], + new Rating(ratingsIdCheckLong(args(0)), ratingsIdCheckLong(args(1)), args(2).asInstanceOf[Double]) } + + private def ratingsIdCheckLong(obj: Object): Int = { + try { + obj.asInstanceOf[Int] + } catch { + case ex: ClassCastException => + throw new PickleException(s"Ratings id ${obj.toString} exceeds " + + s"max integer value of ${Int.MaxValue}", ex) + } + } } var initialized = false diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 0f55980481dc..55dfd973eb25 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.api.python import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + import scala.collection.JavaConverters._ import org.apache.spark.SparkContext diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index 2d52abc122bf..2a7697b5a79c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ 
b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.classification import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.mllib.classification.impl.GLMClassificationModel -import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.linalg.{DenseVector, Vector} +import org.apache.spark.mllib.linalg.BLAS.dot import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{DataValidators, Saveable, Loader} +import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 74d13e4f7794..5c9bc62cb09b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.clustering import breeze.linalg.{DenseVector => BreezeVector} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -26,11 +25,11 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD -import org.apache.spark.mllib.linalg.{Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.mllib.util.{MLUtils, Loader, Saveable} +import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Multivariate Gaussian Mixture Model (GMM) consisting of k Gaussians, where points diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 2895db7c9061..ca11ede4ccd4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer import org.apache.spark.Logging -import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} import org.apache.spark.mllib.util.MLUtils @@ -107,7 +107,7 @@ class KMeans private ( * Number of runs of the algorithm to execute in parallel. */ @Since("1.4.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. This param will have no effect in 2.0.0.", "1.6.0") def getRuns: Int = runs /** @@ -117,7 +117,7 @@ class KMeans private ( * return the best clustering found over any run. Default: 1. */ @Since("0.8.0") - @deprecated("Support for runs is deprecated. This param will have no effect in 1.7.0.", "1.6.0") + @deprecated("Support for runs is deprecated. 
This param will have no effect in 2.0.0.", "1.6.0") def setRuns(runs: Int): this.type = { if (runs <= 0) { throw new IllegalArgumentException("Number of runs must be positive") @@ -301,6 +301,8 @@ class KMeans private ( contribs.iterator }.reduceByKey(mergeContribs).collectAsMap() + bcActiveCenters.unpersist(blocking = false) + // Update the cluster centers and costs for each active run for ((run, i) <- activeRuns.zipWithIndex) { var changed = false @@ -419,14 +421,17 @@ class KMeans private ( s0 } ) + + bcNewCenters.unpersist(blocking = false) preCosts.unpersist(blocking = false) + val chosen = data.zip(costs).mapPartitionsWithIndex { (index, pointsWithCosts) => val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) pointsWithCosts.flatMap { case (p, c) => val rs = (0 until runs).filter { r => rand.nextDouble() < 2.0 * c(r) * k / sumCosts(r) } - if (rs.length > 0) Some(p, rs) else None + if (rs.length > 0) Some((p, rs)) else None } }.collect() mergeNewCenters() @@ -448,6 +453,9 @@ class KMeans private ( ((r, KMeans.findClosest(bcCenters.value(r), p)._1), 1.0) } }.reduceByKey(_ + _).collectAsMap() + + bcCenters.unpersist(blocking = false) + val finalCenters = (0 until runs).par.map { r => val myCenters = centers(r).toArray val myWeights = (0 until myCenters.length).map(i => weightMap.getOrElse((r, i), 0.0)).toArray diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala index 91fa9b0d3590..26c6235fe590 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeansModel.scala @@ -23,15 +23,14 @@ import org.json4s._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.Row +import org.apache.spark.sql.{Row, SQLContext} /** * A clustering model for K-means. Each point belongs to the cluster with the closest center. 
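
The KMeans hunks above add explicit cleanup of broadcast variables (bcActiveCenters, bcNewCenters, bcCenters) once the jobs that read them have finished, instead of leaving the executor-side copies to context cleanup. A minimal sketch of that pattern, assuming a live SparkContext; the object name, data, and centers below are illustrative, not part of the patch:

```scala
import org.apache.spark.{SparkConf, SparkContext}

object BroadcastCleanupSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("bc-sketch").setMaster("local[2]"))

    val centers = Array(1.0, 5.0, 9.0)    // illustrative cluster centers
    val bcCenters = sc.broadcast(centers) // shipped to each executor once

    // The broadcast is only needed while this job runs.
    val assignments = sc.parallelize(Seq(1.2, 4.8, 8.7, 0.3))
      .map { x =>
        val cs = bcCenters.value
        cs.indices.minBy(i => math.abs(cs(i) - x))
      }
      .collect()

    // Release the executor-side copies eagerly once the job is done;
    // blocking = false makes the removal asynchronous, as in the patch.
    bcCenters.unpersist(blocking = false)

    println(assignments.mkString(", "))
    sc.stop()
  }
}
```
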
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 7384d065a2ea..2fce3ff64110 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax, argtopk, normalize, sum} +import breeze.linalg.{argmax, argtopk, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.{exp, lgamma} import org.apache.hadoop.fs.Path import org.json4s.DefaultFormats diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index 17c0609800e9..c19595e6cd21 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.clustering import java.util.Random -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, all, normalize, sum} -import breeze.numerics.{trigamma, abs, exp} +import breeze.linalg.{all, normalize, sum, DenseMatrix => BDM, DenseVector => BDV} +import breeze.numerics.{abs, exp, trigamma} import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala index a9ba7b60bad0..647d37bd822c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAUtils.scala @@ -16,7 +16,7 @@ */ package org.apache.spark.mllib.clustering -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, max, sum} +import breeze.linalg.{max, sum, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics._ /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala index b2f140e1b135..c9a96c68667a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LocalKMeans.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.clustering import scala.util.Random import org.apache.spark.Logging -import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.linalg.BLAS.{axpy, scal} +import org.apache.spark.mllib.linalg.Vectors /** * An utility object to run K-means locally. 
This is private to the ML package because it's used diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala index bb1804505948..2ab0920b0636 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.clustering -import org.json4s.JsonDSL._ import org.json4s._ +import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.graphx._ @@ -30,7 +31,6 @@ import org.apache.spark.mllib.util.{Loader, MLUtils, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.{Logging, SparkContext, SparkException} /** * Model produced by [[PowerIterationClustering]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala index 80843719f50b..79d217e183c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaSparkContext._ import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream} import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala index 078fbfbe4f0e..f0779491e637 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/AreaUnderCurve.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.evaluation -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.rdd.RDD /** * Computes the area under the curve (AUC) using the trapezoidal rule. 
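
For reference, the RegressionMetrics change just below adds a throughOrigin flag and a new SSy term. In the patch's notation (SSerr from summary.normL2(1), SStot from the label variance, SSy from summary.normL2(0)), the two definitions of the coefficient of determination it implements can be sketched as:

```latex
% Intercept fitted: residuals are compared against the mean-only model.
R^2 = 1 - \frac{SS_{err}}{SS_{tot}},
\qquad SS_{err} = \sum_i (y_i - \hat{y}_i)^2,
\qquad SS_{tot} = \sum_i (y_i - \bar{y})^2.

% Regression through the origin: \bar{y} is no longer the fit of a
% constant-only model, so the uncentered total sum of squares is used.
R^2_0 = 1 - \frac{SS_{err}}{SS_y},
\qquad SS_y = \sum_i y_i^2.
```
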
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala index cc01936dd34b..f8de4e2220c4 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.spark.Logging import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala index 1d8f4fe340fb..18c90b204a26 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RegressionMetrics.scala @@ -21,17 +21,24 @@ import org.apache.spark.annotation.Since import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.sql.DataFrame /** * Evaluator for regression. * - * @param predictionAndObservations an RDD of (prediction, observation) pairs. + * @param predictionAndObservations an RDD of (prediction, observation) pairs + * @param throughOrigin True if the regression is through the origin. For example, in linear + * regression, this is true when the intercept is not fitted. */ @Since("1.2.0") -class RegressionMetrics @Since("1.2.0") ( - predictionAndObservations: RDD[(Double, Double)]) extends Logging { +class RegressionMetrics @Since("2.0.0") ( + predictionAndObservations: RDD[(Double, Double)], throughOrigin: Boolean) + extends Logging { + + @Since("1.2.0") + def this(predictionAndObservations: RDD[(Double, Double)]) = + this(predictionAndObservations, false) /** * An auxiliary constructor taking a DataFrame. @@ -53,6 +60,8 @@ class RegressionMetrics @Since("1.2.0") ( ) summary } + + private lazy val SSy = math.pow(summary.normL2(0), 2) private lazy val SSerr = math.pow(summary.normL2(1), 2) private lazy val SStot = summary.variance(0) * (summary.count - 1) private lazy val SSreg = { @@ -102,9 +111,16 @@ class RegressionMetrics @Since("1.2.0") ( /** * Returns R^2^, the unadjusted coefficient of determination. * @see [[http://en.wikipedia.org/wiki/Coefficient_of_determination]] + * For regression through the origin, the definition of R^2^ must be modified. + * @see J. G. Eisenhauer, Regression through the Origin. 
Teaching Statistics 25, 76-80 (2003) + * [[https://online.stat.psu.edu/~ajw13/stat501/SpecialTopics/Reg_thru_origin.pdf]] */ @Since("1.2.0") def r2: Double = { - 1 - SSerr / SStot + if (throughOrigin) { + 1 - SSerr / SSy + } else { + 1 - SSerr / SStot + } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala index eaa99cfe82e2..33728bf5d77e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala @@ -30,7 +30,7 @@ import org.apache.spark.mllib.stat.Statistics import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext -import org.apache.spark.sql.{SQLContext, Row} +import org.apache.spark.sql.{Row, SQLContext} /** * Chi Squared selector model. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 1f400e1430eb..a7e1b76df6a7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -24,7 +24,6 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuilder import com.github.fommil.netlib.BLAS.{getInstance => blas} - import org.json4s.DefaultFormats import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -36,9 +35,9 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd._ +import org.apache.spark.sql.SQLContext import org.apache.spark.util.Utils import org.apache.spark.util.random.XORShiftRandom -import org.apache.spark.sql.SQLContext /** * Entry in vocabulary diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 70ef1ed30c71..5273ed4d7665 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -17,19 +17,29 @@ package org.apache.spark.mllib.fpm -import java.{util => ju} import java.lang.{Iterable => JavaIterable} +import java.{util => ju} -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ + +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.mllib.fpm.FPGrowth._ +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel */ @Since("1.3.0") class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( - @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) 
+ extends Saveable with Serializable { /** * Generates association rules for the [[Item]]s in [[freqItemsets]]. * @param confidence minimal confidence of the rules produced @@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") ( val associationRules = new AssociationRules(confidence) associationRules.run(freqItemsets) } + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[FPGrowthModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + FPGrowthModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object FPGrowthModel extends Loader[FPGrowthModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + FPGrowthModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel" + + def save(model: FPGrowthModel[_], path: String): Unit = { + val sc = model.freqItemsets.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqItemsets.first().items(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("items", ArrayType(itemType)), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqItemsets.map { x => + Row(x.items, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): FPGrowthModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqItemsets.select("items").head().get(0) + loadImpl(freqItemsets, sample) + } + + def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = { + val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x => + val items = x.getAs[Seq[Item]](0).toArray + val freq = x.getLong(1) + new FreqItemset(items, freq) + } + new FPGrowthModel(freqItemsetsRDD) + } + } } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 97916daa2e9a..ed49c9492fdc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util 
=> ju} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import org.apache.spark.Logging diff --git a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala index 72d3aabc9b1f..57ca4d3464f1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/impl/PeriodicCheckpointer.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala index 863abe86d38d..bb94745f078e 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/EigenValueDecomposition.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV} import com.github.fommil.netlib.ARPACK -import org.netlib.util.{intW, doubleW} +import org.netlib.util.{doubleW, intW} /** * Compute eigen-decomposition. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 8879dcf75c9b..d7a74db0b1fd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import java.util.{Arrays, Random} -import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHashSet, ArrayBuffer} +import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet} import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 4dcf351df43f..cecfd067bd87 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -17,8 +17,8 @@ package org.apache.spark.mllib.linalg -import java.util import java.lang.{Double => JavaDouble, Integer => JavaInteger, Iterable => JavaIterable} +import java.util import scala.annotation.varargs import scala.collection.JavaConverters._ @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV} import org.json4s.DefaultFormats import org.json4s.JsonDSL._ -import org.json4s.jackson.JsonMethods.{compact, render, parse => parseJson} +import org.json4s.jackson.JsonMethods.{compact, parse => parseJson, render} import org.apache.spark.SparkException import org.apache.spark.annotation.{AlphaComponent, Since} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala index 8a70f34e70f6..97b03b340f20 100644 --- 
a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrix.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg.{Matrix, SparseMatrix, Vectors} +import org.apache.spark.rdd.RDD /** * Represents an entry in an distributed matrix. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala index 976299124ced..e8de515211a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrix.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.annotation.Since -import org.apache.spark.rdd.RDD import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.linalg.SingularValueDecomposition +import org.apache.spark.rdd.RDD /** * Represents a row of [[org.apache.spark.mllib.linalg.distributed.IndexedRowMatrix]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala index 2018a678688e..0a36da410133 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala @@ -21,8 +21,8 @@ import java.util.Arrays import scala.collection.mutable.ListBuffer -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, SparseVector => BSV, axpy => brzAxpy, - svd => brzSvd, MatrixSingularException, inv} +import breeze.linalg.{axpy => brzAxpy, inv, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV, + MatrixSingularException, SparseVector => BSV} import breeze.numerics.{sqrt => brzSqrt} import org.apache.spark.Logging @@ -30,8 +30,8 @@ import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.rdd.RDD -import org.apache.spark.util.random.XORShiftRandom import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.XORShiftRandom /** * Represents a row-oriented distributed Matrix with no meaningful row indices. 
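
The FPGrowth hunks earlier in this section make FPGrowthModel persistable (save writes JSON metadata plus Parquet data, and the new companion object's load reads them back, recovering the item type from a data sample). A minimal round-trip sketch against that API, assuming a live SparkContext and an illustrative output path that does not exist yet:

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.rdd.RDD

def fpGrowthRoundTrip(sc: SparkContext): Unit = {
  val transactions: RDD[Array[String]] = sc.parallelize(Seq(
    Array("r", "z", "h"),
    Array("z", "y", "x"),
    Array("s", "x", "z"),
    Array("z")), numSlices = 2)

  val model: FPGrowthModel[String] = new FPGrowth()
    .setMinSupport(0.5)
    .setNumPartitions(2)
    .run(transactions)

  val path = "/tmp/fpgrowth-model" // illustrative; save throws if the directory exists
  model.save(sc, path)             // metadata under path/metadata, Parquet under path/data

  // load recovers the item type from a sample of the saved data,
  // so the result is typed as FPGrowthModel[_].
  val reloaded: FPGrowthModel[_] = FPGrowthModel.load(sc, path)
  reloaded.freqItemsets.collect().foreach { fi =>
    println(fi.items.mkString("{", ",", "}") + ": " + fi.freq)
  }
}
```
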
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala index 37bb6f6097f6..5873669b37e3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/GradientDescent.scala @@ -19,12 +19,12 @@ package org.apache.spark.mllib.optimization import scala.collection.mutable.ArrayBuffer -import breeze.linalg.{DenseVector => BDV, norm} +import breeze.linalg.{norm, DenseVector => BDV} -import org.apache.spark.annotation.{Experimental, DeveloperApi} import org.apache.spark.Logging +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Vectors, Vector} /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala index 7f6d94571b5e..d8e56720967d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Optimizer.scala @@ -17,10 +17,9 @@ package org.apache.spark.mllib.optimization -import org.apache.spark.rdd.RDD - import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala index 9f463e0cafb6..03c01e0553d7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Updater.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.optimization import scala.math._ -import breeze.linalg.{norm => brzNorm, axpy => brzAxpy, Vector => BV} +import breeze.linalg.{axpy => brzAxpy, norm => brzNorm, Vector => BV} import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala index 9eab7efc160d..fa04f8eb5e79 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/random/RandomDataGenerator.scala @@ -19,8 +19,8 @@ package org.apache.spark.mllib.random import org.apache.commons.math3.distribution._ -import org.apache.spark.annotation.{Since, DeveloperApi} -import org.apache.spark.util.random.{XORShiftRandom, Pseudorandom} +import org.apache.spark.annotation.{DeveloperApi, Since} +import org.apache.spark.util.random.{Pseudorandom, XORShiftRandom} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala index f8cea7ecea6b..92bc66949ae8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/RandomRDD.scala @@ -17,15 +17,15 @@ package org.apache.spark.mllib.rdd +import scala.reflect.ClassTag +import scala.util.Random + import org.apache.spark.{Partition, SparkContext, TaskContext} import 
org.apache.spark.mllib.linalg.{DenseVector, Vector} import org.apache.spark.mllib.random.RandomDataGenerator import org.apache.spark.rdd.RDD import org.apache.spark.util.Utils -import scala.reflect.ClassTag -import scala.util.Random - private[mllib] class RandomRDDPartition[T](override val index: Int, val size: Int, val generator: RandomDataGenerator[T], diff --git a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala index ead8db634499..adb5e51947f6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/rdd/SlidingRDD.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.rdd import scala.collection.mutable import scala.reflect.ClassTag -import org.apache.spark.{TaskContext, Partition} +import org.apache.spark.{Partition, TaskContext} import org.apache.spark.rdd.RDD private[mllib] diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 8f657bfb9c73..e60edc675c83 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.regression +import org.apache.spark.{Logging, SparkException} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.feature.StandardScaler -import org.apache.spark.{Logging, SparkException} -import org.apache.spark.rdd.RDD +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.optimization._ -import org.apache.spark.mllib.linalg.{Vectors, Vector} import org.apache.spark.mllib.util.MLUtils._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala index c284ad232537..45540f0c5c4c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LabeledPoint.scala @@ -20,7 +20,7 @@ package org.apache.spark.mllib.regression import scala.beans.BeanInfo import org.apache.spark.annotation.Since -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.NumericParser import org.apache.spark.SparkException diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala index a9aba173fa0e..d55e5dfdaaf5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala index 4996ace5df85..7da82c862a2b 
100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/LinearRegression.scala @@ -23,7 +23,7 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.optimization._ import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.impl.GLMRegressionModel -import org.apache.spark.mllib.util.{Saveable, Loader} +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala index 201333c3690d..98404be2603c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/MultivariateOnlineSummarizer.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector} +import org.apache.spark.mllib.linalg.{Vector, Vectors} /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala index bcb33a7a0467..f3159f7e724c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.stat import scala.annotation.varargs import org.apache.spark.annotation.Since -import org.apache.spark.api.java.{JavaRDD, JavaDoubleRDD} -import org.apache.spark.mllib.linalg.distributed.RowMatrix +import org.apache.spark.api.java.{JavaDoubleRDD, JavaRDD} import org.apache.spark.mllib.linalg.{Matrix, Vector} +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.stat.correlation.Correlations import org.apache.spark.mllib.stat.test.{ChiSqTest, ChiSqTestResult, KolmogorovSmirnovTest, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 0724af93088c..052b5b1d65b0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.stat.distribution -import breeze.linalg.{DenseVector => DBV, DenseMatrix => DBM, diag, max, eigSym, Vector => BV} +import breeze.linalg.{diag, eigSym, max, DenseMatrix => DBM, DenseVector => DBV, Vector => BV} import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.linalg.{Vectors, Vector, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.util.MLUtils /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala index 23c8d7c7c807..f22f2df320f0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/ChiSqTest.scala @@ -17,16 +17,16 @@ package org.apache.spark.mllib.stat.test +import scala.collection.mutable + import 
breeze.linalg.{DenseMatrix => BDM} import org.apache.commons.math3.distribution.ChiSquaredDistribution -import org.apache.spark.{SparkException, Logging} +import org.apache.spark.{Logging, SparkException} import org.apache.spark.mllib.linalg.{Matrices, Matrix, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import scala.collection.mutable - /** * Conduct the chi-squared test for the input RDDs using the specified method. * Goodness-of-fit test is conducted on two `Vectors`, whereas test of independence is conducted diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala index d2513a9d5c5b..0b118a76733f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala @@ -21,7 +21,7 @@ import scala.beans.BeanProperty import org.apache.spark.annotation.Since import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} +import org.apache.spark.mllib.tree.loss.{LogLoss, Loss, SquaredError} /** * Configuration options for [[org.apache.spark.mllib.tree.GradientBoostedTrees]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala index 372d6617a401..6c04403f1ad7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala @@ -21,9 +21,9 @@ import scala.beans.BeanProperty import scala.collection.JavaConverters._ import org.apache.spark.annotation.Since -import org.apache.spark.mllib.tree.impurity.{Variance, Entropy, Gini, Impurity} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance} /** * Stores all the configuration options for tree construction diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala index 1c611976a930..fbbec1197404 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -19,13 +19,13 @@ package org.apache.spark.mllib.tree.impl import scala.collection.mutable -import org.apache.hadoop.fs.{Path, FileSystem} +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.rdd.RDD import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.mllib.tree.model.{Bin, Node, Split} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala index ea6e5aa5d94e..66f0908c1250 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala @@ -17,10 +17,10 @@ package org.apache.spark.mllib.tree.model -import 
org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.Logging -import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.configuration.FeatureType._ /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala index b85a66c05a81..783a4acb55ce 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala @@ -18,7 +18,6 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.{DeveloperApi, Since} -import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala index 33477ee20ebb..68835bc79677 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LogisticRegressionDataGenerator.scala @@ -19,11 +19,11 @@ package org.apache.spark.mllib.util import scala.util.Random -import org.apache.spark.annotation.{Since, DeveloperApi} import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala index 906bd30563bd..8af6750da4ff 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MFDataGenerator.scala @@ -23,7 +23,7 @@ import scala.language.postfixOps import scala.util.Random import org.apache.spark.SparkContext -import org.apache.spark.annotation.{Since, DeveloperApi} +import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala index 4c9151f0cb4f..74e9271e4032 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.util import scala.reflect.ClassTag -import org.apache.spark.annotation.Since import org.apache.spark.SparkContext -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PartitionwiseSampledRDD -import org.apache.spark.util.random.BernoulliCellSampler -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{SparseVector, DenseVector, Vector, Vectors} +import org.apache.spark.annotation.Since +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.linalg.BLAS.dot +import org.apache.spark.mllib.regression.LabeledPoint 
+import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD} import org.apache.spark.storage.StorageLevel +import org.apache.spark.util.random.BernoulliCellSampler /** * Helper methods to load, save and pre-process data used in ML Lib. @@ -87,7 +86,8 @@ object MLUtils { val indicesLength = indices.length while (i < indicesLength) { val current = indices(i) - require(current > previous, "indices should be one-based and in ascending order" ) + require(current > previous, s"indices should be one-based and in ascending order;" + + s" found current=$current, previous=$previous; line=\"$line\"") previous = current i += 1 } diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java index 154f75d75e4a..eeeabfe359e6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.io.Serializable; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.util.Utils; public class JavaFPGrowthSuite implements Serializable { private transient JavaSparkContext sc; @@ -69,4 +71,42 @@ public void runFPGrowth() { long freq = itemset.freq(); } } + + @Test + public void runFPGrowthSaveLoad() { + + @SuppressWarnings("unchecked") + JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList( + Arrays.asList("r z h k p".split(" ")), + Arrays.asList("z y x w v u t s".split(" ")), + Arrays.asList("s x o n r".split(" ")), + Arrays.asList("x z y m t s q e".split(" ")), + Arrays.asList("z".split(" ")), + Arrays.asList("x z y r q t p".split(" "))), 2); + + FPGrowthModel model = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath); + List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD() + .collect(); + assertEquals(18, freqItemsets.size()); + + for (FPGrowth.FreqItemset itemset: freqItemsets) { + // Test return types. 
+ List items = itemset.javaItems(); + long freq = itemset.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java index 271dda4662e0..a6631ed7ebd6 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/recommendation/JavaALSSuite.java @@ -56,10 +56,10 @@ void validatePrediction( double matchThreshold, boolean implicitPrefs, DoubleMatrix truePrefs) { - List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList(users * products); + List<Tuple2<Integer, Integer>> localUsersProducts = new ArrayList<>(users * products); for (int u=0; u < users; ++u) { for (int p=0; p < products; ++p) { - localUsersProducts.add(new Tuple2(u, p)); + localUsersProducts.add(new Tuple2<>(u, p)); } } JavaPairRDD usersProducts = sc.parallelizePairs(localUsersProducts); diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java index 32c2f4f3395b..3db9b39e740e 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaIsotonicRegressionSuite.java @@ -36,11 +36,11 @@ public class JavaIsotonicRegressionSuite implements Serializable { private transient JavaSparkContext sc; - private List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) { - ArrayList<Tuple3<Double, Double, Double>> input = new ArrayList(labels.length); + private static List<Tuple3<Double, Double, Double>> generateIsotonicInput(double[] labels) { + List<Tuple3<Double, Double, Double>> input = new ArrayList<>(labels.length); for (int i = 1; i <= labels.length; i++) { - input.add(new Tuple3(labels[i-1], (double) i, 1d)); + input.add(new Tuple3<>(labels[i-1], (double) i, 1.0)); } return input; @@ -70,7 +70,7 @@ public void testIsotonicRegressionJavaRDD() { runIsotonicRegression(new double[]{1, 2, 3, 3, 1, 6, 7, 8, 11, 9, 10, 12}); Assert.assertArrayEquals( - new double[] {1, 2, 7d/3, 7d/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1e-14); + new double[] {1, 2, 7.0/3, 7.0/3, 6, 7, 8, 10, 10, 12}, model.predictions(), 1.0e-14); } @Test @@ -81,10 +81,10 @@ public void testIsotonicRegressionPredictionsJavaRDD() { JavaDoubleRDD testRDD = sc.parallelizeDoubles(Arrays.asList(0.0, 1.0, 9.5, 12.0, 13.0)); List predictions = model.predict(testRDD).collect(); - Assert.assertTrue(predictions.get(0) == 1d); - Assert.assertTrue(predictions.get(1) == 1d); - Assert.assertTrue(predictions.get(2) == 10d); - Assert.assertTrue(predictions.get(3) == 12d); - Assert.assertTrue(predictions.get(4) == 12d); + Assert.assertEquals(1.0, predictions.get(0).doubleValue(), 1.0e-14); + Assert.assertEquals(1.0, predictions.get(1).doubleValue(), 1.0e-14); + Assert.assertEquals(10.0, predictions.get(2).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(3).doubleValue(), 1.0e-14); + Assert.assertEquals(12.0, predictions.get(4).doubleValue(), 1.0e-14); } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 8c8676745636..f3321fb5a1ab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock import org.apache.spark.SparkFunSuite import org.apache.spark.ml.Pipeline.SharedReadWrite -import 
org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler} import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util._ +import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame import org.apache.spark.sql.types.StructType @@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } } + + test("pipeline validateParams") { + val df = sqlContext.createDataFrame( + Seq( + (1, Vectors.dense(0.0, 1.0, 4.0), 1.0), + (2, Vectors.dense(1.0, 0.0, 4.0), 2.0), + (3, Vectors.dense(1.0, 0.0, 5.0), 3.0), + (4, Vectors.dense(0.0, 0.0, 5.0), 4.0)) + ).toDF("id", "features", "label") + + intercept[IllegalArgumentException] { + val scaler = new MinMaxScaler() + .setInputCol("features") + .setOutputCol("features_scaled") + .setMin(10) + .setMax(0) + val pipeline = new Pipeline().setStages(Array(scaler)) + pipeline.fit(df) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 1087afb0cdf7..ff0d0ff77104 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -23,7 +23,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.{Identifiable, DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.{DefaultReadWriteTest, Identifiable, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 5ea71c5317b7..d7983f92a348 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -21,9 +21,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.attribute.NominalAttribute import org.apache.spark.ml.feature.StringIndexer import org.apache.spark.ml.param.{ParamMap, ParamsSuite} -import org.apache.spark.ml.util.{MLTestingUtils, MetadataUtils} -import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS +import org.apache.spark.ml.util.{MetadataUtils, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite._ +import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala index 932d331b472b..0d4e00668ddb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/InteractionSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.ml.feature import scala.collection.mutable.ArrayBuilder -import 
org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute._ import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.functions.col diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala index 9f6618b92929..f372ec58269e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PCASuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.linalg._ +import org.apache.spark.mllib.linalg.distributed.RowMatrix import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.Row diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala index 70892dc57170..dfdc5792c6db 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/PolynomialExpansionSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.param.ParamsSuite import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 3a4f6d235aa6..722f1abde435 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.ml.feature +import org.apache.spark.{SparkContext, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.{SparkContext, SparkFunSuite} class QuantileDiscretizerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala index 1eae125a524e..28631cef7943 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StandardScalerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature - import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 749bfac74782..5d199ca9b51b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.feature -import org.apache.spark.sql.types.{StringType, StructType, StructField, DoubleType} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param.ParamsSuite @@ -25,6 +24,7 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{DoubleType, StringType, StructField, StructType} class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 9c1c00f41ab1..f7de7c1e93fb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -17,10 +17,10 @@ package org.apache.spark.ml.feature -import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericAttribute} import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.Row diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala index 8acc3369c489..94191e5df383 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.DefaultReadWriteTest import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.types.StructType class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index d561bbbb2552..a73b56512566 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.ml.feature import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import 
org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.mllib.feature.{Word2VecModel => OldWord2VecModel} class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { diff --git a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala index 4e2d0e93bd41..a808177cb9bf 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/impl/TreeTests.scala @@ -25,8 +25,7 @@ import org.apache.spark.ml.attribute.{AttributeGroup, NominalAttribute, NumericA import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, DataFrame} - +import org.apache.spark.sql.{DataFrame, SQLContext} private[ml] object TreeTests extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index a1878be747ce..748868554fe6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -17,7 +17,10 @@ package org.apache.spark.ml.param +import java.io.{ByteArrayOutputStream, NotSerializableException, ObjectOutputStream} + import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.util.MyParams import org.apache.spark.mllib.linalg.{Vector, Vectors} class ParamsSuite extends SparkFunSuite { @@ -349,6 +352,31 @@ class ParamsSuite extends SparkFunSuite { val t3 = t.copy(ParamMap(t.maxIter -> 20)) assert(t3.isSet(t3.maxIter)) } + + test("Filtering ParamMap") { + val params1 = new MyParams("my_params1") + val params2 = new MyParams("my_params2") + val paramMap = ParamMap( + params1.intParam -> 1, + params2.intParam -> 1, + params1.doubleParam -> 0.2, + params2.doubleParam -> 0.2) + val filteredParamMap = paramMap.filter(params1) + + assert(filteredParamMap.size === 2) + filteredParamMap.toSeq.foreach { + case ParamPair(p, _) => + assert(p.parent === params1.uid) + } + + // At the previous implementation of ParamMap#filter, + // mutable.Map#filterKeys was used internally but + // the return type of the method is not serializable (see SI-6654). + // Now mutable.Map#filter is used instead of filterKeys and the return type is serializable. + // So let's ensure serializability. 
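+ // Background on why the check matters (assumed from the behaviour described in SI-6654,
+ // not re-verified against this codebase): mutable.Map#filterKeys returns a lazy view that
+ // closes over the original map and the predicate; that view class does not implement
+ // Serializable, so writing it to an ObjectOutputStream would throw NotSerializableException,
+ // whereas a strict filter copies the matching entries into a plain, serializable map.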
+ val objOut = new ObjectOutputStream(new ByteArrayOutputStream()) + objOut.writeObject(filteredParamMap) + } } object ParamsSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2c3fb84160dc..ff0d8f556827 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -25,7 +25,6 @@ import scala.language.existentials import com.github.fommil.netlib.BLAS.{getInstance => blas} -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkException, SparkFunSuite} import org.apache.spark.ml.recommendation.ALS._ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} @@ -34,7 +33,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row} - +import org.apache.spark.util.Utils class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 6999a910c34a..13165f67014c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.ml.regression import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.util.MLTestingUtils +import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame - +import org.apache.spark.sql.{DataFrame, Row} class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -73,6 +73,29 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex MLTestingUtils.checkCopy(model) } + test("predictVariance") { + val dt = new DecisionTreeRegressor() + .setImpurity("variance") + .setMaxDepth(2) + .setMaxBins(100) + .setPredictionCol("") + .setVarianceCol("variance") + val categoricalFeatures = Map(0 -> 2, 1 -> 2) + + val df = TreeTests.setMetadata(categoricalDataPointsRDD, categoricalFeatures, numClasses = 0) + val model = dt.fit(df) + + val predictions = model.transform(df) + .select(model.getFeaturesCol, model.getVarianceCol) + .collect() + + predictions.foreach { case Row(features: Vector, variance: Double) => + val expectedVariance = model.rootNode.predictImpl(features).impurityStats.calculate() + assert(variance === expectedVariance, + s"Expected variance $expectedVariance but got $variance.") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 2f3e703f4c25..273c882c2a47 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.linalg.{Vector, DenseVector, Vectors} import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.sql.{DataFrame, Row} diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index 997f574e51f6..5f4d5f11bdd6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -46,8 +46,11 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { } override def afterAll(): Unit = { - Utils.deleteRecursively(tempDir) - super.afterAll() + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterAll() + } } test("select as sparse vector") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index d281084f913c..56545de14bd3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -18,14 +18,14 @@ package org.apache.spark.ml.tuning import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.feature.HashingTF -import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} -import org.apache.spark.ml.{Pipeline, Estimator, Model} -import org.apache.spark.ml.classification.{LogisticRegressionModel, LogisticRegression} +import org.apache.spark.ml.{Estimator, Model, Pipeline} +import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel} import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator} -import org.apache.spark.ml.param.{ParamPair, ParamMap} +import org.apache.spark.ml.feature.HashingTF +import org.apache.spark.ml.param.{ParamMap, ParamPair} import org.apache.spark.ml.param.shared.HasInputCol import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext} diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 84d06b43d622..0aa774b66078 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -22,7 +22,7 @@ import java.io.{File, IOException} import org.scalatest.Suite import org.apache.spark.SparkFunSuite -import org.apache.spark.ml.{Model, Estimator} +import org.apache.spark.ml.{Estimator, Model} import 
org.apache.spark.ml.param._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.DataFrame diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala index c8a0bb16247b..8f11bbc8e47a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/TempDirectory.scala @@ -39,7 +39,10 @@ trait TempDirectory extends BeforeAndAfterAll { self: Suite => } override def afterAll(): Unit = { - Utils.deleteRecursively(_tempDir) - super.afterAll() + try { + Utils.deleteRecursively(_tempDir) + } finally { + super.afterAll() + } } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala index 59944416d96a..0eb839f20c00 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.mllib.api.python import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Vectors, SparseMatrix} -import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, SparseMatrix, Vectors} import org.apache.spark.mllib.recommendation.Rating +import org.apache.spark.mllib.regression.LabeledPoint class PythonMLLibAPISuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala index d7b291d5a633..bf98bf2f5fde 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/StreamingLogisticRegressionSuite.scala @@ -23,8 +23,8 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.{StreamingContext, TestSuiteBase} +import org.apache.spark.streaming.dstream.DStream class StreamingLogisticRegressionSuite extends SparkFunSuite with TestSuiteBase { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala index a72723eb00da..fb3bd3f412f8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/GaussianMixtureSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Vector, Vectors, Matrices} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.stat.distribution.MultivariateGaussian import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala index 37fb69d68f6b..faef60e084cc 100644 --- 
a/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/LDASuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.clustering import java.util.{ArrayList => JArrayList} -import breeze.linalg.{DenseMatrix => BDM, argtopk, max, argmax} +import breeze.linalg.{argmax, argtopk, max, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite import org.apache.spark.graphx.Edge diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala index c0924a213a84..77ec49d00539 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.evaluation import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { test("Ranking metrics: map, ndcg") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala index 4b7f1be58f99..f1d517383643 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RegressionMetricsSuite.scala @@ -22,91 +22,115 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ class RegressionMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { + val obs = List[Double](77, 85, 62, 55, 63, 88, 57, 81, 51) + val eps = 1E-5 test("regression metrics for unbiased (includes intercept term) predictor") { /* Verify results in R: - preds = c(2.25, -0.25, 1.75, 7.75) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.796875 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.3125 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.559017 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9571734 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.08 91.88 65.48 52.28 62.18 81.98 58.88 78.68 55.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + [1] 157.3 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 3.7355556 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 17.539511 + rmse = sqrt(meanSquaredError) + rmse + [1] 4.18802 + r2 = summary(model)$r.squared + r2 + [1] 0.89968225 */ - val predictionAndObservations = sc.parallelize( - Seq((2.25, 3.0), (-0.25, -0.5), (1.75, 2.0), (7.75, 7.0)), 2) + val preds = List(72.08, 91.88, 65.48, 52.28, 62.18, 81.98, 58.88, 78.68, 55.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.79687 absTol 1E-5, + 
assert(metrics.explainedVariance ~== 157.3 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.3125 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.55901 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 3.7355556 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 17.539511 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 4.18802 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.95717 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.89968225 absTol eps, "r2 score mismatch") } test("regression metrics for biased (no intercept term) predictor") { /* Verify results in R: - preds = c(2.5, 0.0, 2.0, 8.0) - obs = c(3.0, -0.5, 2.0, 7.0) - - SStot = sum((obs - mean(obs))^2) - SSreg = sum((preds - mean(obs))^2) - SSerr = sum((obs - preds)^2) - - explainedVariance = SSreg / length(obs) - explainedVariance - > [1] 8.859375 - meanAbsoluteError = mean(abs(preds - obs)) - meanAbsoluteError - > [1] 0.5 - meanSquaredError = mean((preds - obs)^2) - meanSquaredError - > [1] 0.375 - rmse = sqrt(meanSquaredError) - rmse - > [1] 0.6123724 - r2 = 1 - SSerr / SStot - r2 - > [1] 0.9486081 + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + x = c(16, 22, 14, 10, 13, 19, 12, 18, 11) + df <- as.data.frame(cbind(x, y)) + model <- lm(y ~ 0 + x, data=df) + preds = signif(predict(model), digits = 4) + preds + 1 2 3 4 5 6 7 8 9 + 72.12 99.17 63.11 45.08 58.60 85.65 54.09 81.14 49.58 + options(digits=8) + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + [1] 294.88167 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 4.5888889 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 39.958711 + rmse = sqrt(meanSquaredError) + rmse + [1] 6.3212903 + r2 = summary(model)$r.squared + r2 + [1] 0.99185395 */ - val predictionAndObservations = sc.parallelize( - Seq((2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)), 2) - val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.85937 absTol 1E-5, + val preds = List(72.12, 99.17, 63.11, 45.08, 58.6, 85.65, 54.09, 81.14, 49.58) + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) + val metrics = new RegressionMetrics(predictionAndObservations, true) + assert(metrics.explainedVariance ~== 294.88167 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.5 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.375 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.61237 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 4.5888889 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 39.958711 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 6.3212903 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 0.94860 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 0.99185395 absTol eps, "r2 score mismatch") } test("regression metrics with complete fitting") { - val predictionAndObservations = sc.parallelize( - Seq((3.0, 3.0), (0.0, 0.0), (2.0, 2.0), (8.0, 8.0)), 2) + /* Verify results in R: + y = c(77, 85, 62, 55, 63, 88, 57, 81, 51) + preds = y + explainedVariance = mean((preds - mean(y))^2) + explainedVariance + 
[1] 174.8395 + meanAbsoluteError = mean(abs(preds - y)) + meanAbsoluteError + [1] 0 + meanSquaredError = mean((preds - y)^2) + meanSquaredError + [1] 0 + rmse = sqrt(meanSquaredError) + rmse + [1] 0 + r2 = 1 - sum((preds - y)^2)/sum((y - mean(y))^2) + r2 + [1] 1 + */ + val preds = obs + val predictionAndObservations = sc.parallelize(preds.zip(obs), 2) val metrics = new RegressionMetrics(predictionAndObservations) - assert(metrics.explainedVariance ~== 8.6875 absTol 1E-5, + assert(metrics.explainedVariance ~== 174.83951 absTol eps, "explained variance regression score mismatch") - assert(metrics.meanAbsoluteError ~== 0.0 absTol 1E-5, "mean absolute error mismatch") - assert(metrics.meanSquaredError ~== 0.0 absTol 1E-5, "mean squared error mismatch") - assert(metrics.rootMeanSquaredError ~== 0.0 absTol 1E-5, + assert(metrics.meanAbsoluteError ~== 0.0 absTol eps, "mean absolute error mismatch") + assert(metrics.meanSquaredError ~== 0.0 absTol eps, "mean squared error mismatch") + assert(metrics.rootMeanSquaredError ~== 0.0 absTol eps, "root mean squared error mismatch") - assert(metrics.r2 ~== 1.0 absTol 1E-5, "r2 score mismatch") + assert(metrics.r2 ~== 1.0 absTol eps, "r2 score mismatch") } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala index 21163633051e..5c938a61ed99 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vectors, Vector} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala index 6ab2fa677012..b4e26b2aeb3c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/StandardScalerSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} +import org.apache.spark.mllib.stat.{MultivariateOnlineSummarizer, MultivariateStatisticalSummary} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, MultivariateOnlineSummarizer} import org.apache.spark.rdd.RDD class StandardScalerSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index 37d01e287669..e74ecc16ee9f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext - import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.util.Utils diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 4a9bfdb348d9..b9e997c207bc 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { */ assert(model1.freqItemsets.count() === 65) } + + test("model save/load with String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load with Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala index 96e5ffef7a13..80da03cc2efe 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BLASSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.linalg import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.mllib.linalg.BLAS._ +import org.apache.spark.mllib.util.TestingUtils._ class BLASSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala index dc04258e41d2..de2c3c13bd92 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/BreezeMatrixConversionSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.linalg -import breeze.linalg.{DenseMatrix => BDM, CSCMatrix => BSM} +import breeze.linalg.{CSCMatrix => 
BSM, DenseMatrix => BDM} import org.apache.spark.SparkFunSuite diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index f895e2a8e4af..832ccc0aacf8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.mllib.linalg import scala.util.Random -import breeze.linalg.{DenseMatrix => BDM, squaredDistance => breezeSquaredDistance} +import breeze.linalg.{squaredDistance => breezeSquaredDistance, DenseMatrix => BDM} import org.json4s.jackson.JsonMethods.{parse => parseJson} import org.apache.spark.{Logging, SparkException, SparkFunSuite} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala index b8eb10305801..d91ba8a6fdb7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/BlockMatrixSuite.scala @@ -22,7 +22,7 @@ import java.{util => ju} import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.{SparkException, SparkFunSuite} -import org.apache.spark.mllib.linalg.{SparseMatrix, DenseMatrix, Matrices, Matrix} +import org.apache.spark.mllib.linalg.{DenseMatrix, Matrices, Matrix, SparseMatrix} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala index f3728cd036a3..37d75103d18d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/CoordinateMatrixSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{DenseMatrix => BDM} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.util.MLlibTestSparkContext class CoordinateMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala index 6de6cf2fa863..5b7ccb90158b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/IndexedRowMatrixSuite.scala @@ -20,9 +20,9 @@ package org.apache.spark.mllib.linalg.distributed import breeze.linalg.{diag => brzDiag, DenseMatrix => BDM, DenseVector => BDV} import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD -import org.apache.spark.mllib.linalg.{Matrices, Vectors} class IndexedRowMatrixSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala index 0ff901ddc497..2dff52c601d8 100644 
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/distributed/RowMatrixSuite.scala @@ -21,11 +21,11 @@ import java.util.Arrays import scala.util.Random +import breeze.linalg.{norm => brzNorm, svd => brzSvd, DenseMatrix => BDM, DenseVector => BDV} import breeze.numerics.abs -import breeze.linalg.{DenseVector => BDV, DenseMatrix => BDM, norm => brzNorm, svd => brzSvd} import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{Matrices, Vectors, Vector} +import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors} import org.apache.spark.mllib.random.RandomRDDs import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala index 36ac7d267243..1c9b7c78e5b8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala @@ -25,7 +25,7 @@ import org.scalatest.Matchers import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression._ -import org.apache.spark.mllib.util.{MLUtils, LocalClusterSparkContext, MLlibTestSparkContext} +import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext, MLUtils} import org.apache.spark.mllib.util.TestingUtils._ object GradientDescentSuite { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala index 413db2000d6d..0b4c7eb302d4 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/random/RandomRDDsSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.mllib.random import scala.collection.mutable.ArrayBuffer -import org.apache.spark.SparkFunSuite import org.apache.spark.SparkContext._ +import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.mllib.rdd.{RandomRDDPartition, RandomRDD} +import org.apache.spark.mllib.rdd.{RandomRDD, RandomRDDPartition} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.util.StatCounter diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala index 10f5a2be48f7..56231429859e 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/rdd/MLPairRDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.MLPairRDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class MLPairRDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { test("topByKey") { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala index ac93733bab5f..0e931fca6cf0 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/mllib/rdd/RDDFunctionsSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.rdd import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.rdd.RDDFunctions._ +import org.apache.spark.mllib.util.MLlibTestSparkContext class RDDFunctionsSuite extends SparkFunSuite with MLlibTestSparkContext { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala index 39537e7bb4c7..d96103d01e4a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala index f88a1c33c9f7..0694079b9df9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala index 7a781fee634c..8fb8886645cd 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala @@ -23,7 +23,7 @@ import org.jblas.DoubleMatrix import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors -import org.apache.spark.mllib.util.{LocalClusterSparkContext, LinearDataGenerator, +import org.apache.spark.mllib.util.{LinearDataGenerator, LocalClusterSparkContext, MLlibTestSparkContext} import org.apache.spark.util.Utils diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala index 3c657c8cfe74..1142102bb040 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/StreamingTestSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.mllib.stat import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.stat.test.{StreamingTest, StreamingTestResult, StudentTTest, - WelchTTest, BinarySample} +import org.apache.spark.mllib.stat.test.{BinarySample, StreamingTest, StreamingTestResult, + StudentTTest, WelchTTest} import org.apache.spark.streaming.TestSuiteBase import org.apache.spark.streaming.dstream.DStream import org.apache.spark.util.StatCounter diff --git 
a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala index 6e7a00347545..669d44223d71 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.stat.distribution import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.linalg.{ Vectors, Matrices } +import org.apache.spark.mllib.linalg.{Matrices, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bf8fe1acac2f..a9c935bd4244 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -23,9 +23,9 @@ import scala.collection.mutable import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy} import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala index 3d3f80063f90..1cc8f342021a 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala @@ -17,13 +17,13 @@ package org.apache.spark.mllib.tree +import scala.collection.mutable + import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.model.TreeEnsembleModel import org.apache.spark.util.StatCounter -import scala.collection.mutable - object EnsembleTestHelper { /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala index 6fc9e8df621d..acb3b953b53b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala @@ -19,15 +19,14 @@ package org.apache.spark.mllib.tree import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Strategy} +import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.impurity.Variance -import org.apache.spark.mllib.tree.loss.{AbsoluteError, SquaredError, LogLoss} +import org.apache.spark.mllib.tree.loss.{AbsoluteError, LogLoss, SquaredError} import org.apache.spark.mllib.tree.model.GradientBoostedTreesModel import 
org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.util.Utils - /** * Test suite for [[GradientBoostedTrees]]. */ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala index 525ab68c7921..9b2d023bbf73 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/LocalClusterSparkContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.mllib.util -import org.scalatest.{Suite, BeforeAndAfterAll} +import org.scalatest.{BeforeAndAfterAll, Suite} import org.apache.spark.{SparkConf, SparkContext} @@ -25,18 +25,21 @@ trait LocalClusterSparkContext extends BeforeAndAfterAll { self: Suite => @transient var sc: SparkContext = _ override def beforeAll() { + super.beforeAll() val conf = new SparkConf() .setMaster("local-cluster[2, 1, 1024]") .setAppName("test-cluster") .set("spark.akka.frameSize", "1") // set to 1MB to detect direct serialization of data sc = new SparkContext(conf) - super.beforeAll() } override def afterAll() { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() } - super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala index 378139593b26..ebcd591465cb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLlibTestSparkContext.scala @@ -38,12 +38,15 @@ trait MLlibTestSparkContext extends BeforeAndAfterAll { self: Suite => } override def afterAll() { - sqlContext = null - SQLContext.clearActive() - if (sc != null) { - sc.stop() + try { + sqlContext = null + SQLContext.clearActive() + if (sc != null) { + sc.stop() + } + sc = null + } finally { + super.afterAll() } - sc = null - super.afterAll() } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala index 352193a67860..6de9aaf94f1b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtils.scala @@ -17,9 +17,10 @@ package org.apache.spark.mllib.util -import org.apache.spark.mllib.linalg.{Matrix, Vector} import org.scalatest.exceptions.TestFailedException +import org.apache.spark.mllib.linalg.{Matrix, Vector} + object TestingUtils { val ABS_TOL_MSG = " using absolute tolerance" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala index 8f475f30249d..44c39704e5b9 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/util/TestingUtilsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.mllib.util +import org.scalatest.exceptions.TestFailedException + import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.TestingUtils._ -import org.scalatest.exceptions.TestFailedException class TestingUtilsSuite extends SparkFunSuite { diff --git a/network/common/pom.xml b/network/common/pom.xml index 9af6cc5e925f..eda2b7307088 100644 --- a/network/common/pom.xml +++ b/network/common/pom.xml @@ -22,7 +22,7 @@ 
   org.apache.spark
   spark-parent_2.10
-  1.6.0-SNAPSHOT
+  2.0.0-SNAPSHOT
   ../../pom.xml
@@ -52,11 +52,6 @@
     com.google.code.findbugs
     jsr305
-    com.google.guava
-    guava
diff --git a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
index 51d34cac6e63..29e6a30dc1f6 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/StreamCallback.java
@@ -21,8 +21,8 @@ import java.nio.ByteBuffer;
 
 /**
- * Callback for streaming data. Stream data will be offered to the {@link onData(String, ByteBuffer)}
- * method as it arrives. Once all the stream data is received, {@link onComplete(String)} will be
+ * Callback for streaming data. Stream data will be offered to the {@link #onData(String, ByteBuffer)}
+ * method as it arrives. Once all the stream data is received, {@link #onComplete(String)} will be
  * called.
  *
  * The network library guarantees that a single thread will call these methods at a time, but
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
index c49ca4d5ee92..e15f096d3691 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportClient.java
@@ -288,7 +288,7 @@ public void send(ByteBuffer message) {
   /**
    * Removes any state associated with the given RPC.
    *
-   * @param requestId The RPC id returned by {@link #sendRpc(byte[], RpcResponseCallback)}.
+   * @param requestId The RPC id returned by {@link #sendRpc(ByteBuffer, RpcResponseCallback)}.
    */
   public void removeRpcRequest(long requestId) {
     handler.removeRpcRequest(requestId);
diff --git a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
index 23a8dba59344..f0e2004d2de2 100644
--- a/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java
@@ -116,7 +116,11 @@ private void failOutstandingRequests(Throwable cause) {
   }
 
   @Override
-  public void channelUnregistered() {
+  public void channelActive() {
+  }
+
+  @Override
+  public void channelInactive() {
     if (numOutstandingRequests() > 0) {
       String remoteAddress = NettyUtils.getRemoteAddress(channel);
       logger.error("Still have {} requests outstanding when connection from {} is closed",
diff --git a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
index c215bd9d1504..c41f5b6873f6 100644
--- a/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/sasl/SaslRpcHandler.java
@@ -135,9 +135,14 @@ public StreamManager getStreamManager() {
   }
 
   @Override
-  public void connectionTerminated(TransportClient client) {
+  public void channelActive(TransportClient client) {
+    delegate.channelActive(client);
+  }
+
+  @Override
+  public void channelInactive(TransportClient client) {
     try {
-      delegate.connectionTerminated(client);
+      delegate.channelInactive(client);
     } finally {
       if (saslServer != null) {
         saslServer.dispose();
diff --git a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
index 3843406b2740..4a1f28e9ffb3 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/MessageHandler.java
@@ -28,9 +28,12 @@ public abstract class MessageHandler {
   /** Handles the receipt of a single message. */
   public abstract void handle(T message) throws Exception;
 
+  /** Invoked when the channel this MessageHandler is on is active. */
+  public abstract void channelActive();
+
   /** Invoked when an exception was caught on the Channel. */
   public abstract void exceptionCaught(Throwable cause);
 
-  /** Invoked when the channel this MessageHandler is on has been unregistered. */
-  public abstract void channelUnregistered();
+  /** Invoked when the channel this MessageHandler is on is inactive. */
+  public abstract void channelInactive();
 }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
index ee1c68369947..a99c3015b0e0 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/RpcHandler.java
@@ -57,7 +57,7 @@ public abstract void receive(
   /**
    * Receives an RPC message that does not expect a reply. The default implementation will
-   * call "{@link receive(TransportClient, byte[], RpcResponseCallback)}" and log a warning if
+   * call "{@link #receive(TransportClient, ByteBuffer, RpcResponseCallback)}" and log a warning if
    * any of the callback methods are called.
    *
    * @param client A channel client which enables the handler to make requests back to the sender
@@ -69,10 +69,15 @@ public void receive(TransportClient client, ByteBuffer message) {
   }
 
   /**
-   * Invoked when the connection associated with the given client has been invalidated.
+   * Invoked when the channel associated with the given client is active.
+   */
+  public void channelActive(TransportClient client) { }
+
+  /**
+   * Invoked when the channel associated with the given client is inactive.
    * No further requests will come from this client.
    */
-  public void connectionTerminated(TransportClient client) { }
+  public void channelInactive(TransportClient client) { }
 
   public void exceptionCaught(Throwable cause, TransportClient client) { }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
index 3f0155957a14..07f161a29cfb 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/StreamManager.java
@@ -54,6 +54,7 @@ public abstract class StreamManager {
    * {@link #getChunk(long, int)} method.
    *
    * @param streamId id of a stream that has been previously registered with the StreamManager.
+   * @return A managed buffer for the stream, or null if the stream was not found.
    */
   public ManagedBuffer openStream(String streamId) {
     throw new UnsupportedOperationException();
   }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
index 09435bcbab35..18a9b7887ec2 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportChannelHandler.java
@@ -84,14 +84,29 @@ public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws E
   }
 
   @Override
-  public void channelUnregistered(ChannelHandlerContext ctx) throws Exception {
+  public void channelActive(ChannelHandlerContext ctx) throws Exception {
     try {
-      requestHandler.channelUnregistered();
+      requestHandler.channelActive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from request handler while channel is active", e);
+    }
+    try {
+      responseHandler.channelActive();
+    } catch (RuntimeException e) {
+      logger.error("Exception from response handler while channel is active", e);
+    }
+    super.channelActive(ctx);
+  }
+
+  @Override
+  public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+    try {
+      requestHandler.channelInactive();
     } catch (RuntimeException e) {
       logger.error("Exception from request handler while unregistering channel", e);
     }
     try {
-      responseHandler.channelUnregistered();
+      responseHandler.channelInactive();
     } catch (RuntimeException e) {
       logger.error("Exception from response handler while unregistering channel", e);
     }
diff --git a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
index c864d7ce16bd..296ced3db093 100644
--- a/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
+++ b/network/common/src/main/java/org/apache/spark/network/server/TransportRequestHandler.java
@@ -83,7 +83,12 @@ public void exceptionCaught(Throwable cause) {
   }
 
   @Override
-  public void channelUnregistered() {
+  public void channelActive() {
+    rpcHandler.channelActive(reverseClient);
+  }
+
+  @Override
+  public void channelInactive() {
     if (streamManager != null) {
       try {
         streamManager.connectionTerminated(channel);
@@ -91,7 +96,7 @@ public void channelUnregistered() {
         logger.error("StreamManager connectionTerminated() callback failed.", e);
       }
     }
-    rpcHandler.connectionTerminated(reverseClient);
+    rpcHandler.channelInactive(reverseClient);
   }
 
   @Override
@@ -141,7 +146,12 @@ private void processStreamRequest(final StreamRequest req) {
       return;
     }
 
-    respond(new StreamResponse(req.streamId, buf.size(), buf));
+    if (buf != null) {
+      respond(new StreamResponse(req.streamId, buf.size(), buf));
+    } else {
+      respond(new StreamFailure(req.streamId, String.format(
+        "Stream '%s' was not found.", req.streamId)));
+    }
   }
 
   private void processRpcRequest(final RpcRequest req) {
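
With this change an RpcHandler gets symmetric lifecycle notifications: channelActive when a client's channel comes up, and channelInactive when it goes down, instead of the single one-sided connectionTerminated callback. A minimal sketch of a handler written against the renamed callbacks (the class and its connection counter are illustrative, not part of this patch; it assumes the ByteBuffer-based receive signature shown above and the stock OneForOneStreamManager):

    import java.nio.ByteBuffer
    import java.util.concurrent.atomic.AtomicInteger

    import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
    import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}

    // Hypothetical handler, not part of this patch: tracks live connections
    // using the renamed lifecycle callbacks.
    class ConnectionCountingRpcHandler extends RpcHandler {
      private val streamManager = new OneForOneStreamManager()
      private val activeClients = new AtomicInteger(0)

      override def receive(
          client: TransportClient,
          message: ByteBuffer,
          callback: RpcResponseCallback): Unit = {
        // Echo the request back; a real handler would decode and dispatch it.
        callback.onSuccess(message)
      }

      override def getStreamManager: StreamManager = streamManager

      // channelActive fires once when the connection is established...
      override def channelActive(client: TransportClient): Unit = {
        activeClients.incrementAndGet()
      }

      // ...and channelInactive once when it goes away, replacing the old
      // connectionTerminated() callback.
      override def channelInactive(client: TransportClient): Unit = {
        activeClients.decrementAndGet()
      }
    }
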
diff --git a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
index 751516b9d82a..045773317a78 100644
--- a/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
+++ b/network/common/src/test/java/org/apache/spark/network/sasl/SparkSaslSuite.java
@@ -160,7 +160,7 @@ public Void answer(InvocationOnMock invocation) {
       long deadline = System.nanoTime() + TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS);
       while (deadline > System.nanoTime()) {
         try {
-          verify(rpcHandler, times(2)).connectionTerminated(any(TransportClient.class));
+          verify(rpcHandler, times(2)).channelInactive(any(TransportClient.class));
           error = null;
           break;
         } catch (Throwable t) {
@@ -362,8 +362,8 @@ public void testRpcHandlerDelegate() throws Exception {
     saslHandler.getStreamManager();
     verify(handler).getStreamManager();
 
-    saslHandler.connectionTerminated(null);
-    verify(handler).connectionTerminated(any(TransportClient.class));
+    saslHandler.channelInactive(null);
+    verify(handler).channelInactive(any(TransportClient.class));
 
     saslHandler.exceptionCaught(null, null);
     verify(handler).exceptionCaught(any(Throwable.class), any(TransportClient.class));
diff --git a/network/shuffle/pom.xml b/network/shuffle/pom.xml
index 70ba5cb1995b..f9aa7e2dd1f4 100644
--- a/network/shuffle/pom.xml
+++ b/network/shuffle/pom.xml
@@ -22,7 +22,7 @@
   org.apache.spark
   spark-parent_2.10
-  1.6.0-SNAPSHOT
+  2.0.0-SNAPSHOT
   ../../pom.xml
diff --git a/network/yarn/pom.xml b/network/yarn/pom.xml
index e2360eff5cfe..a19cbb04b18c 100644
--- a/network/yarn/pom.xml
+++ b/network/yarn/pom.xml
@@ -22,7 +22,7 @@
   org.apache.spark
   spark-parent_2.10
-  1.6.0-SNAPSHOT
+  2.0.0-SNAPSHOT
   ../../pom.xml
diff --git a/pom.xml b/pom.xml
index c560e13641c6..0eac21275432 100644
--- a/pom.xml
+++ b/pom.xml
@@ -26,7 +26,7 @@
   org.apache.spark
   spark-parent_2.10
-  1.6.0-SNAPSHOT
+  2.0.0-SNAPSHOT
   pom
   Spark Project Parent POM
   http://spark.apache.org/
@@ -88,7 +88,6 @@
     tags
     core
-    bagel
     graphx
     mllib
     tools
@@ -153,7 +152,6 @@
     1.7.7
     hadoop2
     0.7.1
-    1.9.40
     1.4.0
     0.10.1
@@ -185,6 +183,7 @@
     3.5.2
     1.3.9
     0.9.2
+    3.5.2
     ${java.home}
@@ -194,7 +193,7 @@
        declared in the projects that build assemblies. For other projects the scope should remain as "compile", otherwise they are not available
-       during compilation if the dependency is transivite (e.g. "bagel/" depending on "core/" and
+       during compilation if the dependency is transitive (e.g. "graphx/" depending on "core/" and
        needing Hadoop classes in the classpath to compile).
--> compile @@ -227,93 +226,6 @@ false - - apache-repo - Apache Repository - https://repository.apache.org/content/repositories/releases - - true - - - false - - - - jboss-repo - JBoss Repository - https://repository.jboss.org/nexus/content/repositories/releases - - true - - - false - - - - mqtt-repo - MQTT Repository - https://repo.eclipse.org/content/repositories/paho-releases - - true - - - false - - - - cloudera-repo - Cloudera Repository - https://repository.cloudera.com/artifactory/cloudera-repos - - true - - - false - - - - spark-hive-staging - Staging Repo for Hive 1.2.1 (Spark Version) - https://oss.sonatype.org/content/repositories/orgspark-project-1113 - - true - - - - mapr-repo - MapR Repository - http://repository.mapr.com/maven/ - - true - - - false - - - - - spring-releases - Spring Release Repository - https://repo.spring.io/libs-release - - false - - - false - - - - - twttr-repo - Twttr Repository - http://maven.twttr.com - - true - - - false - - @@ -1845,6 +1757,11 @@ + + org.antlr + antlr-runtime + ${antlr.version} + @@ -1952,6 +1869,11 @@ + + org.antlr + antlr3-maven-plugin + 3.5.2 + org.apache.maven.plugins @@ -2114,6 +2036,23 @@ maven-deploy-plugin 2.8.2 + + org.apache.maven.plugins + maven-dependency-plugin + + + default-cli + + build-classpath + + + + runtime + + + + @@ -2225,17 +2164,6 @@ com.google.common org.spark-project.guava - - - com/google/common/base/Absent* - com/google/common/base/Function - com/google/common/base/Optional* - com/google/common/base/Present* - com/google/common/base/Supplier - @@ -2443,19 +2371,6 @@ http://hadoop.apache.org/docs/ra.b.c/hadoop-project-dist/hadoop-common/dependency-analysis.html --> - - hadoop-1 - - 1.2.1 - 2.4.1 - 0.98.7-hadoop1 - hadoop1 - 1.8.8 - org.spark-project.akka - 2.3.4-spark - - - hadoop-2.2 diff --git a/project/MimaBuild.scala b/project/MimaBuild.scala index 519052620246..9ba9f8286f10 100644 --- a/project/MimaBuild.scala +++ b/project/MimaBuild.scala @@ -91,7 +91,7 @@ object MimaBuild { def mimaSettings(sparkHome: File, projectRef: ProjectRef) = { val organization = "org.apache.spark" - val previousSparkVersion = "1.5.0" + val previousSparkVersion = "1.6.0" val fullId = "spark-" + projectRef.project + "_2.10" mimaDefaultSettings ++ Seq(previousArtifact := Some(organization % fullId % previousSparkVersion), diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index edae59d88266..0d5f938d9ef5 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -30,9 +30,112 @@ import com.typesafe.tools.mima.core.ProblemFilters._ * It is also possible to exclude Spark classes and packages. This should be used sparingly: * * MimaBuild.excludeSparkClass("graphx.util.collection.GraphXPrimitiveKeyOpenHashMap") + * + * For a new Spark version, please update MimaBuild.scala to reflect the previous version. 
  */
 object MimaExcludes {
   def excludes(version: String) = version match {
+    case v if v.startsWith("2.0") =>
+      Seq(
+        excludePackage("org.apache.spark.rpc"),
+        excludePackage("org.spark-project.jetty"),
+        excludePackage("org.apache.spark.unused"),
+        excludePackage("org.apache.spark.util.collection.unsafe"),
+        excludePackage("org.apache.spark.sql.catalyst"),
+        excludePackage("org.apache.spark.sql.execution"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.feature.PCAModel.this"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.status.api.v1.StageData.this"),
+        // SPARK-12600 Remove SQL deprecated methods
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$QueryExecution"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.SQLContext$SparkPlanner"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.applySchema"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.parquetFile"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jdbc"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonFile"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.jsonRDD"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.sql.SQLContext.load")
+      ) ++ Seq(
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.SparkContext.emptyRDD"),
+        ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.broadcast.HttpBroadcastFactory")
+      ) ++
+      Seq(
+        // SPARK-4819 replace Guava Optional
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getCheckpointDir"),
+        ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.api.java.JavaSparkContext.getSparkHome"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.getCheckpointFile"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.partitioner")
+      ) ++
+      Seq(
+        // SPARK-12481 Remove Hadoop 1.x
+        ProblemFilters.exclude[IncompatibleTemplateDefProblem]("org.apache.spark.mapred.SparkHadoopMapRedUtil"),
+        // SPARK-12615 Remove deprecated APIs in core
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.$default$6"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.numericRDDToDoubleRDDFunctions"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intToIntWritable"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.intWritableConverter"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.writableWritableConverter"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToAsyncRDDActions"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.boolToBoolWritable"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longToLongWritable"),
+        ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleWritableConverter"),
+
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToOrderedRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.booleanWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringToText"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.doubleToDoubleWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToSequenceFileRDDFunctions"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.bytesToBytesWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.longWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.stringWritableConverter"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.floatToFloatWritable"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.rddToPairRDDFunctions$default$4"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.addOnCompleteCallback"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.runningLocally"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.TaskContext.attemptId"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.defaultMinSplits"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.runJob"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.tachyonFolderName"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.initLocalProperties"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.clearFiles"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.SparkContext.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.flatMapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.filterWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.foreachWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapWith"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.RDD.mapPartitionsWithSplit$default$2"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.rdd.SequenceFileRDDFunctions.this"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.splits"), + 
ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaRDDLike.toArray"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.defaultMinSplits"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearJars"), + ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.api.java.JavaSparkContext.clearFiles") + ) ++ + // SPARK-12665 Remove deprecated and unused classes + Seq( + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.graphx.GraphKryoRegistrator"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$Multiplier"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.util.Vector$") + ) ++ Seq( + // SPARK-12591 Register OpenHashMapBasedStateMap for Kryo + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoInputDataInputBridge"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.serializer.KryoOutputDataOutputBridge") + ) ++ Seq( + // SPARK-12510 Refactor ActorReceiver to support Java + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.streaming.receiver.ActorReceiver") + ) case v if v.startsWith("1.6") => Seq( MimaBuild.excludeSparkPackage("deploy"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index b1dcaedcba75..4c34c888cfd5 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -34,10 +34,10 @@ object BuildCommons { private val buildLocation = file(".").getAbsoluteFile.getParentFile - val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, + val allProjects@Seq(catalyst, core, graphx, hive, hiveThriftServer, mllib, repl, sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka, streamingMqtt, streamingTwitter, streamingZeromq, launcher, unsafe, testTags) = - Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", + Seq("catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl", "sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter", "streaming-zeromq", "launcher", "unsafe", "test-tags").map(ProjectRef(buildLocation, _)) @@ -141,7 +141,12 @@ object SparkBuild extends PomBuild { publishMavenStyle := true, unidocGenjavadocVersion := "0.9-spark0", - resolvers += Resolver.mavenLocal, + // Override SBT's default resolvers: + resolvers := Seq( + DefaultMavenRepository, + Resolver.mavenLocal + ), + externalResolvers := resolvers.value, otherResolvers <<= SbtPomKeys.mvnLocalRepository(dotM2 => Seq(Resolver.file("dotM2", dotM2))), publishLocalConfiguration in MavenCompile <<= (packagedArtifacts, deliverLocal, ivyLoggingLevel) map { (arts, _, level) => new PublishConfiguration(None, "dotM2", arts, Seq(), level) @@ -247,6 +252,9 @@ object SparkBuild extends PomBuild { /* Enable unidoc only for the root spark project */ enable(Unidoc.settings)(spark) + /* Catalyst ANTLR generation settings */ + enable(Catalyst.settings)(catalyst) + /* Spark SQL Core console settings */ enable(SQL.settings)(sql) @@ -352,11 +360,63 @@ object OldDeps { scalaVersion := "2.10.5", libraryDependencies := Seq("spark-streaming-mqtt", "spark-streaming-zeromq", "spark-streaming-flume", "spark-streaming-kafka", "spark-streaming-twitter", - "spark-streaming", "spark-mllib", "spark-bagel", 
"spark-graphx", + "spark-streaming", "spark-mllib", "spark-graphx", "spark-core").map(versionArtifact(_).get intransitive()) ) } +object Catalyst { + lazy val settings = Seq( + // ANTLR code-generation step. + // + // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of + // build errors in the current plugin. + // Create Parser from ANTLR grammar files. + sourceGenerators in Compile += Def.task { + val log = streams.value.log + + val grammarFileNames = Seq( + "SparkSqlLexer.g", + "SparkSqlParser.g") + val sourceDir = (sourceDirectory in Compile).value / "antlr3" + val targetDir = (sourceManaged in Compile).value + + // Create default ANTLR Tool. + val antlr = new org.antlr.Tool + + // Setup input and output directories. + antlr.setInputDirectory(sourceDir.getPath) + antlr.setOutputDirectory(targetDir.getPath) + antlr.setForceRelativeOutput(true) + antlr.setMake(true) + + // Add grammar files. + grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => + val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath + log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) + antlr.addGrammarFile(relGFilePath) + // We will set library directory multiple times here. However, only the + // last one has effect. Because the grammar files are located under the same directory, + // We assume there is only one library directory. + antlr.setLibDirectory(gFilePath.getParent) + } + + // Generate the parser. + antlr.process + if (antlr.getNumErrors > 0) { + log.error("ANTLR: Caught %d build errors.".format(antlr.getNumErrors)) + } + + // Return all generated java files. + (targetDir ** "*.java").get.toSeq + }.taskValue, + // Include ANTLR tokens files. + resourceGenerators in Compile += Def.task { + ((sourceManaged in Compile).value ** "*.tokens").get.toSeq + }.taskValue + ) +} + object SQL { lazy val settings = Seq( initialCommands in console := @@ -416,7 +476,6 @@ object Hive { // new query tests. fullClasspath in Test := (fullClasspath in Test).value.filterNot { f => f.toString.contains("jcl-over") } ) - } object Assembly { @@ -556,7 +615,7 @@ object Unidoc { unidocProjectFilter in(ScalaUnidoc, unidoc) := inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), unidocProjectFilter in(JavaUnidoc, unidoc) := - inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, streamingFlumeSink, yarn, testTags), + inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, streamingFlumeSink, yarn, testTags), // Skip actual catalyst, but include the subproject. // Catalyst is not public API and contains quasiquotes which break scaladoc. 
diff --git a/project/plugins.sbt b/project/plugins.sbt index 5e23224cf8aa..822a7c4a82d5 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,9 +1,3 @@ -resolvers += Resolver.url("artifactory", url("http://scalasbt.artifactoryonline.com/scalasbt/sbt-plugin-releases"))(Resolver.ivyStylePatterns) - -resolvers += "Typesafe Repository" at "http://repo.typesafe.com/typesafe/releases/" - -resolvers += "sonatype-releases" at "https://oss.sonatype.org/content/repositories/releases/" - addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.11.2") addSbtPlugin("com.typesafe.sbteclipse" % "sbteclipse-plugin" % "2.2.0") @@ -27,3 +21,5 @@ addSbtPlugin("io.spray" % "sbt-revolver" % "0.7.2") libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" + +libraryDependencies += "org.antlr" % "antlr" % "3.5.2" diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 8475dfb1c6ad..d530723ca980 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -65,7 +65,7 @@ def deco(f): # for back compatibility -from pyspark.sql import SQLContext, HiveContext, SchemaRDD, Row +from pyspark.sql import SQLContext, HiveContext, Row __all__ = [ "SparkConf", "SparkContext", "SparkFiles", "RDD", "StorageLevel", "Broadcast", diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 5599b8f3ecd8..265c6a14f1ca 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -273,7 +273,7 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, HasRawPredictionCol, DecisionTreeParams, - TreeClassifierParams, HasCheckpointInterval): + TreeClassifierParams, HasCheckpointInterval, HasSeed): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for classification. 
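
The seed parameter added to the Python DecisionTreeClassifier below mirrors a Param that the Scala estimator already exposes through HasSeed. For comparison, a hedged sketch of the Scala-side usage (assuming the usual HasSeed-generated setter on the estimator):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // Sketch only: assumes the setSeed setter generated by the HasSeed mixin.
    val dt = new DecisionTreeClassifier()
      .setMaxDepth(5)
      .setImpurity("gini")
      .setSeed(42L)
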
@@ -313,12 +313,14 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  probabilityCol="probability", rawPredictionCol="rawPrediction",
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini"):
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
+                 seed=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  probabilityCol="probability", rawPredictionCol="rawPrediction", \
                  maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+                 maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+                 seed=None)
         """
         super(DecisionTreeClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
@@ -335,12 +337,13 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   probabilityCol="probability", rawPredictionCol="rawPrediction",
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
                   maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10,
-                  impurity="gini"):
+                  impurity="gini", seed=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   probabilityCol="probability", rawPredictionCol="rawPrediction", \
                   maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
-                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini")
+                  maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
+                  seed=None)
         Sets params for the DecisionTreeClassifier.
         """
         kwargs = self.setParams._input_kwargs
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 7bb8ab94e17d..9189c0222022 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -36,6 +36,14 @@ def clusterCenters(self):
         """Get the cluster centers, represented as a list of NumPy arrays."""
         return [c.toArray() for c in self._call_java("clusterCenters")]
 
+    @since("2.0.0")
+    def computeCost(self, dataset):
+        """
+        Return the K-means cost (sum of squared distances of points to their nearest center)
+        for this model on the given data.
+        """
+        return self._call_java("computeCost", dataset)
+
 
 @inherit_doc
 class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed):
@@ -53,6 +61,8 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     >>> centers = model.clusterCenters()
     >>> len(centers)
     2
+    >>> model.computeCost(df)
+    2.000...
     >>> transformed = model.transform(df).select("features", "prediction")
     >>> rows = transformed.collect()
     >>> rows[0].prediction == rows[1].prediction
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 35c9b776a3d5..92ce96aa3c4d 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -32,12 +32,13 @@ class Param(object):
 
     .. versionadded:: 1.3.0
     """
 
-    def __init__(self, parent, name, doc):
+    def __init__(self, parent, name, doc, expectedType=None):
         if not isinstance(parent, Identifiable):
             raise TypeError("Parent must be an Identifiable but got type %s."
                             % type(parent))
         self.parent = parent.uid
         self.name = str(name)
         self.doc = str(doc)
+        self.expectedType = expectedType
 
     def __str__(self):
         return str(self.parent) + "__" + self.name
@@ -247,7 +248,24 @@ def _set(self, **kwargs):
         """
         Sets user-supplied params.
         """
         for param, value in kwargs.items():
-            self._paramMap[getattr(self, param)] = value
+            p = getattr(self, param)
+            if p.expectedType is None or type(value) == p.expectedType or value is None:
+                self._paramMap[getattr(self, param)] = value
+            else:
+                try:
+                    # Try to do "safe" conversions that don't lose information
+                    if p.expectedType == float:
+                        self._paramMap[getattr(self, param)] = float(value)
+                    # Python 3 unified long & int
+                    elif p.expectedType == int and type(value).__name__ == 'long':
+                        self._paramMap[getattr(self, param)] = value
+                    else:
+                        raise Exception(
+                            "Provided type {0} incompatible with type {1} for param {2}"
+                            .format(type(value), p.expectedType, p))
+                except ValueError:
+                    raise Exception(("Failed to convert {0} to type {1} for param {2}"
+                                    .format(type(value), p.expectedType, p)))
         return self
 
     def _setDefault(self, **kwargs):
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 0528dc1e3a6b..82855bc4c75b 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -38,7 +38,7 @@
 # python _shared_params_code_gen.py > shared.py
 
 
-def _gen_param_header(name, doc, defaultValueStr):
+def _gen_param_header(name, doc, defaultValueStr, expectedType):
     """
     Generates the header part for shared variables
 
@@ -51,22 +51,26 @@ def _gen_param_header(name, doc, defaultValueStr):
     """
 
     # a placeholder to make it appear in the generated doc
-    $name = Param(Params._dummy(), "$name", "$doc")
+    $name = Param(Params._dummy(), "$name", "$doc", $expectedType)
 
     def __init__(self):
         super(Has$Name, self).__init__()
         #: param for $doc
-        self.$name = Param(self, "$name", "$doc")'''
+        self.$name = Param(self, "$name", "$doc", $expectedType)'''
     if defaultValueStr is not None:
         template += '''
         self._setDefault($name=$defaultValueStr)'''
 
     Name = name[0].upper() + name[1:]
+    expectedTypeName = str(expectedType)
+    if expectedType is not None:
+        expectedTypeName = expectedType.__name__
     return template \
         .replace("$name", name) \
         .replace("$Name", Name) \
         .replace("$doc", doc) \
-        .replace("$defaultValueStr", str(defaultValueStr))
+        .replace("$defaultValueStr", str(defaultValueStr)) \
+        .replace("$expectedType", expectedTypeName)
 
 
 def _gen_param_code(name, doc, defaultValueStr):
@@ -84,7 +88,7 @@ def set$Name(self, value):
         """
         Sets the value of :py:attr:`$name`.
         """
-        self._paramMap[self.$name] = value
+        self._set($name=value)
         return self
 
     def get$Name(self):
@@ -105,44 +109,45 @@ def get$Name(self):
     print("\n# DO NOT MODIFY THIS FILE!
It was generated by _shared_params_code_gen.py.\n") print("from pyspark.ml.param import Param, Params\n\n") shared = [ - ("maxIter", "max number of iterations (>= 0).", None), - ("regParam", "regularization parameter (>= 0).", None), - ("featuresCol", "features column name.", "'features'"), - ("labelCol", "label column name.", "'label'"), - ("predictionCol", "prediction column name.", "'prediction'"), + ("maxIter", "max number of iterations (>= 0).", None, int), + ("regParam", "regularization parameter (>= 0).", None, float), + ("featuresCol", "features column name.", "'features'", str), + ("labelCol", "label column name.", "'label'", str), + ("predictionCol", "prediction column name.", "'prediction'", str), ("probabilityCol", "Column name for predicted class conditional probabilities. " + "Note: Not all models output well-calibrated probability estimates! These probabilities " + - "should be treated as confidences, not precise probabilities.", "'probability'"), - ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'"), - ("inputCol", "input column name.", None), - ("inputCols", "input column names.", None), - ("outputCol", "output column name.", "self.uid + '__output'"), - ("numFeatures", "number of features.", None), + "should be treated as confidences, not precise probabilities.", "'probability'", str), + ("rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", "'rawPrediction'", + str), + ("inputCol", "input column name.", None, str), + ("inputCols", "input column names.", None, None), + ("outputCol", "output column name.", "self.uid + '__output'", str), + ("numFeatures", "number of features.", None, int), ("checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). " + - "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None), - ("seed", "random seed.", "hash(type(self).__name__)"), - ("tol", "the convergence tolerance for iterative algorithms.", None), - ("stepSize", "Step size to be used for each iteration of optimization.", None), + "E.g. 10 means that the cache will get checkpointed every 10 iterations.", None, int), + ("seed", "random seed.", "hash(type(self).__name__)", int), + ("tol", "the convergence tolerance for iterative algorithms.", None, float), + ("stepSize", "Step size to be used for each iteration of optimization.", None, float), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None), + "added later.", None, str), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + - "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0"), - ("fitIntercept", "whether to fit an intercept term.", "True"), + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", float), + ("fitIntercept", "whether to fit an intercept term.", "True", bool), ("standardization", "whether to standardize the training features before fitting the " + - "model.", "True"), + "model.", "True", bool), ("thresholds", "Thresholds in multi-class classification to adjust the probability of " + "predicting each class. Array must have length equal to the number of classes, with " + "values >= 0. 
The class with largest value p/t is predicted, where p is the original " + - "probability of that class and t is the class' threshold.", None), + "probability of that class and t is the class' threshold.", None, None), ("weightCol", "weight column name. If this is not set or empty, we treat " + - "all instance weights as 1.0.", None), + "all instance weights as 1.0.", None, str), ("solver", "the solver algorithm for optimization. If this is not set or empty, " + - "default value is 'auto'.", "'auto'")] + "default value is 'auto'.", "'auto'", str)] code = [] - for name, doc, defaultValueStr in shared: - param_code = _gen_param_header(name, doc, defaultValueStr) + for name, doc, defaultValueStr, expectedType in shared: + param_code = _gen_param_header(name, doc, defaultValueStr, expectedType) code.append(param_code + "\n" + _gen_param_code(name, doc, defaultValueStr)) decisionTreeParams = [ diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4d960801502c..23f94314844f 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -26,18 +26,18 @@ class HasMaxIter(Params): """ # a placeholder to make it appear in the generated doc - maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).") + maxIter = Param(Params._dummy(), "maxIter", "max number of iterations (>= 0).", int) def __init__(self): super(HasMaxIter, self).__init__() #: param for max number of iterations (>= 0). - self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).") + self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0).", int) def setMaxIter(self, value): """ Sets the value of :py:attr:`maxIter`. """ - self._paramMap[self.maxIter] = value + self._set(maxIter=value) return self def getMaxIter(self): @@ -53,18 +53,18 @@ class HasRegParam(Params): """ # a placeholder to make it appear in the generated doc - regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).") + regParam = Param(Params._dummy(), "regParam", "regularization parameter (>= 0).", float) def __init__(self): super(HasRegParam, self).__init__() #: param for regularization parameter (>= 0). - self.regParam = Param(self, "regParam", "regularization parameter (>= 0).") + self.regParam = Param(self, "regParam", "regularization parameter (>= 0).", float) def setRegParam(self, value): """ Sets the value of :py:attr:`regParam`. """ - self._paramMap[self.regParam] = value + self._set(regParam=value) return self def getRegParam(self): @@ -80,19 +80,19 @@ class HasFeaturesCol(Params): """ # a placeholder to make it appear in the generated doc - featuresCol = Param(Params._dummy(), "featuresCol", "features column name.") + featuresCol = Param(Params._dummy(), "featuresCol", "features column name.", str) def __init__(self): super(HasFeaturesCol, self).__init__() #: param for features column name. - self.featuresCol = Param(self, "featuresCol", "features column name.") + self.featuresCol = Param(self, "featuresCol", "features column name.", str) self._setDefault(featuresCol='features') def setFeaturesCol(self, value): """ Sets the value of :py:attr:`featuresCol`. 
""" - self._paramMap[self.featuresCol] = value + self._set(featuresCol=value) return self def getFeaturesCol(self): @@ -108,19 +108,19 @@ class HasLabelCol(Params): """ # a placeholder to make it appear in the generated doc - labelCol = Param(Params._dummy(), "labelCol", "label column name.") + labelCol = Param(Params._dummy(), "labelCol", "label column name.", str) def __init__(self): super(HasLabelCol, self).__init__() #: param for label column name. - self.labelCol = Param(self, "labelCol", "label column name.") + self.labelCol = Param(self, "labelCol", "label column name.", str) self._setDefault(labelCol='label') def setLabelCol(self, value): """ Sets the value of :py:attr:`labelCol`. """ - self._paramMap[self.labelCol] = value + self._set(labelCol=value) return self def getLabelCol(self): @@ -136,19 +136,19 @@ class HasPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.") + predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name.", str) def __init__(self): super(HasPredictionCol, self).__init__() #: param for prediction column name. - self.predictionCol = Param(self, "predictionCol", "prediction column name.") + self.predictionCol = Param(self, "predictionCol", "prediction column name.", str) self._setDefault(predictionCol='prediction') def setPredictionCol(self, value): """ Sets the value of :py:attr:`predictionCol`. """ - self._paramMap[self.predictionCol] = value + self._set(predictionCol=value) return self def getPredictionCol(self): @@ -164,19 +164,19 @@ class HasProbabilityCol(Params): """ # a placeholder to make it appear in the generated doc - probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + probabilityCol = Param(Params._dummy(), "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) def __init__(self): super(HasProbabilityCol, self).__init__() #: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities. - self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.") + self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.", str) self._setDefault(probabilityCol='probability') def setProbabilityCol(self, value): """ Sets the value of :py:attr:`probabilityCol`. """ - self._paramMap[self.probabilityCol] = value + self._set(probabilityCol=value) return self def getProbabilityCol(self): @@ -192,19 +192,19 @@ class HasRawPredictionCol(Params): """ # a placeholder to make it appear in the generated doc - rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. 
confidence) column name.") + rawPredictionCol = Param(Params._dummy(), "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) def __init__(self): super(HasRawPredictionCol, self).__init__() #: param for raw prediction (a.k.a. confidence) column name. - self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.") + self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name.", str) self._setDefault(rawPredictionCol='rawPrediction') def setRawPredictionCol(self, value): """ Sets the value of :py:attr:`rawPredictionCol`. """ - self._paramMap[self.rawPredictionCol] = value + self._set(rawPredictionCol=value) return self def getRawPredictionCol(self): @@ -220,18 +220,18 @@ class HasInputCol(Params): """ # a placeholder to make it appear in the generated doc - inputCol = Param(Params._dummy(), "inputCol", "input column name.") + inputCol = Param(Params._dummy(), "inputCol", "input column name.", str) def __init__(self): super(HasInputCol, self).__init__() #: param for input column name. - self.inputCol = Param(self, "inputCol", "input column name.") + self.inputCol = Param(self, "inputCol", "input column name.", str) def setInputCol(self, value): """ Sets the value of :py:attr:`inputCol`. """ - self._paramMap[self.inputCol] = value + self._set(inputCol=value) return self def getInputCol(self): @@ -247,18 +247,18 @@ class HasInputCols(Params): """ # a placeholder to make it appear in the generated doc - inputCols = Param(Params._dummy(), "inputCols", "input column names.") + inputCols = Param(Params._dummy(), "inputCols", "input column names.", None) def __init__(self): super(HasInputCols, self).__init__() #: param for input column names. - self.inputCols = Param(self, "inputCols", "input column names.") + self.inputCols = Param(self, "inputCols", "input column names.", None) def setInputCols(self, value): """ Sets the value of :py:attr:`inputCols`. """ - self._paramMap[self.inputCols] = value + self._set(inputCols=value) return self def getInputCols(self): @@ -274,19 +274,19 @@ class HasOutputCol(Params): """ # a placeholder to make it appear in the generated doc - outputCol = Param(Params._dummy(), "outputCol", "output column name.") + outputCol = Param(Params._dummy(), "outputCol", "output column name.", str) def __init__(self): super(HasOutputCol, self).__init__() #: param for output column name. - self.outputCol = Param(self, "outputCol", "output column name.") + self.outputCol = Param(self, "outputCol", "output column name.", str) self._setDefault(outputCol=self.uid + '__output') def setOutputCol(self, value): """ Sets the value of :py:attr:`outputCol`. """ - self._paramMap[self.outputCol] = value + self._set(outputCol=value) return self def getOutputCol(self): @@ -302,18 +302,18 @@ class HasNumFeatures(Params): """ # a placeholder to make it appear in the generated doc - numFeatures = Param(Params._dummy(), "numFeatures", "number of features.") + numFeatures = Param(Params._dummy(), "numFeatures", "number of features.", int) def __init__(self): super(HasNumFeatures, self).__init__() #: param for number of features. - self.numFeatures = Param(self, "numFeatures", "number of features.") + self.numFeatures = Param(self, "numFeatures", "number of features.", int) def setNumFeatures(self, value): """ Sets the value of :py:attr:`numFeatures`. 
""" - self._paramMap[self.numFeatures] = value + self._set(numFeatures=value) return self def getNumFeatures(self): @@ -329,18 +329,18 @@ class HasCheckpointInterval(Params): """ # a placeholder to make it appear in the generated doc - checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") + checkpointInterval = Param(Params._dummy(), "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def __init__(self): super(HasCheckpointInterval, self).__init__() #: param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. - self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.") + self.checkpointInterval = Param(self, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", int) def setCheckpointInterval(self, value): """ Sets the value of :py:attr:`checkpointInterval`. """ - self._paramMap[self.checkpointInterval] = value + self._set(checkpointInterval=value) return self def getCheckpointInterval(self): @@ -356,19 +356,19 @@ class HasSeed(Params): """ # a placeholder to make it appear in the generated doc - seed = Param(Params._dummy(), "seed", "random seed.") + seed = Param(Params._dummy(), "seed", "random seed.", int) def __init__(self): super(HasSeed, self).__init__() #: param for random seed. - self.seed = Param(self, "seed", "random seed.") + self.seed = Param(self, "seed", "random seed.", int) self._setDefault(seed=hash(type(self).__name__)) def setSeed(self, value): """ Sets the value of :py:attr:`seed`. """ - self._paramMap[self.seed] = value + self._set(seed=value) return self def getSeed(self): @@ -384,18 +384,18 @@ class HasTol(Params): """ # a placeholder to make it appear in the generated doc - tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.") + tol = Param(Params._dummy(), "tol", "the convergence tolerance for iterative algorithms.", float) def __init__(self): super(HasTol, self).__init__() #: param for the convergence tolerance for iterative algorithms. - self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.") + self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms.", float) def setTol(self, value): """ Sets the value of :py:attr:`tol`. """ - self._paramMap[self.tol] = value + self._set(tol=value) return self def getTol(self): @@ -411,18 +411,18 @@ class HasStepSize(Params): """ # a placeholder to make it appear in the generated doc - stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.") + stepSize = Param(Params._dummy(), "stepSize", "Step size to be used for each iteration of optimization.", float) def __init__(self): super(HasStepSize, self).__init__() #: param for Step size to be used for each iteration of optimization. 
-        self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.")
+        self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.", float)

    def setStepSize(self, value):
        """
        Sets the value of :py:attr:`stepSize`.
        """
-        self._paramMap[self.stepSize] = value
+        self._set(stepSize=value)
        return self

    def getStepSize(self):
@@ -438,18 +438,18 @@ class HasHandleInvalid(Params):
    """

    # a placeholder to make it appear in the generated doc
-    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")
+    handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", str)

    def __init__(self):
        super(HasHandleInvalid, self).__init__()
        #: param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
-        self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.")
+        self.handleInvalid = Param(self, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", str)

    def setHandleInvalid(self, value):
        """
        Sets the value of :py:attr:`handleInvalid`.
        """
-        self._paramMap[self.handleInvalid] = value
+        self._set(handleInvalid=value)
        return self

    def getHandleInvalid(self):
@@ -465,19 +465,19 @@ class HasElasticNetParam(Params):
    """

    # a placeholder to make it appear in the generated doc
-    elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
+    elasticNetParam = Param(Params._dummy(), "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float)

    def __init__(self):
        super(HasElasticNetParam, self).__init__()
        #: param for the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
-        self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.")
+        self.elasticNetParam = Param(self, "elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", float)
        self._setDefault(elasticNetParam=0.0)

    def setElasticNetParam(self, value):
        """
        Sets the value of :py:attr:`elasticNetParam`.
""" - self._paramMap[self.elasticNetParam] = value + self._set(elasticNetParam=value) return self def getElasticNetParam(self): @@ -493,19 +493,19 @@ class HasFitIntercept(Params): """ # a placeholder to make it appear in the generated doc - fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.") + fitIntercept = Param(Params._dummy(), "fitIntercept", "whether to fit an intercept term.", bool) def __init__(self): super(HasFitIntercept, self).__init__() #: param for whether to fit an intercept term. - self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.") + self.fitIntercept = Param(self, "fitIntercept", "whether to fit an intercept term.", bool) self._setDefault(fitIntercept=True) def setFitIntercept(self, value): """ Sets the value of :py:attr:`fitIntercept`. """ - self._paramMap[self.fitIntercept] = value + self._set(fitIntercept=value) return self def getFitIntercept(self): @@ -521,19 +521,19 @@ class HasStandardization(Params): """ # a placeholder to make it appear in the generated doc - standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.") + standardization = Param(Params._dummy(), "standardization", "whether to standardize the training features before fitting the model.", bool) def __init__(self): super(HasStandardization, self).__init__() #: param for whether to standardize the training features before fitting the model. - self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.") + self.standardization = Param(self, "standardization", "whether to standardize the training features before fitting the model.", bool) self._setDefault(standardization=True) def setStandardization(self, value): """ Sets the value of :py:attr:`standardization`. """ - self._paramMap[self.standardization] = value + self._set(standardization=value) return self def getStandardization(self): @@ -549,18 +549,18 @@ class HasThresholds(Params): """ # a placeholder to make it appear in the generated doc - thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + thresholds = Param(Params._dummy(), "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def __init__(self): super(HasThresholds, self).__init__() #: param for Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold. - self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. 
The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.") + self.thresholds = Param(self, "thresholds", "Thresholds in multi-class classification to adjust the probability of predicting each class. Array must have length equal to the number of classes, with values >= 0. The class with largest value p/t is predicted, where p is the original probability of that class and t is the class' threshold.", None) def setThresholds(self, value): """ Sets the value of :py:attr:`thresholds`. """ - self._paramMap[self.thresholds] = value + self._set(thresholds=value) return self def getThresholds(self): @@ -576,18 +576,18 @@ class HasWeightCol(Params): """ # a placeholder to make it appear in the generated doc - weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + weightCol = Param(Params._dummy(), "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def __init__(self): super(HasWeightCol, self).__init__() #: param for weight column name. If this is not set or empty, we treat all instance weights as 1.0. - self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.") + self.weightCol = Param(self, "weightCol", "weight column name. If this is not set or empty, we treat all instance weights as 1.0.", str) def setWeightCol(self, value): """ Sets the value of :py:attr:`weightCol`. """ - self._paramMap[self.weightCol] = value + self._set(weightCol=value) return self def getWeightCol(self): @@ -603,19 +603,19 @@ class HasSolver(Params): """ # a placeholder to make it appear in the generated doc - solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + solver = Param(Params._dummy(), "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) def __init__(self): super(HasSolver, self).__init__() #: param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'. - self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.") + self.solver = Param(self, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.", str) self._setDefault(solver='auto') def setSolver(self, value): """ Sets the value of :py:attr:`solver`. """ - self._paramMap[self.solver] = value + self._set(solver=value) return self def getSolver(self): @@ -658,7 +658,7 @@ def setMaxDepth(self, value): """ Sets the value of :py:attr:`maxDepth`. """ - self._paramMap[self.maxDepth] = value + self._set(maxDepth=value) return self def getMaxDepth(self): @@ -671,7 +671,7 @@ def setMaxBins(self, value): """ Sets the value of :py:attr:`maxBins`. """ - self._paramMap[self.maxBins] = value + self._set(maxBins=value) return self def getMaxBins(self): @@ -684,7 +684,7 @@ def setMinInstancesPerNode(self, value): """ Sets the value of :py:attr:`minInstancesPerNode`. """ - self._paramMap[self.minInstancesPerNode] = value + self._set(minInstancesPerNode=value) return self def getMinInstancesPerNode(self): @@ -697,7 +697,7 @@ def setMinInfoGain(self, value): """ Sets the value of :py:attr:`minInfoGain`. 
""" - self._paramMap[self.minInfoGain] = value + self._set(minInfoGain=value) return self def getMinInfoGain(self): @@ -710,7 +710,7 @@ def setMaxMemoryInMB(self, value): """ Sets the value of :py:attr:`maxMemoryInMB`. """ - self._paramMap[self.maxMemoryInMB] = value + self._set(maxMemoryInMB=value) return self def getMaxMemoryInMB(self): @@ -723,7 +723,7 @@ def setCacheNodeIds(self, value): """ Sets the value of :py:attr:`cacheNodeIds`. """ - self._paramMap[self.cacheNodeIds] = value + self._set(cacheNodeIds=value) return self def getCacheNodeIds(self): diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 4475451edb78..9f5f6ac8fa4e 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -86,7 +86,7 @@ class Transformer(Params): @abstractmethod def _transform(self, dataset): """ - Transforms the input dataset with optional parameters. + Transforms the input dataset. :param dataset: input dataset, which is an instance of :py:class:`pyspark.sql.DataFrame` diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index a0bb8ceed886..401bac0223eb 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -386,7 +386,8 @@ class GBTParams(TreeEnsembleParams): @inherit_doc class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval): + DecisionTreeParams, TreeRegressorParams, HasCheckpointInterval, + HasSeed): """ `http://en.wikipedia.org/wiki/Decision_tree_learning Decision tree` learning algorithm for regression. @@ -415,11 +416,13 @@ class DecisionTreeRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance"): + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", + seed=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) """ super(DecisionTreeRegressor, self).__init__() self._java_obj = self._new_java_obj( @@ -435,11 +438,12 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, - impurity="variance"): + impurity="variance", seed=None): """ setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \ - maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") + maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \ + impurity="variance", seed=None) Sets params for the DecisionTreeRegressor. 
""" kwargs = self.setParams._input_kwargs diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 7a16cf52cccb..4eb17bfdcca9 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -37,6 +37,7 @@ from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand +from pyspark.ml.classification import LogisticRegression from pyspark.ml.evaluation import RegressionEvaluator from pyspark.ml.param import Param, Params from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed @@ -92,6 +93,27 @@ class MockModel(MockTransformer, Model, HasFake): pass +class ParamTypeConversionTests(PySparkTestCase): + """ + Test that param type conversion happens. + """ + + def test_int_to_float(self): + from pyspark.mllib.linalg import Vectors + df = self.sc.parallelize([ + Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF() + lr = LogisticRegression(elasticNetParam=0) + lr.fit(df) + lr.setElasticNetParam(0) + lr.fit(df) + + def test_invalid_to_float(self): + from pyspark.mllib.linalg import Vectors + self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy")) + lr = LogisticRegression(elasticNetParam=0) + self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda")) + + class PipelineTests(PySparkTestCase): def test_pipeline(self): diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 4bcb4aaec89d..dd1d4b076edd 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -15,7 +15,7 @@ # limitations under the License. # -from abc import ABCMeta +from abc import ABCMeta, abstractmethod from pyspark import SparkContext from pyspark.sql import DataFrame @@ -110,6 +110,7 @@ class JavaEstimator(Estimator, JavaWrapper): __metaclass__ = ABCMeta + @abstractmethod def _create_model(self, java_model): """ Creates a model from the input Java model reference. diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py index c9e6f1dec6bf..d22a7f4c3b16 100644 --- a/python/pyspark/mllib/clustering.py +++ b/python/pyspark/mllib/clustering.py @@ -173,7 +173,7 @@ def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||" """Train a k-means clustering model.""" if runs != 1: warnings.warn( - "Support for runs is deprecated in 1.6.0. This param will have no effect in 1.7.0.") + "Support for runs is deprecated in 1.6.0. 
This param will have no effect in 2.0.0.") clusterInitialModel = [] if initialModel is not None: if not isinstance(initialModel, KMeansModel): @@ -346,7 +346,7 @@ def train(cls, rdd, k, convergenceTol=1e-3, maxIterations=100, seed=None, initia if initialModel.k != k: raise Exception("Mismatched cluster count, initialModel.k = %s, however k = %s" % (initialModel.k, k)) - initialModelWeights = initialModel.weights + initialModelWeights = list(initialModel.weights) initialModelMu = [initialModel.gaussians[i].mu for i in range(initialModel.k)] initialModelSigma = [initialModel.gaussians[i].sigma for i in range(initialModel.k)] java_model = callMLlibFunc("trainGaussianMixtureModel", rdd.map(_convert_to_vector), diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index acd7ec57d69d..612935352575 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -172,6 +172,38 @@ def setWithStd(self, withStd): self.call("setWithStd", withStd) return self + @property + @since('2.0.0') + def withStd(self): + """ + Returns if the model scales the data to unit standard deviation. + """ + return self.call("withStd") + + @property + @since('2.0.0') + def withMean(self): + """ + Returns if the model centers the data before scaling. + """ + return self.call("withMean") + + @property + @since('2.0.0') + def std(self): + """ + Return the column standard deviation values. + """ + return self.call("std") + + @property + @since('2.0.0') + def mean(self): + """ + Return the column mean values. + """ + return self.call("mean") + class StandardScaler(object): """ @@ -196,6 +228,14 @@ class StandardScaler(object): >>> for r in result.collect(): r DenseVector([-0.7071, 0.7071, -0.7071]) DenseVector([0.7071, -0.7071, 0.7071]) + >>> int(model.std[0]) + 4 + >>> int(model.mean[0]*10) + 9 + >>> model.withStd + True + >>> model.withMean + True .. versionadded:: 1.2.0 """ diff --git a/python/pyspark/mllib/linalg/__init__.py b/python/pyspark/mllib/linalg/__init__.py index ae9ce5845090..131b855bf99c 100644 --- a/python/pyspark/mllib/linalg/__init__.py +++ b/python/pyspark/mllib/linalg/__init__.py @@ -528,7 +528,9 @@ def __init__(self, size, *args): assert len(self.indices) == len(self.values), "index and value arrays not same length" for i in xrange(len(self.indices) - 1): if self.indices[i] >= self.indices[i + 1]: - raise TypeError("indices array must be sorted") + raise TypeError( + "Indices %s and %s are not strictly increasing" + % (self.indices[i], self.indices[i + 1])) def numNonzeros(self): """ diff --git a/python/pyspark/mllib/linalg/distributed.py b/python/pyspark/mllib/linalg/distributed.py index 0e7605078863..e1f022187d50 100644 --- a/python/pyspark/mllib/linalg/distributed.py +++ b/python/pyspark/mllib/linalg/distributed.py @@ -297,6 +297,20 @@ def numCols(self): """ return self._java_matrix_wrapper.call("numCols") + def columnSimilarities(self): + """ + Compute all cosine similarities between columns. + + >>> rows = sc.parallelize([IndexedRow(0, [1, 2, 3]), + ... IndexedRow(6, [4, 5, 6])]) + >>> mat = IndexedRowMatrix(rows) + >>> cs = mat.columnSimilarities() + >>> print(cs.numCols()) + 3 + """ + java_coordinate_matrix = self._java_matrix_wrapper.call("columnSimilarities") + return CoordinateMatrix(java_coordinate_matrix) + def toRowMatrix(self): """ Convert this matrix to a RowMatrix. 
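The reworded `SparseVector` error now names the offending index pair instead of the old blanket "indices array must be sorted" message; for instance:

```python
from pyspark.mllib.linalg import SparseVector

SparseVector(4, [0, 2], [1.0, 2.0])      # ok: indices strictly increasing
try:
    SparseVector(4, [2, 0], [1.0, 2.0])  # unsorted indices
except TypeError as e:
    print(e)  # "Indices 2 and 0 are not strictly increasing"
```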
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index f8e8e0e0adbe..3436a28b2974 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -54,6 +54,7 @@ from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT from pyspark.mllib.classification import StreamingLogisticRegressionWithSGD +from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics @@ -474,6 +475,18 @@ def test_gmm_deterministic(self): for c1, c2 in zip(clusters1.weights, clusters2.weights): self.assertEqual(round(c1, 7), round(c2, 7)) + def test_gmm_with_initial_model(self): + from pyspark.mllib.clustering import GaussianMixture + data = self.sc.parallelize([ + (-10, -5), (-9, -4), (10, 5), (9, 4) + ]) + + gmm1 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63) + gmm2 = GaussianMixture.train(data, 2, convergenceTol=0.001, + maxIterations=10, seed=63, initialModel=gmm1) + self.assertAlmostEqual((gmm1.weights - gmm2.weights).sum(), 0.0) + def test_classification(self): from pyspark.mllib.classification import LogisticRegressionWithSGD, SVMWithSGD, NaiveBayes from pyspark.mllib.tree import DecisionTree, DecisionTreeModel, RandomForest,\ @@ -1539,6 +1552,22 @@ def test_load_vectors(self): shutil.rmtree(load_vectors_path) +class ALSTests(MLlibTestCase): + + def test_als_ratings_serialize(self): + r = Rating(7, 1123, 3.14) + jr = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(r))) + nr = ser.loads(bytes(self.sc._jvm.SerDe.dumps(jr))) + self.assertEqual(r.user, nr.user) + self.assertEqual(r.product, nr.product) + self.assertAlmostEqual(r.rating, nr.rating, 2) + + def test_als_ratings_id_long_error(self): + r = Rating(1205640308657491975, 50233468418, 1.0) + # rating user id exceeds max int value, should fail when pickled + self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) + + if __name__ == "__main__": if not _have_scipy: print("NOTE: Skipping SciPy tests as it does not seem to be installed") diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 00bb9a62e904..a019c0586254 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -220,18 +220,18 @@ def context(self): def cache(self): """ - Persist this RDD with the default storage level (C{MEMORY_ONLY_SER}). + Persist this RDD with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """ Set this RDD's storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). + If no storage level is specified defaults to (C{MEMORY_ONLY}). 
+        If no storage level is specified defaults to (C{MEMORY_ONLY}).
>>> rdd = sc.parallelize(["b", "a", "c"]) >>> rdd.persist().is_cached diff --git a/python/pyspark/sql/__init__.py b/python/pyspark/sql/__init__.py index 98eaf52866d2..0b06c8339f50 100644 --- a/python/pyspark/sql/__init__.py +++ b/python/pyspark/sql/__init__.py @@ -47,7 +47,7 @@ from pyspark.sql.types import Row from pyspark.sql.context import SQLContext, HiveContext from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame, SchemaRDD, DataFrameNaFunctions, DataFrameStatFunctions +from pyspark.sql.dataframe import DataFrame, DataFrameNaFunctions, DataFrameStatFunctions from pyspark.sql.group import GroupedData from pyspark.sql.readwriter import DataFrameReader, DataFrameWriter from pyspark.sql.window import Window, WindowSpec diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index 81fd4e782628..900def59d23a 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -27,8 +27,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.types import * -__all__ = ["DataFrame", "Column", "SchemaRDD", "DataFrameNaFunctions", - "DataFrameStatFunctions"] +__all__ = ["DataFrame", "Column", "DataFrameNaFunctions", "DataFrameStatFunctions"] def _create_column_from_literal(literal): @@ -272,23 +271,6 @@ def substr(self, startPos, length): __getslice__ = substr - @ignore_unicode_prefix - @since(1.3) - def inSet(self, *cols): - """ - A boolean expression that is evaluated to true if the value of this - expression is contained by the evaluated values of the arguments. - - >>> df[df.name.inSet("Bob", "Mike")].collect() - [Row(age=5, name=u'Bob')] - >>> df[df.age.inSet([1, 2, 3])].collect() - [Row(age=2, name=u'Alice')] - - .. note:: Deprecated in 1.5, use :func:`Column.isin` instead. - """ - warnings.warn("inSet is deprecated. Use isin() instead.") - return self.isin(*cols) - @ignore_unicode_prefix @since(1.5) def isin(self, *cols): diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b05aa2f5c4cd..91e27cf16e43 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -18,6 +18,7 @@ import sys import warnings import json +from functools import reduce if sys.version >= '3': basestring = unicode = str @@ -236,14 +237,9 @@ def _inferSchemaFromList(self, data): if type(first) is dict: warnings.warn("inferring schema from dict is deprecated," "please use pyspark.sql.Row instead") - schema = _infer_schema(first) + schema = reduce(_merge_type, map(_infer_schema, data)) if _has_nulltype(schema): - for r in data: - schema = _merge_type(schema, _infer_schema(r)) - if not _has_nulltype(schema): - break - else: - raise ValueError("Some of types cannot be determined after inferring") + raise ValueError("Some of types cannot be determined after inferring") return schema def _inferSchema(self, rdd, samplingRatio=None): @@ -278,33 +274,6 @@ def _inferSchema(self, rdd, samplingRatio=None): schema = rdd.map(_infer_schema).reduce(_merge_type) return schema - @ignore_unicode_prefix - def inferSchema(self, rdd, samplingRatio=None): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. - """ - warnings.warn("inferSchema is deprecated, please use createDataFrame instead.") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - return self.createDataFrame(rdd, None, samplingRatio) - - @ignore_unicode_prefix - def applySchema(self, rdd, schema): - """ - .. note:: Deprecated in 1.3, use :func:`createDataFrame` instead. 
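Both removals point callers at `createDataFrame`, and the migrations are one-liners; a sketch, assuming an active `sc`/`sqlContext`:

```python
from pyspark.sql import Row
from pyspark.sql.types import StructType, StructField, LongType, StringType

rdd = sc.parallelize([Row(field1=1, field2="row1")])

# was: sqlContext.inferSchema(rdd, samplingRatio=1.0)
df1 = sqlContext.createDataFrame(rdd, samplingRatio=1.0)

# was: sqlContext.applySchema(rdd, schema)
schema = StructType([StructField("field1", LongType(), True),
                     StructField("field2", StringType(), True)])
df2 = sqlContext.createDataFrame(rdd, schema)
```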
- """ - warnings.warn("applySchema is deprecated, please use createDataFrame instead") - - if isinstance(rdd, DataFrame): - raise TypeError("Cannot apply schema to DataFrame") - - if not isinstance(schema, StructType): - raise TypeError("schema should be StructType, but got %s" % type(schema)) - - return self.createDataFrame(rdd, schema) - def _createFromRDD(self, rdd, schema, samplingRatio): """ Create an RDD for DataFrame from an existing RDD, returns the RDD and schema. @@ -454,90 +423,6 @@ def dropTempTable(self, tableName): """ self._ssql_ctx.dropTempTable(tableName) - def parquetFile(self, *paths): - """Loads a Parquet file, returning the result as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.parquet` instead. - - >>> sqlContext.parquetFile('python/test_support/sql/parquet_partitioned').dtypes - [('name', 'string'), ('year', 'int'), ('month', 'int'), ('day', 'int')] - """ - warnings.warn("parquetFile is deprecated. Use read.parquet() instead.") - gateway = self._sc._gateway - jpaths = gateway.new_array(gateway.jvm.java.lang.String, len(paths)) - for i in range(0, len(paths)): - jpaths[i] = paths[i] - jdf = self._ssql_ctx.parquetFile(jpaths) - return DataFrame(jdf, self) - - def jsonFile(self, path, schema=None, samplingRatio=1.0): - """Loads a text file storing one JSON object per line as a :class:`DataFrame`. - - .. note:: Deprecated in 1.4, use :func:`DataFrameReader.json` instead. - - >>> sqlContext.jsonFile('python/test_support/sql/people.json').dtypes - [('age', 'bigint'), ('name', 'string')] - """ - warnings.warn("jsonFile is deprecated. Use read.json() instead.") - if schema is None: - df = self._ssql_ctx.jsonFile(path, samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonFile(path, scala_datatype) - return DataFrame(df, self) - - @ignore_unicode_prefix - @since(1.0) - def jsonRDD(self, rdd, schema=None, samplingRatio=1.0): - """Loads an RDD storing one JSON object per string as a :class:`DataFrame`. - - If the schema is provided, applies the given schema to this JSON dataset. - Otherwise, it samples the dataset with ratio ``samplingRatio`` to determine the schema. - - >>> df1 = sqlContext.jsonRDD(json) - >>> df1.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> df2 = sqlContext.jsonRDD(json, df1.schema) - >>> df2.first() - Row(field1=1, field2=u'row1', field3=Row(field4=11, field5=None), field6=None) - - >>> from pyspark.sql.types import * - >>> schema = StructType([ - ... StructField("field2", StringType()), - ... StructField("field3", - ... StructType([StructField("field5", ArrayType(IntegerType()))])) - ... ]) - >>> df3 = sqlContext.jsonRDD(json, schema) - >>> df3.first() - Row(field2=u'row1', field3=Row(field5=None)) - """ - - def func(iterator): - for x in iterator: - if not isinstance(x, basestring): - x = unicode(x) - if isinstance(x, unicode): - x = x.encode("utf-8") - yield x - keyed = rdd.mapPartitions(func) - keyed._bypass_serializer = True - jrdd = keyed._jrdd.map(self._jvm.BytesToString()) - if schema is None: - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio) - else: - scala_datatype = self._ssql_ctx.parseDataType(schema.json()) - df = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype) - return DataFrame(df, self) - - def load(self, path=None, source=None, schema=None, **options): - """Returns the dataset in a data source as a :class:`DataFrame`. - - .. 
note:: Deprecated in 1.4, use :func:`DataFrameReader.load` instead. - """ - warnings.warn("load is deprecated. Use read.load() instead.") - return self.read.load(path, source, schema, **options) - @since(1.3) def createExternalTable(self, tableName, path=None, source=None, schema=None, **options): """Creates an external table based on the dataset in a data source. diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index 78ab475eb466..a7bc288e3886 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -36,7 +36,7 @@ from pyspark.sql.readwriter import DataFrameWriter from pyspark.sql.types import * -__all__ = ["DataFrame", "SchemaRDD", "DataFrameNaFunctions", "DataFrameStatFunctions"] +__all__ = ["DataFrame", "DataFrameNaFunctions", "DataFrameStatFunctions"] class DataFrame(object): @@ -113,14 +113,6 @@ def toJSON(self, use_unicode=True): rdd = self._jdf.toJSON() return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode)) - def saveAsParquetFile(self, path): - """Saves the contents as a Parquet file, preserving the schema. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.parquet` instead. - """ - warnings.warn("saveAsParquetFile is deprecated. Use write.parquet() instead.") - self._jdf.saveAsParquetFile(path) - @since(1.3) def registerTempTable(self, name): """Registers this RDD as a temporary table using the given name. @@ -135,38 +127,6 @@ def registerTempTable(self, name): """ self._jdf.registerTempTable(name) - def registerAsTable(self, name): - """ - .. note:: Deprecated in 1.4, use :func:`registerTempTable` instead. - """ - warnings.warn("Use registerTempTable instead of registerAsTable.") - self.registerTempTable(name) - - def insertInto(self, tableName, overwrite=False): - """Inserts the contents of this :class:`DataFrame` into the specified table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.insertInto` instead. - """ - warnings.warn("insertInto is deprecated. Use write.insertInto() instead.") - self.write.insertInto(tableName, overwrite) - - def saveAsTable(self, tableName, source=None, mode="error", **options): - """Saves the contents of this :class:`DataFrame` to a data source as a table. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.saveAsTable` instead. - """ - warnings.warn("insertInto is deprecated. Use write.saveAsTable() instead.") - self.write.saveAsTable(tableName, source, mode, **options) - - @since(1.3) - def save(self, path=None, source=None, mode="error", **options): - """Saves the contents of the :class:`DataFrame` to a data source. - - .. note:: Deprecated in 1.4, use :func:`DataFrameWriter.save` instead. - """ - warnings.warn("insertInto is deprecated. Use write.save() instead.") - return self.write.save(path, source, mode, **options) - @property @since(1.4) def write(self): @@ -371,18 +331,18 @@ def foreachPartition(self, f): @since(1.3) def cache(self): - """ Persists with the default storage level (C{MEMORY_ONLY_SER}). + """ Persists with the default storage level (C{MEMORY_ONLY}). """ self.is_cached = True self._jdf.cache() return self @since(1.3) - def persist(self, storageLevel=StorageLevel.MEMORY_ONLY_SER): + def persist(self, storageLevel=StorageLevel.MEMORY_ONLY): """Sets the storage level to persist its values across operations after the first time it is computed. This can only be used to assign a new storage level if the RDD does not have a storage level set yet. - If no storage level is specified defaults to (C{MEMORY_ONLY_SER}). 
+        If no storage level is specified defaults to (C{MEMORY_ONLY}).
         """
         self.is_cached = True
         javaStorageLevel = self._sc._getJavaStorageLevel(storageLevel)
@@ -608,13 +568,16 @@ def join(self, other, on=None, how=None):
         :param on: a string for join column name, a list of column names,
             a join expression (Column) or a list of Columns.
             If `on` is a string or a list of strings indicating the name of the join column(s),
-            the column(s) must exist on both sides, and this performs an inner equi-join.
+            the column(s) must exist on both sides, and this performs an equi-join.
         :param how: str, default 'inner'.
-            One of `inner`, `outer`, `left_outer`, `right_outer`, `semijoin`.
+            One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.

         >>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
         [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]

+        >>> df.join(df2, 'name', 'outer').select('name', 'height').collect()
+        [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+
         >>> cond = [df.name == df3.name, df.age == df3.age]
         >>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
         [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
@@ -1385,12 +1348,6 @@ def toPandas(self):

 drop_duplicates = dropDuplicates

-# Having SchemaRDD for backward compatibility (for docs)
-class SchemaRDD(DataFrame):
-    """SchemaRDD is deprecated, please use :class:`DataFrame`.
-    """
-
-
 def _to_scala_map(sc, jm):
     """
     Convert a dict into a JVM Map.
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 90625949f747..b0390cb9942e 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -149,12 +149,8 @@ def _():
 }

 _window_functions = {
-    'rowNumber':
-        """.. note:: Deprecated in 1.6, use row_number instead.""",
     'row_number':
         """returns a sequential number starting at 1 within a window partition.""",
-    'denseRank':
-        """.. note:: Deprecated in 1.6, use dense_rank instead.""",
     'dense_rank':
         """returns the rank of rows within a window partition, without any gaps.
@@ -171,13 +167,9 @@ def _():
         place and that the next person came in third. This is equivalent to the RANK function in SQL.""",
-    'cumeDist':
-        """.. note:: Deprecated in 1.6, use cume_dist instead.""",
     'cume_dist':
         """returns the cumulative distribution of values within a window partition,
         i.e. the fraction of rows that are below the current row.""",
-    'percentRank':
-        """.. note:: Deprecated in 1.6, use percent_rank instead.""",
     'percent_rank':
         """returns the relative rank (i.e. percentile) of rows within a window partition.""",
 }
@@ -318,14 +310,6 @@ def isnull(col):
     return Column(sc._jvm.functions.isnull(_to_java_column(col)))

-@since(1.4)
-def monotonicallyIncreasingId():
-    """
-    .. note:: Deprecated in 1.6, use monotonically_increasing_id instead.
-    """
-    return monotonically_increasing_id()
-
-
 @since(1.6)
 def monotonically_increasing_id():
     """A column that generates monotonically increasing 64-bit integers.
@@ -434,14 +418,6 @@ def shiftRightUnsigned(col, numBits):
     return Column(jc)

-@since(1.4)
-def sparkPartitionId():
-    """
-    .. note:: Deprecated in 1.6, use spark_partition_id instead.
-    """
-    return spark_partition_id()
-
-
 @since(1.6)
 def spark_partition_id():
     """A column for partition ID of the Spark task.
@@ -1042,6 +1018,18 @@ def sha2(col, numBits):
     return Column(jc)

+@since(2.0)
+def hash(*cols):
+    """Calculates the hash code of given columns, and returns the result as an int column.
+ + >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() + [Row(hash=1358996357)] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) + return Column(jc) + + # ---------------------- String/Binary functions ------------------------------ _string_functions = { @@ -1053,7 +1041,7 @@ def sha2(col, numBits): 'lower': 'Converts a string column to lower case.', 'upper': 'Converts a string column to upper case.', 'reverse': 'Reverses the string column and returns it as a new string column.', - 'ltrim': 'Trim the spaces from right end for the specified string value.', + 'ltrim': 'Trim the spaces from left end for the specified string value.', 'rtrim': 'Trim the spaces from right end for the specified string value.', 'trim': 'Trim the spaces from both ends for the specified string column.', } diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index 2e75f0c8a182..0b20022b14b8 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -130,11 +130,9 @@ def load(self, path=None, format=None, schema=None, **options): self.schema(schema) self.options(**options) if path is not None: - if type(path) == list: - return self._df( - self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) - else: - return self._df(self._jreader.load(path)) + if type(path) != list: + path = [path] + return self._df(self._jreader.load(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) else: return self._df(self._jreader.load()) @@ -160,6 +158,8 @@ def json(self, path, schema=None): quotes * ``allowNumericLeadingZeros`` (default ``false``): allows leading zeros in numbers \ (e.g. 00012) + * ``allowBackslashEscapingAnyCharacter`` (default ``false``): allows accepting quoting \ + of all character using backslash quoting mechanism >>> df1 = sqlContext.read.json('python/test_support/sql/people.json') >>> df1.dtypes @@ -177,7 +177,17 @@ def json(self, path, schema=None): elif type(path) == list: return self._df(self._jreader.json(self._sqlContext._sc._jvm.PythonUtils.toSeq(path))) elif isinstance(path, RDD): - return self._df(self._jreader.json(path._jrdd)) + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = path.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._sqlContext._jvm.BytesToString()) + return self._df(self._jreader.json(jrdd)) else: raise TypeError("path can be only string or RDD") @@ -207,7 +217,7 @@ def parquet(self, *paths): @ignore_unicode_prefix @since(1.6) def text(self, paths): - """Loads a text file and returns a [[DataFrame]] with a single string column named "text". + """Loads a text file and returns a [[DataFrame]] with a single string column named "value". Each line in the text file is a new row in the resulting DataFrame. 
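Putting the readwriter changes together: `load` now funnels a single path through the same list-of-paths call, `text` yields a column named "value" rather than "text", and the new `hash` function is callable from Python. A sketch, assuming an active `sqlContext` (the paths are the repo's test fixtures):

```python
from pyspark.sql.functions import hash as hash_  # shadows the builtin, hence the alias

df = sqlContext.createDataFrame([("ABC",)], ["a"])
df.select(hash_("a").alias("h")).collect()       # [Row(h=1358996357)]

# DataFrameReader.load normalizes a single path to a one-element list,
# so both of these hit the same Scala overload:
df1 = sqlContext.read.load("python/test_support/sql/parquet_partitioned")
df2 = sqlContext.read.load(["python/test_support/sql/parquet_partitioned"])

# text() now yields a single string column named "value":
lines = sqlContext.read.text("python/test_support/sql/text-test.txt")
lines.schema  # StructType with one StringType field called "value"
```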
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 9f5f7cfdf7a6..e396cf41f2f7 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -326,7 +326,7 @@ def test_broadcast_in_udf(self): def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) df.count() df.collect() df.schema @@ -345,7 +345,7 @@ def test_basic_functions(self): df.collect() def test_apply_schema_to_row(self): - df = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + df = self.sqlCtx.read.json(self.sc.parallelize(["""{"a":2}"""])) df2 = self.sqlCtx.createDataFrame(df.map(lambda x: x), df.schema) self.assertEqual(df.collect(), df2.collect()) @@ -353,6 +353,17 @@ def test_apply_schema_to_row(self): df3 = self.sqlCtx.createDataFrame(rdd, df.schema) self.assertEqual(10, df3.count()) + def test_infer_schema_to_local(self): + input = [{"a": 1}, {"b": "coffee"}] + rdd = self.sc.parallelize(input) + df = self.sqlCtx.createDataFrame(input) + df2 = self.sqlCtx.createDataFrame(rdd, samplingRatio=1.0) + self.assertEqual(df.schema, df2.schema) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x, b=None)) + df3 = self.sqlCtx.createDataFrame(rdd, df.schema) + self.assertEqual(10, df3.count()) + def test_serialize_nested_array_and_map(self): d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})] rdd = self.sc.parallelize(d) @@ -397,12 +408,12 @@ def test_infer_nested_schema(self): NestedRow = Row("f1", "f2") nestedRdd1 = self.sc.parallelize([NestedRow([1, 2], {"row1": 1.0}), NestedRow([2, 3], {"row2": 2.0})]) - df = self.sqlCtx.inferSchema(nestedRdd1) + df = self.sqlCtx.createDataFrame(nestedRdd1) self.assertEqual(Row(f1=[1, 2], f2={u'row1': 1.0}), df.collect()[0]) nestedRdd2 = self.sc.parallelize([NestedRow([[1, 2], [2, 3]], [1, 2]), NestedRow([[2, 3], [3, 4]], [2, 3])]) - df = self.sqlCtx.inferSchema(nestedRdd2) + df = self.sqlCtx.createDataFrame(nestedRdd2) self.assertEqual(Row(f1=[[1, 2], [2, 3]], f2=[1, 2]), df.collect()[0]) from collections import namedtuple @@ -410,7 +421,7 @@ def test_infer_nested_schema(self): rdd = self.sc.parallelize([CustomRow(field1=1, field2="row1"), CustomRow(field1=2, field2="row2"), CustomRow(field1=3, field2="row3")]) - df = self.sqlCtx.inferSchema(rdd) + df = self.sqlCtx.createDataFrame(rdd) self.assertEqual(Row(field1=1, field2=u'row1'), df.first()) def test_create_dataframe_from_objects(self): @@ -570,14 +581,14 @@ def test_parquet_with_udt(self): df0 = self.sqlCtx.createDataFrame([row]) output_dir = os.path.join(self.tempdir.name, "labeled_point") df0.write.parquet(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, ExamplePoint(1.0, 2.0)) row = Row(label=1.0, point=PythonOnlyPoint(1.0, 2.0)) df0 = self.sqlCtx.createDataFrame([row]) df0.write.parquet(output_dir, mode='overwrite') - df1 = self.sqlCtx.parquetFile(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) point = df1.head().point self.assertEqual(point, PythonOnlyPoint(1.0, 2.0)) @@ -752,7 +763,7 @@ def test_save_and_load(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) 
self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -785,7 +796,7 @@ def test_save_and_load_builder(self): defaultDataSourceName = self.sqlCtx.getConf("spark.sql.sources.default", "org.apache.spark.sql.parquet") self.sqlCtx.sql("SET spark.sql.sources.default=org.apache.spark.sql.json") - actual = self.sqlCtx.load(path=tmpPath) + actual = self.sqlCtx.read.load(path=tmpPath) self.assertEqual(sorted(df.collect()), sorted(actual.collect())) self.sqlCtx.sql("SET spark.sql.sources.default=" + defaultDataSourceName) @@ -794,7 +805,7 @@ def test_save_and_load_builder(self): def test_help_command(self): # Regression test for SPARK-5464 rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) - df = self.sqlCtx.jsonRDD(rdd) + df = self.sqlCtx.read.json(rdd) # render_doc() reproduces the help() exception without printing output pydoc.render_doc(df) pydoc.render_doc(df.foo) @@ -842,8 +853,8 @@ def test_infer_long_type(self): # this saving as Parquet caused issues as well. output_dir = os.path.join(self.tempdir.name, "infer_long_type") - df.saveAsParquetFile(output_dir) - df1 = self.sqlCtx.parquetFile(output_dir) + df.write.parquet(output_dir) + df1 = self.sqlCtx.read.parquet(output_dir) self.assertEqual('a', df1.first().f1) self.assertEqual(100000000000000, df1.first().f2) @@ -1194,9 +1205,9 @@ def test_window_functions(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ @@ -1216,9 +1227,9 @@ def test_window_functions_without_partitionBy(self): F.max("key").over(w.rowsBetween(0, 1)), F.min("key").over(w.rowsBetween(0, 1)), F.count("key").over(w.rowsBetween(float('-inf'), float('inf'))), - F.rowNumber().over(w), + F.row_number().over(w), F.rank().over(w), - F.denseRank().over(w), + F.dense_rank().over(w), F.ntile(2).over(w)) rs = sorted(sel.collect()) expected = [ diff --git a/python/pyspark/storagelevel.py b/python/pyspark/storagelevel.py index 676aa0f7144a..d4f184a85d76 100644 --- a/python/pyspark/storagelevel.py +++ b/python/pyspark/storagelevel.py @@ -23,8 +23,10 @@ class StorageLevel(object): """ Flags for controlling the storage of an RDD. Each StorageLevel records whether to use memory, whether to drop the RDD to disk if it falls out of memory, whether to keep the data in memory - in a serialized format, and whether to replicate the RDD partitions on multiple nodes. - Also contains static constants for some commonly used storage levels, such as MEMORY_ONLY. + in a JAVA-specific serialized format, and whether to replicate the RDD partitions on multiple + nodes. Also contains static constants for some commonly used storage levels, MEMORY_ONLY. + Since the data is always serialized on the Python side, all the constants use the serialized + formats. 
""" def __init__(self, useDisk, useMemory, useOffHeap, deserialized, replication=1): @@ -49,12 +51,21 @@ def __str__(self): StorageLevel.DISK_ONLY = StorageLevel(True, False, False, False) StorageLevel.DISK_ONLY_2 = StorageLevel(True, False, False, False, 2) -StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, True) -StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, True, 2) -StorageLevel.MEMORY_ONLY_SER = StorageLevel(False, True, False, False) -StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel(False, True, False, False, 2) -StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, True) -StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, True, 2) -StorageLevel.MEMORY_AND_DISK_SER = StorageLevel(True, True, False, False) -StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel(True, True, False, False, 2) +StorageLevel.MEMORY_ONLY = StorageLevel(False, True, False, False) +StorageLevel.MEMORY_ONLY_2 = StorageLevel(False, True, False, False, 2) +StorageLevel.MEMORY_AND_DISK = StorageLevel(True, True, False, False) +StorageLevel.MEMORY_AND_DISK_2 = StorageLevel(True, True, False, False, 2) StorageLevel.OFF_HEAP = StorageLevel(False, False, True, False, 1) + +""" +.. note:: The following four storage level constants are deprecated in 2.0, since the records \ +will always be serialized in Python. +""" +StorageLevel.MEMORY_ONLY_SER = StorageLevel.MEMORY_ONLY +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY`` instead.""" +StorageLevel.MEMORY_ONLY_SER_2 = StorageLevel.MEMORY_ONLY_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_ONLY_2`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER = StorageLevel.MEMORY_AND_DISK +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK`` instead.""" +StorageLevel.MEMORY_AND_DISK_SER_2 = StorageLevel.MEMORY_AND_DISK_2 +""".. note:: Deprecated in 2.0, use ``StorageLevel.MEMORY_AND_DISK_2`` instead.""" diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 1388b6d044e0..0f1f005ce3ed 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -19,6 +19,7 @@ import os import sys +from threading import RLock, Timer from py4j.java_gateway import java_import, JavaObject @@ -32,6 +33,63 @@ __all__ = ["StreamingContext"] +class Py4jCallbackConnectionCleaner(object): + + """ + A cleaner to clean up callback connections that are not closed by Py4j. See SPARK-12617. + It will scan all callback connections every 30 seconds and close the dead connections. 
+ """ + + def __init__(self, gateway): + self._gateway = gateway + self._stopped = False + self._timer = None + self._lock = RLock() + + def start(self): + if self._stopped: + return + + def clean_closed_connections(): + from py4j.java_gateway import quiet_close, quiet_shutdown + + callback_server = self._gateway._callback_server + if callback_server: + with callback_server.lock: + try: + closed_connections = [] + for connection in callback_server.connections: + if not connection.isAlive(): + quiet_close(connection.input) + quiet_shutdown(connection.socket) + quiet_close(connection.socket) + closed_connections.append(connection) + + for closed_connection in closed_connections: + callback_server.connections.remove(closed_connection) + except Exception: + import traceback + traceback.print_exc() + + self._start_timer(clean_closed_connections) + + self._start_timer(clean_closed_connections) + + def _start_timer(self, f): + with self._lock: + if not self._stopped: + self._timer = Timer(30.0, f) + self._timer.daemon = True + self._timer.start() + + def stop(self): + with self._lock: + self._stopped = True + if self._timer: + self._timer.cancel() + self._timer = None + + class StreamingContext(object): """ Main entry point for Spark Streaming functionality. A StreamingContext @@ -47,6 +105,9 @@ class StreamingContext(object): # Reference to a currently active StreamingContext _activeContext = None + # A cleaner to clean leak sockets of callback server every 30 seconds + _py4j_cleaner = None + def __init__(self, sparkContext, batchDuration=None, jssc=None): """ Create a new StreamingContext. @@ -95,11 +156,33 @@ def _ensure_initialized(cls): jgws = JavaObject("GATEWAY_SERVER", gw._gateway_client) # update the port of CallbackClient with real port gw.jvm.PythonDStream.updatePythonGatewayPort(jgws, gw._python_proxy_port) + _py4j_cleaner = Py4jCallbackConnectionCleaner(gw) + _py4j_cleaner.start() # register serializer for TransformFunction # it happens before creating SparkContext when loading from checkpointing - cls._transformerSerializer = TransformFunctionSerializer( - SparkContext._active_spark_context, CloudPickleSerializer(), gw) + if cls._transformerSerializer is None: + transformer_serializer = TransformFunctionSerializer() + transformer_serializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) + # SPARK-12511 streaming driver with checkpointing unable to finalize leading to OOM + # There is an issue that Py4J's PythonProxyHandler.finalize blocks forever. + # (https://github.com/bartdag/py4j/pull/184) + # + # Py4j will create a PythonProxyHandler in Java for "transformer_serializer" when + # calling "registerSerializer". If we call "registerSerializer" twice, the second + # PythonProxyHandler will override the first one, then the first one will be GCed and + # trigger "PythonProxyHandler.finalize". To avoid that, we should not call + # "registerSerializer" more than once, so that "PythonProxyHandler" in Java side won't + # be GCed. + # + # TODO Once Py4J fixes this issue, we should upgrade Py4j to the latest version. 
+ transformer_serializer.gateway.jvm.PythonDStream.registerSerializer( + transformer_serializer) + cls._transformerSerializer = transformer_serializer + else: + cls._transformerSerializer.init( + SparkContext._active_spark_context, CloudPickleSerializer(), gw) @classmethod def getOrCreate(cls, checkpointPath, setupFunc): @@ -116,16 +199,13 @@ def getOrCreate(cls, checkpointPath, setupFunc): gw = SparkContext._gateway # Check whether valid checkpoint information exists in the given path - if gw.jvm.CheckpointReader.read(checkpointPath).isEmpty(): + ssc_option = gw.jvm.StreamingContextPythonHelper().tryRecoverFromCheckpoint(checkpointPath) + if ssc_option.isEmpty(): ssc = setupFunc() ssc.checkpoint(checkpointPath) return ssc - try: - jssc = gw.jvm.JavaStreamingContext(checkpointPath) - except Exception: - print("failed to load StreamingContext from checkpoint", file=sys.stderr) - raise + jssc = gw.jvm.JavaStreamingContext(ssc_option.get()) # If there is already an active instance of Python SparkContext use it, or create a new one if not SparkContext._active_spark_context: @@ -258,7 +338,7 @@ def checkpoint(self, directory): """ self._jssc.checkpoint(directory) - def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + def socketTextStream(self, hostname, port, storageLevel=StorageLevel.MEMORY_AND_DISK_2): """ Create an input from TCP source hostname:port. Data is received using a TCP socket and receive byte is interpreted as UTF8 encoded ``\\n`` delimited diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py index b994a53bf2b8..86447f5e58ec 100644 --- a/python/pyspark/streaming/dstream.py +++ b/python/pyspark/streaming/dstream.py @@ -208,10 +208,10 @@ def func(iterator): def cache(self): """ Persist the RDDs of this DStream with the default storage level - (C{MEMORY_ONLY_SER}). + (C{MEMORY_ONLY}). """ self.is_cached = True - self.persist(StorageLevel.MEMORY_ONLY_SER) + self.persist(StorageLevel.MEMORY_ONLY) return self def persist(self, storageLevel): @@ -247,7 +247,7 @@ def countByValue(self): Return a new DStream in which each RDD contains the counts of each distinct value in each RDD of this DStream. 
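The `countByValue` fix just below is easiest to see at the RDD level: each batch now yields `(value, count)` pairs, matching the Scala API, instead of a single count of distinct values. Assuming an active `sc`:

```python
data = ["a", "b", "a", "c"]
pairs = sc.parallelize(data) \
          .map(lambda x: (x, 1)) \
          .reduceByKey(lambda x, y: x + y)
sorted(pairs.collect())   # [('a', 2), ('b', 1), ('c', 1)]
# The old implementation collapsed this to a single element: the number of
# distinct values (3), which did not match the Scala API.
```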
""" - return self.map(lambda x: (x, None)).reduceByKey(lambda x, y: None).count() + return self.map(lambda x: (x, 1)).reduceByKey(lambda x, y: x+y) def saveAsTextFiles(self, prefix, suffix=None): """ @@ -493,7 +493,7 @@ def countByValueAndWindow(self, windowDuration, slideDuration, numPartitions=Non keyed = self.map(lambda x: (x, 1)) counted = keyed.reduceByKeyAndWindow(operator.add, operator.sub, windowDuration, slideDuration, numPartitions) - return counted.filter(lambda kv: kv[1] > 0).count() + return counted.filter(lambda kv: kv[1] > 0) def groupByKeyAndWindow(self, windowDuration, slideDuration, numPartitions=None): """ diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index b3d190536592..b1fff0a5c7d6 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -40,7 +40,7 @@ class FlumeUtils(object): @staticmethod def createStream(ssc, hostname, port, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, enableDecompression=False, bodyDecoder=utf8_decoder): """ @@ -70,7 +70,7 @@ def createStream(ssc, hostname, port, @staticmethod def createPollingStream(ssc, addresses, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, maxBatchSize=1000, parallelism=5, bodyDecoder=utf8_decoder): diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index cdf97ec73aaf..13f8f9578e62 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -40,7 +40,7 @@ class KafkaUtils(object): @staticmethod def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2, + storageLevel=StorageLevel.MEMORY_AND_DISK_2, keyDecoder=utf8_decoder, valueDecoder=utf8_decoder): """ Create an input stream that pulls messages from a Kafka Broker. diff --git a/python/pyspark/streaming/mqtt.py b/python/pyspark/streaming/mqtt.py index 1ce4093196e6..3a515ea4996f 100644 --- a/python/pyspark/streaming/mqtt.py +++ b/python/pyspark/streaming/mqtt.py @@ -28,7 +28,7 @@ class MQTTUtils(object): @staticmethod def createStream(ssc, brokerUrl, topic, - storageLevel=StorageLevel.MEMORY_AND_DISK_SER_2): + storageLevel=StorageLevel.MEMORY_AND_DISK_2): """ Create an input stream that pulls messages from a Mqtt Broker. 
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 4949cd68e321..86b05d9fd242 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -279,8 +279,10 @@ def test_countByValue(self): def func(dstream): return dstream.countByValue() - expected = [[4], [4], [3]] - self._test_func(input, func, expected) + expected = [[(1, 2), (2, 2), (3, 2), (4, 2)], + [(5, 2), (6, 2), (7, 1), (8, 1)], + [("a", 2), ("b", 1), ("", 1)]] + self._test_func(input, func, expected, sort=True) def test_groupByKey(self): """Basic operation test for DStream.groupByKey.""" @@ -651,7 +653,16 @@ def test_count_by_value_and_window(self): def func(dstream): return dstream.countByValueAndWindow(2.5, .5) - expected = [[1], [2], [3], [4], [5], [6], [6], [6], [6], [6]] + expected = [[(0, 1)], + [(0, 2), (1, 1)], + [(0, 3), (1, 2), (2, 1)], + [(0, 4), (1, 3), (2, 2), (3, 1)], + [(0, 5), (1, 4), (2, 3), (3, 2), (4, 1)], + [(0, 5), (1, 5), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 4), (1, 4), (2, 4), (3, 3), (4, 2), (5, 1)], + [(0, 3), (1, 3), (2, 3), (3, 3), (4, 2), (5, 1)], + [(0, 2), (1, 2), (2, 2), (3, 2), (4, 2), (5, 1)], + [(0, 1), (1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]] self._test_func(input, func, expected) def test_group_by_key_and_window(self): diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index abbbf6eb9394..e617fc9ce9ee 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -89,11 +89,10 @@ class TransformFunctionSerializer(object): it uses this class to invoke Python, which returns the serialized function as a byte array. """ - def __init__(self, ctx, serializer, gateway=None): + def init(self, ctx, serializer, gateway=None): self.ctx = ctx self.serializer = serializer self.gateway = gateway or self.ctx._gateway - self.gateway.jvm.PythonDStream.registerSerializer(self) self.failure = None def dumps(self, id): diff --git a/repl/pom.xml b/repl/pom.xml index 154c99d23c7f..efc3dd452e32 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml @@ -50,12 +50,6 @@ test-jar test - - org.apache.spark - spark-bagel_${scala.binary.version} - ${project.version} - runtime - org.apache.spark spark-mllib_${scala.binary.version} diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index da8f0aa1e336..2bf1be1a582b 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -17,7 +17,7 @@ package org.apache.spark.repl -import java.io.{IOException, ByteArrayOutputStream, InputStream} +import java.io.{ByteArrayOutputStream, FilterInputStream, InputStream, IOException} import java.net.{HttpURLConnection, URI, URL, URLEncoder} import java.nio.channels.Channels @@ -27,10 +27,9 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.xbean.asm5._ import org.apache.xbean.asm5.Opcodes._ -import org.apache.spark.{SparkConf, SparkEnv, Logging} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils -import org.apache.spark.util.ParentClassLoader +import org.apache.spark.util.{ParentClassLoader, Utils} /** * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI, @@ -96,7 +95,24 @@ class ExecutorClassLoader( private def 
getClassFileInputStreamFromSparkRPC(path: String): InputStream = { val channel = env.rpcEnv.openChannel(s"$classUri/$path") - Channels.newInputStream(channel) + new FilterInputStream(Channels.newInputStream(channel)) { + + override def read(): Int = toClassNotFound(super.read()) + + override def read(b: Array[Byte]): Int = toClassNotFound(super.read(b)) + + override def read(b: Array[Byte], offset: Int, len: Int) = + toClassNotFound(super.read(b, offset, len)) + + private def toClassNotFound(fn: => Int): Int = { + try { + fn + } catch { + case e: Exception => + throw new ClassNotFoundException(path, e) + } + } + } } private def getClassFileInputStreamFromHttpServer(pathInDirectory: String): InputStream = { diff --git a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala index 1360f09e7fa1..ce3f51bd72dd 100644 --- a/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ExecutorClassLoaderSuite.scala @@ -30,14 +30,14 @@ import scala.language.implicitConversions import scala.language.postfixOps import com.google.common.io.Files +import org.mockito.Matchers.anyString +import org.mockito.Mockito._ +import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.BeforeAndAfterAll import org.scalatest.concurrent.Interruptor import org.scalatest.concurrent.Timeouts._ import org.scalatest.mock.MockitoSugar -import org.mockito.invocation.InvocationOnMock -import org.mockito.stubbing.Answer -import org.mockito.Matchers.anyString -import org.mockito.Mockito._ import org.apache.spark._ import org.apache.spark.rpc.RpcEnv @@ -72,13 +72,16 @@ class ExecutorClassLoaderSuite } override def afterAll() { - super.afterAll() - if (classServer != null) { - classServer.stop() + try { + if (classServer != null) { + classServer.stop() + } + Utils.deleteRecursively(tempDir1) + Utils.deleteRecursively(tempDir2) + SparkEnv.set(null) + } finally { + super.afterAll() } - Utils.deleteRecursively(tempDir1) - Utils.deleteRecursively(tempDir2) - SparkEnv.set(null) } test("child first") { diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 6925e18737b7..9714c46fe99a 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -187,14 +187,6 @@ This file is divided into 3 sections: scala.collection.JavaConverters._ and use .asScala / .asJava methods - - - ^getConfiguration$|^getTaskAttemptID$ - Instead of calling .getConfiguration() or .getTaskAttemptID() directly, - use SparkHadoopUtil's getConfigurationFromJobContext() and getTaskAttemptIDFromTaskAttemptContext() methods. 
- - - @@ -219,13 +211,19 @@ This file is divided into 3 sections: java,scala,3rdParty,spark - javax?\..+ - scala\..+ + javax?\..* + scala\..* (?!org\.apache\.spark\.).* org\.apache\.spark\..* + + + + COLON, COMMA + + diff --git a/sql/README.md b/sql/README.md index 63d4dac9829e..a13bdab6d457 100644 --- a/sql/README.md +++ b/sql/README.md @@ -20,7 +20,7 @@ If you are working with Hive 0.12.0, you will need to set several environmental ``` export HIVE_HOME="/hive/build/dist" export HIVE_DEV_HOME="/hive/" -export HADOOP_HOME="/hadoop-1.0.4" +export HADOOP_HOME="/hadoop" ``` If you are working with Hive 0.13.1, the following steps are needed: diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 61d6fc63554b..76ca3f3bb1bf 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml @@ -71,6 +71,10 @@ org.codehaus.janino janino + + org.antlr + antlr-runtime + target/scala-${scala.binary.version}/classes @@ -103,6 +107,24 @@ + + org.antlr + antlr3-maven-plugin + + + + antlr + + + + + ../catalyst/src/main/antlr3 + + **/SparkSqlLexer.g + **/SparkSqlParser.g + + + diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g new file mode 100644 index 000000000000..aabb5d49582c --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g @@ -0,0 +1,570 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. +*/ + +parser grammar ExpressionParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +// fun(par1, par2, par3) +function +@init { gParent.pushMsg("function specification", state); } +@after { gParent.popMsg(state); } + : + functionName + LPAREN + ( + (STAR) => (star=STAR) + | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? + ) + RPAREN (KW_OVER ws=window_specification)? + -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) + -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) 
+ -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) + ; + +functionName +@init { gParent.pushMsg("function name", state); } +@after { gParent.popMsg(state); } + : // Keyword IF is also a function name + (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) + | + (functionIdentifier) => functionIdentifier + | + {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] + ; + +castExpression +@init { gParent.pushMsg("cast expression", state); } +@after { gParent.popMsg(state); } + : + KW_CAST + LPAREN + expression + KW_AS + primitiveType + RPAREN -> ^(TOK_FUNCTION primitiveType expression) + ; + +caseExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE expression + (KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_CASE expression*) + ; + +whenExpression +@init { gParent.pushMsg("case expression", state); } +@after { gParent.popMsg(state); } + : + KW_CASE + ( KW_WHEN expression KW_THEN expression)+ + (KW_ELSE expression)? + KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) + ; + +constant +@init { gParent.pushMsg("constant", state); } +@after { gParent.popMsg(state); } + : + Number + | dateLiteral + | timestampLiteral + | intervalLiteral + | StringLiteral + | stringLiteralSequence + | BigintLiteral + | SmallintLiteral + | TinyintLiteral + | DecimalLiteral + | charSetStringLiteral + | booleanValue + ; + +stringLiteralSequence + : + StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) + ; + +charSetStringLiteral +@init { gParent.pushMsg("character string literal", state); } +@after { gParent.popMsg(state); } + : + csName=CharSetName csLiteral=CharSetLiteral -> ^(TOK_CHARSETLITERAL $csName $csLiteral) + ; + +dateLiteral + : + KW_DATE StringLiteral -> + { + // Create DateLiteral token, but with the text of the string value + // This makes the dateLiteral more consistent with the other type literals. + adaptor.create(TOK_DATELITERAL, $StringLiteral.text) + } + | + KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) + ; + +timestampLiteral + : + KW_TIMESTAMP StringLiteral -> + { + adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) + } + | + KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) + ; + +intervalLiteral + : + KW_INTERVAL StringLiteral qualifiers=intervalQualifiers -> + { + adaptor.create($qualifiers.tree.token.getType(), $StringLiteral.text) + } + ; + +intervalQualifiers + : + KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH_LITERAL + | KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME_LITERAL + | KW_YEAR -> TOK_INTERVAL_YEAR_LITERAL + | KW_MONTH -> TOK_INTERVAL_MONTH_LITERAL + | KW_DAY -> TOK_INTERVAL_DAY_LITERAL + | KW_HOUR -> TOK_INTERVAL_HOUR_LITERAL + | KW_MINUTE -> TOK_INTERVAL_MINUTE_LITERAL + | KW_SECOND -> TOK_INTERVAL_SECOND_LITERAL + ; + +expression +@init { gParent.pushMsg("expression specification", state); } +@after { gParent.popMsg(state); } + : + precedenceOrExpression + ; + +atomExpression + : + (KW_NULL) => KW_NULL -> TOK_NULL + | (constant) => constant + | castExpression + | caseExpression + | whenExpression + | (functionName LPAREN) => function + | tableOrColumn + | LPAREN! expression RPAREN! + ; + + +precedenceFieldExpression + : + atomExpression ((LSQUARE^ expression RSQUARE!) 
| (DOT^ identifier))* + ; + +precedenceUnaryOperator + : + PLUS | MINUS | TILDE + ; + +nullCondition + : + KW_NULL -> ^(TOK_ISNULL) + | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) + ; + +precedenceUnaryPrefixExpression + : + (precedenceUnaryOperator^)* precedenceFieldExpression + ; + +precedenceUnarySuffixExpression + : + ( + (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN + | + precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? + ) + -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) + -> precedenceUnaryPrefixExpression + ; + + +precedenceBitwiseXorOperator + : + BITWISEXOR + ; + +precedenceBitwiseXorExpression + : + precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* + ; + + +precedenceStarOperator + : + STAR | DIVIDE | MOD | DIV + ; + +precedenceStarExpression + : + precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* + ; + + +precedencePlusOperator + : + PLUS | MINUS + ; + +precedencePlusExpression + : + precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* + ; + + +precedenceAmpersandOperator + : + AMPERSAND + ; + +precedenceAmpersandExpression + : + precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* + ; + + +precedenceBitwiseOrOperator + : + BITWISEOR + ; + +precedenceBitwiseOrExpression + : + precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* + ; + + +// Equal operators supporting NOT prefix +precedenceEqualNegatableOperator + : + KW_LIKE | KW_RLIKE | KW_REGEXP + ; + +precedenceEqualOperator + : + precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +subQueryExpression + : + LPAREN! selectStatement[true] RPAREN! 
+ ; + +precedenceEqualExpression + : + (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple + | + precedenceEqualExpressionSingle + ; + +precedenceEqualExpressionSingle + : + (left=precedenceBitwiseOrExpression -> $left) + ( + (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) + -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) + | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) + -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) + | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) + -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) + | (KW_NOT KW_IN expressions) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) + | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) + -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) + | (KW_IN expressions) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) + | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) + | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) + -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) + )* + | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) + ; + +expressions + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) +precedenceEqualExpressionMutiple + : + (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) + ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) + | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) + -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) + ; + +expressionsToStruct + : + LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) + ; + +precedenceNotOperator + : + KW_NOT + ; + +precedenceNotExpression + : + (precedenceNotOperator^)* precedenceEqualExpression + ; + + +precedenceAndOperator + : + KW_AND + ; + +precedenceAndExpression + : + precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* + ; + + +precedenceOrOperator + : + KW_OR + ; + +precedenceOrExpression + : + precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* + ; + + +booleanValue + : + KW_TRUE^ | KW_FALSE^ + ; + +booleanValueTok + : + KW_TRUE -> TOK_TRUE + | KW_FALSE -> TOK_FALSE + ; + +tableOrPartition + : + tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) + ; + +partitionSpec + : + KW_PARTITION + LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) + ; + +partitionVal + : + identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) 
+ ; + +dropPartitionSpec + : + KW_PARTITION + LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) + ; + +dropPartitionVal + : + identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) + ; + +dropPartitionOperator + : + EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN + ; + +sysFuncNames + : + KW_AND + | KW_OR + | KW_NOT + | KW_LIKE + | KW_IF + | KW_CASE + | KW_WHEN + | KW_TINYINT + | KW_SMALLINT + | KW_INT + | KW_BIGINT + | KW_FLOAT + | KW_DOUBLE + | KW_BOOLEAN + | KW_STRING + | KW_BINARY + | KW_ARRAY + | KW_MAP + | KW_STRUCT + | KW_UNIONTYPE + | EQUAL + | EQUAL_NS + | NOTEQUAL + | LESSTHANOREQUALTO + | LESSTHAN + | GREATERTHANOREQUALTO + | GREATERTHAN + | DIVIDE + | PLUS + | MINUS + | STAR + | MOD + | DIV + | AMPERSAND + | TILDE + | BITWISEOR + | BITWISEXOR + | KW_RLIKE + | KW_REGEXP + | KW_IN + | KW_BETWEEN + ; + +descFuncNames + : + (sysFuncNames) => sysFuncNames + | StringLiteral + | functionIdentifier + ; + +identifier + : + Identifier + | nonReserved -> Identifier[$nonReserved.text] + // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, + // the sql11keywords in existing q tests will NOT be added back. + | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] + ; + +functionIdentifier +@init { gParent.pushMsg("function identifier", state); } +@after { gParent.popMsg(state); } + : db=identifier DOT fn=identifier + -> Identifier[$db.text + "." + $fn.text] + | + identifier + ; + +principalIdentifier +@init { gParent.pushMsg("identifier for principal spec", state); } +@after { gParent.popMsg(state); } + : identifier + | QuotedIdentifier + ; + +//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved +//Non reserved keywords are basically the keywords that can be used as identifiers. +//All the KW_* are automatically not only keywords, but also reserved keywords. +//That means, they can NOT be used as identifiers. +//If you would like to use them as identifiers, put them in the nonReserved list below. 
+//If you are not sure, please refer to the SQL2011 column in +//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html +nonReserved + : + KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS + | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS + | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY + | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY + | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE + | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT + | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE + | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR + | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG + | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE + | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY + | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER + | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE + | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED + | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED + | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED + | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET + | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR + | KW_WORK + | KW_TRANSACTION + | KW_WRITE + | KW_ISOLATION + | KW_LEVEL + | KW_SNAPSHOT + | KW_AUTOCOMMIT + | KW_ANTI +; + +//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. +sql11ReservedKeywordsUsedAsCastFunctionName + : + KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP + ; + +//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. +//We are planning to remove the following whole list after several releases. +//Thus, please do not change the following list unless you know what to do. 
+sql11ReservedKeywordsUsedAsIdentifier + : + KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN + | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE + | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT + | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL + | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION + | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT + | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE + | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH +//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. + | KW_REGEXP | KW_RLIKE + ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g new file mode 100644 index 000000000000..972c52e3ffce --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g @@ -0,0 +1,332 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar. 
+*/ +parser grammar FromClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +tableAllColumns + : STAR + -> ^(TOK_ALLCOLREF) + | tableName DOT STAR + -> ^(TOK_ALLCOLREF tableName) + ; + +// (table|column) +tableOrColumn +@init { gParent.pushMsg("table or column identifier", state); } +@after { gParent.popMsg(state); } + : + identifier -> ^(TOK_TABLE_OR_COL identifier) + ; + +expressionList +@init { gParent.pushMsg("expression list", state); } +@after { gParent.popMsg(state); } + : + expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) + ; + +aliasList +@init { gParent.pushMsg("alias list", state); } +@after { gParent.popMsg(state); } + : + identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) + ; + +//----------------------- Rules for parsing fromClause ------------------------------ +// from [col1, col2, col3] table1, [col4, col5] table2 +fromClause +@init { gParent.pushMsg("from clause", state); } +@after { gParent.popMsg(state); } + : + KW_FROM joinSource -> ^(TOK_FROM joinSource) + ; + +joinSource +@init { gParent.pushMsg("join source", state); } +@after { gParent.popMsg(state); } + : fromSource ( joinToken^ fromSource ( KW_ON! expression {$joinToken.start.getType() != COMMA}? )? )* + | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ + ; + +uniqueJoinSource +@init { gParent.pushMsg("unique join source", state); } +@after { gParent.popMsg(state); } + : KW_PRESERVE? fromSource uniqueJoinExpr + ; + +uniqueJoinExpr +@init { gParent.pushMsg("unique join expression list", state); } +@after { gParent.popMsg(state); } + : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN + -> ^(TOK_EXPLIST $e1*) + ; + +uniqueJoinToken +@init { gParent.pushMsg("unique join", state); } +@after { gParent.popMsg(state); } + : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; + +joinToken +@init { gParent.pushMsg("join type specifier", state); } +@after { gParent.popMsg(state); } + : + KW_JOIN -> TOK_JOIN + | KW_INNER KW_JOIN -> TOK_JOIN + | COMMA -> TOK_JOIN + | KW_CROSS KW_JOIN -> TOK_CROSSJOIN + | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN + | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN + | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN + | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN + | KW_ANTI KW_JOIN -> TOK_ANTIJOIN + ; + +lateralView +@init {gParent.pushMsg("lateral view", state); } +@after {gParent.popMsg(state); } + : + (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? + -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + | + KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? 
+ -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) + ; + +tableAlias +@init {gParent.pushMsg("table alias", state); } +@after {gParent.popMsg(state); } + : + identifier -> ^(TOK_TABALIAS identifier) + ; + +fromSource +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + (LPAREN KW_VALUES) => fromSource0 + | fromSource0 + | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource + ; + + +fromSource0 +@init { gParent.pushMsg("from source 0", state); } +@after { gParent.popMsg(state); } + : + ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* + ; + +tableBucketSample +@init { gParent.pushMsg("table bucket sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) + ; + +splitSample +@init { gParent.pushMsg("table split sample specification", state); } +@after { gParent.popMsg(state); } + : + KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN + -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) + -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) + | + KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN + -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) + ; + +tableSample +@init { gParent.pushMsg("table sample specification", state); } +@after { gParent.popMsg(state); } + : + tableBucketSample | + splitSample + ; + +tableSource +@init { gParent.pushMsg("table source", state); } +@after { gParent.popMsg(state); } + : tabname=tableName + ((tableProperties) => props=tableProperties)? + ((tableSample) => ts=tableSample)? + ((KW_AS) => (KW_AS alias=Identifier) + | + (Identifier) => (alias=Identifier))? + -> ^(TOK_TABREF $tabname $props? $ts? $alias?) + ; + +tableName +@init { gParent.pushMsg("table name", state); } +@after { gParent.popMsg(state); } + : + db=identifier DOT tab=identifier + -> ^(TOK_TABNAME $db $tab) + | + tab=identifier + -> ^(TOK_TABNAME $tab) + ; + +viewName +@init { gParent.pushMsg("view name", state); } +@after { gParent.popMsg(state); } + : + (db=identifier DOT)? view=identifier + -> ^(TOK_TABNAME $db? $view) + ; + +subQuerySource +@init { gParent.pushMsg("subquery source", state); } +@after { gParent.popMsg(state); } + : + LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) + ; + +//---------------------- Rules for parsing PTF clauses ----------------------------- +partitioningSpec +@init { gParent.pushMsg("partitioningSpec clause", state); } +@after { gParent.popMsg(state); } + : + partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | + orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | + distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) 
| + sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | + clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) + ; + +partitionTableFunctionSource +@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } +@after { gParent.popMsg(state); } + : + subQuerySource | + tableSource | + partitionedTableFunction + ; + +partitionedTableFunction +@init { gParent.pushMsg("ptf clause", state); } +@after { gParent.popMsg(state); } + : + name=Identifier LPAREN KW_ON + ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) + ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? + ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? + -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) + ; + +//----------------------- Rules for parsing whereClause ----------------------------- +// where a=b and ... +whereClause +@init { gParent.pushMsg("where clause", state); } +@after { gParent.popMsg(state); } + : + KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) + ; + +searchCondition +@init { gParent.pushMsg("search condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +//----------------------------------------------------------------------------------- + +//-------- Row Constructor ---------------------------------------------------------- +//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and +// INSERT INTO (col1,col2,...) VALUES(...),(...),... +// INSERT INTO
    (col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) +valueRowConstructor +@init { gParent.pushMsg("value row constructor", state); } +@after { gParent.popMsg(state); } + : + LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) + ; + +valuesTableConstructor +@init { gParent.pushMsg("values table constructor", state); } +@after { gParent.popMsg(state); } + : + valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) + ; + +/* +VALUES(1),(2) means 2 rows, 1 column each. +VALUES(1,2),(3,4) means 2 rows, 2 columns each. +VALUES(1,2,3) means 1 row, 3 columns +*/ +valuesClause +@init { gParent.pushMsg("values clause", state); } +@after { gParent.popMsg(state); } + : + KW_VALUES valuesTableConstructor -> valuesTableConstructor + ; + +/* +This represents a clause like this: +(VALUES(1,2),(2,3)) as VirtTable(col1,col2) +*/ +virtualTableSource +@init { gParent.pushMsg("virtual table source", state); } +@after { gParent.popMsg(state); } + : + LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) + ; +/* +e.g. as VirtTable(col1,col2) +Note that we only want literals as column names +*/ +tableNameColList +@init { gParent.pushMsg("from source", state); } +@after { gParent.popMsg(state); } + : + KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) + ; + +//----------------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g new file mode 100644 index 000000000000..916eb6a7ac26 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g @@ -0,0 +1,184 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. 
+*/ +parser grammar IdentifiersParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------------------------------------------------------------------- + +// group by a,b +groupByClause +@init { gParent.pushMsg("group by clause", state); } +@after { gParent.popMsg(state); } + : + KW_GROUP KW_BY + expression + ( COMMA expression)* + ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? + (sets=KW_GROUPING KW_SETS + LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? + -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) + -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) + -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) + -> ^(TOK_GROUPBY expression+) + ; + +groupingSetExpression +@init {gParent.pushMsg("grouping set expression", state); } +@after {gParent.popMsg(state); } + : + (LPAREN) => groupingSetExpressionMultiple + | + groupingExpressionSingle + ; + +groupingSetExpressionMultiple +@init {gParent.pushMsg("grouping set part expression", state); } +@after {gParent.popMsg(state); } + : + LPAREN + expression? (COMMA expression)* + RPAREN + -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) + ; + +groupingExpressionSingle +@init { gParent.pushMsg("groupingExpression expression", state); } +@after { gParent.popMsg(state); } + : + expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) + ; + +havingClause +@init { gParent.pushMsg("having clause", state); } +@after { gParent.popMsg(state); } + : + KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) + ; + +havingCondition +@init { gParent.pushMsg("having condition", state); } +@after { gParent.popMsg(state); } + : + expression + ; + +expressionsInParenthese + : + LPAREN expression (COMMA expression)* RPAREN -> expression+ + ; + +expressionsNotInParenthese + : + expression (COMMA expression)* -> expression+ + ; + +columnRefOrderInParenthese + : + LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ + ; + +columnRefOrderNotInParenthese + : + columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ + ; + +// order by a,b +orderByClause +@init { gParent.pushMsg("order by clause", state); } +@after { gParent.popMsg(state); } + : + KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) + ; + +clusterByClause +@init { gParent.pushMsg("cluster by clause", state); } +@after { gParent.popMsg(state); } + : + KW_CLUSTER KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) + ) + ; + +partitionByClause +@init { gParent.pushMsg("partition by clause", state); } +@after { gParent.popMsg(state); } + : + KW_PARTITION KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +distributeByClause +@init { gParent.pushMsg("distribute by clause", 
state); } +@after { gParent.popMsg(state); } + : + KW_DISTRIBUTE KW_BY + ( + (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) + | + expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) + ) + ; + +sortByClause +@init { gParent.pushMsg("sort by clause", state); } +@after { gParent.popMsg(state); } + : + KW_SORT KW_BY + ( + (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) + | + columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) + ) + ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g new file mode 100644 index 000000000000..2d2bafb1ee34 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g @@ -0,0 +1,228 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar. +*/ +parser grammar SelectClauseParser; + +options +{ +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} + +@members { + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + @Override + public void displayRecognitionError(String[] tokenNames, + RecognitionException e) { + gParent.displayRecognitionError(tokenNames, e); + } + protected boolean useSQL11ReservedKeywordsForIdentifier() { + return gParent.useSQL11ReservedKeywordsForIdentifier(); + } +} + +@rulecatch { +catch (RecognitionException e) { + throw e; +} +} + +//----------------------- Rules for parsing selectClause ----------------------------- +// select a,b,c ... +selectClause +@init { gParent.pushMsg("select clause", state); } +@after { gParent.popMsg(state); } + : + KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) + | (transform=KW_TRANSFORM selectTrfmClause)) + -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) + -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) + -> ^(TOK_SELECT hintClause? ^(TOK_SELEXPR selectTrfmClause) ) + | + trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) + ; + +selectList +@init { gParent.pushMsg("select list", state); } +@after { gParent.popMsg(state); } + : + selectItem ( COMMA selectItem )* -> selectItem+ + ; + +selectTrfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + LPAREN selectExpressionList RPAREN + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? 
+ outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +hintClause +@init { gParent.pushMsg("hint clause", state); } +@after { gParent.popMsg(state); } + : + DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) + ; + +hintList +@init { gParent.pushMsg("hint list", state); } +@after { gParent.popMsg(state); } + : + hintItem (COMMA hintItem)* -> hintItem+ + ; + +hintItem +@init { gParent.pushMsg("hint item", state); } +@after { gParent.popMsg(state); } + : + hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) + ; + +hintName +@init { gParent.pushMsg("hint name", state); } +@after { gParent.popMsg(state); } + : + KW_MAPJOIN -> TOK_MAPJOIN + | KW_STREAMTABLE -> TOK_STREAMTABLE + ; + +hintArgs +@init { gParent.pushMsg("hint arguments", state); } +@after { gParent.popMsg(state); } + : + hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) + ; + +hintArgName +@init { gParent.pushMsg("hint argument name", state); } +@after { gParent.popMsg(state); } + : + identifier + ; + +selectItem +@init { gParent.pushMsg("selection target", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) + | + ( expression + ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? + ) -> ^(TOK_SELEXPR expression identifier*) + ; + +trfmClause +@init { gParent.pushMsg("transform clause", state); } +@after { gParent.popMsg(state); } + : + ( KW_MAP selectExpressionList + | KW_REDUCE selectExpressionList ) + inSerde=rowFormat inRec=recordWriter + KW_USING StringLiteral + ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? + outSerde=rowFormat outRec=recordReader + -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) + ; + +selectExpression +@init { gParent.pushMsg("select expression", state); } +@after { gParent.popMsg(state); } + : + (tableAllColumns) => tableAllColumns + | + expression + ; + +selectExpressionList +@init { gParent.pushMsg("select expression list", state); } +@after { gParent.popMsg(state); } + : + selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) + ; + +//---------------------- Rules for windowing clauses ------------------------------- +window_clause +@init { gParent.pushMsg("window_clause", state); } +@after { gParent.popMsg(state); } +: + KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) +; + +window_defn +@init { gParent.pushMsg("window_defn", state); } +@after { gParent.popMsg(state); } +: + Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) +; + +window_specification +@init { gParent.pushMsg("window_specification", state); } +@after { gParent.popMsg(state); } +: + (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) 
+; + +window_frame : + window_range_expression | + window_value_expression +; + +window_range_expression +@init { gParent.pushMsg("window_range_expression", state); } +@after { gParent.popMsg(state); } +: + KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | + KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) +; + +window_value_expression +@init { gParent.pushMsg("window_value_expression", state); } +@after { gParent.popMsg(state); } +: + KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | + KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) +; + +window_frame_start_boundary +@init { gParent.pushMsg("windowframestartboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number KW_PRECEDING -> ^(KW_PRECEDING Number) +; + +window_frame_boundary +@init { gParent.pushMsg("windowframeboundary", state); } +@after { gParent.popMsg(state); } +: + KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | + KW_CURRENT KW_ROW -> ^(KW_CURRENT) | + Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) +; + diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g new file mode 100644 index 000000000000..44a63fbef258 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g @@ -0,0 +1,486 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar. 
+*/ +lexer grammar SparkSqlLexer; + +@lexer::header { +package org.apache.spark.sql.catalyst.parser; + +} + +@lexer::members { + private ParserConf parserConf; + private ParseErrorReporter reporter; + + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; + } + + protected boolean allowQuotedId() { + if (parserConf == null) { + return true; + } + return parserConf.supportQuotedId(); + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } + } +} + +// Keywords + +KW_TRUE : 'TRUE'; +KW_FALSE : 'FALSE'; +KW_ALL : 'ALL'; +KW_NONE: 'NONE'; +KW_AND : 'AND'; +KW_OR : 'OR'; +KW_NOT : 'NOT' | '!'; +KW_LIKE : 'LIKE'; + +KW_IF : 'IF'; +KW_EXISTS : 'EXISTS'; + +KW_ASC : 'ASC'; +KW_DESC : 'DESC'; +KW_ORDER : 'ORDER'; +KW_GROUP : 'GROUP'; +KW_BY : 'BY'; +KW_HAVING : 'HAVING'; +KW_WHERE : 'WHERE'; +KW_FROM : 'FROM'; +KW_AS : 'AS'; +KW_SELECT : 'SELECT'; +KW_DISTINCT : 'DISTINCT'; +KW_INSERT : 'INSERT'; +KW_OVERWRITE : 'OVERWRITE'; +KW_OUTER : 'OUTER'; +KW_UNIQUEJOIN : 'UNIQUEJOIN'; +KW_PRESERVE : 'PRESERVE'; +KW_JOIN : 'JOIN'; +KW_LEFT : 'LEFT'; +KW_RIGHT : 'RIGHT'; +KW_FULL : 'FULL'; +KW_ANTI : 'ANTI'; +KW_ON : 'ON'; +KW_PARTITION : 'PARTITION'; +KW_PARTITIONS : 'PARTITIONS'; +KW_TABLE: 'TABLE'; +KW_TABLES: 'TABLES'; +KW_COLUMNS: 'COLUMNS'; +KW_INDEX: 'INDEX'; +KW_INDEXES: 'INDEXES'; +KW_REBUILD: 'REBUILD'; +KW_FUNCTIONS: 'FUNCTIONS'; +KW_SHOW: 'SHOW'; +KW_MSCK: 'MSCK'; +KW_REPAIR: 'REPAIR'; +KW_DIRECTORY: 'DIRECTORY'; +KW_LOCAL: 'LOCAL'; +KW_TRANSFORM : 'TRANSFORM'; +KW_USING: 'USING'; +KW_CLUSTER: 'CLUSTER'; +KW_DISTRIBUTE: 'DISTRIBUTE'; +KW_SORT: 'SORT'; +KW_UNION: 'UNION'; +KW_EXCEPT: 'EXCEPT'; +KW_LOAD: 'LOAD'; +KW_EXPORT: 'EXPORT'; +KW_IMPORT: 'IMPORT'; +KW_REPLICATION: 'REPLICATION'; +KW_METADATA: 'METADATA'; +KW_DATA: 'DATA'; +KW_INPATH: 'INPATH'; +KW_IS: 'IS'; +KW_NULL: 'NULL'; +KW_CREATE: 'CREATE'; +KW_EXTERNAL: 'EXTERNAL'; +KW_ALTER: 'ALTER'; +KW_CHANGE: 'CHANGE'; +KW_COLUMN: 'COLUMN'; +KW_FIRST: 'FIRST'; +KW_AFTER: 'AFTER'; +KW_DESCRIBE: 'DESCRIBE'; +KW_DROP: 'DROP'; +KW_RENAME: 'RENAME'; +KW_TO: 'TO'; +KW_COMMENT: 'COMMENT'; +KW_BOOLEAN: 'BOOLEAN'; +KW_TINYINT: 'TINYINT'; +KW_SMALLINT: 'SMALLINT'; +KW_INT: 'INT'; +KW_BIGINT: 'BIGINT'; +KW_FLOAT: 'FLOAT'; +KW_DOUBLE: 'DOUBLE'; +KW_DATE: 'DATE'; +KW_DATETIME: 'DATETIME'; +KW_TIMESTAMP: 'TIMESTAMP'; +KW_INTERVAL: 'INTERVAL'; +KW_DECIMAL: 'DECIMAL'; +KW_STRING: 'STRING'; +KW_CHAR: 'CHAR'; +KW_VARCHAR: 'VARCHAR'; +KW_ARRAY: 'ARRAY'; +KW_STRUCT: 'STRUCT'; +KW_MAP: 'MAP'; +KW_UNIONTYPE: 'UNIONTYPE'; +KW_REDUCE: 'REDUCE'; +KW_PARTITIONED: 'PARTITIONED'; +KW_CLUSTERED: 'CLUSTERED'; +KW_SORTED: 'SORTED'; +KW_INTO: 'INTO'; +KW_BUCKETS: 'BUCKETS'; +KW_ROW: 'ROW'; +KW_ROWS: 'ROWS'; +KW_FORMAT: 'FORMAT'; +KW_DELIMITED: 'DELIMITED'; +KW_FIELDS: 'FIELDS'; +KW_TERMINATED: 'TERMINATED'; +KW_ESCAPED: 'ESCAPED'; +KW_COLLECTION: 'COLLECTION'; +KW_ITEMS: 'ITEMS'; +KW_KEYS: 'KEYS'; +KW_KEY_TYPE: '$KEY$'; +KW_LINES: 'LINES'; +KW_STORED: 'STORED'; +KW_FILEFORMAT: 'FILEFORMAT'; +KW_INPUTFORMAT: 'INPUTFORMAT'; +KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; +KW_INPUTDRIVER: 'INPUTDRIVER'; +KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; +KW_ENABLE: 'ENABLE'; +KW_DISABLE: 'DISABLE'; +KW_LOCATION: 'LOCATION'; +KW_TABLESAMPLE: 'TABLESAMPLE'; +KW_BUCKET: 'BUCKET'; +KW_OUT: 'OUT'; +KW_OF: 'OF'; +KW_PERCENT: 'PERCENT'; +KW_CAST: 'CAST'; +KW_ADD: 'ADD'; +KW_REPLACE: 'REPLACE'; +KW_RLIKE: 'RLIKE'; 
+KW_REGEXP: 'REGEXP'; +KW_TEMPORARY: 'TEMPORARY'; +KW_FUNCTION: 'FUNCTION'; +KW_MACRO: 'MACRO'; +KW_FILE: 'FILE'; +KW_JAR: 'JAR'; +KW_EXPLAIN: 'EXPLAIN'; +KW_EXTENDED: 'EXTENDED'; +KW_FORMATTED: 'FORMATTED'; +KW_PRETTY: 'PRETTY'; +KW_DEPENDENCY: 'DEPENDENCY'; +KW_LOGICAL: 'LOGICAL'; +KW_SERDE: 'SERDE'; +KW_WITH: 'WITH'; +KW_DEFERRED: 'DEFERRED'; +KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; +KW_DBPROPERTIES: 'DBPROPERTIES'; +KW_LIMIT: 'LIMIT'; +KW_SET: 'SET'; +KW_UNSET: 'UNSET'; +KW_TBLPROPERTIES: 'TBLPROPERTIES'; +KW_IDXPROPERTIES: 'IDXPROPERTIES'; +KW_VALUE_TYPE: '$VALUE$'; +KW_ELEM_TYPE: '$ELEM$'; +KW_DEFINED: 'DEFINED'; +KW_CASE: 'CASE'; +KW_WHEN: 'WHEN'; +KW_THEN: 'THEN'; +KW_ELSE: 'ELSE'; +KW_END: 'END'; +KW_MAPJOIN: 'MAPJOIN'; +KW_STREAMTABLE: 'STREAMTABLE'; +KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; +KW_UTC: 'UTC'; +KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; +KW_LONG: 'LONG'; +KW_DELETE: 'DELETE'; +KW_PLUS: 'PLUS'; +KW_MINUS: 'MINUS'; +KW_FETCH: 'FETCH'; +KW_INTERSECT: 'INTERSECT'; +KW_VIEW: 'VIEW'; +KW_IN: 'IN'; +KW_DATABASE: 'DATABASE'; +KW_DATABASES: 'DATABASES'; +KW_MATERIALIZED: 'MATERIALIZED'; +KW_SCHEMA: 'SCHEMA'; +KW_SCHEMAS: 'SCHEMAS'; +KW_GRANT: 'GRANT'; +KW_REVOKE: 'REVOKE'; +KW_SSL: 'SSL'; +KW_UNDO: 'UNDO'; +KW_LOCK: 'LOCK'; +KW_LOCKS: 'LOCKS'; +KW_UNLOCK: 'UNLOCK'; +KW_SHARED: 'SHARED'; +KW_EXCLUSIVE: 'EXCLUSIVE'; +KW_PROCEDURE: 'PROCEDURE'; +KW_UNSIGNED: 'UNSIGNED'; +KW_WHILE: 'WHILE'; +KW_READ: 'READ'; +KW_READS: 'READS'; +KW_PURGE: 'PURGE'; +KW_RANGE: 'RANGE'; +KW_ANALYZE: 'ANALYZE'; +KW_BEFORE: 'BEFORE'; +KW_BETWEEN: 'BETWEEN'; +KW_BOTH: 'BOTH'; +KW_BINARY: 'BINARY'; +KW_CROSS: 'CROSS'; +KW_CONTINUE: 'CONTINUE'; +KW_CURSOR: 'CURSOR'; +KW_TRIGGER: 'TRIGGER'; +KW_RECORDREADER: 'RECORDREADER'; +KW_RECORDWRITER: 'RECORDWRITER'; +KW_SEMI: 'SEMI'; +KW_LATERAL: 'LATERAL'; +KW_TOUCH: 'TOUCH'; +KW_ARCHIVE: 'ARCHIVE'; +KW_UNARCHIVE: 'UNARCHIVE'; +KW_COMPUTE: 'COMPUTE'; +KW_STATISTICS: 'STATISTICS'; +KW_USE: 'USE'; +KW_OPTION: 'OPTION'; +KW_CONCATENATE: 'CONCATENATE'; +KW_SHOW_DATABASE: 'SHOW_DATABASE'; +KW_UPDATE: 'UPDATE'; +KW_RESTRICT: 'RESTRICT'; +KW_CASCADE: 'CASCADE'; +KW_SKEWED: 'SKEWED'; +KW_ROLLUP: 'ROLLUP'; +KW_CUBE: 'CUBE'; +KW_DIRECTORIES: 'DIRECTORIES'; +KW_FOR: 'FOR'; +KW_WINDOW: 'WINDOW'; +KW_UNBOUNDED: 'UNBOUNDED'; +KW_PRECEDING: 'PRECEDING'; +KW_FOLLOWING: 'FOLLOWING'; +KW_CURRENT: 'CURRENT'; +KW_CURRENT_DATE: 'CURRENT_DATE'; +KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +KW_LESS: 'LESS'; +KW_MORE: 'MORE'; +KW_OVER: 'OVER'; +KW_GROUPING: 'GROUPING'; +KW_SETS: 'SETS'; +KW_TRUNCATE: 'TRUNCATE'; +KW_NOSCAN: 'NOSCAN'; +KW_PARTIALSCAN: 'PARTIALSCAN'; +KW_USER: 'USER'; +KW_ROLE: 'ROLE'; +KW_ROLES: 'ROLES'; +KW_INNER: 'INNER'; +KW_EXCHANGE: 'EXCHANGE'; +KW_URI: 'URI'; +KW_SERVER : 'SERVER'; +KW_ADMIN: 'ADMIN'; +KW_OWNER: 'OWNER'; +KW_PRINCIPALS: 'PRINCIPALS'; +KW_COMPACT: 'COMPACT'; +KW_COMPACTIONS: 'COMPACTIONS'; +KW_TRANSACTIONS: 'TRANSACTIONS'; +KW_REWRITE : 'REWRITE'; +KW_AUTHORIZATION: 'AUTHORIZATION'; +KW_CONF: 'CONF'; +KW_VALUES: 'VALUES'; +KW_RELOAD: 'RELOAD'; +KW_YEAR: 'YEAR'; +KW_MONTH: 'MONTH'; +KW_DAY: 'DAY'; +KW_HOUR: 'HOUR'; +KW_MINUTE: 'MINUTE'; +KW_SECOND: 'SECOND'; +KW_START: 'START'; +KW_TRANSACTION: 'TRANSACTION'; +KW_COMMIT: 'COMMIT'; +KW_ROLLBACK: 'ROLLBACK'; +KW_WORK: 'WORK'; +KW_ONLY: 'ONLY'; +KW_WRITE: 'WRITE'; +KW_ISOLATION: 'ISOLATION'; +KW_LEVEL: 'LEVEL'; +KW_SNAPSHOT: 'SNAPSHOT'; +KW_AUTOCOMMIT: 'AUTOCOMMIT'; + +// Operators +// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. 
+ +DOT : '.'; // generated as a part of Number rule +COLON : ':' ; +COMMA : ',' ; +SEMICOLON : ';' ; + +LPAREN : '(' ; +RPAREN : ')' ; +LSQUARE : '[' ; +RSQUARE : ']' ; +LCURLY : '{'; +RCURLY : '}'; + +EQUAL : '=' | '=='; +EQUAL_NS : '<=>'; +NOTEQUAL : '<>' | '!='; +LESSTHANOREQUALTO : '<='; +LESSTHAN : '<'; +GREATERTHANOREQUALTO : '>='; +GREATERTHAN : '>'; + +DIVIDE : '/'; +PLUS : '+'; +MINUS : '-'; +STAR : '*'; +MOD : '%'; +DIV : 'DIV'; + +AMPERSAND : '&'; +TILDE : '~'; +BITWISEOR : '|'; +BITWISEXOR : '^'; +QUESTION : '?'; +DOLLAR : '$'; + +// LITERALS +fragment +Letter + : 'a'..'z' | 'A'..'Z' + ; + +fragment +HexDigit + : 'a'..'f' | 'A'..'F' + ; + +fragment +Digit + : + '0'..'9' + ; + +fragment +Exponent + : + ('e' | 'E') ( PLUS|MINUS )? (Digit)+ + ; + +fragment +RegexComponent + : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' + | PLUS | STAR | QUESTION | MINUS | DOT + | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY + | BITWISEXOR | BITWISEOR | DOLLAR | '!' + ; + +StringLiteral + : + ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + )+ + ; + +CharSetLiteral + : + StringLiteral + | '0' 'X' (HexDigit|Digit)+ + ; + +BigintLiteral + : + (Digit)+ 'L' + ; + +SmallintLiteral + : + (Digit)+ 'S' + ; + +TinyintLiteral + : + (Digit)+ 'Y' + ; + +DecimalLiteral + : + Number 'B' 'D' + ; + +ByteLengthLiteral + : + (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') + ; + +Number + : + (Digit)+ ( DOT (Digit)* (Exponent)? | Exponent)? + ; + +/* +An Identifier can be: +- tableName +- columnName +- select expr alias +- lateral view aliases +- database name +- view name +- subquery alias +- function name +- ptf argument identifier +- index name +- property name for: db,tbl,partition... +- fileFormat +- role name +- privilege name +- principal name +- macro name +- hint name +- window name +*/ +Identifier + : + (Letter | Digit) (Letter | Digit | '_')* + | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; + at the API level only columns are allowed to be of this form */ + | '`' RegexComponent+ '`' + ; + +fragment +QuotedIdentifier + : + '`' ( '``' | ~('`') )* '`' { setText(getText().substring(1, getText().length() -1 ).replaceAll("``", "`")); } + ; + +CharSetName + : + '_' (Letter | Digit | '_' | '-' | '.' | ':' )+ + ; + +WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} + ; + +COMMENT + : '--' (~('\n'|'\r'))* + { $channel=HIDDEN; } + ; + diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g new file mode 100644 index 000000000000..2c13d3056f46 --- /dev/null +++ b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g @@ -0,0 +1,2485 @@ +/** + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements. See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. + + This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar. +*/ +parser grammar SparkSqlParser; + +options +{ +tokenVocab=SparkSqlLexer; +output=AST; +ASTLabelType=CommonTree; +backtrack=false; +k=3; +} +import SelectClauseParser, FromClauseParser, IdentifiersParser, ExpressionParser; + +tokens { +TOK_INSERT; +TOK_QUERY; +TOK_SELECT; +TOK_SELECTDI; +TOK_SELEXPR; +TOK_FROM; +TOK_TAB; +TOK_PARTSPEC; +TOK_PARTVAL; +TOK_DIR; +TOK_TABREF; +TOK_SUBQUERY; +TOK_INSERT_INTO; +TOK_DESTINATION; +TOK_ALLCOLREF; +TOK_TABLE_OR_COL; +TOK_FUNCTION; +TOK_FUNCTIONDI; +TOK_FUNCTIONSTAR; +TOK_WHERE; +TOK_OP_EQ; +TOK_OP_NE; +TOK_OP_LE; +TOK_OP_LT; +TOK_OP_GE; +TOK_OP_GT; +TOK_OP_DIV; +TOK_OP_ADD; +TOK_OP_SUB; +TOK_OP_MUL; +TOK_OP_MOD; +TOK_OP_BITAND; +TOK_OP_BITNOT; +TOK_OP_BITOR; +TOK_OP_BITXOR; +TOK_OP_AND; +TOK_OP_OR; +TOK_OP_NOT; +TOK_OP_LIKE; +TOK_TRUE; +TOK_FALSE; +TOK_TRANSFORM; +TOK_SERDE; +TOK_SERDENAME; +TOK_SERDEPROPS; +TOK_EXPLIST; +TOK_ALIASLIST; +TOK_GROUPBY; +TOK_ROLLUP_GROUPBY; +TOK_CUBE_GROUPBY; +TOK_GROUPING_SETS; +TOK_GROUPING_SETS_EXPRESSION; +TOK_HAVING; +TOK_ORDERBY; +TOK_CLUSTERBY; +TOK_DISTRIBUTEBY; +TOK_SORTBY; +TOK_UNIONALL; +TOK_UNIONDISTINCT; +TOK_EXCEPT; +TOK_INTERSECT; +TOK_JOIN; +TOK_LEFTOUTERJOIN; +TOK_RIGHTOUTERJOIN; +TOK_FULLOUTERJOIN; +TOK_UNIQUEJOIN; +TOK_CROSSJOIN; +TOK_LOAD; +TOK_EXPORT; +TOK_IMPORT; +TOK_REPLICATION; +TOK_METADATA; +TOK_NULL; +TOK_ISNULL; +TOK_ISNOTNULL; +TOK_TINYINT; +TOK_SMALLINT; +TOK_INT; +TOK_BIGINT; +TOK_BOOLEAN; +TOK_FLOAT; +TOK_DOUBLE; +TOK_DATE; +TOK_DATELITERAL; +TOK_DATETIME; +TOK_TIMESTAMP; +TOK_TIMESTAMPLITERAL; +TOK_INTERVAL_YEAR_MONTH; +TOK_INTERVAL_YEAR_MONTH_LITERAL; +TOK_INTERVAL_DAY_TIME; +TOK_INTERVAL_DAY_TIME_LITERAL; +TOK_INTERVAL_YEAR_LITERAL; +TOK_INTERVAL_MONTH_LITERAL; +TOK_INTERVAL_DAY_LITERAL; +TOK_INTERVAL_HOUR_LITERAL; +TOK_INTERVAL_MINUTE_LITERAL; +TOK_INTERVAL_SECOND_LITERAL; +TOK_STRING; +TOK_CHAR; +TOK_VARCHAR; +TOK_BINARY; +TOK_DECIMAL; +TOK_LIST; +TOK_STRUCT; +TOK_MAP; +TOK_UNIONTYPE; +TOK_COLTYPELIST; +TOK_CREATEDATABASE; +TOK_CREATETABLE; +TOK_TRUNCATETABLE; +TOK_CREATEINDEX; +TOK_CREATEINDEX_INDEXTBLNAME; +TOK_DEFERRED_REBUILDINDEX; +TOK_DROPINDEX; +TOK_LIKETABLE; +TOK_DESCTABLE; +TOK_DESCFUNCTION; +TOK_ALTERTABLE; +TOK_ALTERTABLE_RENAME; +TOK_ALTERTABLE_ADDCOLS; +TOK_ALTERTABLE_RENAMECOL; +TOK_ALTERTABLE_RENAMEPART; +TOK_ALTERTABLE_REPLACECOLS; +TOK_ALTERTABLE_ADDPARTS; +TOK_ALTERTABLE_DROPPARTS; +TOK_ALTERTABLE_PARTCOLTYPE; +TOK_ALTERTABLE_MERGEFILES; +TOK_ALTERTABLE_TOUCH; +TOK_ALTERTABLE_ARCHIVE; +TOK_ALTERTABLE_UNARCHIVE; +TOK_ALTERTABLE_SERDEPROPERTIES; +TOK_ALTERTABLE_SERIALIZER; +TOK_ALTERTABLE_UPDATECOLSTATS; +TOK_TABLE_PARTITION; +TOK_ALTERTABLE_FILEFORMAT; +TOK_ALTERTABLE_LOCATION; +TOK_ALTERTABLE_PROPERTIES; +TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; +TOK_ALTERTABLE_DROPPROPERTIES; +TOK_ALTERTABLE_SKEWED; +TOK_ALTERTABLE_EXCHANGEPARTITION; +TOK_ALTERTABLE_SKEWED_LOCATION; +TOK_ALTERTABLE_BUCKETS; +TOK_ALTERTABLE_CLUSTER_SORT; +TOK_ALTERTABLE_COMPACT; +TOK_ALTERINDEX_REBUILD; +TOK_ALTERINDEX_PROPERTIES; +TOK_MSCK; +TOK_SHOWDATABASES; +TOK_SHOWTABLES; +TOK_SHOWCOLUMNS; +TOK_SHOWFUNCTIONS; +TOK_SHOWPARTITIONS; +TOK_SHOW_CREATEDATABASE; +TOK_SHOW_CREATETABLE; +TOK_SHOW_TABLESTATUS; +TOK_SHOW_TBLPROPERTIES; +TOK_SHOWLOCKS; +TOK_SHOWCONF; +TOK_LOCKTABLE; +TOK_UNLOCKTABLE; +TOK_LOCKDB; +TOK_UNLOCKDB; +TOK_SWITCHDATABASE; +TOK_DROPDATABASE; +TOK_DROPTABLE; +TOK_DATABASECOMMENT; 
+TOK_TABCOLLIST; +TOK_TABCOL; +TOK_TABLECOMMENT; +TOK_TABLEPARTCOLS; +TOK_TABLEROWFORMAT; +TOK_TABLEROWFORMATFIELD; +TOK_TABLEROWFORMATCOLLITEMS; +TOK_TABLEROWFORMATMAPKEYS; +TOK_TABLEROWFORMATLINES; +TOK_TABLEROWFORMATNULL; +TOK_TABLEFILEFORMAT; +TOK_FILEFORMAT_GENERIC; +TOK_OFFLINE; +TOK_ENABLE; +TOK_DISABLE; +TOK_READONLY; +TOK_NO_DROP; +TOK_STORAGEHANDLER; +TOK_NOT_CLUSTERED; +TOK_NOT_SORTED; +TOK_TABCOLNAME; +TOK_TABLELOCATION; +TOK_PARTITIONLOCATION; +TOK_TABLEBUCKETSAMPLE; +TOK_TABLESPLITSAMPLE; +TOK_PERCENT; +TOK_LENGTH; +TOK_ROWCOUNT; +TOK_TMP_FILE; +TOK_TABSORTCOLNAMEASC; +TOK_TABSORTCOLNAMEDESC; +TOK_STRINGLITERALSEQUENCE; +TOK_CHARSETLITERAL; +TOK_CREATEFUNCTION; +TOK_DROPFUNCTION; +TOK_RELOADFUNCTION; +TOK_CREATEMACRO; +TOK_DROPMACRO; +TOK_TEMPORARY; +TOK_CREATEVIEW; +TOK_DROPVIEW; +TOK_ALTERVIEW; +TOK_ALTERVIEW_PROPERTIES; +TOK_ALTERVIEW_DROPPROPERTIES; +TOK_ALTERVIEW_ADDPARTS; +TOK_ALTERVIEW_DROPPARTS; +TOK_ALTERVIEW_RENAME; +TOK_VIEWPARTCOLS; +TOK_EXPLAIN; +TOK_EXPLAIN_SQ_REWRITE; +TOK_TABLESERIALIZER; +TOK_TABLEPROPERTIES; +TOK_TABLEPROPLIST; +TOK_INDEXPROPERTIES; +TOK_INDEXPROPLIST; +TOK_TABTYPE; +TOK_LIMIT; +TOK_TABLEPROPERTY; +TOK_IFEXISTS; +TOK_IFNOTEXISTS; +TOK_ORREPLACE; +TOK_HINTLIST; +TOK_HINT; +TOK_MAPJOIN; +TOK_STREAMTABLE; +TOK_HINTARGLIST; +TOK_USERSCRIPTCOLNAMES; +TOK_USERSCRIPTCOLSCHEMA; +TOK_RECORDREADER; +TOK_RECORDWRITER; +TOK_LEFTSEMIJOIN; +TOK_ANTIJOIN; +TOK_LATERAL_VIEW; +TOK_LATERAL_VIEW_OUTER; +TOK_TABALIAS; +TOK_ANALYZE; +TOK_CREATEROLE; +TOK_DROPROLE; +TOK_GRANT; +TOK_REVOKE; +TOK_SHOW_GRANT; +TOK_PRIVILEGE_LIST; +TOK_PRIVILEGE; +TOK_PRINCIPAL_NAME; +TOK_USER; +TOK_GROUP; +TOK_ROLE; +TOK_RESOURCE_ALL; +TOK_GRANT_WITH_OPTION; +TOK_GRANT_WITH_ADMIN_OPTION; +TOK_ADMIN_OPTION_FOR; +TOK_GRANT_OPTION_FOR; +TOK_PRIV_ALL; +TOK_PRIV_ALTER_METADATA; +TOK_PRIV_ALTER_DATA; +TOK_PRIV_DELETE; +TOK_PRIV_DROP; +TOK_PRIV_INDEX; +TOK_PRIV_INSERT; +TOK_PRIV_LOCK; +TOK_PRIV_SELECT; +TOK_PRIV_SHOW_DATABASE; +TOK_PRIV_CREATE; +TOK_PRIV_OBJECT; +TOK_PRIV_OBJECT_COL; +TOK_GRANT_ROLE; +TOK_REVOKE_ROLE; +TOK_SHOW_ROLE_GRANT; +TOK_SHOW_ROLES; +TOK_SHOW_SET_ROLE; +TOK_SHOW_ROLE_PRINCIPALS; +TOK_SHOWINDEXES; +TOK_SHOWDBLOCKS; +TOK_INDEXCOMMENT; +TOK_DESCDATABASE; +TOK_DATABASEPROPERTIES; +TOK_DATABASELOCATION; +TOK_DBPROPLIST; +TOK_ALTERDATABASE_PROPERTIES; +TOK_ALTERDATABASE_OWNER; +TOK_TABNAME; +TOK_TABSRC; +TOK_RESTRICT; +TOK_CASCADE; +TOK_TABLESKEWED; +TOK_TABCOLVALUE; +TOK_TABCOLVALUE_PAIR; +TOK_TABCOLVALUES; +TOK_SKEWED_LOCATIONS; +TOK_SKEWED_LOCATION_LIST; +TOK_SKEWED_LOCATION_MAP; +TOK_STOREDASDIRS; +TOK_PARTITIONINGSPEC; +TOK_PTBLFUNCTION; +TOK_WINDOWDEF; +TOK_WINDOWSPEC; +TOK_WINDOWVALUES; +TOK_WINDOWRANGE; +TOK_SUBQUERY_EXPR; +TOK_SUBQUERY_OP; +TOK_SUBQUERY_OP_NOTIN; +TOK_SUBQUERY_OP_NOTEXISTS; +TOK_DB_TYPE; +TOK_TABLE_TYPE; +TOK_CTE; +TOK_ARCHIVE; +TOK_FILE; +TOK_JAR; +TOK_RESOURCE_URI; +TOK_RESOURCE_LIST; +TOK_SHOW_COMPACTIONS; +TOK_SHOW_TRANSACTIONS; +TOK_DELETE_FROM; +TOK_UPDATE_TABLE; +TOK_SET_COLUMNS_CLAUSE; +TOK_VALUE_ROW; +TOK_VALUES_TABLE; +TOK_VIRTUAL_TABLE; +TOK_VIRTUAL_TABREF; +TOK_ANONYMOUS; +TOK_COL_NAME; +TOK_URI_TYPE; +TOK_SERVER_TYPE; +TOK_START_TRANSACTION; +TOK_ISOLATION_LEVEL; +TOK_ISOLATION_SNAPSHOT; +TOK_TXN_ACCESS_MODE; +TOK_TXN_READ_ONLY; +TOK_TXN_READ_WRITE; +TOK_COMMIT; +TOK_ROLLBACK; +TOK_SET_AUTOCOMMIT; +} + + +// Package headers +@header { +package org.apache.spark.sql.catalyst.parser; + +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +} + + +@members { + Stack msgs = new Stack(); + + private static 
HashMap xlateMap; + static { + //this is used to support auto completion in CLI + xlateMap = new HashMap(); + + // Keywords + xlateMap.put("KW_TRUE", "TRUE"); + xlateMap.put("KW_FALSE", "FALSE"); + xlateMap.put("KW_ALL", "ALL"); + xlateMap.put("KW_NONE", "NONE"); + xlateMap.put("KW_AND", "AND"); + xlateMap.put("KW_OR", "OR"); + xlateMap.put("KW_NOT", "NOT"); + xlateMap.put("KW_LIKE", "LIKE"); + + xlateMap.put("KW_ASC", "ASC"); + xlateMap.put("KW_DESC", "DESC"); + xlateMap.put("KW_ORDER", "ORDER"); + xlateMap.put("KW_BY", "BY"); + xlateMap.put("KW_GROUP", "GROUP"); + xlateMap.put("KW_WHERE", "WHERE"); + xlateMap.put("KW_FROM", "FROM"); + xlateMap.put("KW_AS", "AS"); + xlateMap.put("KW_SELECT", "SELECT"); + xlateMap.put("KW_DISTINCT", "DISTINCT"); + xlateMap.put("KW_INSERT", "INSERT"); + xlateMap.put("KW_OVERWRITE", "OVERWRITE"); + xlateMap.put("KW_OUTER", "OUTER"); + xlateMap.put("KW_JOIN", "JOIN"); + xlateMap.put("KW_LEFT", "LEFT"); + xlateMap.put("KW_RIGHT", "RIGHT"); + xlateMap.put("KW_FULL", "FULL"); + xlateMap.put("KW_ON", "ON"); + xlateMap.put("KW_PARTITION", "PARTITION"); + xlateMap.put("KW_PARTITIONS", "PARTITIONS"); + xlateMap.put("KW_TABLE", "TABLE"); + xlateMap.put("KW_TABLES", "TABLES"); + xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_SHOW", "SHOW"); + xlateMap.put("KW_MSCK", "MSCK"); + xlateMap.put("KW_DIRECTORY", "DIRECTORY"); + xlateMap.put("KW_LOCAL", "LOCAL"); + xlateMap.put("KW_TRANSFORM", "TRANSFORM"); + xlateMap.put("KW_USING", "USING"); + xlateMap.put("KW_CLUSTER", "CLUSTER"); + xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); + xlateMap.put("KW_SORT", "SORT"); + xlateMap.put("KW_UNION", "UNION"); + xlateMap.put("KW_LOAD", "LOAD"); + xlateMap.put("KW_DATA", "DATA"); + xlateMap.put("KW_INPATH", "INPATH"); + xlateMap.put("KW_IS", "IS"); + xlateMap.put("KW_NULL", "NULL"); + xlateMap.put("KW_CREATE", "CREATE"); + xlateMap.put("KW_EXTERNAL", "EXTERNAL"); + xlateMap.put("KW_ALTER", "ALTER"); + xlateMap.put("KW_DESCRIBE", "DESCRIBE"); + xlateMap.put("KW_DROP", "DROP"); + xlateMap.put("KW_RENAME", "RENAME"); + xlateMap.put("KW_TO", "TO"); + xlateMap.put("KW_COMMENT", "COMMENT"); + xlateMap.put("KW_BOOLEAN", "BOOLEAN"); + xlateMap.put("KW_TINYINT", "TINYINT"); + xlateMap.put("KW_SMALLINT", "SMALLINT"); + xlateMap.put("KW_INT", "INT"); + xlateMap.put("KW_BIGINT", "BIGINT"); + xlateMap.put("KW_FLOAT", "FLOAT"); + xlateMap.put("KW_DOUBLE", "DOUBLE"); + xlateMap.put("KW_DATE", "DATE"); + xlateMap.put("KW_DATETIME", "DATETIME"); + xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); + xlateMap.put("KW_STRING", "STRING"); + xlateMap.put("KW_BINARY", "BINARY"); + xlateMap.put("KW_ARRAY", "ARRAY"); + xlateMap.put("KW_MAP", "MAP"); + xlateMap.put("KW_REDUCE", "REDUCE"); + xlateMap.put("KW_PARTITIONED", "PARTITIONED"); + xlateMap.put("KW_CLUSTERED", "CLUSTERED"); + xlateMap.put("KW_SORTED", "SORTED"); + xlateMap.put("KW_INTO", "INTO"); + xlateMap.put("KW_BUCKETS", "BUCKETS"); + xlateMap.put("KW_ROW", "ROW"); + xlateMap.put("KW_FORMAT", "FORMAT"); + xlateMap.put("KW_DELIMITED", "DELIMITED"); + xlateMap.put("KW_FIELDS", "FIELDS"); + xlateMap.put("KW_TERMINATED", "TERMINATED"); + xlateMap.put("KW_COLLECTION", "COLLECTION"); + xlateMap.put("KW_ITEMS", "ITEMS"); + xlateMap.put("KW_KEYS", "KEYS"); + xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); + xlateMap.put("KW_LINES", "LINES"); + xlateMap.put("KW_STORED", "STORED"); + xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); + xlateMap.put("KW_TEXTFILE", "TEXTFILE"); + xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); + 
xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); + xlateMap.put("KW_LOCATION", "LOCATION"); + xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); + xlateMap.put("KW_BUCKET", "BUCKET"); + xlateMap.put("KW_OUT", "OUT"); + xlateMap.put("KW_OF", "OF"); + xlateMap.put("KW_CAST", "CAST"); + xlateMap.put("KW_ADD", "ADD"); + xlateMap.put("KW_REPLACE", "REPLACE"); + xlateMap.put("KW_COLUMNS", "COLUMNS"); + xlateMap.put("KW_RLIKE", "RLIKE"); + xlateMap.put("KW_REGEXP", "REGEXP"); + xlateMap.put("KW_TEMPORARY", "TEMPORARY"); + xlateMap.put("KW_FUNCTION", "FUNCTION"); + xlateMap.put("KW_EXPLAIN", "EXPLAIN"); + xlateMap.put("KW_EXTENDED", "EXTENDED"); + xlateMap.put("KW_SERDE", "SERDE"); + xlateMap.put("KW_WITH", "WITH"); + xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); + xlateMap.put("KW_LIMIT", "LIMIT"); + xlateMap.put("KW_SET", "SET"); + xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); + xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); + xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); + xlateMap.put("KW_DEFINED", "DEFINED"); + xlateMap.put("KW_SUBQUERY", "SUBQUERY"); + xlateMap.put("KW_REWRITE", "REWRITE"); + xlateMap.put("KW_UPDATE", "UPDATE"); + xlateMap.put("KW_VALUES", "VALUES"); + xlateMap.put("KW_PURGE", "PURGE"); + + + // Operators + xlateMap.put("DOT", "."); + xlateMap.put("COLON", ":"); + xlateMap.put("COMMA", ","); + xlateMap.put("SEMICOLON", ");"); + + xlateMap.put("LPAREN", "("); + xlateMap.put("RPAREN", ")"); + xlateMap.put("LSQUARE", "["); + xlateMap.put("RSQUARE", "]"); + + xlateMap.put("EQUAL", "="); + xlateMap.put("NOTEQUAL", "<>"); + xlateMap.put("EQUAL_NS", "<=>"); + xlateMap.put("LESSTHANOREQUALTO", "<="); + xlateMap.put("LESSTHAN", "<"); + xlateMap.put("GREATERTHANOREQUALTO", ">="); + xlateMap.put("GREATERTHAN", ">"); + + xlateMap.put("DIVIDE", "/"); + xlateMap.put("PLUS", "+"); + xlateMap.put("MINUS", "-"); + xlateMap.put("STAR", "*"); + xlateMap.put("MOD", "\%"); + + xlateMap.put("AMPERSAND", "&"); + xlateMap.put("TILDE", "~"); + xlateMap.put("BITWISEOR", "|"); + xlateMap.put("BITWISEXOR", "^"); + xlateMap.put("CharSetLiteral", "\\'"); + } + + public static Collection getKeywords() { + return xlateMap.values(); + } + + private static String xlate(String name) { + + String ret = xlateMap.get(name); + if (ret == null) { + ret = name; + } + + return ret; + } + + @Override + public Object recoverFromMismatchedSet(IntStream input, + RecognitionException re, BitSet follow) throws RecognitionException { + throw re; + } + + @Override + public void displayRecognitionError(String[] tokenNames, RecognitionException e) { + if (reporter != null) { + reporter.report(this, e, tokenNames); + } + } + + @Override + public String getErrorHeader(RecognitionException e) { + String header = null; + if (e.charPositionInLine < 0 && input.LT(-1) != null) { + Token t = input.LT(-1); + header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); + } else { + header = super.getErrorHeader(e); + } + + return header; + } + + @Override + public String getErrorMessage(RecognitionException e, String[] tokenNames) { + String msg = null; + + // Translate the token names to something that the user can understand + String[] xlateNames = new String[tokenNames.length]; + for (int i = 0; i < tokenNames.length; ++i) { + xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); + } + + if (e instanceof NoViableAltException) { + @SuppressWarnings("unused") + NoViableAltException nvae = (NoViableAltException) e; + // for development, can add + // "decision=<<"+nvae.grammarDecisionDescription+">>" + // and 
"(decision="+nvae.decisionNumber+") and + // "state "+nvae.stateNumber + msg = "cannot recognize input near" + + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") + + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") + + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); + } else if (e instanceof MismatchedTokenException) { + MismatchedTokenException mte = (MismatchedTokenException) e; + msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; + } else if (e instanceof FailedPredicateException) { + FailedPredicateException fpe = (FailedPredicateException) e; + msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; + } else { + msg = super.getErrorMessage(e, xlateNames); + } + + if (msgs.size() > 0) { + msg = msg + " in " + msgs.peek(); + } + return msg; + } + + public void pushMsg(String msg, RecognizerSharedState state) { + // ANTLR generated code does not wrap the @init code wit this backtracking check, + // even if the matching @after has it. If we have parser rules with that are doing + // some lookahead with syntactic predicates this can cause the push() and pop() calls + // to become unbalanced, so make sure both push/pop check the backtracking state. + if (state.backtracking == 0) { + msgs.push(msg); + } + } + + public void popMsg(RecognizerSharedState state) { + if (state.backtracking == 0) { + Object o = msgs.pop(); + } + } + + // counter to generate unique union aliases + private int aliasCounter; + private String generateUnionAlias() { + return "u_" + (++aliasCounter); + } + private char [] excludedCharForColumnName = {'.', ':'}; + private boolean containExcludedCharForCreateTableColumnName(String input) { + for(char c : excludedCharForColumnName) { + if(input.indexOf(c)>-1) { + return true; + } + } + return false; + } + private CommonTree throwSetOpException() throws RecognitionException { + throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); + } + private CommonTree throwColumnNameException() throws RecognitionException { + throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); + } + + private ParserConf parserConf; + private ParseErrorReporter reporter; + + public void configure(ParserConf parserConf, ParseErrorReporter reporter) { + this.parserConf = parserConf; + this.reporter = reporter; + } + + protected boolean useSQL11ReservedKeywordsForIdentifier() { + if (parserConf == null) { + return true; + } + return !parserConf.supportSQL11ReservedKeywords(); + } +} + +@rulecatch { +catch (RecognitionException e) { + reportError(e); + throw e; +} +} + +// starting rule +statement + : explainStatement EOF + | execStatement EOF + ; + +explainStatement +@init { pushMsg("explain statement", state); } +@after { popMsg(state); } + : KW_EXPLAIN ( + explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) + | + KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) + ; + +explainOption +@init { msgs.push("explain option"); } +@after { msgs.pop(); } + : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION + ; + +execStatement +@init { pushMsg("statement", state); } +@after { popMsg(state); } + : queryStatementExpression[true] + | loadStatement + 
| exportStatement + | importStatement + | ddlStatement + | deleteStatement + | updateStatement + | sqlTransactionStatement + ; + +loadStatement +@init { pushMsg("load statement", state); } +@after { popMsg(state); } + : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) + -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) + ; + +replicationClause +@init { pushMsg("replication clause", state); } +@after { popMsg(state); } + : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN + -> ^(TOK_REPLICATION $replId $isMetadataOnly?) + ; + +exportStatement +@init { pushMsg("export statement", state); } +@after { popMsg(state); } + : KW_EXPORT + KW_TABLE (tab=tableOrPartition) + KW_TO (path=StringLiteral) + replicationClause? + -> ^(TOK_EXPORT $tab $path replicationClause?) + ; + +importStatement +@init { pushMsg("import statement", state); } +@after { popMsg(state); } + : KW_IMPORT + ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? + KW_FROM (path=StringLiteral) + tableLocation? + -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) + ; + +ddlStatement +@init { pushMsg("ddl statement", state); } +@after { popMsg(state); } + : createDatabaseStatement + | switchDatabaseStatement + | dropDatabaseStatement + | createTableStatement + | dropTableStatement + | truncateTableStatement + | alterStatement + | descStatement + | showStatement + | metastoreCheck + | createViewStatement + | dropViewStatement + | createFunctionStatement + | createMacroStatement + | createIndexStatement + | dropIndexStatement + | dropFunctionStatement + | reloadFunctionStatement + | dropMacroStatement + | analyzeStatement + | lockStatement + | unlockStatement + | lockDatabase + | unlockDatabase + | createRoleStatement + | dropRoleStatement + | (grantPrivileges) => grantPrivileges + | (revokePrivileges) => revokePrivileges + | showGrants + | showRoleGrants + | showRolePrincipals + | showRoles + | grantRole + | revokeRole + | setRole + | showCurrentRole + ; + +ifExists +@init { pushMsg("if exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_EXISTS + -> ^(TOK_IFEXISTS) + ; + +restrictOrCascade +@init { pushMsg("restrict or cascade clause", state); } +@after { popMsg(state); } + : KW_RESTRICT + -> ^(TOK_RESTRICT) + | KW_CASCADE + -> ^(TOK_CASCADE) + ; + +ifNotExists +@init { pushMsg("if not exists clause", state); } +@after { popMsg(state); } + : KW_IF KW_NOT KW_EXISTS + -> ^(TOK_IFNOTEXISTS) + ; + +storedAsDirs +@init { pushMsg("stored as directories", state); } +@after { popMsg(state); } + : KW_STORED KW_AS KW_DIRECTORIES + -> ^(TOK_STOREDASDIRS) + ; + +orReplace +@init { pushMsg("or replace clause", state); } +@after { popMsg(state); } + : KW_OR KW_REPLACE + -> ^(TOK_ORREPLACE) + ; + +createDatabaseStatement +@init { pushMsg("create database statement", state); } +@after { popMsg(state); } + : KW_CREATE (KW_DATABASE|KW_SCHEMA) + ifNotExists? + name=identifier + databaseComment? + dbLocation? + (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? + -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) 
+ ; + +dbLocation +@init { pushMsg("database location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) + ; + +dbProperties +@init { pushMsg("dbproperties", state); } +@after { popMsg(state); } + : + LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) + ; + +dbPropertiesList +@init { pushMsg("database properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) + ; + + +switchDatabaseStatement +@init { pushMsg("switch database statement", state); } +@after { popMsg(state); } + : KW_USE identifier + -> ^(TOK_SWITCHDATABASE identifier) + ; + +dropDatabaseStatement +@init { pushMsg("drop database statement", state); } +@after { popMsg(state); } + : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? + -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) + ; + +databaseComment +@init { pushMsg("database's comment", state); } +@after { popMsg(state); } + : KW_COMMENT comment=StringLiteral + -> ^(TOK_DATABASECOMMENT $comment) + ; + +createTableStatement +@init { pushMsg("create table statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName + ( like=KW_LIKE likeName=tableName + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + | (LPAREN columnNameTypeList RPAREN)? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + (KW_AS selectStatementWithCTE)? + ) + -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? + ^(TOK_LIKETABLE $likeName?) + columnNameTypeList? + tableComment? + tablePartition? + tableBuckets? + tableSkewed? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + selectStatementWithCTE? + ) + ; + +truncateTableStatement +@init { pushMsg("truncate table statement", state); } +@after { popMsg(state); } + : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); + +createIndexStatement +@init { pushMsg("create index statement", state);} +@after {popMsg(state);} + : KW_CREATE KW_INDEX indexName=identifier + KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN + KW_AS typeName=StringLiteral + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment? + ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols + autoRebuild? + indexPropertiesPrefixed? + indexTblName? + tableRowFormat? + tableFileFormat? + tableLocation? + tablePropertiesPrefixed? + indexComment?) + ; + +indexComment +@init { pushMsg("comment on an index", state);} +@after {popMsg(state);} + : + KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) + ; + +autoRebuild +@init { pushMsg("auto rebuild index", state);} +@after {popMsg(state);} + : KW_WITH KW_DEFERRED KW_REBUILD + ->^(TOK_DEFERRED_REBUILDINDEX) + ; + +indexTblName +@init { pushMsg("index table name", state);} +@after {popMsg(state);} + : KW_IN KW_TABLE indexTbl=tableName + ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) + ; + +indexPropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_IDXPROPERTIES! 
indexProperties + ; + +indexProperties +@init { pushMsg("index properties", state); } +@after { popMsg(state); } + : + LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) + ; + +indexPropertiesList +@init { pushMsg("index properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) + ; + +dropIndexStatement +@init { pushMsg("drop index statement", state);} +@after {popMsg(state);} + : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName + ->^(TOK_DROPINDEX $indexName $tab ifExists?) + ; + +dropTableStatement +@init { pushMsg("drop statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? + -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) + ; + +alterStatement +@init { pushMsg("alter statement", state); } +@after { popMsg(state); } + : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) + | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) + | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix + | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix + ; + +alterTableStatementSuffix +@init { pushMsg("alter table statement", state); } +@after { popMsg(state); } + : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] + | alterStatementSuffixDropPartitions[true] + | alterStatementSuffixAddPartitions[true] + | alterStatementSuffixTouch + | alterStatementSuffixArchive + | alterStatementSuffixUnArchive + | alterStatementSuffixProperties + | alterStatementSuffixSkewedby + | alterStatementSuffixExchangePartition + | alterStatementPartitionKeyType + | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? + ; + +alterTblPartitionStatementSuffix +@init {pushMsg("alter table partition statement suffix", state);} +@after {popMsg(state);} + : alterStatementSuffixFileFormat + | alterStatementSuffixLocation + | alterStatementSuffixMergeFiles + | alterStatementSuffixSerdeProperties + | alterStatementSuffixRenamePart + | alterStatementSuffixBucketNum + | alterTblPartitionStatementSuffixSkewedLocation + | alterStatementSuffixClusterbySortby + | alterStatementSuffixCompact + | alterStatementSuffixUpdateStatsCol + | alterStatementSuffixRenameCol + | alterStatementSuffixAddCol + ; + +alterStatementPartitionKeyType +@init {msgs.push("alter partition key type"); } +@after {msgs.pop();} + : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN + -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) + ; + +alterViewStatementSuffix +@init { pushMsg("alter view statement", state); } +@after { popMsg(state); } + : alterViewSuffixProperties + | alterStatementSuffixRename[false] + | alterStatementSuffixAddPartitions[false] + | alterStatementSuffixDropPartitions[false] + | selectStatementWithCTE + ; + +alterIndexStatementSuffix +@init { pushMsg("alter index statement", state); } +@after { popMsg(state); } + : indexName=identifier KW_ON tableName partitionSpec? + ( + KW_REBUILD + ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) 
+ | + KW_SET KW_IDXPROPERTIES + indexProperties + ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) + ) + ; + +alterDatabaseStatementSuffix +@init { pushMsg("alter database statement", state); } +@after { popMsg(state); } + : alterDatabaseSuffixProperties + | alterDatabaseSuffixSetOwner + ; + +alterDatabaseSuffixProperties +@init { pushMsg("alter database properties statement", state); } +@after { popMsg(state); } + : name=identifier KW_SET KW_DBPROPERTIES dbProperties + -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) + ; + +alterDatabaseSuffixSetOwner +@init { pushMsg("alter database set owner", state); } +@after { popMsg(state); } + : dbName=identifier KW_SET KW_OWNER principalName + -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) + ; + +alterStatementSuffixRename[boolean table] +@init { pushMsg("rename statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO tableName + -> { table }? ^(TOK_ALTERTABLE_RENAME tableName) + -> ^(TOK_ALTERVIEW_RENAME tableName) + ; + +alterStatementSuffixAddCol +@init { pushMsg("add column statement", state); } +@after { popMsg(state); } + : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? + -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) + -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) + ; + +alterStatementSuffixRenameCol +@init { pushMsg("rename column name", state); } +@after { popMsg(state); } + : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? + ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) + ; + +alterStatementSuffixUpdateStatsCol +@init { pushMsg("update column statistics", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementChangeColPosition + : first=KW_FIRST|KW_AFTER afterCol=identifier + ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) + -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) + ; + +alterStatementSuffixAddPartitions[boolean table] +@init { pushMsg("add partition statement", state); } +@after { popMsg(state); } + : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ + -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) + ; + +alterStatementSuffixAddPartitionsElement + : partitionSpec partitionLocation? 
+ ; + +alterStatementSuffixTouch +@init { pushMsg("touch statement", state); } +@after { popMsg(state); } + : KW_TOUCH (partitionSpec)* + -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) + ; + +alterStatementSuffixArchive +@init { pushMsg("archive statement", state); } +@after { popMsg(state); } + : KW_ARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) + ; + +alterStatementSuffixUnArchive +@init { pushMsg("unarchive statement", state); } +@after { popMsg(state); } + : KW_UNARCHIVE (partitionSpec)* + -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) + ; + +partitionLocation +@init { pushMsg("partition location", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) + ; + +alterStatementSuffixDropPartitions[boolean table] +@init { pushMsg("drop partition statement", state); } +@after { popMsg(state); } + : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? + -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) + -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) + ; + +alterStatementSuffixProperties +@init { pushMsg("alter properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) + ; + +alterViewSuffixProperties +@init { pushMsg("alter view properties statement", state); } +@after { popMsg(state); } + : KW_SET KW_TBLPROPERTIES tableProperties + -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) + | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties + -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) + ; + +alterStatementSuffixSerdeProperties +@init { pushMsg("alter serdes statement", state); } +@after { popMsg(state); } + : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? + -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) + | KW_SET KW_SERDEPROPERTIES tableProperties + -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) + ; + +tablePartitionPrefix +@init {pushMsg("table partition prefix", state);} +@after {popMsg(state);} + : tableName partitionSpec? + ->^(TOK_TABLE_PARTITION tableName partitionSpec?) 
+ ; + +alterStatementSuffixFileFormat +@init {pushMsg("alter fileformat statement", state); } +@after {popMsg(state);} + : KW_SET KW_FILEFORMAT fileFormat + -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) + ; + +alterStatementSuffixClusterbySortby +@init {pushMsg("alter partition cluster by sort by statement", state);} +@after {popMsg(state);} + : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) + | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) + | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) + ; + +alterTblPartitionStatementSuffixSkewedLocation +@init {pushMsg("alter partition skewed location", state);} +@after {popMsg(state);} + : KW_SET KW_SKEWED KW_LOCATION skewedLocations + -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) + ; + +skewedLocations +@init { pushMsg("skewed locations", state); } +@after { popMsg(state); } + : + LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) + ; + +skewedLocationsList +@init { pushMsg("skewed locations list", state); } +@after { popMsg(state); } + : + skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) + ; + +skewedLocationMap +@init { pushMsg("specifying skewed location map", state); } +@after { popMsg(state); } + : + key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) + ; + +alterStatementSuffixLocation +@init {pushMsg("alter location", state);} +@after {popMsg(state);} + : KW_SET KW_LOCATION newLoc=StringLiteral + -> ^(TOK_ALTERTABLE_LOCATION $newLoc) + ; + + +alterStatementSuffixSkewedby +@init {pushMsg("alter skewed by statement", state);} +@after{popMsg(state);} + : tableSkewed + ->^(TOK_ALTERTABLE_SKEWED tableSkewed) + | + KW_NOT KW_SKEWED + ->^(TOK_ALTERTABLE_SKEWED) + | + KW_NOT storedAsDirs + ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) + ; + +alterStatementSuffixExchangePartition +@init {pushMsg("alter exchange partition", state);} +@after{popMsg(state);} + : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName + -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) + ; + +alterStatementSuffixRenamePart +@init { pushMsg("alter table rename partition statement", state); } +@after { popMsg(state); } + : KW_RENAME KW_TO partitionSpec + ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) + ; + +alterStatementSuffixStatsPart +@init { pushMsg("alter table stats partition statement", state); } +@after { popMsg(state); } + : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? + ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) + ; + +alterStatementSuffixMergeFiles +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_CONCATENATE + -> ^(TOK_ALTERTABLE_MERGEFILES) + ; + +alterStatementSuffixBucketNum +@init { pushMsg("", state); } +@after { popMsg(state); } + : KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $num) + ; + +alterStatementSuffixCompact +@init { msgs.push("compaction request"); } +@after { msgs.pop(); } + : KW_COMPACT compactType=StringLiteral + -> ^(TOK_ALTERTABLE_COMPACT $compactType) + ; + + +fileFormat +@init { pushMsg("file format specification", state); } +@after { popMsg(state); } + : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? 
$outDriver?) + | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tabTypeExpr +@init { pushMsg("specifying table types", state); } +@after { popMsg(state); } + : identifier (DOT^ identifier)? + (identifier (DOT^ + ( + (KW_ELEM_TYPE) => KW_ELEM_TYPE + | + (KW_KEY_TYPE) => KW_KEY_TYPE + | + (KW_VALUE_TYPE) => KW_VALUE_TYPE + | identifier + ))* + )? + ; + +partTypeExpr +@init { pushMsg("specifying table partitions", state); } +@after { popMsg(state); } + : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) + ; + +tabPartColTypeExpr +@init { pushMsg("specifying table partitions columnName", state); } +@after { popMsg(state); } + : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) + ; + +descStatement +@init { pushMsg("describe statement", state); } +@after { popMsg(state); } + : + (KW_DESCRIBE|KW_DESC) + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) + | + (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) + | + (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) + | + parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) + ) + ; + +analyzeStatement +@init { pushMsg("analyze statement", state); } +@after { popMsg(state); } + : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) + | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? + -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) + ; + +showStatement +@init { pushMsg("show statement", state); } +@after { popMsg(state); } + : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) + | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES (TOK_FROM $db_name)? showStmtIdentifier?) + | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWCOLUMNS tableName $db_name?) + | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) + | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) + | KW_SHOW KW_CREATE ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) + | + KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) + ) + | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? + -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) + | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) + | KW_SHOW KW_LOCKS + ( + (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) + | + (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) + ) + | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? + -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) 
+ | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) + | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) + | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) + ; + +lockStatement +@init { pushMsg("lock statement", state); } +@after { popMsg(state); } + : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) + ; + +lockDatabase +@init { pushMsg("lock database statement", state); } +@after { popMsg(state); } + : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) + ; + +lockMode +@init { pushMsg("lock mode", state); } +@after { popMsg(state); } + : KW_SHARED | KW_EXCLUSIVE + ; + +unlockStatement +@init { pushMsg("unlock statement", state); } +@after { popMsg(state); } + : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) + ; + +unlockDatabase +@init { pushMsg("unlock database statement", state); } +@after { popMsg(state); } + : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) + ; + +createRoleStatement +@init { pushMsg("create role", state); } +@after { popMsg(state); } + : KW_CREATE KW_ROLE roleName=identifier + -> ^(TOK_CREATEROLE $roleName) + ; + +dropRoleStatement +@init {pushMsg("drop role", state);} +@after {popMsg(state);} + : KW_DROP KW_ROLE roleName=identifier + -> ^(TOK_DROPROLE $roleName) + ; + +grantPrivileges +@init {pushMsg("grant privileges", state);} +@after {popMsg(state);} + : KW_GRANT privList=privilegeList + privilegeObject? + KW_TO principalSpecification + withGrantOption? + -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) + ; + +revokePrivileges +@init {pushMsg("revoke privileges", state);} +@afer {popMsg(state);} + : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification + -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) + ; + +grantRole +@init {pushMsg("grant role", state);} +@after {popMsg(state);} + : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? + -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) + ; + +revokeRole +@init {pushMsg("revoke role", state);} +@after {popMsg(state);} + : KW_REVOKE adminOptionFor? KW_ROLE? identifier (COMMA identifier)* KW_FROM principalSpecification + -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) + ; + +showRoleGrants +@init {pushMsg("show role grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLE KW_GRANT principalName + -> ^(TOK_SHOW_ROLE_GRANT principalName) + ; + + +showRoles +@init {pushMsg("show roles", state);} +@after {popMsg(state);} + : KW_SHOW KW_ROLES + -> ^(TOK_SHOW_ROLES) + ; + +showCurrentRole +@init {pushMsg("show current role", state);} +@after {popMsg(state);} + : KW_SHOW KW_CURRENT KW_ROLES + -> ^(TOK_SHOW_SET_ROLE) + ; + +setRole +@init {pushMsg("set role", state);} +@after {popMsg(state);} + : KW_SET KW_ROLE + ( + (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) + | + (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) + | + identifier -> ^(TOK_SHOW_SET_ROLE identifier) + ) + ; + +showGrants +@init {pushMsg("show grants", state);} +@after {popMsg(state);} + : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? + -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) 
+ ; + +showRolePrincipals +@init {pushMsg("show role principals", state);} +@after {popMsg(state);} + : KW_SHOW KW_PRINCIPALS roleName=identifier + -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) + ; + + +privilegeIncludeColObject +@init {pushMsg("privilege object including columns", state);} +@after {popMsg(state);} + : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) + | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) + ; + +privilegeObject +@init {pushMsg("privilege object", state);} +@after {popMsg(state);} + : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) + ; + +// database or table type. Type is optional, default type is table +privObject + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privObjectCols + : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) + | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) + | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) + | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) + ; + +privilegeList +@init {pushMsg("grant privilege list", state);} +@after {popMsg(state);} + : privlegeDef (COMMA privlegeDef)* + -> ^(TOK_PRIVILEGE_LIST privlegeDef+) + ; + +privlegeDef +@init {pushMsg("grant privilege", state);} +@after {popMsg(state);} + : privilegeType (LPAREN cols=columnNameList RPAREN)? + -> ^(TOK_PRIVILEGE privilegeType $cols?) + ; + +privilegeType +@init {pushMsg("privilege type", state);} +@after {popMsg(state);} + : KW_ALL -> ^(TOK_PRIV_ALL) + | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) + | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) + | KW_CREATE -> ^(TOK_PRIV_CREATE) + | KW_DROP -> ^(TOK_PRIV_DROP) + | KW_INDEX -> ^(TOK_PRIV_INDEX) + | KW_LOCK -> ^(TOK_PRIV_LOCK) + | KW_SELECT -> ^(TOK_PRIV_SELECT) + | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) + | KW_INSERT -> ^(TOK_PRIV_INSERT) + | KW_DELETE -> ^(TOK_PRIV_DELETE) + ; + +principalSpecification +@init { pushMsg("user/group/role name list", state); } +@after { popMsg(state); } + : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) + ; + +principalName +@init {pushMsg("user|group|role name", state);} +@after {popMsg(state);} + : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) + | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) + | KW_ROLE identifier -> ^(TOK_ROLE identifier) + ; + +withGrantOption +@init {pushMsg("with grant option", state);} +@after {popMsg(state);} + : KW_WITH KW_GRANT KW_OPTION + -> ^(TOK_GRANT_WITH_OPTION) + ; + +grantOptionFor +@init {pushMsg("grant option for", state);} +@after {popMsg(state);} + : KW_GRANT KW_OPTION KW_FOR + -> ^(TOK_GRANT_OPTION_FOR) +; + +adminOptionFor +@init {pushMsg("admin option for", state);} +@after {popMsg(state);} + : KW_ADMIN KW_OPTION KW_FOR + -> ^(TOK_ADMIN_OPTION_FOR) +; + +withAdminOption +@init {pushMsg("with admin option", state);} +@after {popMsg(state);} + : KW_WITH KW_ADMIN KW_OPTION + -> ^(TOK_GRANT_WITH_ADMIN_OPTION) + ; + +metastoreCheck +@init { pushMsg("metastore check statement", state); } +@after { popMsg(state); } + : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? + -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) 
+ ; + +resourceList +@init { pushMsg("resource list", state); } +@after { popMsg(state); } + : + resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) + ; + +resource +@init { pushMsg("resource", state); } +@after { popMsg(state); } + : + resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) + ; + +resourceType +@init { pushMsg("resource type", state); } +@after { popMsg(state); } + : + KW_JAR -> ^(TOK_JAR) + | + KW_FILE -> ^(TOK_FILE) + | + KW_ARCHIVE -> ^(TOK_ARCHIVE) + ; + +createFunctionStatement +@init { pushMsg("create function statement", state); } +@after { popMsg(state); } + : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral + (KW_USING rList=resourceList)? + -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) + -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) + ; + +dropFunctionStatement +@init { pushMsg("drop function statement", state); } +@after { popMsg(state); } + : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier + -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) + -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) + ; + +reloadFunctionStatement +@init { pushMsg("reload function statement", state); } +@after { popMsg(state); } + : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); + +createMacroStatement +@init { pushMsg("create macro statement", state); } +@after { popMsg(state); } + : KW_CREATE KW_TEMPORARY KW_MACRO Identifier + LPAREN columnNameTypeList? RPAREN expression + -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? expression) + ; + +dropMacroStatement +@init { pushMsg("drop macro statement", state); } +@after { popMsg(state); } + : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier + -> ^(TOK_DROPMACRO Identifier ifExists?) + ; + +createViewStatement +@init { + pushMsg("create view statement", state); +} +@after { popMsg(state); } + : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName + (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? + tablePropertiesPrefixed? + KW_AS + selectStatementWithCTE + -> ^(TOK_CREATEVIEW $name orReplace? + ifNotExists? + columnNameCommentList? + tableComment? + viewPartition? + tablePropertiesPrefixed? + selectStatementWithCTE + ) + ; + +viewPartition +@init { pushMsg("view partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN + -> ^(TOK_VIEWPARTCOLS columnNameList) + ; + +dropViewStatement +@init { pushMsg("drop view statement", state); } +@after { popMsg(state); } + : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) 
+ ; + +showFunctionIdentifier +@init { pushMsg("identifier for show function statement", state); } +@after { popMsg(state); } + : functionIdentifier + | StringLiteral + ; + +showStmtIdentifier +@init { pushMsg("identifier for show statement", state); } +@after { popMsg(state); } + : identifier + | StringLiteral + ; + +tableComment +@init { pushMsg("table's comment", state); } +@after { popMsg(state); } + : + KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) + ; + +tablePartition +@init { pushMsg("table partition specification", state); } +@after { popMsg(state); } + : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN + -> ^(TOK_TABLEPARTCOLS columnNameTypeList) + ; + +tableBuckets +@init { pushMsg("table buckets specification", state); } +@after { popMsg(state); } + : + KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS + -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) + ; + +tableSkewed +@init { pushMsg("table skewed specification", state); } +@after { popMsg(state); } + : + KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? + -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) + ; + +rowFormat +@init { pushMsg("serde specification", state); } +@after { popMsg(state); } + : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) + | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) + | -> ^(TOK_SERDE) + ; + +recordReader +@init { pushMsg("record reader specification", state); } +@after { popMsg(state); } + : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) + | -> ^(TOK_RECORDREADER) + ; + +recordWriter +@init { pushMsg("record writer specification", state); } +@after { popMsg(state); } + : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) + | -> ^(TOK_RECORDWRITER) + ; + +rowFormatSerde +@init { pushMsg("serde format specification", state); } +@after { popMsg(state); } + : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_SERDENAME $name $serdeprops?) + ; + +rowFormatDelimited +@init { pushMsg("serde properties specification", state); } +@after { popMsg(state); } + : + KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? + -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) + ; + +tableRowFormat +@init { pushMsg("table row format specification", state); } +@after { popMsg(state); } + : + rowFormatDelimited + -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) + | rowFormatSerde + -> ^(TOK_TABLESERIALIZER rowFormatSerde) + ; + +tablePropertiesPrefixed +@init { pushMsg("table properties with prefix", state); } +@after { popMsg(state); } + : + KW_TBLPROPERTIES! 
tableProperties + ; + +tableProperties +@init { pushMsg("table properties", state); } +@after { popMsg(state); } + : + LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) + ; + +tablePropertiesList +@init { pushMsg("table properties list", state); } +@after { popMsg(state); } + : + keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) + | + keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) + ; + +keyValueProperty +@init { pushMsg("specifying key/value property", state); } +@after { popMsg(state); } + : + key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) + ; + +keyProperty +@init { pushMsg("specifying key property", state); } +@after { popMsg(state); } + : + key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) + ; + +tableRowFormatFieldIdentifier +@init { pushMsg("table row format's field separator", state); } +@after { popMsg(state); } + : + KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? + -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) + ; + +tableRowFormatCollItemsIdentifier +@init { pushMsg("table row format's column separator", state); } +@after { popMsg(state); } + : + KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) + ; + +tableRowFormatMapKeysIdentifier +@init { pushMsg("table row format's map key separator", state); } +@after { popMsg(state); } + : + KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) + ; + +tableRowFormatLinesIdentifier +@init { pushMsg("table row format's line separator", state); } +@after { popMsg(state); } + : + KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) + ; + +tableRowNullFormat +@init { pushMsg("table row format's null specifier", state); } +@after { popMsg(state); } + : + KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral + -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) + ; +tableFileFormat +@init { pushMsg("table file format specification", state); } +@after { popMsg(state); } + : + (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? + -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) + | KW_STORED KW_BY storageHandler=StringLiteral + (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? + -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) 
+ | KW_STORED KW_AS genericSpec=identifier + -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) + ; + +tableLocation +@init { pushMsg("table location specification", state); } +@after { popMsg(state); } + : + KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) + ; + +columnNameTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) + ; + +columnNameColonTypeList +@init { pushMsg("column name type list", state); } +@after { popMsg(state); } + : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) + ; + +columnNameList +@init { pushMsg("column name list", state); } +@after { popMsg(state); } + : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) + ; + +columnName +@init { pushMsg("column name", state); } +@after { popMsg(state); } + : + identifier + ; + +extColumnName +@init { pushMsg("column name for complex types", state); } +@after { popMsg(state); } + : + identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* + ; + +columnNameOrderList +@init { pushMsg("column name order list", state); } +@after { popMsg(state); } + : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) + ; + +skewedValueElement +@init { pushMsg("skewed value element", state); } +@after { popMsg(state); } + : + skewedColumnValues + | skewedColumnValuePairList + ; + +skewedColumnValuePairList +@init { pushMsg("column value pair list", state); } +@after { popMsg(state); } + : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) + ; + +skewedColumnValuePair +@init { pushMsg("column value pair", state); } +@after { popMsg(state); } + : + LPAREN colValues=skewedColumnValues RPAREN + -> ^(TOK_TABCOLVALUES $colValues) + ; + +skewedColumnValues +@init { pushMsg("column values", state); } +@after { popMsg(state); } + : skewedColumnValue (COMMA skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) + ; + +skewedColumnValue +@init { pushMsg("column value", state); } +@after { popMsg(state); } + : + constant + ; + +skewedValueLocationElement +@init { pushMsg("skewed value location element", state); } +@after { popMsg(state); } + : + skewedColumnValue + | skewedColumnValuePair + ; + +columnNameOrder +@init { pushMsg("column name order", state); } +@after { popMsg(state); } + : identifier (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) + -> ^(TOK_TABSORTCOLNAMEDESC identifier) + ; + +columnNameCommentList +@init { pushMsg("column name comment list", state); } +@after { popMsg(state); } + : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) + ; + +columnNameComment +@init { pushMsg("column name comment", state); } +@after { popMsg(state); } + : colName=identifier (KW_COMMENT comment=StringLiteral)? + -> ^(TOK_TABCOL $colName TOK_NULL $comment?) + ; + +columnRefOrder +@init { pushMsg("column order", state); } +@after { popMsg(state); } + : expression (asc=KW_ASC | desc=KW_DESC)? + -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) + -> ^(TOK_TABSORTCOLNAMEDESC expression) + ; + +columnNameType +@init { pushMsg("column specification", state); } +@after { popMsg(state); } + : colName=identifier colType (KW_COMMENT comment=StringLiteral)? + -> {containExcludedCharForCreateTableColumnName($colName.text)}? 
{throwColumnNameException()}
+    -> {$comment == null}? ^(TOK_TABCOL $colName colType)
+    ->                     ^(TOK_TABCOL $colName colType $comment)
+    ;
+
+columnNameColonType
+@init { pushMsg("column specification", state); }
+@after { popMsg(state); }
+    : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)?
+    -> {$comment == null}? ^(TOK_TABCOL $colName colType)
+    ->                     ^(TOK_TABCOL $colName colType $comment)
+    ;
+
+colType
+@init { pushMsg("column type", state); }
+@after { popMsg(state); }
+    : type
+    ;
+
+colTypeList
+@init { pushMsg("column type list", state); }
+@after { popMsg(state); }
+    : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+)
+    ;
+
+type
+    : primitiveType
+    | listType
+    | structType
+    | mapType
+    | unionType;
+
+primitiveType
+@init { pushMsg("primitive type specification", state); }
+@after { popMsg(state); }
+    : KW_TINYINT   -> TOK_TINYINT
+    | KW_SMALLINT  -> TOK_SMALLINT
+    | KW_INT       -> TOK_INT
+    | KW_BIGINT    -> TOK_BIGINT
+    | KW_BOOLEAN   -> TOK_BOOLEAN
+    | KW_FLOAT     -> TOK_FLOAT
+    | KW_DOUBLE    -> TOK_DOUBLE
+    | KW_DATE      -> TOK_DATE
+    | KW_DATETIME  -> TOK_DATETIME
+    | KW_TIMESTAMP -> TOK_TIMESTAMP
+    // Uncomment to allow intervals as table column types
+    //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH
+    //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME
+    | KW_STRING    -> TOK_STRING
+    | KW_BINARY    -> TOK_BINARY
+    | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?)
+    | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length)
+    | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length)
+    ;
+
+listType
+@init { pushMsg("list type", state); }
+@after { popMsg(state); }
+    : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type)
+    ;
+
+structType
+@init { pushMsg("struct type", state); }
+@after { popMsg(state); }
+    : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList)
+    ;
+
+mapType
+@init { pushMsg("map type", state); }
+@after { popMsg(state); }
+    : KW_MAP LESSTHAN left=primitiveType COMMA right=type GREATERTHAN
+    -> ^(TOK_MAP $left $right)
+    ;
+
+unionType
+@init { pushMsg("uniontype type", state); }
+@after { popMsg(state); }
+    : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList)
+    ;
+
+setOperator
+@init { pushMsg("set operator", state); }
+@after { popMsg(state); }
+    : KW_UNION KW_ALL -> ^(TOK_UNIONALL)
+    | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT)
+    | KW_EXCEPT -> ^(TOK_EXCEPT)
+    | KW_INTERSECT -> ^(TOK_INTERSECT)
+    ;
+
+queryStatementExpression[boolean topLevel]
+    :
+    /* Would be nice to do this as a gated semantic predicate
+       But the predicate gets pushed as a lookahead decision.
+       The calling rule does not know about topLevel
+    */
+    (w=withClause {topLevel}?)?
+    queryStatementExpressionBody[topLevel] {
+      if ($w.tree != null) {
+      $queryStatementExpressionBody.tree.insertChild(0, $w.tree);
+      }
+    }
+    ->  queryStatementExpressionBody
+    ;
+
+queryStatementExpressionBody[boolean topLevel]
+    :
+    fromStatement[topLevel]
+    | regularBody[topLevel]
+    ;
+
+withClause
+  :
+  KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+)
+;
+
+cteStatement
+   :
+   identifier KW_AS LPAREN queryStatementExpression[false] RPAREN
+   -> ^(TOK_SUBQUERY queryStatementExpression identifier)
+;
+
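The type rules above only produce `TOK_*` tokens; it is `nodeToDataType` in the `CatalystQl` class added later in this patch that maps those tokens onto Catalyst types. A minimal, self-contained sketch of that correspondence, assuming spark-catalyst on the classpath (the `toDataType` helper is illustrative, not code from this patch):

```scala
import org.apache.spark.sql.types._

// Illustrative mapping from a few of the grammar's type tokens to Catalyst types.
def toDataType(token: String, args: List[Int] = Nil): DataType = (token, args) match {
  case ("TOK_INT", Nil) => IntegerType
  case ("TOK_BIGINT", Nil) => LongType
  case ("TOK_STRING", Nil) => StringType
  case ("TOK_DECIMAL", precision :: scale :: Nil) => DecimalType(precision, scale)
  case ("TOK_DECIMAL", precision :: Nil) => DecimalType(precision, 0)
  case ("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT // decimal(10, 0)
  case other => sys.error(s"No mapping for $other")
}

assert(toDataType("TOK_DECIMAL", List(10, 2)) == DecimalType(10, 2))
```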
+fromStatement[boolean topLevel]
+: (singleFromStatement  -> singleFromStatement)
+  (u=setOperator r=singleFromStatement
+    -> ^($u {$fromStatement.tree} $r)
+  )*
+ -> {u != null && topLevel}? ^(TOK_QUERY
+       ^(TOK_FROM
+         ^(TOK_SUBQUERY
+           {$fromStatement.tree}
+           {adaptor.create(Identifier, generateUnionAlias())}
+          )
+       )
+       ^(TOK_INSERT
+          ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+          ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
+       )
+      )
+    -> {$fromStatement.tree}
+ ;
+
+
+singleFromStatement
+    :
+    fromClause
+    ( b+=body )+ -> ^(TOK_QUERY fromClause body+)
+    ;
+
+/*
+The valuesClause rule below ensures that the parse tree for
+"insert into table FOO values (1,2),(3,4)" looks the same as
+"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look
+very similar to the tree for "insert into table FOO select a,b from BAR". Since the virtual table
+name is implicit, it's represented as TOK_ANONYMOUS.
+*/
+regularBody[boolean topLevel]
+   :
+   i=insertClause
+   (
+   s=selectStatement[topLevel]
+     {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree}
+     |
+     valuesClause
+      -> ^(TOK_QUERY
+            ^(TOK_FROM
+              ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause)
+            )
+            ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)))
+          )
+   )
+   |
+   selectStatement[topLevel]
+   ;
+
+selectStatement[boolean topLevel]
+   :
+   (
+   (
+   LPAREN
+   s=selectClause
+   f=fromClause?
+   w=whereClause?
+   g=groupByClause?
+   h=havingClause?
+   o=orderByClause?
+   c=clusterByClause?
+   d=distributeByClause?
+   sort=sortByClause?
+   win=window_clause?
+   l=limitClause?
+   RPAREN
+   |
+   s=selectClause
+   f=fromClause?
+   w=whereClause?
+   g=groupByClause?
+   h=havingClause?
+   o=orderByClause?
+   c=clusterByClause?
+   d=distributeByClause?
+   sort=sortByClause?
+   win=window_clause?
+   l=limitClause?
+   )
+   -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+                     $s $w? $g? $h? $o? $c?
+                     $d? $sort? $win? $l?))
+   )
+   (set=setOpSelectStatement[$selectStatement.tree, topLevel])?
+   -> {set == null}?
+      {$selectStatement.tree}
+   -> {o==null && c==null && d==null && sort==null && l==null}?
+      {$set.tree}
+   -> {throwSetOpException()}
+   ;
+
+setOpSelectStatement[CommonTree t, boolean topLevel]
+   :
+   ((
+    u=setOperator LPAREN b=simpleSelectStatement RPAREN
+   |
+    u=setOperator b=simpleSelectStatement)
+   -> {$setOpSelectStatement.tree != null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}?
+      ^(TOK_QUERY
+          ^(TOK_FROM
+            ^(TOK_SUBQUERY
+              ^($u {$setOpSelectStatement.tree} $b)
+              {adaptor.create(Identifier, generateUnionAlias())}
+             )
+          )
+          ^(TOK_INSERT
+             ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+             ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF))
+          )
+       )
+   -> {$setOpSelectStatement.tree != null && $u.tree.getType()!=SparkSqlParser.TOK_UNIONDISTINCT}?
+      ^($u {$setOpSelectStatement.tree} $b)
+   -> {$setOpSelectStatement.tree == null && $u.tree.getType()==SparkSqlParser.TOK_UNIONDISTINCT}?
+      ^(TOK_QUERY
+          ^(TOK_FROM
+            ^(TOK_SUBQUERY
+              ^($u {$t} $b)
+              {adaptor.create(Identifier, generateUnionAlias())}
+             )
+           )
+          ^(TOK_INSERT
+            ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+            ^(TOK_SELECTDI ^(TOK_SELEXPR TOK_ALLCOLREF))
+          )
+       )
+   -> ^($u {$t} $b)
+   )+
+   o=orderByClause?
+   c=clusterByClause?
+   d=distributeByClause?
+   sort=sortByClause?
+   win=window_clause?
+   l=limitClause?
+   -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}?
+      {$setOpSelectStatement.tree}
+   -> ^(TOK_QUERY
+          ^(TOK_FROM
+            ^(TOK_SUBQUERY
+              {$setOpSelectStatement.tree}
+              {adaptor.create(Identifier, generateUnionAlias())}
+             )
+          )
+          ^(TOK_INSERT
+            ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+            ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))
+            $o? $c? $d? $sort? $win? $l?
+          )
+       )
+   ;
+
+simpleSelectStatement
+   :
+   selectClause
+   fromClause?
+   whereClause?
+   groupByClause?
+   havingClause?
+   ((window_clause) => window_clause)?
+   -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+                selectClause whereClause? groupByClause? havingClause? window_clause?))
+   ;
+
+selectStatementWithCTE
+    :
+    (w=withClause)?
+    selectStatement[true] {
+      if ($w.tree != null) {
+      $selectStatement.tree.insertChild(0, $w.tree);
+      }
+    }
+    ->  selectStatement
+    ;
+
+body
+   :
+   insertClause
+   selectClause
+   lateralView?
+   whereClause?
+   groupByClause?
+   havingClause?
+   orderByClause?
+   clusterByClause?
+   distributeByClause?
+   sortByClause?
+   window_clause?
+   limitClause? -> ^(TOK_INSERT insertClause
+                     selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
+                     distributeByClause? sortByClause? window_clause? limitClause?)
+   |
+   selectClause
+   lateralView?
+   whereClause?
+   groupByClause?
+   havingClause?
+   orderByClause?
+   clusterByClause?
+   distributeByClause?
+   sortByClause?
+   window_clause?
+   limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE))
+                     selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause?
+                     distributeByClause? sortByClause? window_clause? limitClause?)
+   ;
+
+insertClause
+@init { pushMsg("insert clause", state); }
+@after { popMsg(state); }
+   :
+     KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?)
+   | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)?
+       -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?)
+   ;
+
+destination
+@init { pushMsg("destination specification", state); }
+@after { popMsg(state); }
+   :
+     (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat?
+       -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?)
+   | KW_TABLE tableOrPartition -> tableOrPartition
+   ;
+
+limitClause
+@init { pushMsg("limit clause", state); }
+@after { popMsg(state); }
+   :
+   KW_LIMIT num=Number -> ^(TOK_LIMIT $num)
+   ;
+
+//DELETE FROM <tableName> WHERE ...;
+deleteStatement
+@init { pushMsg("delete statement", state); }
+@after { popMsg(state); }
+   :
+   KW_DELETE KW_FROM tableName (whereClause)? -> ^(TOK_DELETE_FROM tableName whereClause?)
+   ;
+
+/*SET <column> = (3 + col2)*/
+columnAssignmentClause
+   :
+   tableOrColumn EQUAL^ precedencePlusExpression
+   ;
+
+/*SET col1 = 5, col2 = (4 + col4), ...*/
+setColumnsClause
+   :
+   KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* )
+   ;
+
+/*
+  UPDATE <table>
    + SET col1 = val1, col2 = val2... WHERE ... +*/ +updateStatement +@init { pushMsg("update statement", state); } +@after { popMsg(state); } + : + KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) + ; + +/* +BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of +"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. +*/ +sqlTransactionStatement +@init { pushMsg("transaction statement", state); } +@after { popMsg(state); } + : + startTransactionStatement + | commitStatement + | rollbackStatement + | setAutoCommitStatement + ; + +startTransactionStatement + : + KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) + ; + +transactionMode + : + isolationLevel + | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) + ; + +transactionAccessMode + : + KW_READ KW_ONLY -> TOK_TXN_READ_ONLY + | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE + ; + +isolationLevel + : + KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) + ; + +/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ +levelOfIsolation + : + KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT + ; + +commitStatement + : + KW_COMMIT ( KW_WORK )? -> TOK_COMMIT + ; + +rollbackStatement + : + KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK + ; +setAutoCommitStatement + : + KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) + ; +/* +END user defined transaction boundaries +*/ diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java index 3513960b4181..3d80df227151 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeArrayData.java @@ -270,8 +270,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final int offset = getElementOffset(ordinal); if (offset < 0) return null; final int size = getElementSize(offset, ordinal); - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b6979d0c8297..b8d3c4910047 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -17,11 +17,7 @@ package org.apache.spark.sql.catalyst.expressions; -import java.io.Externalizable; -import java.io.IOException; -import java.io.ObjectInput; -import java.io.ObjectOutput; -import java.io.OutputStream; +import java.io.*; import java.math.BigDecimal; import java.math.BigInteger; import java.nio.ByteBuffer; @@ -30,26 +26,12 @@ import java.util.HashSet; import java.util.Set; -import org.apache.spark.sql.types.ArrayType; -import org.apache.spark.sql.types.BinaryType; -import org.apache.spark.sql.types.BooleanType; -import org.apache.spark.sql.types.ByteType; -import org.apache.spark.sql.types.CalendarIntervalType; -import 
org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.DateType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.sql.types.DecimalType; -import org.apache.spark.sql.types.DoubleType; -import org.apache.spark.sql.types.FloatType; -import org.apache.spark.sql.types.IntegerType; -import org.apache.spark.sql.types.LongType; -import org.apache.spark.sql.types.MapType; -import org.apache.spark.sql.types.NullType; -import org.apache.spark.sql.types.ShortType; -import org.apache.spark.sql.types.StringType; -import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.types.TimestampType; -import org.apache.spark.sql.types.UserDefinedType; +import com.esotericsoftware.kryo.Kryo; +import com.esotericsoftware.kryo.KryoSerializable; +import com.esotericsoftware.kryo.io.Input; +import com.esotericsoftware.kryo.io.Output; + +import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.bitset.BitSetMethods; @@ -57,23 +39,9 @@ import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; -import static org.apache.spark.sql.types.DataTypes.BooleanType; -import static org.apache.spark.sql.types.DataTypes.ByteType; -import static org.apache.spark.sql.types.DataTypes.DateType; -import static org.apache.spark.sql.types.DataTypes.DoubleType; -import static org.apache.spark.sql.types.DataTypes.FloatType; -import static org.apache.spark.sql.types.DataTypes.IntegerType; -import static org.apache.spark.sql.types.DataTypes.LongType; -import static org.apache.spark.sql.types.DataTypes.NullType; -import static org.apache.spark.sql.types.DataTypes.ShortType; -import static org.apache.spark.sql.types.DataTypes.TimestampType; +import static org.apache.spark.sql.types.DataTypes.*; import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET; -import com.esotericsoftware.kryo.Kryo; -import com.esotericsoftware.kryo.KryoSerializable; -import com.esotericsoftware.kryo.io.Input; -import com.esotericsoftware.kryo.io.Output; - /** * An Unsafe implementation of Row which is backed by raw memory instead of Java objects. * @@ -167,8 +135,16 @@ private void assertIndexIsValid(int index) { /** * Construct a new UnsafeRow. The resulting row won't be usable until `pointTo()` has been called, * since the value returned by this constructor is equivalent to a null pointer. 
+ * + * @param numFields the number of fields in this row */ - public UnsafeRow() { } + public UnsafeRow(int numFields) { + this.numFields = numFields; + this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); + } + + // for serializer + public UnsafeRow() {} public Object getBaseObject() { return baseObject; } public long getBaseOffset() { return baseOffset; } @@ -182,15 +158,12 @@ public UnsafeRow() { } * * @param baseObject the base object * @param baseOffset the offset within the base object - * @param numFields the number of fields in this row * @param sizeInBytes the size of this row's backing data, in bytes */ - public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeInBytes) { + public void pointTo(Object baseObject, long baseOffset, int sizeInBytes) { assert numFields >= 0 : "numFields (" + numFields + ") should >= 0"; - this.bitSetWidthInBytes = calculateBitSetWidthInBytes(numFields); this.baseObject = baseObject; this.baseOffset = baseOffset; - this.numFields = numFields; this.sizeInBytes = sizeInBytes; } @@ -198,22 +171,15 @@ public void pointTo(Object baseObject, long baseOffset, int numFields, int sizeI * Update this UnsafeRow to point to the underlying byte array. * * @param buf byte array to point to - * @param numFields the number of fields in this row - * @param sizeInBytes the number of bytes valid in the byte array - */ - public void pointTo(byte[] buf, int numFields, int sizeInBytes) { - pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); - } - - /** - * Updates this UnsafeRow preserving the number of fields. - * @param buf byte array to point to * @param sizeInBytes the number of bytes valid in the byte array */ public void pointTo(byte[] buf, int sizeInBytes) { - pointTo(buf, numFields, sizeInBytes); + pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); } + public void setTotalSize(int sizeInBytes) { + this.sizeInBytes = sizeInBytes; + } public void setNotNullAt(int i) { assertIndexIsValid(i); @@ -489,8 +455,8 @@ public UnsafeRow getStruct(int ordinal, int numFields) { final long offsetAndSize = getLong(ordinal); final int offset = (int) (offsetAndSize >> 32); final int size = (int) offsetAndSize; - final UnsafeRow row = new UnsafeRow(); - row.pointTo(baseObject, baseOffset + offset, numFields, size); + final UnsafeRow row = new UnsafeRow(numFields); + row.pointTo(baseObject, baseOffset + offset, size); return row; } } @@ -529,7 +495,7 @@ public UnsafeMapData getMap(int ordinal) { */ @Override public UnsafeRow copy() { - UnsafeRow rowCopy = new UnsafeRow(); + UnsafeRow rowCopy = new UnsafeRow(numFields); final byte[] rowDataCopy = new byte[sizeInBytes]; Platform.copyMemory( baseObject, @@ -538,7 +504,7 @@ public UnsafeRow copy() { Platform.BYTE_ARRAY_OFFSET, sizeInBytes ); - rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); + rowCopy.pointTo(rowDataCopy, Platform.BYTE_ARRAY_OFFSET, sizeInBytes); return rowCopy; } @@ -547,8 +513,8 @@ public UnsafeRow copy() { * The returned row is invalid until we call copyFrom on it. 
 */
   public static UnsafeRow createFromByteArray(int numBytes, int numFields) {
-    final UnsafeRow row = new UnsafeRow();
-    row.pointTo(new byte[numBytes], numFields, numBytes);
+    final UnsafeRow row = new UnsafeRow(numFields);
+    row.pointTo(new byte[numBytes], numBytes);
     return row;
   }
 
@@ -600,6 +566,10 @@ public int hashCode() {
     return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
   }
 
+  public int hashCode(int seed) {
+    return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed);
+  }
+
   @Override
   public boolean equals(Object other) {
     if (other instanceof UnsafeRow) {
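With numFields moved into the constructor, the bitset width is computed once up front and `pointTo` only rebinds the row to new backing memory. A small sketch of the new calling convention using the methods from the hunks above; the 24-byte layout (one 8-byte null-bitset word covering up to 64 fields, plus one 8-byte slot per field) is a hand-computed assumption, not something stated in the patch:

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

// A two-field row: the null-bitset width is fixed at construction time.
val row = new UnsafeRow(2)

// Rebind the row to a byte array; note pointTo() no longer takes numFields.
val backing = new Array[Byte](8 + 2 * 8) // assumed layout: bitset word + two 8-byte fields
row.pointTo(backing, backing.length)

// The factory from the same file does both steps at once.
val row2 = UnsafeRow.createFromByteArray(8 + 2 * 8, 2)
```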
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
new file mode 100644
index 000000000000..5bc87b680f9a
--- /dev/null
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java
@@ -0,0 +1,162 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.parser;
+
+import java.io.UnsupportedEncodingException;
+
+/**
+ * A couple of utility methods that help with parsing ASTs.
+ *
+ * Both methods in this class were taken from the SemanticAnalyzer in Hive:
+ * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java
+ */
+public final class ParseUtils {
+  private ParseUtils() {
+    super();
+  }
+
+  public static String charSetString(String charSetName, String charSetString)
+      throws UnsupportedEncodingException {
+    // The character set name starts with a _, so strip that
+    charSetName = charSetName.substring(1);
+    if (charSetString.charAt(0) == '\'') {
+      return new String(unescapeSQLString(charSetString).getBytes(), charSetName);
+    } else // hex input is also supported
+    {
+      assert charSetString.charAt(0) == '0';
+      assert charSetString.charAt(1) == 'x';
+      charSetString = charSetString.substring(2);
+
+      byte[] bArray = new byte[charSetString.length() / 2];
+      int j = 0;
+      for (int i = 0; i < charSetString.length(); i += 2) {
+        int val = Character.digit(charSetString.charAt(i), 16) * 16
+            + Character.digit(charSetString.charAt(i + 1), 16);
+        if (val > 127) {
+          val = val - 256;
+        }
+        bArray[j++] = (byte)val;
+      }
+
+      return new String(bArray, charSetName);
+    }
+  }
+
+  private static final int[] multiplier = new int[] {1000, 100, 10, 1};
+
+  @SuppressWarnings("nls")
+  public static String unescapeSQLString(String b) {
+    Character enclosure = null;
+
+    // Some of the strings can be passed in as unicode. For example, the
+    // delimiter can be passed in as \002 - So, we first check if the
+    // string is a unicode number, else go back to the old behavior
+    StringBuilder sb = new StringBuilder(b.length());
+    for (int i = 0; i < b.length(); i++) {
+
+      char currentChar = b.charAt(i);
+      if (enclosure == null) {
+        if (currentChar == '\'' || b.charAt(i) == '\"') {
+          enclosure = currentChar;
+        }
+        // ignore all other chars outside the enclosure
+        continue;
+      }
+
+      if (enclosure.equals(currentChar)) {
+        enclosure = null;
+        continue;
+      }
+
+      if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') {
+        int code = 0;
+        int base = i + 2;
+        for (int j = 0; j < 4; j++) {
+          int digit = Character.digit(b.charAt(j + base), 16);
+          code += digit * multiplier[j];
+        }
+        sb.append((char)code);
+        i += 5;
+        continue;
+      }
+
+      if (currentChar == '\\' && (i + 4 < b.length())) {
+        char i1 = b.charAt(i + 1);
+        char i2 = b.charAt(i + 2);
+        char i3 = b.charAt(i + 3);
+        if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7')
+            && (i3 >= '0' && i3 <= '7')) {
+          byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8));
+          byte[] bValArr = new byte[1];
+          bValArr[0] = bVal;
+          String tmp = new String(bValArr);
+          sb.append(tmp);
+          i += 3;
+          continue;
+        }
+      }
+
+      if (currentChar == '\\' && (i + 2 < b.length())) {
+        char n = b.charAt(i + 1);
+        switch (n) {
+          case '0':
+            sb.append("\0");
+            break;
+          case '\'':
+            sb.append("'");
+            break;
+          case '"':
+            sb.append("\"");
+            break;
+          case 'b':
+            sb.append("\b");
+            break;
+          case 'n':
+            sb.append("\n");
+            break;
+          case 'r':
+            sb.append("\r");
+            break;
+          case 't':
+            sb.append("\t");
+            break;
+          case 'Z':
+            sb.append("\u001A");
+            break;
+          case '\\':
+            sb.append("\\");
+            break;
+          // The following 2 lines are exactly what MySQL does. TODO: why do we do this?
+ case '%': + sb.append("\\%"); + break; + case '_': + sb.append("\\_"); + break; + default: + sb.append(n); + } + i++; + } else { + sb.append(currentChar); + } + } + return sb.toString(); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index 352002b3499a..27ae62f1212f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -26,10 +26,9 @@ import org.apache.spark.SparkEnv; import org.apache.spark.TaskContext; -import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection; import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.util.AbstractScalaRowIterator; import org.apache.spark.sql.types.StructType; import org.apache.spark.unsafe.Platform; import org.apache.spark.util.collection.unsafe.sort.PrefixComparator; @@ -123,7 +122,7 @@ Iterator sort() throws IOException { return new AbstractScalaRowIterator() { private final int numFields = schema.length(); - private UnsafeRow row = new UnsafeRow(); + private UnsafeRow row = new UnsafeRow(numFields); @Override public boolean hasNext() { @@ -137,7 +136,6 @@ public UnsafeRow next() { row.pointTo( sortedIterator.getBaseObject(), sortedIterator.getBaseOffset(), - numFields, sortedIterator.getRecordLength()); if (!hasNext()) { UnsafeRow copy = row.copy(); // so that we don't have dangling pointers to freed page @@ -173,19 +171,21 @@ public Iterator sort(Iterator inputIterator) throws IOExce private static final class RowComparator extends RecordComparator { private final Ordering ordering; private final int numFields; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; public RowComparator(Ordering ordering, int numFields) { this.numFields = numFields; + this.row1 = new UnsafeRow(numFields); + this.row2 = new UnsafeRow(numFields); this.ordering = ordering; } @Override public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // TODO: Why are the sizes -1? 
- row1.pointTo(baseObj1, baseOff1, numFields, -1); - row2.pointTo(baseObj2, baseOff2, numFields, -1); + row1.pointTo(baseObj1, baseOff1, -1); + row2.pointTo(baseObj2, baseOff2, -1); return ordering.compare(row1, row2); } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index bb0fdc4c3d83..b19538a23f19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql import java.lang.reflect.Modifier import scala.annotation.implicitNotFound -import scala.reflect.{ClassTag, classTag} +import scala.reflect.{classTag, ClassTag} import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} -import org.apache.spark.sql.catalyst.expressions.{DecodeUsingSerializer, BoundReference, EncodeUsingSerializer} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.types._ /** @@ -157,6 +157,12 @@ object Encoders { */ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + /** + * An encoder for arrays of bytes. + * @since 1.6.1 + */ + def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + /** * Creates an encoder for Java Bean of type T. * diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala new file mode 100644 index 000000000000..1eda4a9a9764 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystQl.scala @@ -0,0 +1,966 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst + +import java.sql.Date + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.CurrentOrigin +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * This class translates a HQL String to a Catalyst [[LogicalPlan]] or [[Expression]]. 
+ */ +private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) { + object Token { + def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { + CurrentOrigin.setPosition(node.line, node.positionInLine) + node.pattern + } + } + + + /** + * Returns the AST for the given SQL string. + */ + protected def getAst(sql: String): ASTNode = ParseDriver.parse(sql, conf) + + /** Creates LogicalPlan for a given HiveQL string. */ + def createPlan(sql: String): LogicalPlan = { + try { + createPlan(sql, ParseDriver.parse(sql, conf)) + } catch { + case e: MatchError => throw e + case e: AnalysisException => throw e + case e: Exception => + throw new AnalysisException(e.getMessage) + case e: NotImplementedError => + throw new AnalysisException( + s""" + |Unsupported language features in query: $sql + |${getAst(sql).treeString} + |$e + |${e.getStackTrace.head} + """.stripMargin) + } + } + + protected def createPlan(sql: String, tree: ASTNode): LogicalPlan = nodeToPlan(tree) + + def parseDdl(ddl: String): Seq[Attribute] = { + val tree = getAst(ddl) + assert(tree.text == "TOK_CREATETABLE", "Only CREATE TABLE supported.") + val tableOps = tree.children + val colList = tableOps + .find(_.text == "TOK_TABCOLLIST") + .getOrElse(sys.error("No columnList!")) + + colList.children.map(nodeToAttribute) + } + + protected def getClauses( + clauseNames: Seq[String], + nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { + var remainingNodes = nodeList + val clauses = clauseNames.map { clauseName => + val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) + remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) + matches.headOption + } + + if (remainingNodes.nonEmpty) { + sys.error( + s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. 
+           |You are likely trying to use an unsupported Hive feature.""".stripMargin)
+    }
+    clauses
+  }
+
+  protected def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode =
+    getClauseOption(clauseName, nodeList).getOrElse(sys.error(
+      s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}"))
+
+  protected def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = {
+    nodeList.filter { case ast: ASTNode => ast.text == clauseName } match {
+      case Seq(oneMatch) => Some(oneMatch)
+      case Seq() => None
+      case _ => sys.error(s"Found multiple instances of clause $clauseName")
+    }
+  }
+
+  protected def nodeToAttribute(node: ASTNode): Attribute = node match {
+    case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) =>
+      AttributeReference(colName, nodeToDataType(dataType), nullable = true)()
+    case _ =>
+      noParseRule("Attribute", node)
+  }
+
+  protected def nodeToDataType(node: ASTNode): DataType = node match {
+    case Token("TOK_DECIMAL", precision :: scale :: Nil) =>
+      DecimalType(precision.text.toInt, scale.text.toInt)
+    case Token("TOK_DECIMAL", precision :: Nil) =>
+      DecimalType(precision.text.toInt, 0)
+    case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT
+    case Token("TOK_BIGINT", Nil) => LongType
+    case Token("TOK_INT", Nil) => IntegerType
+    case Token("TOK_TINYINT", Nil) => ByteType
+    case Token("TOK_SMALLINT", Nil) => ShortType
+    case Token("TOK_BOOLEAN", Nil) => BooleanType
+    case Token("TOK_STRING", Nil) => StringType
+    case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType
+    case Token("TOK_FLOAT", Nil) => FloatType
+    case Token("TOK_DOUBLE", Nil) => DoubleType
+    case Token("TOK_DATE", Nil) => DateType
+    case Token("TOK_TIMESTAMP", Nil) => TimestampType
+    case Token("TOK_BINARY", Nil) => BinaryType
+    case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType))
+    case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) =>
+      StructType(fields.map(nodeToStructField))
+    case Token("TOK_MAP", keyType :: valueType :: Nil) =>
+      MapType(nodeToDataType(keyType), nodeToDataType(valueType))
+    case _ =>
+      noParseRule("DataType", node)
+  }
+
+  protected def nodeToStructField(node: ASTNode): StructField = node match {
+    case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) =>
+      StructField(fieldName, nodeToDataType(dataType), nullable = true)
+    case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: _ /* comment */:: Nil) =>
+      StructField(fieldName, nodeToDataType(dataType), nullable = true)
+    case _ =>
+      noParseRule("StructField", node)
+  }
+
+  protected def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = {
+    tableNameParts.children.map {
+      case Token(part, Nil) => cleanIdentifier(part)
+    } match {
+      case Seq(tableOnly) => TableIdentifier(tableOnly)
+      case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName))
+      case other => sys.error("Hive only supports table names like 'tableName' " +
+        s"or 'databaseName.tableName', found '$other'")
+    }
+  }
+
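extractTableIdent accepts exactly the one- and two-part name shapes above. Spelled out directly against TableIdentifier (a small illustration assuming spark-catalyst on the classpath; the names are made up):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// "src"      -> an unqualified identifier
val unqualified = TableIdentifier("src")
// "mydb.src" -> a database-qualified identifier
val qualified = TableIdentifier("src", Some("mydb"))
// Anything with three or more parts, e.g. "a.b.c", falls into the sys.error branch above.
```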
+  /**
+   * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2))
+   * is equivalent to
+   * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2
+   * Check the following link for details.
+   *
+https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup
+   *
+   * The bitmask denotes which grouping expressions are valid for a grouping set; the bitmask
+   * is also called the grouping id (`GROUPING__ID`, the virtual column in Hive).
+   * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of
+   * GROUPING SETS (k1, k2) and (k2) should be 3 (0b011) and 2 (0b010) respectively.
+   */
+  protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = {
+    val (keyASTs, setASTs) = children.partition {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets
+      case _ => true // grouping keys
+    }
+
+    val keys = keyASTs.map(nodeToExpr)
+    val keyMap = keyASTs.zipWithIndex.toMap
+
+    val bitmasks: Seq[Int] = setASTs.map {
+      case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0
+      case Token("TOK_GROUPING_SETS_EXPRESSION", columns) =>
+        columns.foldLeft(0)((bitmap, col) => {
+          val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2)
+          bitmap | 1 << keyIndex.getOrElse(
+            throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list"))
+        })
+      case _ => sys.error("Expect GROUPING SETS clause")
+    }
+
+    (keys, bitmasks)
+  }
+
+  protected def nodeToPlan(node: ASTNode): LogicalPlan = node match {
+    case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) =>
+      val (fromClause: Option[ASTNode], insertClauses, cteRelations) =
+        queryArgs match {
+          case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts =>
+            val cteRelations = ctes.map { node =>
+              val relation = nodeToRelation(node).asInstanceOf[Subquery]
+              relation.alias -> relation
+            }
+            (Some(from.head), inserts, Some(cteRelations.toMap))
+          case Token("TOK_FROM", from) :: inserts =>
+            (Some(from.head), inserts, None)
+          case Token("TOK_INSERT", _) :: Nil =>
+            (None, queryArgs, None)
+        }
+
+      // Return one query for each insert clause.
+      val queries = insertClauses.map {
+        case Token("TOK_INSERT", singleInsert) =>
+          val (
+              intoClause ::
+              destClause ::
+              selectClause ::
+              selectDistinctClause ::
+              whereClause ::
+              groupByClause ::
+              rollupGroupByClause ::
+              cubeGroupByClause ::
+              groupingSetsClause ::
+              orderByClause ::
+              havingClause ::
+              sortByClause ::
+              clusterByClause ::
+              distributeByClause ::
+              limitClause ::
+              lateralViewClause ::
+              windowClause :: Nil) = {
+            getClauses(
+              Seq(
+                "TOK_INSERT_INTO",
+                "TOK_DESTINATION",
+                "TOK_SELECT",
+                "TOK_SELECTDI",
+                "TOK_WHERE",
+                "TOK_GROUPBY",
+                "TOK_ROLLUP_GROUPBY",
+                "TOK_CUBE_GROUPBY",
+                "TOK_GROUPING_SETS",
+                "TOK_ORDERBY",
+                "TOK_HAVING",
+                "TOK_SORTBY",
+                "TOK_CLUSTERBY",
+                "TOK_DISTRIBUTEBY",
+                "TOK_LIMIT",
+                "TOK_LATERAL_VIEW",
+                "WINDOW"),
+              singleInsert)
+          }
+
+          val relations = fromClause match {
+            case Some(f) => nodeToRelation(f)
+            case None => OneRowRelation
+          }
+
+          val withWhere = whereClause.map { whereNode =>
+            val Seq(whereExpr) = whereNode.children
+            Filter(nodeToExpr(whereExpr), relations)
+          }.getOrElse(relations)
+
+          val select = (selectClause orElse selectDistinctClause)
+            .getOrElse(sys.error("No select clause."))
+
+          val transformation = nodeToTransformation(select.children.head, withWhere)
+
+          val withLateralView = lateralViewClause.map { lv =>
+            nodeToGenerate(lv.children.head, outer = false, withWhere)
+          }.getOrElse(withWhere)
+
+          // The projection of the query can either be a normal projection, an aggregation
+          // (if there is a group by) or a script transformation.
+ val withProject: LogicalPlan = transformation.getOrElse { + val selectExpressions = + select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) + Seq( + groupByClause.map(e => e match { + case Token("TOK_GROUPBY", children) => + // Not a transformation so must be either project or aggregation. + Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) + case _ => sys.error("Expect GROUP BY") + }), + groupingSetsClause.map(e => e match { + case Token("TOK_GROUPING_SETS", children) => + val(groupByExprs, masks) = extractGroupingSet(children) + GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) + case _ => sys.error("Expect GROUPING SETS") + }), + rollupGroupByClause.map(e => e match { + case Token("TOK_ROLLUP_GROUPBY", children) => + Aggregate( + Seq(Rollup(children.map(nodeToExpr))), + selectExpressions, + withLateralView) + case _ => sys.error("Expect WITH ROLLUP") + }), + cubeGroupByClause.map(e => e match { + case Token("TOK_CUBE_GROUPBY", children) => + Aggregate( + Seq(Cube(children.map(nodeToExpr))), + selectExpressions, + withLateralView) + case _ => sys.error("Expect WITH CUBE") + }), + Some(Project(selectExpressions, withLateralView))).flatten.head + } + + // Handle HAVING clause. + val withHaving = havingClause.map { h => + val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) } + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(havingExpr, BooleanType), withProject) + }.getOrElse(withProject) + + // Handle SELECT DISTINCT + val withDistinct = + if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withSort = + (orderByClause, sortByClause, distributeByClause, clusterByClause) match { + case (Some(totalOrdering), None, None, None) => + Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct) + case (None, Some(perPartitionOrdering), None, None) => + Sort( + perPartitionOrdering.children.map(nodeToSortOrder), + global = false, withDistinct) + case (None, None, Some(partitionExprs), None) => + RepartitionByExpression( + partitionExprs.children.map(nodeToExpr), withDistinct) + case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => + Sort( + perPartitionOrdering.children.map(nodeToSortOrder), global = false, + RepartitionByExpression( + partitionExprs.children.map(nodeToExpr), + withDistinct)) + case (None, None, None, Some(clusterExprs)) => + Sort( + clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression( + clusterExprs.children.map(nodeToExpr), + withDistinct)) + case (None, None, None, None) => withDistinct + case _ => sys.error("Unsupported set of ordering / distribution clauses.") + } + + val withLimit = + limitClause.map(l => nodeToExpr(l.children.head)) + .map(Limit(_, withSort)) + .getOrElse(withSort) + + // Collect all window specifications defined in the WINDOW clause. 
+        val windowDefinitions = windowClause.map(_.children.collect {
+          case Token("TOK_WINDOWDEF",
+              Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) =>
+            windowName -> nodesToWindowSpecification(spec)
+        }.toMap)
+        // Handle cases like
+        // window w1 as (partition by p_mfgr order by p_name
+        //               range between 2 preceding and 2 following),
+        //        w2 as w1
+        val resolvedCrossReference = windowDefinitions.map {
+          windowDefMap => windowDefMap.map {
+            case (windowName, WindowSpecReference(other)) =>
+              (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition])
+            case o => o.asInstanceOf[(String, WindowSpecDefinition)]
+          }
+        }
+
+        val withWindowDefinitions =
+          resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit)
+
+        // TOK_INSERT_INTO means to add files to the table.
+        // TOK_DESTINATION means to overwrite the table.
+        val resultDestination =
+          (intoClause orElse destClause).getOrElse(sys.error("No destination found."))
+        val overwrite = intoClause.isEmpty
+        nodeToDest(
+          resultDestination,
+          withWindowDefinitions,
+          overwrite)
+      }
+
+      // If there are multiple INSERTS just UNION them together into one query.
+      val query = queries.reduceLeft(Union)
+
+      // return With plan if there is CTE
+      cteRelations.map(With(query, _)).getOrElse(query)
+
+    case Token("TOK_UNIONALL", left :: right :: Nil) =>
+      Union(nodeToPlan(left), nodeToPlan(right))
+    case Token("TOK_UNIONDISTINCT", left :: right :: Nil) =>
+      Distinct(Union(nodeToPlan(left), nodeToPlan(right)))
+    case Token("TOK_EXCEPT", left :: right :: Nil) =>
+      Except(nodeToPlan(left), nodeToPlan(right))
+    case Token("TOK_INTERSECT", left :: right :: Nil) =>
+      Intersect(nodeToPlan(left), nodeToPlan(right))
+
+    case _ =>
+      noParseRule("Plan", node)
+  }
+
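Once nodeToPlan is in place, the set-operator tokens produced by the grammar surface as ordinary logical plans. A rough usage sketch, not from the patch itself: CatalystQl is private[sql], so this assumes code living in the org.apache.spark.sql package, the default SimpleParserConf, and that the top-level plan shape is exactly the cases above:

```scala
import org.apache.spark.sql.catalyst.CatalystQl
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Union}

val parser = new CatalystQl()

// TOK_UNIONALL is lowered to a plain Union of the two child plans.
parser.createPlan("SELECT 1 UNION ALL SELECT 2") match {
  case Union(_, _) => println("union all -> Union")
  case other => println(s"unexpected plan: $other")
}

// TOK_UNIONDISTINCT (UNION with or without DISTINCT) adds a Distinct on top.
parser.createPlan("SELECT 1 UNION SELECT 2") match {
  case Distinct(Union(_, _)) => println("union distinct -> Distinct(Union)")
  case other => println(s"unexpected plan: $other")
}
```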
+  val allJoinTokens = "(TOK_.*JOIN)".r
+  val laterViewToken = "TOK_LATERAL_VIEW(.*)".r
+  protected def nodeToRelation(node: ASTNode): LogicalPlan = {
+    node match {
+      case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) =>
+        Subquery(cleanIdentifier(alias), nodeToPlan(query))
+
+      case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) =>
+        nodeToGenerate(
+          selectClause,
+          outer = isOuter.nonEmpty,
+          nodeToRelation(relationClause))
+
+      /* All relations, possibly with aliases or sampling clauses. */
+      case Token("TOK_TABREF", clauses) =>
+        // If the last clause is not a token then it's the alias of the table.
+        val (nonAliasClauses, aliasClause) =
+          if (clauses.last.text.startsWith("TOK")) {
+            (clauses, None)
+          } else {
+            (clauses.dropRight(1), Some(clauses.last))
+          }
+
+        val (Some(tableNameParts) ::
+          splitSampleClause ::
+          bucketSampleClause :: Nil) = {
+          getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"),
+            nonAliasClauses)
+        }
+
+        val tableIdent = extractTableIdent(tableNameParts)
+        val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) }
+        val relation = UnresolvedRelation(tableIdent, alias)
+
+        // Apply sampling if requested.
+        (bucketSampleClause orElse splitSampleClause).map {
+          case Token("TOK_TABLESPLITSAMPLE",
+            Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) =>
+            Limit(Literal(count.toInt), relation)
+          case Token("TOK_TABLESPLITSAMPLE",
+            Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) =>
+            // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling
+            // function takes X PERCENT as the input and the range of X is [0, 100], we need to
+            // adjust the fraction.
+            require(
+              fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon)
+                && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon),
+              s"Sampling fraction ($fraction) must be on interval [0, 100]")
+            Sample(0.0, fraction.toDouble / 100, withReplacement = false,
+              (math.random * 1000).toInt,
+              relation)
+          case Token("TOK_TABLEBUCKETSAMPLE",
+            Token(numerator, Nil) ::
+              Token(denominator, Nil) :: Nil) =>
+            val fraction = numerator.toDouble / denominator.toDouble
+            Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)
+          case a =>
+            noParseRule("Sampling", a)
+        }.getOrElse(relation)
+
+      case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) =>
+        if (other.size > 1) {
+          sys.error(s"Unsupported join operation: $other")
+        }
+
+        val joinType = joinToken match {
+          case "TOK_JOIN" => Inner
+          case "TOK_CROSSJOIN" => Inner
+          case "TOK_RIGHTOUTERJOIN" => RightOuter
+          case "TOK_LEFTOUTERJOIN" => LeftOuter
+          case "TOK_FULLOUTERJOIN" => FullOuter
+          case "TOK_LEFTSEMIJOIN" => LeftSemi
+          case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node)
+          case "TOK_ANTIJOIN" => noParseRule("Anti Join", node)
+        }
+        Join(nodeToRelation(relation1),
+          nodeToRelation(relation2),
+          joinType,
+          other.headOption.map(nodeToExpr))
+
+      case _ =>
+        noParseRule("Relation", node)
+    }
+  }
+
+  protected def nodeToSortOrder(node: ASTNode): SortOrder = node match {
+    case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) =>
+      SortOrder(nodeToExpr(sortExpr), Ascending)
+    case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) =>
+      SortOrder(nodeToExpr(sortExpr), Descending)
+    case _ =>
+      noParseRule("SortOrder", node)
+  }
+
+  val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r
+  protected def nodeToDest(
+      node: ASTNode,
+      query: LogicalPlan,
+      overwrite: Boolean): LogicalPlan = node match {
+    case Token(destinationToken(),
+        Token("TOK_DIR",
+          Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) =>
+      query
+
+    case Token(destinationToken(),
+        Token("TOK_TAB",
+          tableArgs) :: Nil) =>
+      val Some(tableNameParts) :: partitionClause :: Nil =
+        getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+      val tableIdent = extractTableIdent(tableNameParts)
+
+      val partitionKeys = partitionClause.map(_.children.map {
+        // Parse partitions. We also make keys case insensitive.
+        case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+          cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
+        case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+          cleanIdentifier(key.toLowerCase) -> None
+      }.toMap).getOrElse(Map.empty)
+
+      InsertIntoTable(
+        UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false)
+
+    case Token(destinationToken(),
+        Token("TOK_TAB",
+          tableArgs) ::
+        Token("TOK_IFNOTEXISTS",
+          ifNotExists) :: Nil) =>
+      val Some(tableNameParts) :: partitionClause :: Nil =
+        getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs)
+
+      val tableIdent = extractTableIdent(tableNameParts)
+
+      val partitionKeys = partitionClause.map(_.children.map {
+        // Parse partitions. We also make keys case insensitive.
+        case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) =>
+          cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value))
+        case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) =>
+          cleanIdentifier(key.toLowerCase) -> None
+      }.toMap).getOrElse(Map.empty)
+
+      InsertIntoTable(
+        UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true)
+
+    case _ =>
+      noParseRule("Destination", node)
+  }
+
+  protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match {
+    case Token("TOK_SELEXPR", e :: Nil) =>
+      Some(nodeToExpr(e))
+
+    case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) =>
+      Some(Alias(nodeToExpr(e), cleanIdentifier(alias))())
+
+    case Token("TOK_SELEXPR", e :: aliasChildren) =>
+      val aliasNames = aliasChildren.collect {
+        case Token(name, Nil) => cleanIdentifier(name)
+      }
+      Some(MultiAlias(nodeToExpr(e), aliasNames))
+
+    /* Hints are ignored */
+    case Token("TOK_HINTLIST", _) => None
+
+    case _ =>
+      noParseRule("Select", node)
+  }
+
+  protected val escapedIdentifier = "`([^`]+)`".r
+  protected val doubleQuotedString = "\"([^\"]+)\"".r
+  protected val singleQuotedString = "'([^']+)'".r
+
+  protected def unquoteString(str: String) = str match {
+    case singleQuotedString(s) => s
+    case doubleQuotedString(s) => s
+    case other => other
+  }
+
+  /** Strips backticks from ident if present */
+  protected def cleanIdentifier(ident: String): String = ident match {
+    case escapedIdentifier(i) => i
+    case plainIdent => plainIdent
+  }
+
+  val numericAstTypes = Seq(
+    SparkSqlParser.Number,
+    SparkSqlParser.TinyintLiteral,
+    SparkSqlParser.SmallintLiteral,
+    SparkSqlParser.BigintLiteral,
+    SparkSqlParser.DecimalLiteral)
+
+  /* Case insensitive matches */
+  val COUNT = "(?i)COUNT".r
+  val SUM = "(?i)SUM".r
+  val AND = "(?i)AND".r
+  val OR = "(?i)OR".r
+  val NOT = "(?i)NOT".r
+  val TRUE = "(?i)TRUE".r
+  val FALSE = "(?i)FALSE".r
+  val LIKE = "(?i)LIKE".r
+  val RLIKE = "(?i)RLIKE".r
+  val REGEXP = "(?i)REGEXP".r
+  val IN = "(?i)IN".r
+  val DIV = "(?i)DIV".r
+  val BETWEEN = "(?i)BETWEEN".r
+  val WHEN = "(?i)WHEN".r
+  val CASE = "(?i)CASE".r
+
+  protected def nodeToExpr(node: ASTNode): Expression = node match {
+    /* Attribute References */
+    case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) =>
+      UnresolvedAttribute.quoted(cleanIdentifier(name))
+    case Token(".", qualifier :: Token(attr, Nil) :: Nil) =>
+      nodeToExpr(qualifier) match {
+        case UnresolvedAttribute(nameParts) =>
+          UnresolvedAttribute(nameParts :+ cleanIdentifier(attr))
+        case other => UnresolvedExtractValue(other, Literal(attr))
+      }
+
+    /* Stars (*) */
+    case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None)
+    // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only
+    // have a single child, which is tableName.
+    case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) =>
+      UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name)))
+
+    /* Aggregate Functions */
+    case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) =>
+      Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true)
+    case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) =>
+      Count(Literal(1)).toAggregateExpression()
+
+    /* Casts */
+    case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), StringType)
+    case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), IntegerType)
+    case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), LongType)
+    case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), FloatType)
+    case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DoubleType)
+    case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), ShortType)
+    case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), ByteType)
+    case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), BinaryType)
+    case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), BooleanType)
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0))
+    case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT)
+    case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), TimestampType)
+    case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) =>
+      Cast(nodeToExpr(arg), DateType)
+
+    /* Arithmetic */
+    case Token("+", child :: Nil) => nodeToExpr(child)
+    case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child))
+    case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child))
+    case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right))
+    case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right))
+    case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right))
+    case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right))
+    case Token(DIV(), left :: right:: Nil) =>
+      Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType)
+    case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right))
+    case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right))
+    case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right))
+    case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right))
+
+    /* Comparisons */
+    case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+    case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right))
+    case Token("<=>", left :: right:: Nil) =>
EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) + case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) + case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) + case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) + case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) + case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) + case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) + case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => + IsNotNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => + IsNull(nodeToExpr(child)) + case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => + In(nodeToExpr(value), list.map(nodeToExpr)) + case Token("TOK_FUNCTION", + Token(BETWEEN(), Nil) :: + kw :: + target :: + minValue :: + maxValue :: Nil) => + + val targetExpression = nodeToExpr(target) + val betweenExpr = + And( + GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), + LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) + kw match { + case Token("KW_FALSE", Nil) => betweenExpr + case Token("KW_TRUE", Nil) => Not(betweenExpr) + } + + /* Boolean Logic */ + case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) + case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) + case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) + case Token("!", child :: Nil) => Not(nodeToExpr(child)) + + /* Case statements */ + case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => + CaseWhen(branches.map(nodeToExpr)) + case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => + val keyExpr = nodeToExpr(branches.head) + CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) + + /* Complex datatype manipulation */ + case Token("[", child :: ordinal :: Nil) => + UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) + + /* Window Functions */ + case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) => + val function = nodeToExpr(node.copy(children = node.children.init)) + nodesToWindowSpecification(spec) match { + case reference: WindowSpecReference => + UnresolvedWindowExpression(function, reference) + case definition: WindowSpecDefinition => + WindowExpression(function, definition) + } + + /* UDFs - Must be last otherwise will preempt built in functions */ + case Token("TOK_FUNCTION", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) + // Aggregate function with DISTINCT keyword. 
+ case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => + UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) + case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => + UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) + + /* Literals */ + case Token("TOK_NULL", Nil) => Literal.create(null, NullType) + case Token(TRUE(), Nil) => Literal.create(true, BooleanType) + case Token(FALSE(), Nil) => Literal.create(false, BooleanType) + case Token("TOK_STRINGLITERALSEQUENCE", strings) => + Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) + + // This code is adapted from + // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 + case ast: ASTNode if numericAstTypes contains ast.tokenType => + var v: Literal = null + try { + if (ast.text.endsWith("L")) { + // Literal bigint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) + } else if (ast.text.endsWith("S")) { + // Literal smallint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) + } else if (ast.text.endsWith("Y")) { + // Literal tinyint. + v = Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) + } else if (ast.text.endsWith("BD") || ast.text.endsWith("D")) { + // Literal decimal + val strVal = ast.text.stripSuffix("D").stripSuffix("B") + v = Literal(Decimal(strVal)) + } else { + v = Literal.create(ast.text.toDouble, DoubleType) + v = Literal.create(ast.text.toLong, LongType) + v = Literal.create(ast.text.toInt, IntegerType) + } + } catch { + case nfe: NumberFormatException => // Do nothing + } + + if (v == null) { + sys.error(s"Failed to parse number '${ast.text}'.") + } else { + v + } + + case ast: ASTNode if ast.tokenType == SparkSqlParser.StringLiteral => + Literal(ParseUtils.unescapeSQLString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => + Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_CHARSETLITERAL => + Literal(ParseUtils.charSetString(ast.children.head.text, ast.children(1).text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => + Literal(CalendarInterval.fromYearMonthString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => + Literal(CalendarInterval.fromDayTimeString(ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("year", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MONTH_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("month", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("day", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_HOUR_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("hour", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_MINUTE_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("minute", ast.text)) + + case ast: ASTNode if ast.tokenType == SparkSqlParser.TOK_INTERVAL_SECOND_LITERAL => + Literal(CalendarInterval.fromSingleUnitString("second", ast.text)) + + case _ => + noParseRule("Expression", node) + } + + /* Case insensitive matches for Window Specification */ + val PRECEDING = "(?i)preceding".r + val 
FOLLOWING = "(?i)following".r + val CURRENT = "(?i)current".r + protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { + case Token(windowName, Nil) :: Nil => + // Refer to a window spec defined in the window clause. + WindowSpecReference(windowName) + case Nil => + // OVER() + WindowSpecDefinition( + partitionSpec = Nil, + orderSpec = Nil, + frameSpecification = UnspecifiedFrame) + case spec => + val (partitionClause :: rowFrame :: rangeFrame :: Nil) = + getClauses( + Seq( + "TOK_PARTITIONINGSPEC", + "TOK_WINDOWRANGE", + "TOK_WINDOWVALUES"), + spec) + + // Handle Partition By and Order By. + val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => + val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = + getClauses( + Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), + partitionAndOrdering.children) + + (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { + case (Some(partitionByExpr), Some(orderByExpr), None) => + (partitionByExpr.children.map(nodeToExpr), + orderByExpr.children.map(nodeToSortOrder)) + case (Some(partitionByExpr), None, None) => + (partitionByExpr.children.map(nodeToExpr), Nil) + case (None, Some(orderByExpr), None) => + (Nil, orderByExpr.children.map(nodeToSortOrder)) + case (None, None, Some(clusterByExpr)) => + val expressions = clusterByExpr.children.map(nodeToExpr) + (expressions, expressions.map(SortOrder(_, Ascending))) + case _ => + noParseRule("Partition & Ordering", partitionAndOrdering) + } + }.getOrElse { + (Nil, Nil) + } + + // Handle Window Frame + val windowFrame = + if (rowFrame.isEmpty && rangeFrame.isEmpty) { + UnspecifiedFrame + } else { + val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) + def nodeToBoundary(node: ASTNode): FrameBoundary = node match { + case Token(PRECEDING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedPreceding + } else { + ValuePreceding(count.toInt) + } + case Token(FOLLOWING(), Token(count, Nil) :: Nil) => + if (count.toLowerCase() == "unbounded") { + UnboundedFollowing + } else { + ValueFollowing(count.toInt) + } + case Token(CURRENT(), Nil) => CurrentRow + case _ => + noParseRule("Window Frame Boundary", node) + } + + rowFrame.orElse(rangeFrame).map { frame => + frame.children match { + case precedingNode :: followingNode :: Nil => + SpecifiedWindowFrame( + frameType, + nodeToBoundary(precedingNode), + nodeToBoundary(followingNode)) + case precedingNode :: Nil => + SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) + case _ => + noParseRule("Window Frame", frame) + } + }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) + } + + WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + } + + protected def nodeToTransformation( + node: ASTNode, + child: LogicalPlan): Option[ScriptTransformation] = None + + val explode = "(?i)explode".r + val jsonTuple = "(?i)json_tuple".r + protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { + val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node + + val alias = getClause("TOK_TABALIAS", clauses).children.head.text + + val generator = clauses.head match { + case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => + Explode(nodeToExpr(childNode)) + case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => + JsonTuple(children.map(nodeToExpr)) + case other => + 
nodeToGenerator(other) + } + + val attributes = clauses.collect { + case Token(a, Nil) => UnresolvedAttribute(a.toLowerCase) + } + + Generate(generator, join = true, outer = outer, Some(alias.toLowerCase), attributes, child) + } + + protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node) + + protected def noParseRule(msg: String, node: ASTNode): Nothing = throw new NotImplementedError( + s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}") +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index c8ee87e8819f..b5de60cdb7b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -17,21 +17,20 @@ package org.apache.spark.sql.catalyst -import java.beans.{PropertyDescriptor, Introspector} +import java.beans.{Introspector, PropertyDescriptor} import java.lang.{Iterable => JIterable} -import java.util.{Iterator => JIterator, Map => JMap, List => JList} +import java.util.{Iterator => JIterator, List => JList, Map => JMap} import scala.language.existentials import com.google.common.reflect.TypeToken -import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String - /** * Type-inference utilities for POJOs and Java collections. 
*/ @@ -178,23 +177,23 @@ object JavaTypeInference { case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath case c if c == classOf[java.lang.Short] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Integer] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Long] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Double] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Byte] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Float] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.lang.Boolean] => - NewInstance(c, getPath :: Nil, propagateNull = true, ObjectType(c)) + NewInstance(c, getPath :: Nil, ObjectType(c)) case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaDate", getPath :: Nil, @@ -202,7 +201,7 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(c), "toJavaTimestamp", getPath :: Nil, @@ -276,7 +275,7 @@ object JavaTypeInference { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[JMap[_, _]]), "toJavaMap", keyData :: valueData :: Nil) @@ -288,10 +287,17 @@ object JavaTypeInference { val setters = properties.map { p => val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType - p.getWriteMethod.getName -> constructorFor(fieldType, Some(addToPath(fieldName))) + val (_, nullable) = inferDataType(fieldType) + val constructor = constructorFor(fieldType, Some(addToPath(fieldName))) + val setter = if (nullable) { + constructor + } else { + AssertNotNull(constructor, other.getName, fieldName, fieldType.toString) + } + p.getWriteMethod.getName -> setter }.toMap - val newInstance = NewInstance(other, Nil, propagateNull = false, ObjectType(other)) + val newInstance = NewInstance(other, Nil, ObjectType(other), propagateNull = false) val result = InitializeJavaBean(newInstance, setters) if (path.nonEmpty) { @@ -341,21 +347,21 @@ object JavaTypeInference { case c if c == classOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case c if c == classOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case c if c == classOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index ecff8605706d..79f723cf9b8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -17,9 +17,9 @@ package 
org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -68,7 +68,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(elementType)) = tpe arrayClassFor(elementType) case other => - val clazz = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + val clazz = getClassFromType(tpe) ObjectType(clazz) } } @@ -177,6 +177,7 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } + val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -189,41 +190,41 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Long] => val boxedType = classOf[java.lang.Long] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Double] => val boxedType = classOf[java.lang.Double] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Float] => val boxedType = classOf[java.lang.Float] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Short] => val boxedType = classOf[java.lang.Short] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Byte] => val boxedType = classOf[java.lang.Byte] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.lang.Boolean] => val boxedType = classOf[java.lang.Boolean] val objectType = ObjectType(boxedType) - NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + NewInstance(boxedType, getPath :: Nil, objectType) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", getPath :: Nil, @@ -231,7 +232,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", getPath :: Nil, @@ -287,7 +288,7 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - scala.collection.mutable.WrappedArray, + 
scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -315,36 +316,18 @@ object ScalaReflection extends ScalaReflection { ObjectType(classOf[Array[Any]])) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) - val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) + val cls = getClassFromType(tpe) - val arguments = params.head.zipWithIndex.map { case (p, i) => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) - val dataType = schemaFor(fieldType).dataType + val arguments = params.zipWithIndex.map { case ((fieldName, fieldType), i) => + val Schema(dataType, nullable) = schemaFor(fieldType) val clsName = getClassNameFromType(fieldType) val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. @@ -354,14 +337,20 @@ object ScalaReflection extends ScalaReflection { Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - constructorFor( + val constructor = constructorFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) + + if (!nullable) { + AssertNotNull(constructor, t.toString, fieldName, fieldType.toString) + } else { + constructor + } } } - val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + val newInstance = NewInstance(cls, arguments, ObjectType(cls), propagateNull = false) if (path.nonEmpty) { expressions.If( @@ -372,6 +361,16 @@ object ScalaReflection extends ScalaReflection { } else { newInstance } + + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil) } } @@ -392,7 +391,7 @@ object ScalaReflection extends ScalaReflection { val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil extractorFor(inputObject, tpe, walkedTypePath) match { - case s: CreateNamedStruct => s + case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } @@ -406,7 +405,7 @@ object ScalaReflection extends ScalaReflection { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { val 
externalDataType = dataTypeFor(elementType) val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(catalystType)) { + if (isNativeType(externalDataType)) { NewInstance( classOf[GenericArrayData], input :: Nil, @@ -421,6 +420,7 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { + val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -477,33 +477,15 @@ object ScalaReflection extends ScalaReflection { } case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = - constructorSymbol.asTerm.alternatives.find(s => - s.isMethod && s.asMethod.isPrimaryConstructor) - - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } - - CreateNamedStruct(params.head.flatMap { p => - val fieldName = p.name.toString - val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val params = getConstructorParameters(t) + val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) => val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil }) + val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) + expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) case t if t <:< localTypeOf[Array[_]] => val TypeRef(_, _, Seq(elementType)) = t @@ -548,28 +530,28 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case t if t <:< localTypeOf[java.sql.Date] => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case t if t <:< localTypeOf[BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) case t if t <:< localTypeOf[java.math.BigDecimal] => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -589,12 +571,37 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) + case t if Utils.classIsLoadable(className) && + Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => + val udt = Utils.classForName(className) + .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + val obj = NewInstance( + udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), + Nil, + dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) + Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) + case other => throw new UnsupportedOperationException( s"No Encoder found for $tpe\n" + 
walkedTypePath.mkString("\n")) } } } + + /** + * Returns the parameter names and types for the primary constructor of this class. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(cls: Class[_]): Seq[(String, Type)] = { + val m = runtimeMirror(cls.getClassLoader) + val classSymbol = m.staticClass(cls.getName) + val t = classSymbol.selfType + getConstructorParameters(t) + } + + def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) } /** @@ -668,26 +675,11 @@ trait ScalaReflection { Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) case t if t <:< localTypeOf[Product] => - val formalTypeArgs = t.typeSymbol.asClass.typeParams - val TypeRef(_, _, actualTypeArgs) = t - val constructorSymbol = t.member(nme.CONSTRUCTOR) - val params = if (constructorSymbol.isMethod) { - constructorSymbol.asMethod.paramss - } else { - // Find the primary constructor, and use its parameter ordering. - val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( - s => s.isMethod && s.asMethod.isPrimaryConstructor) - if (primaryConstructorSymbol.isEmpty) { - sys.error("Internal SQL error: Product object did not have a primary constructor.") - } else { - primaryConstructorSymbol.get.asMethod.paramss - } - } + val params = getConstructorParameters(t) Schema(StructType( - params.head.map { p => - val Schema(dataType, nullable) = - schemaFor(p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)) - StructField(p.name.toString, dataType, nullable) + params.map { case (fieldName, fieldType) => + val Schema(dataType, nullable) = schemaFor(fieldType) + StructField(fieldName, dataType, nullable) }), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) @@ -740,4 +732,32 @@ trait ScalaReflection { assert(methods.length == 1) methods.head.getParameterTypes } + + /** + * Returns the parameter names and types for the primary constructor of this type. + * + * Note that it only works for scala classes with primary constructor, and currently doesn't + * support inner class. + */ + def getConstructorParameters(tpe: Type): Seq[(String, Type)] = { + val formalTypeArgs = tpe.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = tpe + val constructorSymbol = tpe.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. 
+ val primaryConstructorSymbol: Option[Symbol] = constructorSymbol.asTerm.alternatives.find( + s => s.isMethod && s.asMethod.isPrimaryConstructor) + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + params.flatten.map { p => + p.name.toString -> p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index ca00a5e49f66..8a33af820735 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.analysis import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{CatalystConf, ScalaReflection, SimpleCatalystConf} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.sql.catalyst.{ScalaReflection, SimpleCatalystConf, CatalystConf} import org.apache.spark.sql.types._ /** @@ -77,6 +77,8 @@ class Analyzer( ResolveGenerate :: ResolveFunctions :: ResolveAliases :: + ResolveWindowOrder :: + ResolveWindowFrame :: ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: @@ -84,8 +86,7 @@ class Analyzer( HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, - PullOutNondeterministic, - ComputeCurrentTime), + PullOutNondeterministic), Batch("UDF", Once, HandleNullInputsForUDF), Batch("Cleanup", fixedPoint, @@ -127,14 +128,12 @@ class Analyzer( // Lookup WindowSpecDefinitions. This rule works with unresolved children. case WithWindowDefinition(windowDefinitions, child) => child.transform { - case plan => plan.transformExpressions { + case p => p.transformExpressions { case UnresolvedWindowExpression(c, WindowSpecReference(windowName)) => val errorMessage = s"Window specification $windowName is not defined in the WINDOW clause." val windowSpecDefinition = - windowDefinitions - .get(windowName) - .getOrElse(failAnalysis(errorMessage)) + windowDefinitions.getOrElse(windowName, failAnalysis(errorMessage)) WindowExpression(c, windowSpecDefinition) } } @@ -149,12 +148,12 @@ class Analyzer( exprs.zipWithIndex.map { case (expr, i) => expr transform { - case u @ UnresolvedAlias(child) => child match { + case u @ UnresolvedAlias(child, optionalAliasName) => child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) case c @ Cast(ne: NamedExpression, _) => Alias(c, ne.name)() - case other => Alias(other, s"_c$i")() + case other => Alias(other, optionalAliasName.getOrElse(s"_c$i"))() } } }.asInstanceOf[Seq[NamedExpression]] @@ -208,10 +207,10 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case a if !a.childrenResolved => a // be sure all of the children are resolved. 
- case a: Cube => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) - case a: Rollup => - GroupingSets(bitmasks(a), a.groupByExprs, a.child, a.aggregations) + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) + case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => + GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) case x: GroupingSets => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() @@ -287,7 +286,7 @@ class Analyzer( } } val newGroupByExprs = groupByExprs.map { - case UnresolvedAlias(e) => e + case UnresolvedAlias(e, _) => e case e => e } Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child) @@ -352,19 +351,19 @@ class Analyzer( Project( projectList.flatMap { case s: Star => s.expand(child, resolver) - case UnresolvedAlias(f @ UnresolvedFunction(_, args, _)) if containsStar(args) => + case UnresolvedAlias(f @ UnresolvedFunction(_, args, _), _) if containsStar(args) => val newChildren = expandStarExpressions(args, child) UnresolvedAlias(child = f.copy(children = newChildren)) :: Nil case Alias(f @ UnresolvedFunction(_, args, _), name) if containsStar(args) => val newChildren = expandStarExpressions(args, child) Alias(child = f.copy(children = newChildren), name)() :: Nil - case UnresolvedAlias(c @ CreateArray(args)) if containsStar(args) => + case UnresolvedAlias(c @ CreateArray(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil } UnresolvedAlias(c.copy(children = expandedArgs)) :: Nil - case UnresolvedAlias(c @ CreateStruct(args)) if containsStar(args) => + case UnresolvedAlias(c @ CreateStruct(args), _) if containsStar(args) => val expandedArgs = args.flatMap { case s: Star => s.expand(child, resolver) case o => o :: Nil @@ -577,6 +576,10 @@ class Analyzer( AggregateExpression(max, Complete, isDistinct = false) case min: Min if isDistinct => AggregateExpression(min, Complete, isDistinct = false) + // AggregateWindowFunctions are AggregateFunctions that can only be evaluated within + // the context of a Window clause. They do not need to be wrapped in an + // AggregateExpression. + case wf: AggregateWindowFunction => wf // We get an aggregate function, we need to wrap it in an AggregateExpression. case agg: AggregateFunction => AggregateExpression(agg, Complete, isDistinct) // This function is not an aggregate function, just return the resolved one. @@ -597,11 +600,17 @@ class Analyzer( } def containsAggregates(exprs: Seq[Expression]): Boolean = { - exprs.foreach(_.foreach { - case agg: AggregateExpression => return true - case _ => - }) - false + // Collect all Windowed Aggregate Expressions. + val windowedAggExprs = exprs.flatMap { expr => + expr.collect { + case WindowExpression(ae: AggregateExpression, _) => ae + } + }.toSet + + // Find the first Aggregate Expression that is not Windowed. + exprs.exists(_.collectFirst { + case ae: AggregateExpression if !windowedAggExprs.contains(ae) => ae + }.isDefined) } } @@ -875,26 +884,37 @@ class Analyzer( // Now, we extract regular expressions from expressionsWithWindowFunctions // by using extractExpr. + val seenWindowAggregates = new ArrayBuffer[AggregateExpression] val newExpressionsWithWindowFunctions = expressionsWithWindowFunctions.map { _.transform { // Extracts children expressions of a WindowFunction (input parameters of // a WindowFunction). 
case wf : WindowFunction => - val newChildren = wf.children.map(extractExpr(_)) + val newChildren = wf.children.map(extractExpr) wf.withNewChildren(newChildren) // Extracts expressions from the partition spec and order spec. case wsc @ WindowSpecDefinition(partitionSpec, orderSpec, _) => - val newPartitionSpec = partitionSpec.map(extractExpr(_)) + val newPartitionSpec = partitionSpec.map(extractExpr) val newOrderSpec = orderSpec.map { so => val newChild = extractExpr(so.child) so.copy(child = newChild) } wsc.copy(partitionSpec = newPartitionSpec, orderSpec = newOrderSpec) + // Extract Windowed AggregateExpression + case we @ WindowExpression( + AggregateExpression(function, mode, isDistinct), + spec: WindowSpecDefinition) => + val newChildren = function.children.map(extractExpr) + val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] + val newAgg = AggregateExpression(newFunction, mode, isDistinct) + seenWindowAggregates += newAgg + WindowExpression(newAgg, spec) + // Extracts AggregateExpression. For example, for SUM(x) - Sum(y) OVER (...), // we need to extract SUM(x). - case agg: AggregateExpression => + case agg: AggregateExpression if !seenWindowAggregates.contains(agg) => val withName = Alias(agg, s"_w${extractedExprBuffer.length}")() extractedExprBuffer += withName withName.toAttribute @@ -1102,6 +1122,42 @@ class Analyzer( } } } + + /** + * Check and add proper window frames for all window functions. + */ + object ResolveWindowFrame extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, + WindowSpecDefinition(_, _, f: SpecifiedWindowFrame)) + if wf.frame != UnspecifiedFrame && wf.frame != f => + failAnalysis(s"Window Frame $f must match the required frame ${wf.frame}") + case WindowExpression(wf: WindowFunction, + s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) + if wf.frame != UnspecifiedFrame => + WindowExpression(wf, s.copy(frameSpecification = wf.frame)) + case we @ WindowExpression(e, s @ WindowSpecDefinition(_, o, UnspecifiedFrame)) => + val frame = SpecifiedWindowFrame.defaultWindowFrame(o.nonEmpty, acceptWindowFrame = true) + we.copy(windowSpec = s.copy(frameSpecification = frame)) + } + } + } + + /** + * Check and add order to [[AggregateWindowFunction]]s. + */ + object ResolveWindowOrder extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case logical: LogicalPlan => logical transformExpressions { + case WindowExpression(wf: WindowFunction, spec) if spec.orderSpec.isEmpty => + failAnalysis(s"WindowFunction $wf requires window to be ordered") + case WindowExpression(rank: RankLike, spec) if spec.resolved => + val order = spec.orderSpec.map(_.child) + WindowExpression(rank.withOrder(order), spec) + } + } + } } /** @@ -1172,23 +1228,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Computes the current date and time to make sure we return the same result in a single query. 
- */ -object ComputeCurrentTime extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = { - val dateExpr = CurrentDate() - val timeExpr = CurrentTimestamp() - val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) - val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) - - plan transformAllExpressions { - case CurrentDate() => currentDate - case CurrentTimestamp() => currentTime - } - } -} - /** * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala index 8f4ce74a2ea3..a8f89ce6de45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Catalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{TableIdentifier, CatalystConf, EmptyConf} +import org.apache.spark.sql.catalyst.{CatalystConf, EmptyConf, TableIdentifier} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Subquery} /** @@ -104,13 +104,15 @@ class SimpleCatalog(val conf: CatalystConf) extends Catalog { val tableName = getTableName(tableIdent) val table = tables.get(tableName) if (table == null) { - throw new NoSuchTableException + throw new AnalysisException("Table not found: " + tableName) } val tableWithQualifiers = Subquery(tableName, table) // If an alias was specified by the lookup, wrap the plan in a subquery so that attributes are // properly qualified with this alias. 
- alias.map(a => Subquery(a, tableWithQualifiers)).getOrElse(tableWithQualifiers) + alias + .map(a => Subquery(a, tableWithQualifiers)) + .getOrElse(tableWithQualifiers) } override def getTables(databaseName: Option[String]): Seq[(String, Boolean)] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 7b2c93d63d67..2a2e0d27d943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, AggregateExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -57,7 +57,7 @@ trait CheckAnalysis { operator transformExpressionsUp { case a: Attribute if !a.resolved => val from = operator.inputSet.map(_.name).mkString(", ") - a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns $from") + a.failAnalysis(s"cannot resolve '${a.prettyString}' given input columns: [$from]") case e: Expression if e.checkInputDataTypes().isFailure => e.checkInputDataTypes() match { @@ -70,15 +70,32 @@ trait CheckAnalysis { failAnalysis( s"invalid cast from ${c.child.dataType.simpleString} to ${c.dataType.simpleString}") - case WindowExpression(UnresolvedWindowFunction(name, _), _) => - failAnalysis( - s"Could not resolve window function '$name'. " + - "Note that, using window functions currently requires a HiveContext") + case w @ WindowExpression(AggregateExpression(_, _, true), _) => + failAnalysis(s"Distinct window functions are not supported: $w") + + case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, + SpecifiedWindowFrame(frame, + FrameBoundary(l), + FrameBoundary(h)))) + if order.isEmpty || frame != RowFrame || l != h => + failAnalysis("An offset window function can only be evaluated in an ordered " + + s"row-based window frame with a single offset: $w") + + case w @ WindowExpression(e, s) => + // Only allow window functions with an aggregate expression or an offset window + // function. + e match { + case _: AggregateExpression | _: OffsetWindowFunction | _: AggregateWindowFunction => + case _ => + failAnalysis(s"Expression '$e' not supported within a window function.") + } + // Make sure the window specification is valid. + s.validate match { + case Some(m) => + failAnalysis(s"Window specification $s is not valid because $m") + case None => w + } - case w @ WindowExpression(windowFunction, windowSpec) if windowSpec.validate.nonEmpty => - // The window spec is not valid. - val reason = windowSpec.validate.get - failAnalysis(s"Window specification $windowSpec is not valid because $reason") } operator match { @@ -204,10 +221,12 @@ trait CheckAnalysis { s"unresolved operator ${operator.simpleString}") case o if o.expressions.exists(!_.deterministic) && - !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] & !o.isInstanceOf[Aggregate] => + !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && + !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => // The rule above is used to check Aggregate operator. 
failAnalysis( + s"""nondeterministic expressions are only allowed in + |Project, Filter, Aggregate or Window, found: | ${o.expressions.map(_.prettyString).mkString(",")} |in operator ${operator.simpleString} """.stripMargin) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f9c04d7ec0b0..5c2aa3c06b3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -49,7 +49,7 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = + private[sql] val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) override def registerFunction( @@ -278,12 +278,27 @@ object FunctionRegistry { // misc functions expression[Crc32]("crc32"), expression[Md5]("md5"), + expression[Murmur3Hash]("hash"), expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), - expression[MonotonicallyIncreasingID]("monotonically_increasing_id") + expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + + // grouping sets + expression[Cube]("cube"), + expression[Rollup]("rollup"), + + // window functions + expression[Lead]("lead"), + expression[Lag]("lag"), + expression[RowNumber]("row_number"), + expression[CumeDist]("cume_dist"), + expression[NTile]("ntile"), + expression[Rank]("rank"), + expression[DenseRank]("dense_rank"), + expression[PercentRank]("percent_rank") ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 4f89b462a6ce..fc0e87aa68ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{errors, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode} +import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode -import org.apache.spark.sql.catalyst.{TableIdentifier, errors} import org.apache.spark.sql.types.{DataType, StructType} /** @@ -284,8 +284,12 @@ case class UnresolvedExtractValue(child: Expression, extraction: Expression) /** * Holds the expression that has yet to be aliased. + * + * @param child The computation that needs to be resolved during analysis.
+ * @param aliasName The name, if specified, to be associated with the result of computing [[child]] + * + */ -case class UnresolvedAlias(child: Expression) +case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) extends UnaryExpression with NamedExpression with Unevaluable { override def toAttribute: Attribute = throw new UnresolvedException(this, "toAttribute") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 8102c93c6f10..5ac1984043d8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,11 +21,11 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 7a4401cf5810..05f746e72b49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -22,15 +22,15 @@ import java.util.concurrent.ConcurrentMap import scala.reflect.ClassTag import scala.reflect.runtime.universe.{typeTag, TypeTag} -import org.apache.spark.util.Utils import org.apache.spark.sql.{AnalysisException, Encoder} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedExtractValue, UnresolvedAttribute} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts -import org.apache.spark.sql.catalyst.{JavaTypeInference, InternalRow, ScalaReflection} -import org.apache.spark.sql.types.{StructField, ObjectType, StructType} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project} +import org.apache.spark.sql.types.{ObjectType, StructField, StructType} +import org.apache.spark.util.Utils /** * A factory for constructing encoders that convert objects and primitives to and from the @@ -133,7 +133,7 @@ object ExpressionEncoder { } val fromRowExpression = - NewInstance(cls, fromRowExpressions, propagateNull = false, ObjectType(cls)) + NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, @@ -198,6 +198,15 @@ case class ExpressionEncoder[T]( @transient private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) +
/** + * Returns this encoder where it has been bound to its own output (i.e. no remapping of columns + * is performed). + */ + def defaultBinding: ExpressionEncoder[T] = { + val attrs = schema.toAttributes + resolve(attrs, OuterScopes.outerScopes).bind(attrs) + } + /** * Returns an encoded version of `t` as a Spark SQL row. Note that multiple calls to * toRow are allowed to return the same actual [[InternalRow]] object. Thus, the caller should diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index d34ec9408ae1..89d40b3b2c14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -22,7 +22,7 @@ import scala.reflect.ClassTag import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -35,7 +35,8 @@ object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) - val extractExpressions = extractorsFor(inputObject, schema) + // extractorsFor wraps the StructType result in an If expression; take its false (non-null) branch + val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue val constructExpression = constructorFor(schema) new ExpressionEncoder[Row]( schema, @@ -55,27 +56,26 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil) case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, TimestampType, "fromJavaTimestamp", inputObject :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, DateType, "fromJavaDate", inputObject :: Nil) case _: DecimalType => StaticInvoke( - Decimal, + Decimal.getClass, DecimalType.SYSTEM_DEFAULT, "apply", inputObject :: Nil) @@ -130,7 +130,9 @@ object RowEncoder { Invoke(inputObject, method, externalDataTypeFor(f.dataType), Literal(i) :: Nil), f.dataType)) } - CreateStruct(convertedFields) + If(IsNull(inputObject), + Literal.create(null, inputType), + CreateStruct(convertedFields)) } private def externalDataTypeFor(dt: DataType): DataType = dt match { @@ -166,20 +168,19 @@ object RowEncoder { val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, - false, dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt())) Invoke(obj, "deserialize", ObjectType(udt.userClass), input :: Nil) case TimestampType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", input :: Nil) case DateType => StaticInvoke( - DateTimeUtils, + DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", input :: Nil) @@ -197,7 +198,7 @@ object RowEncoder { "array", ObjectType(classOf[Array[_]])) StaticInvoke( - scala.collection.mutable.WrappedArray, +
scala.collection.mutable.WrappedArray.getClass, ObjectType(classOf[Seq[_]]), "make", arrayData :: Nil) @@ -210,7 +211,7 @@ object RowEncoder { val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( - ArrayBasedMapData, + ArrayBasedMapData.getClass, ObjectType(classOf[Map[_, _]]), "toScalaMap", keyData :: valueData :: Nil) @@ -222,6 +223,8 @@ object RowEncoder { Literal.create(null, externalDataTypeFor(f.dataType)), constructorFor(GetStructField(input, i))) } - CreateExternalRow(convertedFields) + If(IsNull(input), + Literal.create(null, externalDataTypeFor(input.dataType)), + CreateExternalRow(convertedFields)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index ff1f28ddbbf3..7293d5d4472a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -69,10 +69,17 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) - s""" - boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + if (nullable) { + s""" + boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); + """ + } else { + ev.isNull = "false" + s""" + $javaType ${ev.value} = $value; + """ + } } } @@ -92,7 +99,7 @@ object BindReferences extends Logging { sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}") } } else { - BoundReference(ordinal, a.dataType, a.nullable) + BoundReference(ordinal, a.dataType, input(ordinal).nullable) } } }.asInstanceOf[A] // Kind of a hack, but safe. TODO: Tighten return type when possible. 
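The BoundReference change above follows a pattern that recurs later in this patch (UnaryExpression, BinaryExpression, TernaryExpression): when a column or child is statically known to be non-nullable, the generated code can drop the isNullAt guard and the extra isNull boolean entirely, and BindReferences now takes the nullability from the input attribute rather than from the unbound reference. What follows is a minimal, self-contained sketch of that idea in plain Scala, written for this edit; CodeSlot and emitRead are invented names, not Spark's CodeGenContext/GeneratedExpressionCode API, and the emitted strings only approximate the real generated Java.

// Illustrative sketch only: a toy model of nullability-aware code generation.
// CodeSlot and emitRead are hypothetical names, not Spark classes.
final case class CodeSlot(ordinal: Int, javaType: String, nullable: Boolean)

object NullAwareCodegen {
  // Emits Java-like source for reading one primitive column of an input row,
  // skipping the null guard when the slot is statically known to be non-nullable.
  def emitRead(slot: CodeSlot, isNullVar: String, valueVar: String): String = {
    val read = s"input.get${slot.javaType.capitalize}(${slot.ordinal})"
    if (slot.nullable) {
      // Nullable slot: guard the read with an isNullAt check, as the patched
      // BoundReference.genCode does for nullable == true.
      s"""boolean $isNullVar = input.isNullAt(${slot.ordinal});
         |${slot.javaType} $valueVar = $isNullVar ? 0 : $read;""".stripMargin
    } else {
      // Non-nullable slot: the null check and the extra boolean disappear.
      s"${slot.javaType} $valueVar = $read;"
    }
  }

  def main(args: Array[String]): Unit = {
    println(emitRead(CodeSlot(0, "int", nullable = true), "isNull0", "value0"))
    println(emitRead(CodeSlot(1, "long", nullable = false), "isNull1", "value1"))
  }
}

Running the sketch prints a guarded read for the nullable slot and a bare read for the non-nullable one, which is essentially the shape of the two branches introduced into BoundReference.genCode above.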
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index cb60d5958d53..6f199cfc5d8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.math.{BigDecimal => JavaBigDecimal} +import org.apache.spark.SparkException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -81,24 +82,31 @@ object Cast { toField.nullable) } + case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt1.userClass == udt2.userClass => + true + case _ => false } private def resolvableNullability(from: Boolean, to: Boolean) = !from || to private def forceNullable(from: DataType, to: DataType) = (from, to) match { - case (StringType, _: NumericType) => true - case (StringType, TimestampType) => true - case (DoubleType, TimestampType) => true - case (FloatType, TimestampType) => true - case (StringType, DateType) => true - case (_: NumericType, DateType) => true - case (BooleanType, DateType) => true - case (DateType, _: NumericType) => true - case (DateType, BooleanType) => true - case (DoubleType, _: DecimalType) => true - case (FloatType, _: DecimalType) => true - case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null + case (NullType, _) => true + case (_, _) if from == to => false + + case (StringType, BinaryType) => false + case (StringType, _) => true + case (_, StringType) => false + + case (FloatType | DoubleType, TimestampType) => true + case (TimestampType, DateType) => false + case (_, DateType) => true + case (DateType, TimestampType) => false + case (DateType, _) => true + case (_, CalendarIntervalType) => true + + case (_, _: DecimalType) => true // overflow + case (_: FractionalType, _: IntegralType) => true // NaN, infinity case _ => false } } @@ -427,6 +435,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) case map: MapType => castMap(from.asInstanceOf[MapType], map) case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + identity[Any] + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } private[this] lazy val cast: Any => Any = cast(child.dataType, dataType) @@ -469,6 +482,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) + case udt: UserDefinedType[_] + if udt.userClass == from.asInstanceOf[UserDefinedType[_]].userClass => + (c, evPrim, evNull) => s"$evPrim = $c;" + case _: UserDefinedType[_] => + throw new SparkException(s"Cannot cast $from to $to.") } // Since we need to cast child expressions recursively inside ComplexTypes, such as Map's @@ -913,6 +931,14 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { $evPrim = 
$result.copy(); """ } + + override def sql: String = dataType match { + // HiveQL doesn't allow casting to complex types. For logical plans translated from HiveQL, this + // type of casting can only be introduced by the analyzer, and can be omitted when converting + // back to a SQL query string. + case _: ArrayType | _: MapType | _: StructType => child.sql + case _ => s"CAST(${child.sql} AS ${dataType.sql})" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala index 2dcbd4eb1503..04650d85dec0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpectsInputTypes.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.types.AbstractDataType -import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion.ImplicitTypeCasts /** * A trait that gets mixed in to define the expected input types of an expression. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 6d807c9ecf30..d6219514b752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -18,9 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{Analyzer, TypeCheckResult, UnresolvedAttribute} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.trees.TreeNode +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,6 +224,15 @@ abstract class Expression extends TreeNode[Expression] { protected def toCommentSafeString: String = this.toString .replace("*/", "\\*\\/") .replace("\\u", "\\\\u") + + /** + * Returns the SQL representation of this expression. For expressions that don't have a SQL + * representation (e.g. `ScalaUDF`), this method should throw an `UnsupportedOperationException`.
+ */ + @throws[UnsupportedOperationException](cause = "Expression doesn't have a SQL representation") + def sql: String = throw new UnsupportedOperationException( + s"Cannot map expression $this to its SQL representation" + ) } @@ -340,15 +350,24 @@ abstract class UnaryExpression extends Expression { ev: GeneratedExpressionCode, f: String => String): String = { val eval = child.gen(ctx) - val resultCode = f(eval.value) - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $resultCode - } - """ + if (nullable) { + eval.code + s""" + boolean ${ev.isNull} = ${eval.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${eval.isNull}) { + ${f(eval.value)} + } + """ + } else { + ev.isNull = "false" + eval.code + s""" + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + ${f(eval.value)} + """ + } } + + override def sql: String = s"($prettyName(${child.sql}))" } @@ -424,20 +443,33 @@ abstract class BinaryExpression extends Expression { val eval1 = left.gen(ctx) val eval2 = right.gen(ctx) val resultCode = f(eval1.value, eval2.value) - s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; + if (nullable) { + s""" + ${eval1.code} + boolean ${ev.isNull} = ${eval1.isNull}; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${eval2.code} + if (!${eval2.isNull}) { + $resultCode + } else { + ${ev.isNull} = true; + } } - } - """ + """ + + } else { + ev.isNull = "false" + s""" + ${eval1.code} + ${eval2.code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } @@ -474,6 +506,8 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { TypeCheckResult.TypeCheckSuccess } } + + override def sql: String = s"(${left.sql} $symbol ${right.sql})" } @@ -548,20 +582,36 @@ abstract class TernaryExpression extends Expression { f: (String, String, String) => String): String = { val evals = children.map(_.gen(ctx)) val resultCode = f(evals(0).value, evals(1).value, evals(2).value) - s""" - ${evals(0).code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode + if (nullable) { + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } } } - } - """ + """ + } else { + ev.isNull = "false" + s""" + ${evals(0).code} + ${evals(1).code} + ${evals(2).code} + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; + $resultCode + """ + } + } + + override def sql: String = { + val childrenSQL = children.map(_.sql).mkString(", ") + s"$prettyName($childrenSQL)" } } diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index bf215783fc27..827dce8af100 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -19,13 +19,16 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.rdd.SqlNewHadoopRDDState import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** * Expression that returns the name of the file currently being read, using [[SqlNewHadoopRDD]]. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the name of the current file being read if available", + extended = "> SELECT _FUNC_();\n ''") case class InputFileName() extends LeafExpression with Nondeterministic { override def nullable: Boolean = true @@ -46,4 +49,5 @@ case class InputFileName() extends LeafExpression with Nondeterministic { "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } + override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala index 935c3aa28c99..ed894f6d6e10 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - /** * A mutable wrapper that makes two rows appear as a single concatenated row. Designed to * be instantiated once per thread and reused. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 2d7679fdfe04..94f8801dec36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{LongType, DataType} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types.{DataType, LongType} /** * Returns monotonically increasing 64-bit integers. @@ -32,6 +32,14 @@ import org.apache.spark.sql.types.{LongType, DataType} * * Since this expression is stateful, it cannot be a case object.
*/ +@ExpressionDescription( + usage = + """_FUNC_() - Returns monotonically increasing 64-bit integers. + The generated ID is guaranteed to be monotonically increasing and unique, but not consecutive. + The current implementation puts the partition ID in the upper 31 bits, and the lower 33 bits + represent the record number within each partition. The assumption is that the data frame has + fewer than 1 billion partitions, and each partition has fewer than 8 billion records.""", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with Nondeterministic { /** @@ -70,4 +78,8 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with $countTerm++; """ } + + override def prettyName: String = "monotonically_increasing_id" + + override def sql: String = s"$prettyName()" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 85faa19bbf5e..3a6c909fffce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types.DataType @@ -30,7 +29,10 @@ import org.apache.spark.sql.types.DataType * null. Use a boxed type or [[Option]] if you want to do the null-handling yourself. * @param dataType Return type of the function. * @param children The input expressions of this UDF. - * @param inputTypes The expected input types of this UDF. + * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do + * not want to perform coercion, simply use "Nil". Note that it would have been + * better to use Option of Seq[DataType] so we can use "None" as the case for no + * type coercion. However, that would require more refactoring of the codebase.
*/ case class ScalaUDF( function: AnyRef, @@ -43,7 +45,7 @@ case class ScalaUDF( override def toString: String = s"UDF(${children.mkString(",")})" - // scalastyle:off + // scalastyle:off line.size.limit /** This method has been generated by this script @@ -969,7 +971,7 @@ case class ScalaUDF( } } - // scalastyle:on + // scalastyle:on line.size.limit // Generate code used to convert the arguments to Scala types for user-defined functions private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { @@ -1010,7 +1012,7 @@ case class ScalaUDF( // This must be called before children expressions' codegen // because ctx.references is used in genCodeForConverter - val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + val converterTerms = children.indices.map(genCodeForConverter(ctx, _)) // Initialize user-defined function val funcClassName = s"scala.Function${children.size}" @@ -1054,5 +1056,6 @@ case class ScalaUDF( } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: InternalRow): Any = converter(f(input)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index 290c128d65b3..1cb1b9da3049 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -19,14 +19,22 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} import org.apache.spark.sql.types._ import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.BinaryPrefixComparator import org.apache.spark.util.collection.unsafe.sort.PrefixComparators.DoublePrefixComparator -abstract sealed class SortDirection -case object Ascending extends SortDirection -case object Descending extends SortDirection +abstract sealed class SortDirection { + def sql: String +} + +case object Ascending extends SortDirection { + override def sql: String = "ASC" +} + +case object Descending extends SortDirection { + override def sql: String = "DESC" +} /** * An expression that can be used to sort a tuple.
This class extends expression primarily so that diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8bff173d64eb..aa3951480c50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -19,13 +19,15 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.TaskContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.types.{IntegerType, DataType} - +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.types.{DataType, IntegerType} /** * Expression that returns the current partition id of the Spark task. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current partition id of the Spark task", + extended = "> SELECT _FUNC_();\n 0") private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterministic { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index d07d4c338cdf..30f602227b17 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -53,7 +53,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w override def children: Seq[Expression] = Seq(child) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index 00d7436b710d..d25f3335ffd9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ /** @@ -42,7 +41,7 @@ case class Corr( override def children: Seq[Expression] = Seq(left, right) - override def nullable: Boolean = false + override def nullable: Boolean = true override def dataType: DataType = DoubleType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 441f52ab5ca5..663c69e799fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -31,7 +31,7 @@ case class Count(children: Seq[Expression]) extends 
DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(AnyDataType) - private lazy val count = AttributeReference("count", LongType)() + private lazy val count = AttributeReference("count", LongType, nullable = false)() override lazy val aggBufferAttributes = count :: Nil @@ -39,15 +39,24 @@ case class Count(children: Seq[Expression]) extends DeclarativeAggregate { /* count = */ Literal(0L) ) - override lazy val updateExpressions = Seq( - /* count = */ If(children.map(IsNull).reduce(Or), count, count + 1L) - ) + override lazy val updateExpressions = { + val nullableChildren = children.filter(_.nullable) + if (nullableChildren.isEmpty) { + Seq( + /* count = */ count + 1L + ) + } else { + Seq( + /* count = */ If(nullableChildren.map(IsNull).reduce(Or), count, count + 1L) + ) + } + } override lazy val mergeExpressions = Seq( /* count = */ count.left + count.right ) - override lazy val evaluateExpression = Cast(count, LongType) + override lazy val evaluateExpression = count override def defaultResult: Option[Literal] = Option(Literal(0L)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index cfb042e0aa78..08a67ea3df51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -40,8 +40,6 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private lazy val resultType = child.dataType match { case DecimalType.Fixed(precision, scale) => DecimalType.bounded(precision + 10, scale) - // TODO: Remove this line once we remove the NullType from inputTypes. 
- case NullType => IntegerType case _ => child.dataType } @@ -57,18 +55,26 @@ case class Sum(child: Expression) extends DeclarativeAggregate { /* sum = */ Literal.create(null, sumDataType) ) - override lazy val updateExpressions: Seq[Expression] = Seq( - /* sum = */ - Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) - ) + override lazy val updateExpressions: Seq[Expression] = { + if (child.nullable) { + Seq( + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) + ) + } else { + Seq( + /* sum = */ + Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)) + ) + } + } override lazy val mergeExpressions: Seq[Expression] = { - val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( /* sum = */ - Coalesce(Seq(add, sum.left)) + Coalesce(Seq(Add(Coalesce(Seq(sum.left, zero)), sum.right), sum.left)) ) } - override lazy val evaluateExpression: Expression = Cast(sum, resultType) + override lazy val evaluateExpression: Expression = sum } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 3b441de34a49..ddd99c51ab0c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ /** The mode of an [[AggregateFunction]]. */ @@ -94,11 +94,13 @@ private[sql] case class AggregateExpression( override def prettyString: String = aggregateFunction.prettyString - override def toString: String = s"(${aggregateFunction},mode=$mode,isDistinct=$isDistinct)" + override def toString: String = s"($aggregateFunction,mode=$mode,isDistinct=$isDistinct)" + + override def sql: String = aggregateFunction.sql(isDistinct) } /** - * AggregateFunction2 is the superclass of two aggregation function interfaces: + * AggregateFunction is the superclass of two aggregation function interfaces: * * - [[ImperativeAggregate]] is for aggregation functions that are specified in terms of * initialize(), update(), and merge() functions that operate on Row-based aggregation buffers. 
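The Sum rewrite above follows the same principle as Count: when the child can never be null, the outer Coalesce that preserved the "all inputs were null" case is unnecessary. A rough model of the two update rules over an Option[Long] buffer, assuming the hypothetical names SumUpdateSketch, updateNullable, and updateNonNullable (this is not the DeclarativeAggregate API, just its semantics):

object SumUpdateSketch {
  // The aggregation buffer starts as None (SQL NULL); update runs once per input row.
  def updateNullable(sum: Option[Long], child: Option[Long]): Option[Long] =
    child match {
      case Some(v) => Some(sum.getOrElse(0L) + v) // Coalesce(Add(Coalesce(sum, 0), child), sum)
      case None    => sum                         // a NULL input leaves the buffer untouched
    }

  // When the child is non-nullable, the outer Coalesce is dropped.
  def updateNonNullable(sum: Option[Long], child: Long): Option[Long] =
    Some(sum.getOrElse(0L) + child)               // Add(Coalesce(sum, 0), child)

  def main(args: Array[String]): Unit = {
    println(Seq(Some(1L), None, Some(2L)).foldLeft(Option.empty[Long])(updateNullable)) // Some(3)
    println(Seq.empty[Long].foldLeft(Option.empty[Long])(updateNonNullable))            // None, i.e. SQL NULL
  }
}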
@@ -144,9 +146,6 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu */ def defaultResult: Option[Literal] = None - override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = - throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - /** * Wraps this [[AggregateFunction]] in an [[AggregateExpression]] because * [[AggregateExpression]] is the container of an [[AggregateFunction]], aggregation mode, @@ -167,6 +166,11 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu def toAggregateExpression(isDistinct: Boolean): AggregateExpression = { AggregateExpression(aggregateFunction = this, mode = Complete, isDistinct = isDistinct) } + + def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$prettyName($distinct${children.map(_.sql).mkString(", ")})" + } } /** @@ -187,7 +191,7 @@ sealed abstract class AggregateFunction extends Expression with ImplicitCastInpu * `inputAggBufferOffset`, but not on the correctness of the attribute ids in `aggBufferAttributes` * and `inputAggBufferAttributes`. */ -abstract class ImperativeAggregate extends AggregateFunction { +abstract class ImperativeAggregate extends AggregateFunction with CodegenFallback { /** * The offset of this function's first buffer value in the underlying shared mutable aggregation diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 61a17fd7db0f..7bd851c059d0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -54,6 +54,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp numeric.negate(input) } } + + override def sql: String = s"(-${child.sql})" } case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { @@ -67,6 +69,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input + + override def sql: String = s"(+${child.sql})" } /** @@ -91,6 +95,8 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) + + override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { @@ -513,4 +519,6 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { val r = a % n if (r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r } + + override def sql: String = s"$prettyName(${left.sql}, ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 440c7d2fc115..6daa8ee2f42b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -27,7 +27,7 @@ import org.codehaus.janino.ClassBodyEvaluator import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import 
org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 26fb143d1e45..335358014879 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{Nondeterministic, Expression} +import org.apache.spark.sql.catalyst.expressions.{Expression, Nondeterministic} /** * A trait that can be used to provide a fallback mode for expression code generation. @@ -32,14 +32,23 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") - s""" - /* expression: ${this.toCommentSafeString} */ - java.lang.Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); - boolean ${ev.isNull} = $objectTerm == null; - ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; - if (!${ev.isNull}) { - ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + if (nullable) { + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + boolean ${ev.isNull} = $objectTerm == null; + ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; + if (!${ev.isNull}) { + ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + } + """ + } else { + ev.isNull = "false" + s""" + /* expression: ${this.toCommentSafeString} */ + Object $objectTerm = expressions[${ctx.references.size - 1}].eval(${ctx.INPUT_ROW}); + ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 40189f087776..a6ec242589fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,38 +44,55 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) - val isNull = s"isNull_$i" - val value = s"value_$i" - ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") - ctx.addMutableState(ctx.javaType(e.dataType), value, - s"this.$value = ${ctx.defaultValue(e.dataType)};") - s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; - """ + if (e.nullable) { + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + 
this.$value = ${evaluationCode.value}; + """ + } else { + val value = s"value_$i" + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$value = ${evaluationCode.value}; + """ + } } val updates = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, null)}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ + if (e.nullable) { + if (e.dataType.isInstanceOf[DecimalType]) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + s""" + if (this.isNull_$i) { + ${ctx.setColumn("mutableRow", e.dataType, i, null)}; + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } else { + s""" + if (this.isNull_$i) { + mutableRow.setNullAt($i); + } else { + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; + } + """ + } } else { s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; """ } + } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 13634b69457a..364dbb770f5e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.NoOp -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 68005afb21d2..d0e031f27990 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -135,14 +135,21 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => s"$rowWriter.write($index, ${input.value});" } - s""" - ${input.code} - if (${input.isNull}) { - ${setNull.trim} - } else { + if (input.isNull == "false") { + s""" + ${input.code} ${writeField.trim} - } - """ + """ + } else { + s""" + ${input.code} + if (${input.isNull}) { + ${setNull.trim} + } else { + ${writeField.trim} + } + """ + } } s""" @@ -282,7 +289,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val exprTypes = expressions.map(_.dataType) val result = ctx.freshName("result") - ctx.addMutableState("UnsafeRow", result, s"this.$result = new UnsafeRow();") + 
ctx.addMutableState("UnsafeRow", result, s"$result = new UnsafeRow(${expressions.length});") val bufferHolder = ctx.freshName("bufferHolder") val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") @@ -296,7 +303,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro $subexprReset ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} - $result.pointTo($bufferHolder.buffer, ${expressions.length}, $bufferHolder.totalSize()); + $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); """ GeneratedExpressionCode(code, "false", result) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index da602d9b4bce..88b3c5e47f6f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.catalyst.expressions.codegen -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.Platform - abstract class UnsafeRowJoiner { def join(row1: UnsafeRow, row2: UnsafeRow): UnsafeRow } @@ -61,9 +60,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val outputBitsetWords = (schema1.size + schema2.size + 63) / 64 val bitset1Remainder = schema1.size % 64 - // The number of words we can reduce when we concat two rows together. + // The number of bytes we can reduce when we concat two rows together. // The only reduction comes from merging the bitset portion of the two rows, saving 1 word. 
- val sizeReduction = bitset1Words + bitset2Words - outputBitsetWords + val sizeReduction = (bitset1Words + bitset2Words - outputBitsetWords) * 8 // --------------------- copy bitset from row 1 and row 2 --------------------------- // val copyBitset = Seq.tabulate(outputBitsetWords) { i => @@ -165,13 +164,13 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | |class SpecificUnsafeRowJoiner extends ${classOf[UnsafeRowJoiner].getName} { | private byte[] buf = new byte[64]; - | private UnsafeRow out = new UnsafeRow(); + | private UnsafeRow out = new UnsafeRow(${schema1.size + schema2.size}); | | public UnsafeRow join(UnsafeRow row1, UnsafeRow row2) { | // row1: ${schema1.size} fields, $bitset1Words words in bitset | // row2: ${schema2.size}, $bitset2Words words in bitset | // output: ${schema1.size + schema2.size} fields, $outputBitsetWords words in bitset - | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes(); + | final int sizeInBytes = row1.getSizeInBytes() + row2.getSizeInBytes() - $sizeReduction; | if (sizeInBytes > buf.length) { | buf = new byte[sizeInBytes]; | } @@ -188,7 +187,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U | $copyVariableLengthRow2 | $updateOffset | - | out.pointTo(buf, ${schema1.size + schema2.size}, sizeInBytes - $sizeReduction); + | out.pointTo(buf, sizeInBytes); | | return out; | } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 741ad1f3efd8..7aac2e5e6c1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -20,7 +20,7 @@ import java.util.Comparator import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, GeneratedExpressionCode} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 72cc89c8be91..d71bbd63c8e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 58f6a7ec8a5f..5bd97cc7467a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GeneratedExpressionCode, CodeGenContext} -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -115,15 +115,23 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { nullSafeCodeGen(ctx, ev, eval => { - s""" - if ($eval.isNullAt($ordinal)) { - ${ev.isNull} = true; - } else { + if (nullable) { + s""" + if ($eval.isNullAt($ordinal)) { + ${ev.isNull} = true; + } else { + ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; + } + """ + } else { + s""" ${ev.value} = ${ctx.getValue(eval, dataType, ordinal.toString)}; - } - """ + """ + } }) } + + override def sql: String = child.sql + s".`${childSchema(ordinal).name}`" } /** @@ -139,7 +147,6 @@ case class GetArrayStructFields( containsNull: Boolean) extends UnaryExpression { override def dataType: DataType = ArrayType(field.dataType, containsNull) - override def nullable: Boolean = child.nullable || containsNull || field.nullable override def toString: String = s"$child.${field.name}" protected override def nullSafeEval(input: Any): Any = { @@ -222,7 +229,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" final int index = (int) $eval2; - if (index >= $eval1.numElements() || index < 0) { + if (index >= $eval1.numElements() || index < 0 || $eval1.isNullAt(index)) { ${ev.isNull} = true; } else { ${ev.value} = ${ctx.getValue(eval1, dataType, "index")}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 40b1eec63e55..19da849d2bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.util.TypeUtils +import org.apache.spark.sql.catalyst.util.{sequenceOption, TypeUtils} import org.apache.spark.sql.types._ @@ -74,6 +74,8 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } override 
def toString: String = s"if ($predicate) $trueValue else $falseValue" + + override def sql: String = s"(IF(${predicate.sql}, ${trueValue.sql}, ${falseValue.sql}))" } trait CaseWhenLike extends Expression { @@ -91,7 +93,9 @@ trait CaseWhenLike extends Expression { // both then and else expressions should be considered. def valueTypes: Seq[DataType] = (thenList ++ elseValue).map(_.dataType) - def valueTypesEqual: Boolean = valueTypes.distinct.size == 1 + def valueTypesEqual: Boolean = valueTypes.size <= 1 || valueTypes.sliding(2, 1).forall { + case Seq(dt1, dt2) => dt1.sameType(dt2) + } override def checkInputDataTypes(): TypeCheckResult = { if (valueTypesEqual) { @@ -108,7 +112,7 @@ trait CaseWhenLike extends Expression { override def nullable: Boolean = { // If no value is nullable and no elseValue is provided, the whole statement defaults to null. - thenList.exists(_.nullable) || (elseValue.map(_.nullable).getOrElse(true)) + thenList.exists(_.nullable) || elseValue.map(_.nullable).getOrElse(true) } } @@ -204,6 +208,23 @@ case class CaseWhen(branches: Seq[Expression]) extends CaseWhenLike { case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } // scalastyle:off @@ -308,6 +329,24 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW case Seq(elseValue) => s" ELSE $elseValue" }.mkString } + + override def sql: String = { + val keySQL = key.sql + val branchesSQL = branches.map(_.sql) + val (cases, maybeElse) = if (branches.length % 2 == 0) { + (branchesSQL, None) + } else { + (branchesSQL.init, Some(branchesSQL.last)) + } + + val head = s"CASE $keySQL " + val tail = maybeElse.map(e => s" ELSE $e").getOrElse("") + " END" + val body = cases.grouped(2).map { + case Seq(whenExpr, thenExpr) => s"WHEN $whenExpr THEN $thenExpr" + }.mkString(" ") + + head + body + tail + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 03c39f8404e7..17f1df06f2fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -20,15 +20,15 @@ package org.apache.spark.sql.catalyst.expressions import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} -import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, GeneratedExpressionCode} +import scala.util.Try + import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenContext, CodegenFallback, + GeneratedExpressionCode} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} -import scala.util.Try - /** * Returns the current date at the start of query evaluation. 
* All calls of current_date within the same query return the same value. @@ -44,6 +44,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { DateTimeUtils.millisToDays(System.currentTimeMillis()) } + + override def prettyName: String = "current_date" } /** @@ -61,6 +63,8 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def eval(input: InternalRow): Any = { System.currentTimeMillis() * 1000L } + + override def prettyName: String = "current_timestamp" } /** @@ -85,6 +89,8 @@ case class DateAdd(startDate: Expression, days: Expression) s"""${ev.value} = $sd + $d;""" }) } + + override def prettyName: String = "date_add" } /** @@ -108,6 +114,8 @@ case class DateSub(startDate: Expression, days: Expression) s"""${ev.value} = $sd - $d;""" }) } + + override def prettyName: String = "date_sub" } case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -309,6 +317,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix def this(time: Expression) = { this(time, Literal("yyyy-MM-dd HH:mm:ss")) } + + override def prettyName: String = "to_unix_timestamp" } /** @@ -332,6 +342,8 @@ case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTi def this() = { this(CurrentTimestamp()) } + + override def prettyName: String = "unix_timestamp" } abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { @@ -340,6 +352,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { Seq(TypeCollection(StringType, DateType, TimestampType), StringType) override def dataType: DataType = LongType + override def nullable: Boolean = true private lazy val constFormat: UTF8String = right.eval().asInstanceOf[UTF8String] @@ -436,6 +449,8 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ } } + + override def prettyName: String = "unix_time" } /** @@ -450,11 +465,14 @@ case class FromUnixTime(sec: Expression, format: Expression) override def left: Expression = sec override def right: Expression = format + override def prettyName: String = "from_unixtime" + def this(unix: Expression) = { this(unix, Literal("yyyy-MM-dd HH:mm:ss")) } override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(LongType, StringType) @@ -561,6 +579,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) override def inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def nullSafeEval(start: Any, dayOfW: Any): Any = { val dow = DateTimeUtils.getDayOfWeekFromString(dayOfW.asInstanceOf[UTF8String]) @@ -730,6 +749,8 @@ case class AddMonths(startDate: Expression, numMonths: Expression) s"""$dtu.dateAddMonths($sd, $m)""" }) } + + override def prettyName: String = "add_months" } /** @@ -755,6 +776,8 @@ case class MonthsBetween(date1: Expression, date2: Expression) s"""$dtu.monthsBetween($l, $r)""" }) } + + override def prettyName: String = "months_between" } /** @@ -820,6 +843,8 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, d => d) } + + override def prettyName: String = "to_date" } /** @@ -832,6 +857,7 @@ case class TruncDate(date: Expression, format: Expression) override def 
inputTypes: Seq[AbstractDataType] = Seq(DateType, StringType) override def dataType: DataType = DateType + override def nullable: Boolean = true override def prettyName: String = "trunc" private lazy val truncLevel: Int = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 78f6631e4647..5f8b544edb51 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -47,6 +47,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends UnaryExpression { override def dataType: DataType = DecimalType(precision, scale) + override def nullable: Boolean = true override def toString: String = s"MakeDecimal($child,$precision,$scale)" protected override def nullSafeEval(input: Any): Any = @@ -72,6 +73,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def gen(ctx: CodeGenContext): GeneratedExpressionCode = child.gen(ctx) override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = "" override def prettyName: String = "promote_precision" + override def sql: String = child.sql } /** @@ -106,4 +108,6 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } override def toString: String = s"CheckOverflow($child, $dataType)" + + override def sql: String = child.sql } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 894a0730d1c2..e7ef21aa8589 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala new file mode 100644 index 000000000000..2997ee879d47 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types._ + +/** + * A placeholder expression for cube/rollup, which will be replaced by analyzer + */ +trait GroupingSet extends Expression with CodegenFallback { + + def groupByExprs: Seq[Expression] + override def children: Seq[Expression] = groupByExprs + + // this should be replaced first + override lazy val resolved: Boolean = false + + override def dataType: DataType = throw new UnsupportedOperationException + override def foldable: Boolean = false + override def nullable: Boolean = true + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException +} + +case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} + +case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 4991b9cb54e5..72b323587c63 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -17,18 +17,19 @@ package org.apache.spark.sql.catalyst.expressions -import java.io.{StringWriter, ByteArrayOutputStream} +import java.io.{ByteArrayOutputStream, StringWriter} + +import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{StructField, StructType, StringType, DataType} +import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import scala.util.parsing.combinator.RegexParsers - private[this] sealed trait PathInstruction private[this] object PathInstruction { private[expressions] case object Subscript extends PathInstruction @@ -108,15 +109,17 @@ private[this] object SharedFactory { case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { - import SharedFactory._ + import com.fasterxml.jackson.core.JsonToken._ + import PathInstruction._ + import SharedFactory._ import WriteStyle._ - import com.fasterxml.jackson.core.JsonToken._ override def left: Expression = json override def right: Expression = path override def inputTypes: Seq[DataType] = Seq(StringType, StringType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def prettyName: String = "get_json_object" @transient private lazy val parsedPath = parsePath(path.eval().asInstanceOf[UTF8String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index 68ec688c99f9..0eb915fdc169 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.sql.{Date, Timestamp} -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.json4s.JsonAST._ + import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types._ @@ -55,6 +57,34 @@ object Literal { */ def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def fromJSON(json: JValue): Literal = { + val dataType = DataType.parseDataType(json \ "dataType") + json \ "value" match { + case JNull => Literal.create(null, dataType) + case JString(str) => + val value = dataType match { + case BooleanType => str.toBoolean + case ByteType => str.toByte + case ShortType => str.toShort + case IntegerType => str.toInt + case LongType => str.toLong + case FloatType => str.toFloat + case DoubleType => str.toDouble + case StringType => UTF8String.fromString(str) + case DateType => java.sql.Date.valueOf(str) + case TimestampType => java.sql.Timestamp.valueOf(str) + case CalendarIntervalType => CalendarInterval.fromString(str) + case t: DecimalType => + val d = Decimal(str) + assert(d.changePrecision(t.precision, t.scale)) + d + case _ => null + } + Literal.create(value, dataType) + case other => sys.error(s"$other is not a valid Literal json value") + } + } + def create(v: Any, dataType: DataType): Literal = { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } @@ -123,6 +153,18 @@ case class Literal protected (value: Any, dataType: DataType) case _ => false } + override protected def jsonFields: List[JField] = { + // Turns all kinds of literal values to strings in the json field, as the type info is hard to + // retain in json format, e.g. {"a": 123} can be an int, a double, or a decimal, etc. + val jsonValue = (value, dataType) match { + case (null, _) => JNull + case (i: Int, DateType) => JString(DateTimeUtils.toJavaDate(i).toString) + case (l: Long, TimestampType) => JString(DateTimeUtils.toJavaTimestamp(l).toString) + case (other, _) => JString(other.toString) + } + ("value" -> jsonValue) :: ("dataType" -> dataType.jsonValue) :: Nil + } + override def eval(input: InternalRow): Any = value override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { @@ -172,6 +214,41 @@ case class Literal protected (value: Any, dataType: DataType) } } } + + override def sql: String = (value, dataType) match { + case (_, NullType | _: ArrayType | _: MapType | _: StructType) if value == null => + "NULL" + + case _ if value == null => + s"CAST(NULL AS ${dataType.sql})" + + case (v: UTF8String, StringType) => + // Escapes all backslashes and double quotes.
+ "\"" + v.toString.replace("\\", "\\\\").replace("\"", "\\\"") + "\"" + + case (v: Byte, ByteType) => + s"CAST($v AS ${ByteType.simpleString.toUpperCase})" + + case (v: Short, ShortType) => + s"CAST($v AS ${ShortType.simpleString.toUpperCase})" + + case (v: Long, LongType) => + s"CAST($v AS ${LongType.simpleString.toUpperCase})" + + case (v: Float, FloatType) => + s"CAST($v AS ${FloatType.simpleString.toUpperCase})" + + case (v: Decimal, DecimalType.Fixed(precision, scale)) => + s"CAST($v AS ${DecimalType.simpleString.toUpperCase}($precision, $scale))" + + case (v: Int, DateType) => + s"DATE '${DateTimeUtils.toJavaDate(v)}'" + + case (v: Long, TimestampType) => + s"TIMESTAMP('${DateTimeUtils.toJavaTimestamp(v)}')" + + case _ => value.toString + } } // TODO: Specialize diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 28f616fbb9ca..66d8631a846a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions import java.{lang => jl} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult -import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckSuccess, TypeCheckFailure} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.NumberConverter @@ -70,11 +70,15 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } + + override def sql: String = s"$name(${child.sql})" } abstract class UnaryLogExpression(f: Double => Double, name: String) extends UnaryMathExpression(f, name) { + override def nullable: Boolean = true + // values less than or equal to yAsymptote eval to null in Hive, instead of NaN or -Infinity protected val yAsymptote: Double = 0.0 @@ -194,6 +198,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) override def dataType: DataType = StringType + override def nullable: Boolean = true override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { NumberConverter.convert( @@ -621,6 +626,8 @@ case class Logarithm(left: Expression, right: Expression) this(EulerNumber(), child) } + override def nullable: Boolean = true + protected override def nullSafeEval(input1: Any, input2: Any): Any = { val dLeft = input1.asInstanceOf[Double] val dRight = input2.asInstanceOf[Double] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 0f6d02f2e00c..cc406a39f040 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -22,6 +22,8 @@ import java.util.zip.CRC32 import 
org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -30,6 +32,9 @@ import org.apache.spark.unsafe.types.UTF8String * A function that calculates an MD5 128-bit checksum and returns it as a hex string * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns an MD5 128-bit checksum as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '8cde774d6f7333752ed72cacddb05126'") case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -53,10 +58,18 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput * asking for an unsupported SHA function, the return value is NULL. If either argument is NULL or * the hash length is not one of the permitted values, the return value is NULL. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, bitLength) - Returns a checksum of SHA-2 family as a hex string of the input. + SHA-224, SHA-256, SHA-384, and SHA-512 are supported. Bit length of 0 is equivalent to 256.""", + extended = """> SELECT _FUNC_('Spark', 0); + '529bc3b07127ecb7e53a4dcf1991d9152c24537d919178022b2c42657f79a26b'""") +// scalastyle:on line.size.limit case class Sha2(left: Expression, right: Expression) extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[DataType] = Seq(BinaryType, IntegerType) @@ -117,6 +130,9 @@ case class Sha2(left: Expression, right: Expression) * A function that calculates a sha1 hash value and returns it as a hex string * For input of type [[BinaryType]] or [[StringType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a sha1 hash value as a hex string of the input", + extended = "> SELECT _FUNC_('Spark');\n '85f5955f4b27a9a4c2aab6ffe5d7189fc298b92c'") case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -137,6 +153,9 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu * A function that computes a cyclic redundancy check value and returns it as a bigint * For input of type [[BinaryType]] */ +@ExpressionDescription( + usage = "_FUNC_(input) - Returns a cyclic redundancy check value as a bigint of the input", + extended = "> SELECT _FUNC_('Spark');\n '1557323817'") case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = LongType @@ -160,3 +179,49 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** + * A function that calculates hash value for a group of expressions. Note that the `seed` argument + * is not exposed to users and should only be set inside spark SQL. + * + * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of + * the unsafe row using murmur3 hasher with a seed. + * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle + * and bucketing have same data distribution. 
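+ *
+ * A minimal usage sketch (hypothetical arguments; `Literal` and `EmptyRow` as defined in this
+ * package): {{{
+ *   val hash = new Murmur3Hash(Seq(Literal("Spark"), Literal(1)))
+ *   hash.eval(EmptyRow)  // a deterministic Int, computed with the default seed 42
+ * }}}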
+ */ +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure("function hash requires at least one argument") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private lazy val unsafeProjection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + unsafeProjection(input).hashCode(seed) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = "false" + s""" + ${unsafeRow.code} + final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); + """ + } + + override def prettyName: String = "hash" + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 26b6aca79971..eee708cb02f9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -164,6 +164,12 @@ case class Alias(child: Expression, name: String)( explicitMetadata == a.explicitMetadata case _ => false } + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"${child.sql} AS $qualifiersString`$name`" + } } /** @@ -262,11 +268,21 @@ case class AttributeReference( } } + override protected final def otherCopyArgs: Seq[AnyRef] = { + exprId :: qualifiers :: Nil + } + override def toString: String = s"$name#${exprId.id}$typeSuffix" // Since the expression id is not in the first constructor it is missing from the default // tree string. 
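  // For example, an attribute named "a" with expression id 1 and IntegerType prints here
  // as "a#1: int", while the sql method below renders it as `a` (with any qualifiers
  // prepended in backticks).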
override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}" + + override def sql: String = { + val qualifiersString = + if (qualifiers.isEmpty) "" else qualifiers.map("`" + _ + "`").mkString("", ".", ".") + s"$qualifiersString`$name`" + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index df4747d4e6f7..89aec2b20fd0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -83,6 +83,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { """ }.mkString("\n") } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -193,6 +195,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { ev.value = eval.isNull eval.code } + + override def sql: String = s"(${child.sql} IS NULL)" } @@ -212,6 +216,8 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { ev.value = s"(!(${eval.isNull}))" eval.code } + + override def sql: String = s"(${child.sql} IS NOT NULL)" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 10ec75eca37f..c0c3e6e89166 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -42,16 +42,14 @@ import org.apache.spark.sql.types._ * of calling the function. */ case class StaticInvoke( - staticObject: Any, + staticObject: Class[_], dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, propagateNull: Boolean = true) extends Expression { - val objectName = staticObject match { - case c: Class[_] => c.getName - case other => other.getClass.getName.stripSuffix("$") - } + val objectName = staticObject.getName.stripSuffix("$") + override def nullable: Boolean = true override def children: Seq[Expression] = arguments @@ -167,7 +165,7 @@ case class Invoke( ${obj.code} ${argGen.map(_.code).mkString("\n")} - boolean ${ev.isNull} = ${obj.value} == null; + boolean ${ev.isNull} = ${obj.isNull}; $javaType ${ev.value} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($javaType) $value;
@@ -180,8 +178,8 @@ object NewInstance {
   def apply(
       cls: Class[_],
       arguments: Seq[Expression],
-      propagateNull: Boolean = false,
-      dataType: DataType): NewInstance =
+      dataType: DataType,
+      propagateNull: Boolean = true): NewInstance =
     new NewInstance(cls, arguments, propagateNull, dataType, None)
 }
@@ -233,7 +231,7 @@ case class NewInstance(
       s"new $className($argString)"
     }

-    if (propagateNull) {
+    if (propagateNull && argGen.nonEmpty) {
       val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})"

       s"""
@@ -250,8 +248,8 @@ case class NewInstance(
       s"""
         $setup

-        $javaType ${ev.value} = $constructorCall;
-        final boolean ${ev.isNull} = ${ev.value} == null;
+        final $javaType ${ev.value} = $constructorCall;
+        final boolean ${ev.isNull} = false;
       """
     }
   }
@@ -458,10 +456,10 @@ case class MapObjects(
             ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)};
           $loopNullCheck

-          if (${loopVar.isNull}) {
+          ${genFunction.code}
+          if (${genFunction.isNull}) {
             $convertedArray[$loopIndex] = null;
           } else {
-            ${genFunction.code}
             $convertedArray[$loopIndex] = ${genFunction.value};
           }
@@ -626,3 +624,43 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp
     """
   }
 }
+
+/**
+ * Asserts that input values of a non-nullable child expression are not null.
+ *
+ * Note that there are cases where `child.nullable == true`, while we still need to add this
+ * assertion. Consider a nullable column `s` whose data type is a struct containing a non-nullable
+ * `Int` field named `i`. Expression `s.i` is nullable because `s` can be null. However, for all
+ * non-null `s`, `s.i` can't be null.
+ */
+case class AssertNotNull(
+    child: Expression, parentType: String, fieldName: String, fieldType: String)
+  extends UnaryExpression {
+
+  override def dataType: DataType = child.dataType
+
+  override def nullable: Boolean = false
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported.")
+
+  override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val childGen = child.gen(ctx)
+
+    ev.isNull = "false"
+    ev.value = childGen.value
+
+    s"""
+      ${childGen.code}
+
+      if (${childGen.isNull}) {
+        throw new RuntimeException(
+          "Null value appeared in non-nullable field $parentType.$fieldName of type $fieldType. " +
+          "If the schema is inferred from a Scala tuple/case class, or a Java bean, " +
+          "please try to use scala.Option[_] or other nullable types " +
+          "(e.g. java.lang.Integer instead of int/scala.Int)."
+ ); + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 304b438c84ba..bca12a8d2102 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -101,6 +101,8 @@ case class Not(child: Expression) override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { defineCodeGen(ctx, ev, c => s"!($c)") } + + override def sql: String = s"(NOT ${child.sql})" } @@ -176,6 +178,13 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } """ } + + override def sql: String = { + val childrenSQL = children.map(_.sql) + val valueSQL = childrenSQL.head + val listSQL = childrenSQL.tail.mkString(", ") + s"($valueSQL IN ($listSQL))" + } } /** @@ -226,6 +235,12 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } """ } + + override def sql: String = { + val valueSQL = child.sql + val listSQL = hset.toSeq.map(Literal(_).sql).mkString(", ") + s"($valueSQL IN ($listSQL))" + } } case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { @@ -274,6 +289,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } """ } + + override def sql: String = s"(${left.sql} AND ${right.sql})" } @@ -323,6 +340,8 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } """ } + + override def sql: String = s"(${left.sql} OR ${right.sql})" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 8bde8cb9fe87..8de47e9ddc28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -49,6 +49,9 @@ abstract class RDG extends LeafExpression with Nondeterministic { override def nullable: Boolean = false override def dataType: DataType = DoubleType + + // NOTE: Even if the user doesn't provide a seed, Spark SQL adds a default seed. + override def sql: String = s"$prettyName($seed)" } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
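 * For example, assuming the default prettyName is used, Rand(42) renders through the new
 * sql method above as rand(42), keeping the (possibly auto-generated) seed visible.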
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index adef6050c356..db266639b856 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -59,6 +59,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { matches(regex, input1.asInstanceOf[UTF8String].toString) } } + + override def sql: String = s"${left.sql} ${prettyName.toUpperCase} ${right.sql}" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index cfc68fc00bea..387d979484f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} @@ -199,9 +199,9 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row { override def get(i: Int): Any = values(i) - override def toSeq: Seq[Any] = values.toSeq + override def toSeq: Seq[Any] = values.clone() - override def copy(): Row = this + override def copy(): GenericRow = this } class GenericRowWithSchema(values: Array[Any], override val schema: StructType) @@ -226,11 +226,11 @@ class GenericInternalRow(private[sql] val values: Array[Any]) extends BaseGeneri override protected def genericGet(ordinal: Int) = values(ordinal) - override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values + override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = values.clone() override def numFields: Int = values.length - override def copy(): InternalRow = new GenericInternalRow(values.clone()) + override def copy(): GenericInternalRow = this } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8770c4b76c2e..931f752b4dc1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -23,6 +23,7 @@ import java.util.{HashMap, Locale, Map => JMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} @@ -61,6 +62,8 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas } """ } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } @@ -153,6 +156,8 @@ case class ConcatWs(children: Seq[Expression]) """ } } + + override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")})" } trait String2StringExpression extends ImplicitCastInputTypes 
{ @@ -292,24 +297,24 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac val termDict = ctx.freshName("dict") val classNameDict = classOf[JMap[Character, Character]].getCanonicalName - ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;") - ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;") - ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;") + ctx.addMutableState("UTF8String", termLastMatching, s"$termLastMatching = null;") + ctx.addMutableState("UTF8String", termLastReplace, s"$termLastReplace = null;") + ctx.addMutableState(classNameDict, termDict, s"$termDict = null;") nullSafeCodeGen(ctx, ev, (src, matching, replace) => { val check = if (matchingExpr.foldable && replaceExpr.foldable) { - s"${termDict} == null" + s"$termDict == null" } else { - s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})" + s"!$matching.equals($termLastMatching) || !$replace.equals($termLastReplace)" } s"""if ($check) { // Not all of them is literal or matching or replace value changed - ${termLastMatching} = ${matching}.clone(); - ${termLastReplace} = ${replace}.clone(); - ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate - .buildDict(${termLastMatching}, ${termLastReplace}); + $termLastMatching = $matching.clone(); + $termLastReplace = $replace.clone(); + $termDict = org.apache.spark.sql.catalyst.expressions.StringTranslate + .buildDict($termLastMatching, $termLastReplace); } - ${ev.value} = ${src}.translate(${termDict}); + ${ev.value} = $src.translate($termDict); """ }) } @@ -340,6 +345,8 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi } override def dataType: DataType = IntegerType + + override def prettyName: String = "find_in_set" } /** @@ -832,7 +839,6 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn org.apache.commons.codec.binary.Base64.encodeBase64($child)); """}) } - } /** @@ -924,6 +930,7 @@ case class FormatNumber(x: Expression, d: Expression) override def left: Expression = x override def right: Expression = d override def dataType: DataType = StringType + override def nullable: Boolean = true override def inputTypes: Seq[AbstractDataType] = Seq(NumericType, IntegerType) // Associated with the pattern, for the last d value, and we will update the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index 1680aa8252ec..afe122f6a0e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedException -import org.apache.spark.sql.types.{DataType, NumericType} +import org.apache.spark.sql.catalyst.expressions.aggregate.{DeclarativeAggregate, NoOp} +import org.apache.spark.sql.types._ /** * The trait of the Window Specification (specified in the OVER clause or WINDOW clause) for @@ -117,6 +118,19 @@ sealed trait FrameBoundary { def notFollows(other: FrameBoundary): Boolean } +/** + * Extractor for making working with frame boundaries easier. 
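+ *
+ * A sketch of the intended use (names as defined below; hypothetical match): {{{
+ *   boundary match {
+ *     case FrameBoundary(offset) => // 0 for CURRENT ROW, +/- n for FOLLOWING/PRECEDING
+ *     case _ =>                     // UNBOUNDED boundaries carry no fixed offset
+ *   }
+ * }}}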
+ */
+object FrameBoundary {
+  def apply(boundary: FrameBoundary): Option[Int] = unapply(boundary)
+  def unapply(boundary: FrameBoundary): Option[Int] = boundary match {
+    case CurrentRow => Some(0)
+    case ValuePreceding(offset) => Some(-offset)
+    case ValueFollowing(offset) => Some(offset)
+    case _ => None
+  }
+}
+
 /** UNBOUNDED PRECEDING boundary. */
 case object UnboundedPreceding extends FrameBoundary {
   def notFollows(other: FrameBoundary): Boolean = other match {
@@ -243,85 +257,405 @@ object SpecifiedWindowFrame {
   }
 }

+case class UnresolvedWindowExpression(
+    child: Expression,
+    windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable {
+
+  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
+
+case class WindowExpression(
+    windowFunction: Expression,
+    windowSpec: WindowSpecDefinition) extends Expression with Unevaluable {
+
+  override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil
+
+  override def dataType: DataType = windowFunction.dataType
+  override def foldable: Boolean = windowFunction.foldable
+  override def nullable: Boolean = windowFunction.nullable
+
+  override def toString: String = s"$windowFunction $windowSpec"
+}
+
 /**
- * Every window function needs to maintain a output buffer for its output.
- * It should expect that for a n-row window frame, it will be called n times
- * to retrieve value corresponding with these n rows.
+ * A window function is a function that can only be evaluated in the context of a window operator.
  */
 trait WindowFunction extends Expression {
-  def init(): Unit
+  /** Frame in which the window operator must be executed. */
+  def frame: WindowFrame = UnspecifiedFrame
+}

-  def reset(): Unit
+/**
+ * An offset window function is a window function that returns the value of the input column offset
+ * by a number of rows within the partition. For instance: an OffsetWindowFunction for value x with
+ * offset -2 will get the value of x two rows back in the partition.
+ */
+abstract class OffsetWindowFunction
+  extends Expression with WindowFunction with Unevaluable with ImplicitCastInputTypes {
+  /**
+   * Input expression to evaluate against a row which is a number of rows below or above (depending
+   * on the value and sign of the offset) the current row.
+   */
+  val input: Expression
+
+  /**
+   * Default result value for the function when the input expression returns NULL. The default will
+   * be evaluated against the current row instead of the offset row.
+   */
+  val default: Expression
+
+  /**
+   * (Foldable) expression that contains the number of rows between the current row and the row
+   * where the input expression is evaluated.
+   */
+  val offset: Expression
+
+  /**
+   * Direction of the number of rows between the current row and the row where the input expression
+   * is evaluated.
+   */
+  val direction: SortDirection

-  def prepareInputParameters(input: InternalRow): AnyRef
+  override def children: Seq[Expression] = Seq(input, offset, default)

-  def update(input: AnyRef): Unit
+  /*
+   * The result of an OffsetWindowFunction is dependent on the frame in which the
+   * OffsetWindowFunction is executed, the input expression and the default expression. Even when
+   * both the input and the default expression are foldable, the result is still not foldable due to
+   * the frame.
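+   *
+   * For example, lead(1, 2, 3) evaluates to 1 for every row that has a row two places
+   * below it in the partition and to the default 3 otherwise, so the result varies with
+   * the partition even though all three arguments are foldable.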
+   */
+  override def foldable: Boolean = false

-  def batchUpdate(inputs: Array[AnyRef]): Unit
+  override def nullable: Boolean = default == null || default.nullable

-  def evaluate(): Unit
+  override lazy val frame = {
+    // This will be triggered by the Analyzer.
+    val offsetValue = offset.eval() match {
+      case o: Int => o
+      case x => throw new AnalysisException(
+        s"Offset expression must be a foldable integer expression: $x")
+    }
+    val boundary = direction match {
+      case Ascending => ValueFollowing(offsetValue)
+      case Descending => ValuePreceding(offsetValue)
+    }
+    SpecifiedWindowFrame(RowFrame, boundary, boundary)
+  }
+
+  override def dataType: DataType = input.dataType

-  def get(index: Int): Any
+  override def inputTypes: Seq[AbstractDataType] =
+    Seq(AnyDataType, IntegerType, TypeCollection(input.dataType, NullType))

-  def newInstance(): WindowFunction
+  override def toString: String = s"$prettyName($input, $offset, $default)"
 }

-case class UnresolvedWindowFunction(
-    name: String,
-    children: Seq[Expression])
-  extends Expression with WindowFunction with Unevaluable {
+/**
+ * The Lead function returns the value of 'x' at 'offset' rows after the current row in the window.
+ * Offsets start at 0, which is the current row. The offset must be a constant integer value. The
+ * default offset is 1. When the value of 'x' is null at the offset, or when the offset is larger
+ * than the window, the default expression is evaluated.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param input expression to evaluate 'offset' rows after the current row.
+ * @param offset rows to jump ahead in the partition.
+ * @param default to use when the input value is null or when the offset is larger than the window.
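+ *
+ * For example, over a hypothetical partition holding [10, 20, 30] in ascending order,
+ * LEAD(x, 1, -1) produces [20, 30, -1].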
+ */
+@ExpressionDescription(usage =
+  """_FUNC_(input, offset, default) - LEAD returns the value of 'x' at 'offset' rows
+     after the current row in the window""")
+case class Lead(input: Expression, offset: Expression, default: Expression)
+  extends OffsetWindowFunction {

-  override def dataType: DataType = throw new UnresolvedException(this, "dataType")
-  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
-  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
-  override lazy val resolved = false
+  def this(input: Expression, offset: Expression) = this(input, offset, Literal(null))

-  override def init(): Unit = throw new UnresolvedException(this, "init")
-  override def reset(): Unit = throw new UnresolvedException(this, "reset")
-  override def prepareInputParameters(input: InternalRow): AnyRef =
-    throw new UnresolvedException(this, "prepareInputParameters")
-  override def update(input: AnyRef): Unit = throw new UnresolvedException(this, "update")
-  override def batchUpdate(inputs: Array[AnyRef]): Unit =
-    throw new UnresolvedException(this, "batchUpdate")
-  override def evaluate(): Unit = throw new UnresolvedException(this, "evaluate")
-  override def get(index: Int): Any = throw new UnresolvedException(this, "get")
+  def this(input: Expression) = this(input, Literal(1))

-  override def toString: String = s"'$name(${children.mkString(",")})"
+  def this() = this(Literal(null))

-  override def newInstance(): WindowFunction = throw new UnresolvedException(this, "newInstance")
+  override val direction = Ascending
 }

-case class UnresolvedWindowExpression(
-    child: UnresolvedWindowFunction,
-    windowSpec: WindowSpecReference) extends UnaryExpression with Unevaluable {
+/**
+ * The Lag function returns the value of 'x' at 'offset' rows before the current row in the window.
+ * Offsets start at 0, which is the current row. The offset must be a constant integer value. The
+ * default offset is 1. When the value of 'x' is null at the offset, or when the offset is smaller
+ * than the window, the default expression is evaluated.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param input expression to evaluate 'offset' rows before the current row.
+ * @param offset rows to jump back in the partition.
+ * @param default to use when the input value is null or when the offset is smaller than the window.
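+ *
+ * For example, over a hypothetical partition holding [10, 20, 30] in ascending order,
+ * LAG(x, 1, -1) produces [-1, 10, 20].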
+ */ +@ExpressionDescription(usage = + """_FUNC_(input, offset, default) - LAG returns the value of 'x' at 'offset' rows + before the current row in the window""") +case class Lag(input: Expression, offset: Expression, default: Expression) + extends OffsetWindowFunction { - override def dataType: DataType = throw new UnresolvedException(this, "dataType") - override def foldable: Boolean = throw new UnresolvedException(this, "foldable") - override def nullable: Boolean = throw new UnresolvedException(this, "nullable") - override lazy val resolved = false + def this(input: Expression, offset: Expression) = this(input, offset, Literal(null)) + + def this(input: Expression) = this(input, Literal(1)) + + def this() = this(Literal(null)) + + override val direction = Descending } -case class WindowExpression( - windowFunction: WindowFunction, - windowSpec: WindowSpecDefinition) extends Expression with Unevaluable { +abstract class AggregateWindowFunction extends DeclarativeAggregate with WindowFunction { + self: Product => + override val frame = SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow) + override def dataType: DataType = IntegerType + override def nullable: Boolean = true + override def supportsPartial: Boolean = false + override lazy val mergeExpressions = + throw new UnsupportedOperationException("Window Functions do not support merging.") +} - override def children: Seq[Expression] = windowFunction :: windowSpec :: Nil +abstract class RowNumberLike extends AggregateWindowFunction { + override def children: Seq[Expression] = Nil + override def inputTypes: Seq[AbstractDataType] = Nil + protected val zero = Literal(0) + protected val one = Literal(1) + protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)() + override val aggBufferAttributes: Seq[AttributeReference] = rowNumber :: Nil + override val initialValues: Seq[Expression] = zero :: Nil + override val updateExpressions: Seq[Expression] = Add(rowNumber, one) :: Nil +} - override def dataType: DataType = windowFunction.dataType - override def foldable: Boolean = windowFunction.foldable - override def nullable: Boolean = windowFunction.nullable +/** + * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. + */ +trait SizeBasedWindowFunction extends AggregateWindowFunction { + protected def n: AttributeReference = SizeBasedWindowFunction.n +} - override def toString: String = s"$windowFunction $windowSpec" +object SizeBasedWindowFunction { + val n = AttributeReference("window__partition__size", IntegerType, nullable = false)() } /** - * Extractor for making working with frame boundaries easier. + * The RowNumber function computes a unique, sequential number to each row, starting with one, + * according to the ordering of rows within the window partition. + * + * This documentation has been based upon similar documentation for the Hive and Presto projects. 
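+ *
+ * For example, over a hypothetical four-row partition, ROW_NUMBER() produces 1, 2, 3, 4,
+ * even when the ordering contains ties.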
 */
+@ExpressionDescription(usage =
+  """_FUNC_() - The ROW_NUMBER() function assigns a unique, sequential number to
+     each row, starting with one, according to the ordering of rows within
+     the window partition.""")
+case class RowNumber() extends RowNumberLike {
+  override val evaluateExpression = rowNumber
+}
+
+/**
+ * The CumeDist function computes the position of a value relative to all values in the partition.
+ * The result is the number of rows preceding or equal to the current row in the ordering of the
+ * partition divided by the total number of rows in the window partition. Any tie values in the
+ * ordering will evaluate to the same position.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ */
+@ExpressionDescription(usage =
+  """_FUNC_() - The CUME_DIST() function computes the position of a value relative to
+     all values in the partition.""")
+case class CumeDist() extends RowNumberLike with SizeBasedWindowFunction {
+  override def dataType: DataType = DoubleType
+  // The frame for CUME_DIST is Range based instead of Row based, because CUME_DIST must
+  // return the same value for equal values in the partition.
+  override val frame = SpecifiedWindowFrame(RangeFrame, UnboundedPreceding, CurrentRow)
+  override val evaluateExpression = Divide(Cast(rowNumber, DoubleType), Cast(n, DoubleType))
+}
+
+/**
+ * The NTile function divides the rows for each window partition into 'n' buckets ranging from 1 to
+ * at most 'n'. Bucket values will differ by at most 1. If the number of rows in the partition does
+ * not divide evenly into the number of buckets, then the remainder values are distributed one per
+ * bucket, starting with the first bucket.
+ *
+ * The NTile function is particularly useful for the calculation of tertiles, quartiles, deciles
+ * and other common summary statistics.
+ *
+ * The function calculates two variables during initialization: the size of a regular bucket, and
+ * the number of buckets that will have one extra row added to them (when the rows do not evenly
+ * fit into the number of buckets); both variables are based on the size of the current partition.
+ * During the calculation process the function keeps track of the current row number, the current
+ * bucket number, and the row number at which the bucket will change (bucketThreshold). When the
+ * current row number reaches the bucket threshold, the bucket value is increased by one and the
+ * threshold is increased by the bucket size (plus one extra if the current bucket is padded).
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param buckets number of buckets to divide the rows in. Default value is 1.
+ */
+@ExpressionDescription(usage =
+  """_FUNC_(x) - The NTILE(n) function divides the rows for each window partition
+     into 'n' buckets ranging from 1 to at most 'n'.""")
+case class NTile(buckets: Expression) extends RowNumberLike with SizeBasedWindowFunction {
+  def this() = this(Literal(1))
+
+  // Validate buckets. Note that this could be relaxed: the bucket value only needs to be constant
+  // for each partition.
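+  // For example, NTILE(4) over a hypothetical 10-row partition yields bucket sizes
+  // 3, 3, 2, 2: bucketSize = 10 / 4 = 2 and bucketsWithPadding = 10 % 4 = 2, so the
+  // first two buckets each receive one extra row.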
+  buckets.eval() match {
+    case b: Int if b > 0 => // Ok
+    case x => throw new AnalysisException(
+      s"Buckets expression must be a foldable positive integer expression: $x")
+  }
+
+  private val bucket = AttributeReference("bucket", IntegerType, nullable = false)()
+  private val bucketThreshold =
+    AttributeReference("bucketThreshold", IntegerType, nullable = false)()
+  private val bucketSize = AttributeReference("bucketSize", IntegerType, nullable = false)()
+  private val bucketsWithPadding =
+    AttributeReference("bucketsWithPadding", IntegerType, nullable = false)()
+  private def bucketOverflow(e: Expression) =
+    If(GreaterThanOrEqual(rowNumber, bucketThreshold), e, zero)
+
+  override val aggBufferAttributes = Seq(
+    rowNumber,
+    bucket,
+    bucketThreshold,
+    bucketSize,
+    bucketsWithPadding
+  )
+
+  override val initialValues = Seq(
+    zero,
+    zero,
+    zero,
+    Cast(Divide(n, buckets), IntegerType),
+    Cast(Remainder(n, buckets), IntegerType)
+  )
+
+  override val updateExpressions = Seq(
+    Add(rowNumber, one),
+    Add(bucket, bucketOverflow(one)),
+    Add(bucketThreshold, bucketOverflow(
+      Add(bucketSize, If(LessThan(bucket, bucketsWithPadding), one, zero)))),
+    NoOp,
+    NoOp
+  )
+
+  override val evaluateExpression = bucket
+}
+
+/**
+ * A RankLike function is a WindowFunction that changes its value based on a change in the value of
+ * the order of the window in which it is processed. For instance, when the value of 'x' changes in
+ * a window ordered by 'x' the rank function also changes. The size of the change of the rank
+ * function is (typically) not dependent on the size of the change in 'x'.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ */
+abstract class RankLike extends AggregateWindowFunction {
+  override def inputTypes: Seq[AbstractDataType] = children.map(_ => AnyDataType)
+
+  /** Store the values of the window 'order' expressions. */
+  protected val orderAttrs = children.map { expr =>
+    AttributeReference(expr.prettyString, expr.dataType)()
+  }
+
+  /** Predicate that detects if the order attributes have changed. */
+  protected val orderEquals = children.zip(orderAttrs)
+    .map(EqualNullSafe.tupled)
+    .reduceOption(And)
+    .getOrElse(Literal(true))
+
+  protected val orderInit = children.map(e => Literal.create(null, e.dataType))
+  protected val rank = AttributeReference("rank", IntegerType, nullable = false)()
+  protected val rowNumber = AttributeReference("rowNumber", IntegerType, nullable = false)()
+  protected val zero = Literal(0)
+  protected val one = Literal(1)
+  protected val increaseRowNumber = Add(rowNumber, one)
+
+  /**
+   * Different RankLike implementations use different source expressions to update their rank
+   * value. Rank for instance uses the number of rows seen, whereas DenseRank uses the number of
+   * changes.
+   */
+  protected def rankSource: Expression = rowNumber
+
+  /** Increase the rank when the current rank == 0 or when one of the order attributes changes. */
+  protected val increaseRank = If(And(orderEquals, Not(EqualTo(rank, zero))), rank, rankSource)
+
+  override val aggBufferAttributes: Seq[AttributeReference] = rank +: rowNumber +: orderAttrs
+  override val initialValues = zero +: one +: orderInit
+  override val updateExpressions = increaseRank +: increaseRowNumber +: children
+  override val evaluateExpression: Expression = rank
+
+  def withOrder(order: Seq[Expression]): RankLike
+}
+
+/**
+ * The Rank function computes the rank of a value in a group of values.
The result is one plus the
+ * number of rows preceding or equal to the current row in the ordering of the partition. Tie
+ * values will produce gaps in the sequence.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param children to base the rank on; a change in the value of one of the children will trigger
+ *                 a change in rank. This is an internal parameter and will be assigned by the
+ *                 Analyser.
+ */
+@ExpressionDescription(usage =
+  """_FUNC_() - RANK() computes the rank of a value in a group of values. The result
+     is one plus the number of rows preceding or equal to the current row in the
+     ordering of the partition. Tie values will produce gaps in the sequence.""")
+case class Rank(children: Seq[Expression]) extends RankLike {
+  def this() = this(Nil)
+  override def withOrder(order: Seq[Expression]): Rank = Rank(order)
+}
+
+/**
+ * The DenseRank function computes the rank of a value in a group of values. The result is one plus
+ * the previously assigned rank value. Unlike Rank, DenseRank will not produce gaps in the ranking
+ * sequence.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param children to base the rank on; a change in the value of one of the children will trigger
+ *                 a change in rank. This is an internal parameter and will be assigned by the
+ *                 Analyser.
+ */
+@ExpressionDescription(usage =
+  """_FUNC_() - The DENSE_RANK() function computes the rank of a value in a group of
+     values. The result is one plus the previously assigned rank value. Unlike Rank,
+     DenseRank will not produce gaps in the ranking sequence.""")
+case class DenseRank(children: Seq[Expression]) extends RankLike {
+  def this() = this(Nil)
+  override def withOrder(order: Seq[Expression]): DenseRank = DenseRank(order)
+  override protected def rankSource = Add(rank, one)
+  override val updateExpressions = increaseRank +: children
+  override val aggBufferAttributes = rank +: orderAttrs
+  override val initialValues = zero +: orderInit
+}
+
+/**
+ * The PercentRank function computes the percentage ranking of a value in a group of values. The
+ * result is the rank of the current row minus one, divided by the total number of rows in the
+ * partition minus one: (r - 1) / (n - 1). If a partition only contains one row, the function will
+ * return 0.
+ *
+ * The PercentRank function is similar to the CumeDist function, but it uses rank values instead of
+ * row counts in its numerator.
+ *
+ * This documentation has been based upon similar documentation for the Hive and Presto projects.
+ *
+ * @param children to base the rank on; a change in the value of one of the children will trigger
+ *                 a change in rank. This is an internal parameter and will be assigned by the
+ *                 Analyser.
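+ *
+ * For example, over a hypothetical ordered partition [10, 20, 20, 30] the rank values
+ * are [1, 2, 2, 4], so PERCENT_RANK() produces [0.0, 1/3, 1/3, 1.0].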
+ */
+@ExpressionDescription(usage =
+  """_FUNC_() - The PERCENT_RANK() function computes the percentage
+     ranking of a value in a group of values.""")
+case class PercentRank(children: Seq[Expression]) extends RankLike with SizeBasedWindowFunction {
+  def this() = this(Nil)
+  override def withOrder(order: Seq[Expression]): PercentRank = PercentRank(order)
+  override def dataType: DataType = DoubleType
+  override val evaluateExpression = If(GreaterThan(n, one),
+    Divide(Cast(Subtract(rank, one), DoubleType), Cast(Subtract(n, one), DoubleType)),
+    Literal(0.0d))
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f6088695a927..f8121a733a8d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -28,13 +28,17 @@ import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
 import org.apache.spark.sql.types._

-abstract class Optimizer extends RuleExecutor[LogicalPlan]
-
-object DefaultOptimizer extends Optimizer {
-  val batches =
+/**
+ * Abstract class that all optimizers should inherit from; it contains the standard batches
+ * (extending optimizers can override these).
+ */
+abstract class Optimizer extends RuleExecutor[LogicalPlan] {
+  def batches: Seq[Batch] = {
     // SubQueries are only needed for analysis and can be removed before execution.
     Batch("Remove SubQueries", FixedPoint(100),
       EliminateSubQueries) ::
+    Batch("Compute Current Time", Once,
+      ComputeCurrentTime) ::
     Batch("Aggregate", FixedPoint(100),
       ReplaceDistinctWithAggregate,
       RemoveLiteralFromGroupExpressions) ::
@@ -66,8 +70,17 @@ object DefaultOptimizer extends Optimizer {
       DecimalAggregates) ::
     Batch("LocalRelation", FixedPoint(100),
       ConvertToLocalRelation) :: Nil
+  }
 }

+/**
+ * Non-abstract representation of the standard Spark optimizing strategies.
+ *
+ * To ensure extensibility, the standard rules live in the abstract optimizer, while
+ * optimizer-specific rules go in the subclasses.
+ */
+object DefaultOptimizer extends Optimizer
+
 /**
  * Pushes operations down into a Sample.
  */
@@ -322,6 +335,39 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
       )
       Project(cleanedProjection, child)
     }
+
+    // TODO Eliminate duplicate code
+    // This clause is identical to the one above except that the inner operator is an `Aggregate`
+    // rather than a `Project`.
+    case p @ Project(projectList1, agg @ Aggregate(_, projectList2, child)) =>
+      // Create a map of Aliases to their values from the child projection.
+      // e.g., 'SELECT ... FROM (SELECT a + b AS c, d ...)' produces Map(c -> Alias(a + b, c)).
+      val aliasMap = AttributeMap(projectList2.collect {
+        case a: Alias => (a.toAttribute, a)
+      })
+
+      // We only collapse these two Projects if their overlapping expressions are all
+      // deterministic.
+      val hasNondeterministic = projectList1.exists(_.collect {
+        case a: Attribute if aliasMap.contains(a) => aliasMap(a).child
+      }.exists(!_.deterministic))
+
+      if (hasNondeterministic) {
+        p
+      } else {
+        // Substitute any attributes that are produced by the child projection, so that we safely
+        // eliminate it.
+        // e.g., 'SELECT c + 1 FROM (SELECT a + b AS C ...' produces 'SELECT a + b + 1 ...'
+        // TODO: Fix TransformBase to avoid the cast below.
+ val substitutedProjection = projectList1.map(_.transform { + case a: Attribute => aliasMap.getOrElse(a, a) + }).asInstanceOf[Seq[NamedExpression]] + // collapse 2 projects may introduce unnecessary Aliases, trim them here. + val cleanedProjection = substitutedProjection.map(p => + CleanupAliases.trimNonTopLevelAliases(p).asInstanceOf[NamedExpression] + ) + agg.copy(aggregateExpressions = cleanedProjection) + } } } @@ -965,3 +1011,20 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] { a.copy(groupingExpressions = newGrouping) } } + +/** + * Computes the current date and time to make sure we return the same result in a single query. + */ +object ComputeCurrentTime extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + val dateExpr = CurrentDate() + val timeExpr = CurrentTimestamp() + val currentDate = Literal.create(dateExpr.eval(EmptyRow), dateExpr.dataType) + val currentTime = Literal.create(timeExpr.eval(EmptyRow), timeExpr.dataType) + + plan transformAllExpressions { + case CurrentDate() => currentDate + case CurrentTimestamp() => currentTime + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala new file mode 100644 index 000000000000..ec5e71042d4b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.antlr.runtime.{Token, TokenRewriteStream} + +import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} + +case class ASTNode( + token: Token, + startIndex: Int, + stopIndex: Int, + children: List[ASTNode], + stream: TokenRewriteStream) extends TreeNode[ASTNode] { + /** Cache the number of children. */ + val numChildren = children.size + + /** tuple used in pattern matching. */ + val pattern = Some((token.getText, children)) + + /** Line in which the ASTNode starts. */ + lazy val line: Int = { + val line = token.getLine + if (line == 0) { + if (children.nonEmpty) children.head.line + else 0 + } else { + line + } + } + + /** Position of the Character at which ASTNode starts. */ + lazy val positionInLine: Int = { + val line = token.getCharPositionInLine + if (line == -1) { + if (children.nonEmpty) children.head.positionInLine + else 0 + } else { + line + } + } + + /** Origin of the ASTNode. */ + override val origin = Origin(Some(line), Some(positionInLine)) + + /** Source text. */ + lazy val source = stream.toString(startIndex, stopIndex) + + def text: String = token.getText + + def tokenType: Int = token.getType + + /** + * Checks if this node is equal to another node. 
+ * + * Right now this function only checks the name, type, text and children of the node + * for equality. + */ + def treeEquals(other: ASTNode): Boolean = { + def check(f: ASTNode => Any): Boolean = { + val l = f(this) + val r = f(other) + (l == null && r == null) || l.equals(r) + } + if (other == null) { + false + } else if (!check(_.token.getType) + || !check(_.token.getText) + || !check(_.numChildren)) { + false + } else { + children.zip(other.children).forall { + case (l, r) => l treeEquals r + } + } + } + + override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine " +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala new file mode 100644 index 000000000000..0e93af8b92cd --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.antlr.runtime._ +import org.antlr.runtime.tree.CommonTree + +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException + +/** + * The ParseDriver takes a SQL command and turns this into an AST. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + */ +object ParseDriver extends Logging { + def parse(command: String, conf: ParserConf): ASTNode = { + logInfo(s"Parsing command: $command") + + // Setup error collection. + val reporter = new ParseErrorReporter() + + // Create lexer. + val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) + val tokens = new TokenRewriteStream(lexer) + lexer.configure(conf, reporter) + + // Create the parser. + val parser = new SparkSqlParser(tokens) + parser.configure(conf, reporter) + + try { + val result = parser.statement() + + // Check errors. + reporter.checkForErrors() + + // Return the AST node from the result. + logInfo(s"Parse completed.") + + // Find the non null token tree in the result. + def nonNullToken(tree: CommonTree): CommonTree = { + if (tree.token != null || tree.getChildCount == 0) tree + else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + } + val tree = nonNullToken(result.getTree) + + // Make sure all boundaries are set. + tree.setUnknownTokenBoundaries() + + // Construct the immutable AST. 
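+      // (The conversion is bottom-up: each CommonTree child becomes an ASTNode first, and
+      // the recorded token start/stop indexes let ASTNode.source recover the original text
+      // from the TokenRewriteStream.)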
+      def createASTNode(tree: CommonTree): ASTNode = {
+        val children = (0 until tree.getChildCount).map { i =>
+          createASTNode(tree.getChild(i).asInstanceOf[CommonTree])
+        }.toList
+        ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens)
+      }
+      createASTNode(tree)
+    }
+    catch {
+      case e: RecognitionException =>
+        logInfo(s"Parse failed.")
+        reporter.throwError(e)
+    }
+  }
+}
+
+/**
+ * This string stream provides the lexer with upper case characters only. This greatly simplifies
+ * lexing the stream, while we can maintain the original command.
+ *
+ * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream
+ *
+ * The comment below (taken from the original class) describes the rationale for doing this:
+ *
+ * This class provides an implementation for a case insensitive token checker for the lexical
+ * analysis part of antlr. By converting the token stream into upper case at the time when lexical
+ * rules are checked, this class ensures that the lexical rules need to just match the token with
+ * upper case letters as opposed to a combination of upper case and lower case characters. This is
+ * purely used for matching lexical rules. The actual token text is stored in the same way as the
+ * user input without actually converting it into an upper case. The token values are generated by
+ * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead
+ * function and is purely used for matching lexical rules. This also means that the grammar will
+ * only accept capitalized tokens in case it is run from other tools like antlrworks which do not
+ * have the ANTLRNoCaseStringStream implementation.
+ */
+
+private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) {
+  override def LA(i: Int): Int = {
+    val la = super.LA(i)
+    if (la == 0 || la == CharStream.EOF) la
+    else Character.toUpperCase(la)
+  }
+}
+
+/**
+ * Utility used by the Parser and the Lexer for error collection and reporting.
+ */
+private[parser] class ParseErrorReporter {
+  val errors = scala.collection.mutable.Buffer.empty[ParseError]
+
+  def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = {
+    errors += ParseError(br, re, tokenNames)
+  }
+
+  def checkForErrors(): Unit = {
+    if (errors.nonEmpty) {
+      val first = errors.head
+      val e = first.re
+      throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail)
+    }
+  }
+
+  def throwError(e: RecognitionException): Nothing = {
+    throwError(e.line, e.charPositionInLine, e.toString, errors)
+  }
+
+  private def throwError(
+      line: Int,
+      startPosition: Int,
+      msg: String,
+      errors: Seq[ParseError]): Nothing = {
+    val b = new StringBuilder
+    b.append(msg).append("\n")
+    errors.foreach(error => error.buildMessage(b).append("\n"))
+    throw new AnalysisException(b.toString, Option(line), Option(startPosition))
+  }
+}
+
+/**
+ * Error collected during the parsing process.
+ * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + */ +private[parser] case class ParseError( + br: BaseRecognizer, + re: RecognitionException, + tokenNames: Array[String]) { + def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { + s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) + } +} diff --git a/bagel/src/main/scala/org/apache/spark/bagel/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala similarity index 74% rename from bagel/src/main/scala/org/apache/spark/bagel/package.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala index 2fb193457978..ce449b11431a 100644 --- a/bagel/src/main/scala/org/apache/spark/bagel/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala @@ -14,10 +14,13 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package org.apache.spark.sql.catalyst.parser -package org.apache.spark +trait ParserConf { + def supportQuotedId: Boolean + def supportSQL11ReservedKeywords: Boolean +} -/** - * Bagel: An implementation of Pregel in Spark. THIS IS DEPRECATED - use Spark's GraphX library. - */ -package object bagel +case class SimpleParserConf( + supportQuotedId: Boolean = true, + supportSQL11ReservedKeywords: Boolean = false) extends ParserConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index b9db7838db08..b43b7ee71e7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -43,16 +43,17 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy def inputSet: AttributeSet = AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) + /** + * The set of all attributes that are produced by this node. + */ + def producedAttributes: AttributeSet = AttributeSet.empty + /** * Attributes that are referenced by expressions but not provided by this nodes children. * Subclasses should override this method if they produce attributes internally as it is used by * assertions designed to prevent the construction of invalid plans. - * - * Note that virtual columns should be excluded. Currently, we only support the grouping ID - * virtual column. */ - def missingInput: AttributeSet = - (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName) + def missingInput: AttributeSet = references -- inputSet -- producedAttributes /** * Runs [[transform]] with `rule` on all expressions present in this query operator. 
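The hunk above replaces the old virtual-column special case in `missingInput` with an explicit `producedAttributes` set: any node that generates attributes internally declares them, and `missingInput` simply subtracts them along with the children's output. A minimal, self-contained sketch of the new relationship, using plain Scala sets to stand in for `AttributeSet` (the `Toy`/demo names are illustrative, not Spark's API):

// Toy model of the revised QueryPlan contract:
//   missingInput = references -- inputSet -- producedAttributes
object MissingInputDemo extends App {
  trait ToyPlan {
    def references: Set[String]                     // attributes the node's expressions mention
    def inputSet: Set[String]                       // attributes supplied by the children
    def producedAttributes: Set[String] = Set.empty // attributes the node itself generates
    final def missingInput: Set[String] = references -- inputSet -- producedAttributes
  }

  // A MapGroups-like node produces its whole output itself, so nothing is "missing"
  // even though its children supply none of the referenced attributes.
  val mapGroupsLike = new ToyPlan {
    val references: Set[String] = Set("key", "value")
    val inputSet: Set[String] = Set.empty
    override val producedAttributes: Set[String] = Set("key", "value")
  }

  println(mapGroupsLike.missingInput) // Set(): the plan is considered valid
}

The `LeafNode` change later in this patch follows the same pattern by declaring that a leaf produces its entire `outputSet`.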
@@ -88,6 +89,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray @@ -120,6 +122,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy case d: DataType => d // Avoid unpacking Structs case seq: Traversable[_] => seq.map(recursiveTransform) case other: AnyRef => other + case null => null } val newArgs = productIterator.map(recursiveTransform).toArray diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 77dec7ca6e2b..a5f6764aef7c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -37,14 +37,26 @@ object JoinType { } } -sealed abstract class JoinType +sealed abstract class JoinType { + def sql: String +} -case object Inner extends JoinType +case object Inner extends JoinType { + override def sql: String = "INNER" +} -case object LeftOuter extends JoinType +case object LeftOuter extends JoinType { + override def sql: String = "LEFT OUTER" +} -case object RightOuter extends JoinType +case object RightOuter extends JoinType { + override def sql: String = "RIGHT OUTER" +} -case object FullOuter extends JoinType +case object FullOuter extends JoinType { + override def sql: String = "FULL OUTER" +} -case object LeftSemi extends JoinType +case object LeftSemi extends JoinType { + override def sql: String = "LEFT SEMI" +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala index e3e7a11dba97..d3b5879777a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis} +import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet} import org.apache.spark.sql.types.{StructField, StructType} object LocalRelation { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 8f8747e10593..6d859551f8c5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ abstract class LeafNode extends LogicalPlan { override def children: Seq[LogicalPlan] = Nil + override def producedAttributes: AttributeSet = outputSet } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala index ccf5291219ad..578027da776e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ScriptTransformation.scala
@@ -17,7 +17,7 @@
 package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.expressions.{AttributeSet, Attribute, Expression}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression}
 /**
 * Transforms the input by forking and running the specified script.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 5665fd7e5f41..64957db6b401 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -17,13 +17,14 @@
 package org.apache.spark.sql.catalyst.plans.logical
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.types._
-import scala.collection.mutable.ArrayBuffer
 case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode {
 override def output: Seq[Attribute] = projectList.map(_.toAttribute)
@@ -210,6 +211,38 @@ case class Sort(
 override def output: Seq[Attribute] = child.output
 }
+/** Factory for constructing new `Range` nodes. */
+object Range {
+ def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = {
+ val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes
+ new Range(start, end, step, numSlices, output)
+ }
+}
+
+case class Range(
+ start: Long,
+ end: Long,
+ step: Long,
+ numSlices: Int,
+ output: Seq[Attribute]) extends LeafNode {
+ require(step != 0, "step cannot be 0")
+ val numElements: BigInt = {
+ val safeStart = BigInt(start)
+ val safeEnd = BigInt(end)
+ if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) {
+ (safeEnd - safeStart) / step
+ } else {
+ // the remainder has the same sign as the range, so we can add 1 more
+ (safeEnd - safeStart) / step + 1
+ }
+ }
+
+ override def statistics: Statistics = {
+ val sizeInBytes = LongType.defaultSize * numElements
+ Statistics(sizeInBytes = sizeInBytes)
+ }
+}
+
 case class Aggregate(
 groupingExpressions: Seq[Expression],
 aggregateExpressions: Seq[NamedExpression],
@@ -293,7 +326,14 @@ private[sql] object Expand {
 Literal.create(bitmask, IntegerType)
 })
 }
- Expand(projections, child.output :+ gid, child)
+ val output = child.output.map { attr =>
+ if (groupByExprs.exists(_.semanticEquals(attr))) {
+ attr.withNullability(true)
+ } else {
+ attr
+ }
+ }
+ Expand(projections, output :+ gid, child)
 }
 }
@@ -358,43 +398,6 @@ case class GroupingSets(
 this.copy(aggregations = aggs)
 }
-/**
- * Cube is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets,
- * and eventually will be transformed to Aggregate(.., Expand) in Analyzer
- *
- * @param groupByExprs The Group By expressions candidates.
- * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Cube( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) -} - -/** - * Rollup is a syntactic sugar for GROUPING SETS, and will be transformed to GroupingSets, - * and eventually will be transformed to Aggregate(.., Expand) in Analyzer - * - * @param groupByExprs The Group By expressions candidates, take effective only if the - * associated bit in the bitmask set to 1. - * @param child Child operator - * @param aggregations The Aggregation expressions, those non selected group by expressions - * will be considered as constant null if it appears in the expressions - */ -case class Rollup( - groupByExprs: Seq[Expression], - child: LogicalPlan, - aggregations: Seq[NamedExpression]) extends GroupingAnalytics { - - def withNewAggs(aggs: Seq[NamedExpression]): GroupingAnalytics = - this.copy(aggregations = aggs) -} - case class Pivot( groupByExprs: Seq[NamedExpression], pivotColumn: Expression, @@ -420,6 +423,7 @@ case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { } case class Subquery(alias: String, child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output.map(_.withQualifiers(alias :: Nil)) } @@ -487,7 +491,7 @@ case class MapPartitions[T, U]( uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } /** Factory for constructing new `AppendColumn` nodes. */ @@ -513,7 +517,7 @@ case class AppendColumns[T, U]( newColumns: Seq[Attribute], child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns - override def missingInput: AttributeSet = super.missingInput -- newColumns + override def producedAttributes: AttributeSet = AttributeSet(newColumns) } /** Factory for constructing new `MapGroups` nodes. */ @@ -548,7 +552,7 @@ case class MapGroups[K, T, U]( groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } /** Factory for constructing new `CoGroup` nodes. 
*/ @@ -591,5 +595,5 @@ case class CoGroup[Key, Left, Right, Result]( rightGroup: Seq[Attribute], left: LogicalPlan, right: LogicalPlan) extends BinaryNode { - override def missingInput: AttributeSet = AttributeSet.empty + override def producedAttributes: AttributeSet = outputSet } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index e6621e0f50a9..47b34d1fa2e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types.StringType /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f6fb31a2af59..1bfe0ecb1e20 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.plans.physical -import org.apache.spark.sql.catalyst.expressions.{Unevaluable, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable} import org.apache.spark.sql.types.{DataType, IntegerType} /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala index f80d2a93241d..9ebacb4680dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleExecutor.scala @@ -37,7 +37,7 @@ object RuleExecutor { val maxSize = map.keys.map(_.toString.length).max map.toSeq.sortBy(_._2).reverseMap { case (k, v) => s"${k.padTo(maxSize, " ").mkString} $v" - }.mkString("\n") + }.mkString("\n", "\n", "") } } @@ -59,7 +59,7 @@ abstract class RuleExecutor[TreeType <: TreeNode[_]] extends Logging { protected case class Batch(name: String, strategy: Strategy, rules: Rule[TreeType]*) /** Defines a sequence of rule batches, to be overridden by the implementation. 
*/
- protected val batches: Seq[Batch]
+ protected def batches: Seq[Batch]
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index d838d845d20f..d4be545a35ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -17,10 +17,27 @@
 package org.apache.spark.sql.catalyst.trees
+import java.util.UUID
+
 import scala.collection.Map
+import scala.collection.mutable.Stack
+
+import org.json4s.JsonAST._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+import org.apache.spark.SparkContext
+import org.apache.spark.rdd.{EmptyRDD, RDD}
+import org.apache.spark.sql.catalyst.{ScalaReflectionLock, TableIdentifier}
+import org.apache.spark.sql.catalyst.ScalaReflection._
 import org.apache.spark.sql.catalyst.errors._
-import org.apache.spark.sql.types.{StructType, DataType}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.types._
+import org.apache.spark.storage.StorageLevel
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.util.Utils
 /** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
 private class MutableInt(var i: Int)
@@ -463,4 +480,244 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
 }
 s"$nodeName(${args.mkString(",")})"
 }
+
+ def toJSON: String = compact(render(jsonValue))
+
+ def prettyJson: String = pretty(render(jsonValue))
+
+ private def jsonValue: JValue = {
+ val jsonValues = scala.collection.mutable.ArrayBuffer.empty[JValue]
+
+ def collectJsonValue(tn: BaseType): Unit = {
+ val jsonFields = ("class" -> JString(tn.getClass.getName)) ::
+ ("num-children" -> JInt(tn.children.length)) :: tn.jsonFields
+ jsonValues += JObject(jsonFields)
+ tn.children.foreach(collectJsonValue)
+ }
+
+ collectJsonValue(this)
+ jsonValues
+ }
+
+ protected def jsonFields: List[JField] = {
+ val fieldNames = getConstructorParameters(getClass).map(_._1)
+ val fieldValues = productIterator.toSeq ++ otherCopyArgs
+ assert(fieldNames.length == fieldValues.length, s"${getClass.getSimpleName} fields: " +
+ fieldNames.mkString(", ") + s", values: " + fieldValues.map(_.toString).mkString(", "))
+
+ fieldNames.zip(fieldValues).map {
+ // If the field value is a child, then use an int to encode it, representing the index of
+ // this child among all children.
+ case (name, value: TreeNode[_]) if containsChild(value) => + name -> JInt(children.indexOf(value)) + case (name, value: Seq[BaseType]) if value.toSet.subsetOf(containsChild) => + name -> JArray( + value.map(v => JInt(children.indexOf(v.asInstanceOf[TreeNode[_]]))).toList + ) + case (name, value) => name -> parseToJson(value) + }.toList + } + + private def parseToJson(obj: Any): JValue = obj match { + case b: Boolean => JBool(b) + case b: Byte => JInt(b.toInt) + case s: Short => JInt(s.toInt) + case i: Int => JInt(i) + case l: Long => JInt(l) + case f: Float => JDouble(f) + case d: Double => JDouble(d) + case b: BigInt => JInt(b) + case null => JNull + case s: String => JString(s) + case u: UUID => JString(u.toString) + case dt: DataType => dt.jsonValue + case m: Metadata => m.jsonValue + case s: StorageLevel => + ("useDisk" -> s.useDisk) ~ ("useMemory" -> s.useMemory) ~ ("useOffHeap" -> s.useOffHeap) ~ + ("deserialized" -> s.deserialized) ~ ("replication" -> s.replication) + case n: TreeNode[_] => n.jsonValue + case o: Option[_] => o.map(parseToJson) + case t: Seq[_] => JArray(t.map(parseToJson).toList) + case m: Map[_, _] => + val fields = m.toList.map { case (k: String, v) => (k, parseToJson(v)) } + JObject(fields) + case r: RDD[_] => JNothing + // if it's a scala object, we can simply keep the full class path. + // TODO: currently if the class name ends with "$", we think it's a scala object, there is + // probably a better way to check it. + case obj if obj.getClass.getName.endsWith("$") => "object" -> obj.getClass.getName + // returns null if the product type doesn't have a primary constructor, e.g. HiveFunctionWrapper + case p: Product => try { + val fieldNames = getConstructorParameters(p.getClass).map(_._1) + val fieldValues = p.productIterator.toSeq + assert(fieldNames.length == fieldValues.length) + ("product-class" -> JString(p.getClass.getName)) :: fieldNames.zip(fieldValues).map { + case (name, value) => name -> parseToJson(value) + }.toList + } catch { + case _: RuntimeException => null + } + case _ => JNull + } +} + +object TreeNode { + def fromJSON[BaseType <: TreeNode[BaseType]](json: String, sc: SparkContext): BaseType = { + val jsonAST = parse(json) + assert(jsonAST.isInstanceOf[JArray]) + reconstruct(jsonAST.asInstanceOf[JArray], sc).asInstanceOf[BaseType] + } + + private def reconstruct(treeNodeJson: JArray, sc: SparkContext): TreeNode[_] = { + assert(treeNodeJson.arr.forall(_.isInstanceOf[JObject])) + val jsonNodes = Stack(treeNodeJson.arr.map(_.asInstanceOf[JObject]): _*) + + def parseNextNode(): TreeNode[_] = { + val nextNode = jsonNodes.pop() + + val cls = Utils.classForName((nextNode \ "class").asInstanceOf[JString].s) + if (cls == classOf[Literal]) { + Literal.fromJSON(nextNode) + } else if (cls.getName.endsWith("$")) { + cls.getField("MODULE$").get(cls).asInstanceOf[TreeNode[_]] + } else { + val numChildren = (nextNode \ "num-children").asInstanceOf[JInt].num.toInt + + val children: Seq[TreeNode[_]] = (1 to numChildren).map(_ => parseNextNode()) + val fields = getConstructorParameters(cls) + + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => + parseFromJson(nextNode \ fieldName, fieldType, children, sc) + }.toArray + + val maybeCtor = cls.getConstructors.find { p => + val expectedTypes = p.getParameterTypes + expectedTypes.length == fields.length && expectedTypes.zip(fields.map(_._2)).forall { + case (cls, tpe) => cls == getClassFromType(tpe) + } + } + if (maybeCtor.isEmpty) { + sys.error(s"No valid constructor for ${cls.getName}") 
+ } else {
+ try {
+ maybeCtor.get.newInstance(parameters: _*).asInstanceOf[TreeNode[_]]
+ } catch {
+ case e: java.lang.IllegalArgumentException =>
+ throw new RuntimeException(
+ s"""
+ |Failed to construct tree node: ${cls.getName}
+ |ctor: ${maybeCtor.get}
+ |types: ${parameters.map(_.getClass).mkString(", ")}
+ |args: ${parameters.mkString(", ")}
+ """.stripMargin, e)
+ }
+ }
+ }
+ }
+
+ parseNextNode()
+ }
+
+ import universe._
+
+ private def parseFromJson(
+ value: JValue,
+ expectedType: Type,
+ children: Seq[TreeNode[_]],
+ sc: SparkContext): AnyRef = ScalaReflectionLock.synchronized {
+ if (value == JNull) return null
+
+ expectedType match {
+ case t if t <:< definitions.BooleanTpe =>
+ value.asInstanceOf[JBool].value: java.lang.Boolean
+ case t if t <:< definitions.ByteTpe =>
+ value.asInstanceOf[JInt].num.toByte: java.lang.Byte
+ case t if t <:< definitions.ShortTpe =>
+ value.asInstanceOf[JInt].num.toShort: java.lang.Short
+ case t if t <:< definitions.IntTpe =>
+ value.asInstanceOf[JInt].num.toInt: java.lang.Integer
+ case t if t <:< definitions.LongTpe =>
+ value.asInstanceOf[JInt].num.toLong: java.lang.Long
+ case t if t <:< definitions.FloatTpe =>
+ value.asInstanceOf[JDouble].num.toFloat: java.lang.Float
+ case t if t <:< definitions.DoubleTpe =>
+ value.asInstanceOf[JDouble].num: java.lang.Double
+
+ case t if t <:< localTypeOf[BigInt] => value.asInstanceOf[JInt].num
+ case t if t <:< localTypeOf[java.lang.String] => value.asInstanceOf[JString].s
+ case t if t <:< localTypeOf[UUID] => UUID.fromString(value.asInstanceOf[JString].s)
+ case t if t <:< localTypeOf[DataType] => DataType.parseDataType(value)
+ case t if t <:< localTypeOf[Metadata] => Metadata.fromJObject(value.asInstanceOf[JObject])
+ case t if t <:< localTypeOf[StorageLevel] =>
+ val JBool(useDisk) = value \ "useDisk"
+ val JBool(useMemory) = value \ "useMemory"
+ val JBool(useOffHeap) = value \ "useOffHeap"
+ val JBool(deserialized) = value \ "deserialized"
+ val JInt(replication) = value \ "replication"
+ StorageLevel(useDisk, useMemory, useOffHeap, deserialized, replication.toInt)
+ case t if t <:< localTypeOf[TreeNode[_]] => value match {
+ case JInt(i) => children(i.toInt)
+ case arr: JArray => reconstruct(arr, sc)
+ case _ => throw new RuntimeException(s"$value is not a valid json value for tree node.")
+ }
+ case t if t <:< localTypeOf[Option[_]] =>
+ if (value == JNothing) {
+ None
+ } else {
+ val TypeRef(_, _, Seq(optType)) = t
+ Option(parseFromJson(value, optType, children, sc))
+ }
+ case t if t <:< localTypeOf[Seq[_]] =>
+ val TypeRef(_, _, Seq(elementType)) = t
+ val JArray(elements) = value
+ elements.map(parseFromJson(_, elementType, children, sc)).toSeq
+ case t if t <:< localTypeOf[Map[_, _]] =>
+ val TypeRef(_, _, Seq(keyType, valueType)) = t
+ val JObject(fields) = value
+ fields.map {
+ case (name, value) => name -> parseFromJson(value, valueType, children, sc)
+ }.toMap
+ case t if t <:< localTypeOf[RDD[_]] =>
+ new EmptyRDD[Any](sc)
+ case _ if isScalaObject(value) =>
+ val JString(clsName) = value \ "object"
+ val cls = Utils.classForName(clsName)
+ cls.getField("MODULE$").get(cls)
+ case t if t <:< localTypeOf[Product] =>
+ val fields = getConstructorParameters(t)
+ val clsName = getClassNameFromType(t)
+ parseToProduct(clsName, fields, value, children, sc)
+ // There may be cases where the parameter type signature is not Product but the value is,
+ // e.g. `SpecifiedWindowFrame` with type signature `WindowFrame`; handle that here.
+ case _ if isScalaProduct(value) => + val JString(clsName) = value \ "product-class" + val fields = getConstructorParameters(Utils.classForName(clsName)) + parseToProduct(clsName, fields, value, children, sc) + case _ => sys.error(s"Do not support type $expectedType with json $value.") + } + } + + private def parseToProduct( + clsName: String, + fields: Seq[(String, Type)], + value: JValue, + children: Seq[TreeNode[_]], + sc: SparkContext): AnyRef = { + val parameters: Array[AnyRef] = fields.map { + case (fieldName, fieldType) => parseFromJson(value \ fieldName, fieldType, children, sc) + }.toArray + val ctor = Utils.classForName(clsName).getConstructors.maxBy(_.getParameterTypes.size) + ctor.newInstance(parameters: _*).asInstanceOf[AnyRef] + } + + private def isScalaObject(jValue: JValue): Boolean = (jValue \ "object") match { + case JString(str) if str.endsWith("$") => true + case _ => false + } + + private def isScalaProduct(jValue: JValue): Boolean = (jValue \ "product-class") match { + case _: JString => true + case _ => false + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 2b9388291948..f18c052b68e3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} -import java.util.{TimeZone, Calendar} +import java.util.{Calendar, TimeZone} import javax.xml.bind.DatatypeConverter import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala new file mode 100644 index 000000000000..e27cf9c1989f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/LegacyTypeStringParser.scala @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.parsing.combinator.RegexParsers + +import org.apache.spark.sql.types._ + +/** + * Parser that turns case class strings into datatypes. This is only here to maintain compatibility + * with Parquet files written by Spark 1.1 and below. 
+ */ +object LegacyTypeStringParser extends RegexParsers { + + protected lazy val primitiveType: Parser[DataType] = + ( "StringType" ^^^ StringType + | "FloatType" ^^^ FloatType + | "IntegerType" ^^^ IntegerType + | "ByteType" ^^^ ByteType + | "ShortType" ^^^ ShortType + | "DoubleType" ^^^ DoubleType + | "LongType" ^^^ LongType + | "BinaryType" ^^^ BinaryType + | "BooleanType" ^^^ BooleanType + | "DateType" ^^^ DateType + | "DecimalType()" ^^^ DecimalType.USER_DEFAULT + | fixedDecimalType + | "TimestampType" ^^^ TimestampType + ) + + protected lazy val fixedDecimalType: Parser[DataType] = + ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { + case precision ~ scale => DecimalType(precision.toInt, scale.toInt) + } + + protected lazy val arrayType: Parser[DataType] = + "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { + case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) + } + + protected lazy val mapType: Parser[DataType] = + "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { + case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) + } + + protected lazy val structField: Parser[StructField] = + ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { + case name ~ tpe ~ nullable => + StructField(name, tpe, nullable = nullable) + } + + protected lazy val boolVal: Parser[Boolean] = + ( "true" ^^^ true + | "false" ^^^ false + ) + + protected lazy val structType: Parser[DataType] = + "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { + case fields => StructType(fields) + } + + protected lazy val dataType: Parser[DataType] = + ( arrayType + | mapType + | structType + | primitiveType + ) + + /** + * Parses a string representation of a DataType. + */ + def parse(asString: String): DataType = parseAll(dataType, asString) match { + case Success(result, _) => result + case failure: NoSuccess => + throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 71293475ca0f..7a0d0de6328a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -130,6 +130,20 @@ package object util { ret } + /** + * Converts a `Seq` of `Option[T]` to an `Option` of `Seq[T]`. 
+ */ + def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match { + case xs if xs.isEmpty => + Option(Seq.empty[T]) + + case xs => + for { + head <- xs.head + tail <- sequenceOption(xs.tail) + } yield head +: tail + } + /* FIX ME implicit class debugLogging(a: Any) { def debugLogging() { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala index a5ae8bb0e5eb..90af10f7a6b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/AbstractDataType.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types import scala.reflect.ClassTag -import scala.reflect.runtime.universe.{TypeTag, runtimeMirror} +import scala.reflect.runtime.universe.{runtimeMirror, TypeTag} import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala index a001eadcc61d..520e34436162 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayType.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.types -import org.apache.spark.sql.catalyst.util.ArrayData +import scala.math.Ordering + import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi - -import scala.math.Ordering - +import org.apache.spark.sql.catalyst.util.ArrayData object ArrayType extends AbstractDataType { /** Construct a [[ArrayType]] object with the given element type. The `containsNull` is true. */ @@ -78,6 +77,8 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT override def simpleString: String = s"array<${elementType.simpleString}>" + override def sql: String = s"ARRAY<${elementType.sql}>" + override private[spark] def asNullable: ArrayType = ArrayType(elementType.asNullable, containsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala index 2ca427975a1c..d37130e27ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ByteType.scala @@ -17,13 +17,12 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.ScalaReflectionLock - /** * :: DeveloperApi :: * The data type representing `Byte` values. Please use the singleton [[DataTypes.ByteType]]. 
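To make the semantics of the `sequenceOption` helper added to `catalyst/util` above concrete: it returns `Some` only when every element of the input is defined, and `None` as soon as any element is missing. A standalone copy of the helper, reproduced from the hunk above purely for illustration (the demo object is hypothetical):

object SequenceOptionDemo extends App {
  // Same body as the patched org.apache.spark.sql.catalyst.util.sequenceOption.
  def sequenceOption[T](seq: Seq[Option[T]]): Option[Seq[T]] = seq match {
    case xs if xs.isEmpty => Option(Seq.empty[T])
    case xs =>
      for {
        head <- xs.head
        tail <- sequenceOption(xs.tail)
      } yield head +: tail
  }

  println(sequenceOption(Seq(Some(1), Some(2), Some(3)))) // Some(List(1, 2, 3))
  println(sequenceOption(Seq(Some(1), None, Some(3))))    // None: one missing value poisons the result
  println(sequenceOption(Seq.empty[Option[Int]]))         // Some(List())
}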
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 4b54c31dcc27..92cf8d4c46bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -17,19 +17,15 @@ package org.apache.spark.sql.types -import scala.util.Try -import scala.util.parsing.combinator.RegexParsers - +import org.json4s._ import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ -import org.json4s._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.util.Utils - /** * :: DeveloperApi :: * The base type of all Spark SQL data types. @@ -66,6 +62,11 @@ abstract class DataType extends AbstractDataType { /** Readable string representation for the type. */ def simpleString: String = typeName + /** Readable string representation for the type with truncation */ + private[sql] def simpleString(maxNumberFields: Int): String = simpleString + + def sql: String = simpleString.toUpperCase + /** * Check if `this` and `other` are the same data type when ignoring nullability * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). @@ -91,21 +92,12 @@ abstract class DataType extends AbstractDataType { object DataType { - private[sql] def fromString(raw: String): DataType = { - Try(DataType.fromJson(raw)).getOrElse(DataType.fromCaseClassString(raw)) - } def fromJson(json: String): DataType = parseDataType(parse(json)) - /** - * @deprecated As of 1.2.0, replaced by `DataType.fromJson()` - */ - @deprecated("Use DataType.fromJson instead", "1.2.0") - def fromCaseClassString(string: String): DataType = CaseClassStringParser(string) - private val nonDecimalNameToType = { - Seq(NullType, DateType, TimestampType, BinaryType, - IntegerType, BooleanType, LongType, DoubleType, FloatType, ShortType, ByteType, StringType) + Seq(NullType, DateType, TimestampType, BinaryType, IntegerType, BooleanType, LongType, + DoubleType, FloatType, ShortType, ByteType, StringType, CalendarIntervalType) .map(t => t.typeName -> t).toMap } @@ -127,7 +119,7 @@ object DataType { } // NOTE: Map fields must be sorted in alphabetical order to keep consistent with the Python side. 
- private def parseDataType(json: JValue): DataType = json match { + private[sql] def parseDataType(json: JValue): DataType = json match { case JString(name) => nameToType(name) @@ -181,73 +173,6 @@ object DataType { StructField(name, parseDataType(dataType), nullable) } - private object CaseClassStringParser extends RegexParsers { - protected lazy val primitiveType: Parser[DataType] = - ( "StringType" ^^^ StringType - | "FloatType" ^^^ FloatType - | "IntegerType" ^^^ IntegerType - | "ByteType" ^^^ ByteType - | "ShortType" ^^^ ShortType - | "DoubleType" ^^^ DoubleType - | "LongType" ^^^ LongType - | "BinaryType" ^^^ BinaryType - | "BooleanType" ^^^ BooleanType - | "DateType" ^^^ DateType - | "DecimalType()" ^^^ DecimalType.USER_DEFAULT - | fixedDecimalType - | "TimestampType" ^^^ TimestampType - ) - - protected lazy val fixedDecimalType: Parser[DataType] = - ("DecimalType(" ~> "[0-9]+".r) ~ ("," ~> "[0-9]+".r <~ ")") ^^ { - case precision ~ scale => DecimalType(precision.toInt, scale.toInt) - } - - protected lazy val arrayType: Parser[DataType] = - "ArrayType" ~> "(" ~> dataType ~ "," ~ boolVal <~ ")" ^^ { - case tpe ~ _ ~ containsNull => ArrayType(tpe, containsNull) - } - - protected lazy val mapType: Parser[DataType] = - "MapType" ~> "(" ~> dataType ~ "," ~ dataType ~ "," ~ boolVal <~ ")" ^^ { - case t1 ~ _ ~ t2 ~ _ ~ valueContainsNull => MapType(t1, t2, valueContainsNull) - } - - protected lazy val structField: Parser[StructField] = - ("StructField(" ~> "[a-zA-Z0-9_]*".r) ~ ("," ~> dataType) ~ ("," ~> boolVal <~ ")") ^^ { - case name ~ tpe ~ nullable => - StructField(name, tpe, nullable = nullable) - } - - protected lazy val boolVal: Parser[Boolean] = - ( "true" ^^^ true - | "false" ^^^ false - ) - - protected lazy val structType: Parser[DataType] = - "StructType\\([A-zA-z]*\\(".r ~> repsep(structField, ",") <~ "))" ^^ { - case fields => StructType(fields) - } - - protected lazy val dataType: Parser[DataType] = - ( arrayType - | mapType - | structType - | primitiveType - ) - - /** - * Parses a string representation of a DataType. - * - * TODO: Generate parser as pickler... 
- */ - def apply(asString: String): DataType = parseAll(dataType, asString) match { - case Success(result, _) => result - case failure: NoSuccess => - throw new IllegalArgumentException(s"Unsupported dataType: $asString, $failure") - } - } - protected[types] def buildFormattedString( dataType: DataType, prefix: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index c7a1a2e7469e..38ce1604b1ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import java.math.{RoundingMode, MathContext} +import java.math.{MathContext, RoundingMode} import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala index ce45245b9f6d..cf5322125bd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DecimalType.scala @@ -25,20 +25,6 @@ import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.Expression -/** Precision parameters for a Decimal */ -@deprecated("Use DecimalType(precision, scale) directly", "1.5") -case class PrecisionInfo(precision: Int, scale: Int) { - if (scale > precision) { - throw new AnalysisException( - s"Decimal scale ($scale) cannot be greater than precision ($precision).") - } - if (precision > DecimalType.MAX_PRECISION) { - throw new AnalysisException( - s"DecimalType can only support precision up to 38" - ) - } -} - /** * :: DeveloperApi :: * The data type representing `java.math.BigDecimal` values. 
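With `PrecisionInfo` removed above, the precision/scale validation moves into the `DecimalType` constructor itself in the next hunk, so an invalid type now fails at construction time rather than through the deprecated wrapper. A small usage sketch under that assumption (the demo object and `rejects` helper are hypothetical):

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.types.DecimalType

object DecimalTypeValidationDemo extends App {
  // Fine: scale <= precision <= DecimalType.MAX_PRECISION (38).
  val ok = DecimalType(10, 2)

  // Returns true if constructing the type throws AnalysisException.
  def rejects(mk: => DecimalType): Boolean =
    try { mk; false } catch { case _: AnalysisException => true }

  println(rejects(DecimalType(5, 7)))  // true: scale (7) greater than precision (5)
  println(rejects(DecimalType(40, 0))) // true: precision above MAX_PRECISION
}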
@@ -54,18 +40,18 @@ case class PrecisionInfo(precision: Int, scale: Int) { @DeveloperApi case class DecimalType(precision: Int, scale: Int) extends FractionalType { - // default constructor for Java - def this(precision: Int) = this(precision, 0) - def this() = this(10) + if (scale > precision) { + throw new AnalysisException( + s"Decimal scale ($scale) cannot be greater than precision ($precision).") + } - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def this(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) + if (precision > DecimalType.MAX_PRECISION) { + throw new AnalysisException(s"DecimalType can only support precision up to 38") } - @deprecated("Use DecimalType.precision and DecimalType.scale instead", "1.5") - val precisionInfo = Some(PrecisionInfo(precision, scale)) + // default constructor for Java + def this(precision: Int) = this(precision, 0) + def this() = this(10) private[sql] type InternalType = Decimal @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] } @@ -122,9 +108,6 @@ object DecimalType extends AbstractDataType { val SYSTEM_DEFAULT: DecimalType = DecimalType(MAX_PRECISION, 18) val USER_DEFAULT: DecimalType = DecimalType(10, 0) - @deprecated("Does not support unlimited precision, please specify the precision and scale", "1.5") - val Unlimited: DecimalType = SYSTEM_DEFAULT - // The decimal types compatible with other numeric types private[sql] val ByteDecimal = DecimalType(3, 0) private[sql] val ShortDecimal = DecimalType(5, 0) @@ -142,15 +125,6 @@ object DecimalType extends AbstractDataType { case DoubleType => DoubleDecimal } - @deprecated("please specify precision and scale", "1.5") - def apply(): DecimalType = USER_DEFAULT - - @deprecated("Use DecimalType(precision, scale) instead", "1.5") - def apply(precisionInfo: Option[PrecisionInfo]) { - this(precisionInfo.getOrElse(PrecisionInfo(10, 0)).precision, - precisionInfo.getOrElse(PrecisionInfo(10, 0)).scale) - } - private[sql] def bounded(precision: Int, scale: Int): DecimalType = { DecimalType(min(precision, MAX_PRECISION), min(scale, MAX_SCALE)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala index 2a1bf0938e5a..e553f65f3c99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Fractional, Numeric} +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.DoubleAsIfIntegral import scala.reflect.runtime.universe.typeTag diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala index 08e22252aef8..ae9aa9eefaf2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.types +import scala.math.{Fractional, Numeric, Ordering} import scala.math.Numeric.FloatAsIfIntegral -import scala.math.{Ordering, Fractional, Numeric} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala index a2c6e19b05b3..38a7b8ee5265 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/IntegerType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala index 2b3adf6ade83..88aff0c87755 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/LongType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 00461e529ca0..5474954af70e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -62,6 +62,8 @@ case class MapType( override def simpleString: String = s"map<${keyType.simpleString},${valueType.simpleString}>" + override def sql: String = s"MAP<${keyType.sql}, ${valueType.sql}>" + override private[spark] def asNullable: MapType = MapType(keyType.asNullable, valueType.asNullable, valueContainsNull = true) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala index a13119e65906..486cf585284d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ShortType.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.types -import scala.math.{Ordering, Integral, Numeric} +import scala.math.{Integral, Numeric, Ordering} import scala.reflect.runtime.universe.typeTag import org.apache.spark.annotation.DeveloperApi diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index 9778df271ddd..9b5c86a8984b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.types import scala.collection.mutable.ArrayBuffer +import scala.util.Try import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, InterpretedOrdering} -import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.catalyst.util.{LegacyTypeStringParser, DataTypeParser} /** @@ -278,6 +279,28 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru s"struct<${fieldTypes.mkString(",")}>" } + override def sql: String = { + val fieldTypes = fields.map(f => s"`${f.name}`: ${f.dataType.sql}") + 
s"STRUCT<${fieldTypes.mkString(", ")}>" + } + + private[sql] override def simpleString(maxNumberFields: Int): String = { + val builder = new StringBuilder + val fieldTypes = fields.take(maxNumberFields).map { + case f => s"${f.name}: ${f.dataType.simpleString(maxNumberFields)}" + } + builder.append("struct<") + builder.append(fieldTypes.mkString(", ")) + if (fields.length > 2) { + if (fields.length - fieldTypes.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (fields.length - 2) + " more fields") + } + } + builder.append(">").toString() + } + /** * Merges with another schema (`StructType`). For a struct field A from `this` and a struct field * B from `that`, @@ -320,9 +343,11 @@ object StructType extends AbstractDataType { override private[sql] def simpleString: String = "struct" - private[sql] def fromString(raw: String): StructType = DataType.fromString(raw) match { - case t: StructType => t - case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + private[sql] def fromString(raw: String): StructType = { + Try(DataType.fromJson(raw)).getOrElse(LegacyTypeStringParser.parse(raw)) match { + case t: StructType => t + case _ => throw new RuntimeException(s"Failed parsing StructType: $raw") + } } def apply(fields: Seq[StructField]): StructType = StructType(fields.toArray) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index 4305903616bd..d7a2c23be8a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -84,6 +84,8 @@ abstract class UserDefinedType[UserType] extends DataType with Serializable { override private[sql] def acceptsType(dataType: DataType) = this.getClass == dataType.getClass + + override def sql: String = sqlType.sql } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index 5c22a7219254..1e7118144f2e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql +import org.scalatest.{FunSpec, Matchers} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericRow, GenericRowWithSchema} import org.apache.spark.sql.types._ -import org.scalatest.{Matchers, FunSpec} class RowTest extends FunSpec with Matchers { @@ -104,4 +105,34 @@ class RowTest extends FunSpec with Matchers { internalRow shouldEqual internalRow2 } } + + describe("row immutability") { + val values = Seq(1, 2, "3", "IV", 6L) + val externalRow = Row.fromSeq(values) + val internalRow = InternalRow.fromSeq(values) + + def modifyValues(values: Seq[Any]): Seq[Any] = { + val array = values.toArray + array(2) = "42" + array + } + + it("copy should return same ref for external rows") { + externalRow should be theSameInstanceAs externalRow.copy() + } + + it("copy should return same ref for interal rows") { + internalRow should be theSameInstanceAs internalRow.copy() + } + + it("toSeq should not expose internal state for external rows") { + val modifiedValues = modifyValues(externalRow.toSeq) + externalRow.toSeq should not equal modifiedValues + } + + it("toSeq should not expose internal state for internal rows") { + val modifiedValues = 
modifyValues(internalRow.toSeq(Seq.empty)) + internalRow.toSeq(Seq.empty) should not equal modifiedValues + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala new file mode 100644 index 000000000000..d7204c348831 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/CatalystQlSuite.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.sql.catalyst.plans.PlanTest + +class CatalystQlSuite extends PlanTest { + val parser = new CatalystQl() + + test("parse union/except/intersect") { + parser.createPlan("select * from t1 union all select * from t2") + parser.createPlan("select * from t1 union distinct select * from t2") + parser.createPlan("select * from t1 union select * from t2") + parser.createPlan("select * from t1 except select * from t2") + parser.createPlan("select * from t1 intersect select * from t2") + parser.createPlan("(select * from t1) union all (select * from t2)") + parser.createPlan("(select * from t1) union distinct (select * from t2)") + parser.createPlan("(select * from t1) union (select * from t2)") + parser.createPlan("select * from ((select * from t1) union (select * from t2)) t") + } + + test("window function: better support of parentheses") { + parser.createPlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + + "order by 2) from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + + "order by 2) from windowData") + + parser.createPlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + + "from windowData") + parser.createPlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + + "from windowData") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index 827f7ce69271..b47b8adfe5d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -18,10 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.plans.physical._ - /* Implicit conversions */ import org.apache.spark.sql.catalyst.dsl.expressions._ +import 
org.apache.spark.sql.catalyst.plans.physical._ class DistributionSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala index 9ff893b84775..b0884f528742 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/SqlParserSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias -import org.apache.spark.sql.catalyst.expressions.{Literal, GreaterThan, Not, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GreaterThan, Literal, Not} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project, LogicalPlan, Command} +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan, OneRowRelation, Project} import org.apache.spark.unsafe.types.CalendarInterval private[sql] case class TestCommand(cmd: String) extends LogicalPlan with Command { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ee435578743f..fc35959f2054 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -17,17 +17,18 @@ package org.apache.spark.sql.catalyst.analysis +import scala.beans.{BeanInfo, BeanProperty} + import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Sum} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.util.{MapData, ArrayBasedMapData, GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ -import scala.beans.{BeanProperty, BeanInfo} - @BeanInfo private[sql] case class GroupableData(@BeanProperty data: Int) @@ -133,17 +134,37 @@ class AnalysisErrorSuite extends AnalysisTest { "requires int type" :: "'null' is of date type" :: Nil) errorTest( - "unresolved window function", + "invalid window function", testRelation2.select( WindowExpression( - UnresolvedWindowFunction( - "lead", - UnresolvedAttribute("c") :: Nil), + Literal(0), WindowSpecDefinition( UnresolvedAttribute("a") :: Nil, SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, UnspecifiedFrame)).as('window)), - "lead" :: "window functions currently requires a HiveContext" :: Nil) + "not supported within a window function" :: Nil) + + errorTest( + "distinct window function", + testRelation2.select( + WindowExpression( + AggregateExpression(Count(UnresolvedAttribute("b")), Complete, isDistinct = true), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + UnspecifiedFrame)).as('window)), + "Distinct window functions are not supported" :: 
Nil) + + errorTest( + "offset window function", + testRelation2.select( + WindowExpression( + new Lead(UnresolvedAttribute("b")), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("b"), Ascending) :: Nil, + SpecifiedWindowFrame(RangeFrame, ValueFollowing(1), ValueFollowing(2)))).as('window)), + "window frame" :: "must match the required frame" :: Nil) errorTest( "too many generators", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index aeeca802d8bb..cf84855885a3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ class AnalysisSuite extends AnalysisTest { @@ -238,40 +237,9 @@ class AnalysisSuite extends AnalysisTest { checkAnalysis(plan, expected) } - test("analyzer should replace current_timestamp with literals") { - val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), - LocalRelation()) - - val min = System.currentTimeMillis() * 1000 - val plan = in.analyze.asInstanceOf[Project] - val max = (System.currentTimeMillis() + 1) * 1000 - - val lits = new scala.collection.mutable.ArrayBuffer[Long] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Long] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) - } - - test("analyzer should replace current_date with literals") { - val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) - - val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) - val plan = in.analyze.asInstanceOf[Project] - val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) - - val lits = new scala.collection.mutable.ArrayBuffer[Int] - plan.transformAllExpressions { case e: Literal => - lits += e.value.asInstanceOf[Int] - e - } - assert(lits.size == 2) - assert(lits(0) >= min && lits(0) <= max) - assert(lits(1) >= min && lits(1) <= max) - assert(lits(0) == lits(1)) + test("SPARK-12102: Ignore nullablity when comparing two sides of case") { + val relation = LocalRelation('a.struct('x.int), 'b.struct('x.int.withNullability(false))) + val plan = relation.select(CaseWhen(Seq(Literal(true), 'a, 'b)).as("val")) + assertAnalysisSuccess(plan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 23861ed15da6..af214b7af062 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ 
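The three window-function cases above all funnel through the suite's errorTest helper. Its shape is roughly the following (a sketch, not part of this patch; assertAnalysisError is the AnalysisTest method the suite builds on): analysis of the bad plan must fail, and the message must carry every expected fragment, e.g. "Distinct window functions are not supported" for the distinct case.

    def errorTest(
        name: String,
        plan: LogicalPlan,
        errorMessages: Seq[String]): Unit = {
      test(name) {
        // Analysis must throw, and the error text must contain each fragment.
        assertAnalysisError(plan, errorMessages)
      }
    }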
-import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} trait AnalysisTest extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index fed591fd90a9..39c8f56c1bca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Union, Project, LocalRelation} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project, Union} import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.{TableIdentifier, SimpleCatalystConf} class DecimalPrecisionSuite extends SparkFunSuite with BeforeAndAfter { val conf = new SimpleCatalystConf(true) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 915c585ec91f..0521ed848c79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.types.{LongType, TypeCollection, StringType} +import org.apache.spark.sql.types.{LongType, StringType, TypeCollection} class ExpressionTypeCheckingSuite extends SparkFunSuite { @@ -163,6 +163,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Coalesce(Seq('intField, 'booleanField)), "input to function coalesce should all be the same type") assertError(Coalesce(Nil), "input to function coalesce cannot be empty") + assertError(new Murmur3Hash(Nil), "function hash requires at least one argument") assertError(Explode('intField), "input to function explode should be array or map type") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 142915056f45..58d808c55860 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -19,9 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import java.sql.Timestamp -import org.apache.spark.sql.catalyst.plans.PlanTest - import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types._ diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala index 815a03f7c1a8..bc36a55ae0ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala @@ -36,14 +36,18 @@ class EncoderResolutionSuite extends PlanTest { val encoder = ExpressionEncoder[StringLongClass] val cls = classOf[StringLongClass] + { val attrs = Seq('a.string, 'b.int) val fromRowExpr: Expression = encoder.resolve(attrs, null).fromRowExpression val expected: Expression = NewInstance( cls, - toExternalString('a.string) :: 'b.int.cast(LongType) :: Nil, - false, - ObjectType(cls)) + Seq( + toExternalString('a.string), + AssertNotNull('b.int.cast(LongType), cls.getName, "b", "Long") + ), + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -52,9 +56,12 @@ class EncoderResolutionSuite extends PlanTest { val fromRowExpr = encoder.resolve(attrs, null).fromRowExpression val expected = NewInstance( cls, - toExternalString('a.int.cast(StringType)) :: 'b.long :: Nil, - false, - ObjectType(cls)) + Seq( + toExternalString('a.int.cast(StringType)), + AssertNotNull('b.long, cls.getName, "b", "Long") + ), + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } } @@ -69,7 +76,7 @@ class EncoderResolutionSuite extends PlanTest { val expected: Expression = NewInstance( cls, Seq( - 'a.int.cast(LongType), + AssertNotNull('a.int.cast(LongType), cls.getName, "a", "Long"), If( 'b.struct('a.int, 'b.long).isNull, Literal.create(null, ObjectType(innerCls)), @@ -78,12 +85,14 @@ class EncoderResolutionSuite extends PlanTest { Seq( toExternalString( GetStructField('b.struct('a.int, 'b.long), 0, Some("a")).cast(StringType)), - GetStructField('b.struct('a.int, 'b.long), 1, Some("b"))), - false, - ObjectType(innerCls)) + AssertNotNull( + GetStructField('b.struct('a.int, 'b.long), 1, Some("b")), + innerCls.getName, "b", "Long")), + ObjectType(innerCls), + propagateNull = false) )), - false, - ObjectType(cls)) + ObjectType(cls), + propagateNull = false) compareExpressions(fromRowExpr, expected) } @@ -102,12 +111,14 @@ class EncoderResolutionSuite extends PlanTest { cls, Seq( toExternalString(GetStructField('a.struct('a.string, 'b.byte), 0, Some("a"))), - GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType)), - false, - ObjectType(cls)), + AssertNotNull( + GetStructField('a.struct('a.string, 'b.byte), 1, Some("b")).cast(LongType), + cls.getName, "b", "Long")), + ObjectType(cls), + propagateNull = false), 'b.int.cast(LongType)), - false, - ObjectType(classOf[Tuple2[_, _]])) + ObjectType(classOf[Tuple2[_, _]]), + propagateNull = false) compareExpressions(fromRowExpr, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index 7233e0f1b5ba..88c558d80a79 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql.catalyst.encoders -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.util.Arrays 
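For orientation, the encode/decode cycle that the EncoderResolutionSuite and ExpressionEncoderSuite changes exercise looks roughly like this (a sketch; defaultBinding is the helper this patch has the suite call instead of resolving against a hand-maintained outer-scope map):

    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

    val encoder = ExpressionEncoder[(Int, String)]()
    val row = encoder.toRow((1, "a"))          // object -> InternalRow
    val bound = encoder.defaultBinding         // bind to the encoder's own schema
    assert(bound.fromRow(row) == ((1, "a")))   // InternalRow -> object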
import java.util.concurrent.ConcurrentMap + import scala.collection.mutable.ArrayBuffer import scala.reflect.runtime.universe.TypeTag @@ -27,10 +28,10 @@ import com.google.common.collect.MapMaker import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Encoders +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} -import org.apache.spark.sql.types.{StructType, ArrayType} +import org.apache.spark.sql.types.{ArrayType, StructType} case class RepeatedStruct(s: Seq[PrimitiveData]) @@ -77,6 +78,8 @@ class JavaSerializable(val value: Int) extends Serializable { } class ExpressionEncoderSuite extends SparkFunSuite { + OuterScopes.outerScopes.put(getClass.getName, this) + implicit def encoder[T : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder() // test flat encoders @@ -128,6 +131,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { encodeDecodeTest(Map(1 -> "a", 2 -> null), "map with null") encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), "map of map") + encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple") + encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple") + // Kryo encoders encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String])) encodeDecodeTest(new KryoSerializable(15), "kryo object")( @@ -155,6 +161,9 @@ class ExpressionEncoderSuite extends SparkFunSuite { productTest(OptionalData(None, None, None, None, None, None, None, None)) + encodeDecodeTest(Seq(Some(1), None), "Option in array") + encodeDecodeTest(Map(1 -> Some(10L), 2 -> Some(20L), 3 -> None), "Option in map") + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) productTest(BoxedData(null, null, null, null, null, null, null)) @@ -239,6 +248,8 @@ class ExpressionEncoderSuite extends SparkFunSuite { ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc)) } + productTest(("UDT", new ExamplePoint(0.1, 0.2))) + test("nullable of encoder schema") { def checkNullable[T: ExpressionEncoder](nullable: Boolean*): Unit = { assert(implicitly[ExpressionEncoder[T]].schema.map(_.nullable) === nullable.toSeq) @@ -275,8 +286,6 @@ class ExpressionEncoderSuite extends SparkFunSuite { } } - private val outers: ConcurrentMap[String, AnyRef] = new MapMaker().weakValues().makeMap() - outers.put(getClass.getName, this) private def encodeDecodeTest[T : ExpressionEncoder]( input: T, testName: String): Unit = { @@ -284,7 +293,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { val encoder = implicitly[ExpressionEncoder[T]] val row = encoder.toRow(input) val schema = encoder.schema.toAttributes - val boundEncoder = encoder.resolve(schema, outers).bind(schema) + val boundEncoder = encoder.defaultBinding val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 0ea51ece4bc5..932511134c63 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.SparkFunSuite import org.apache.spark.sql.{RandomDataGenerator, Row} -import 
org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("binary", BinaryType) .add("date", DateType) .add("timestamp", TimestampType) - .add("udt", new ExamplePointUDT, false)) + .add("udt", new ExamplePointUDT)) encodeDecodeTest( new StructType() @@ -108,7 +108,8 @@ class RowEncoderSuite extends SparkFunSuite { .add("arrayOfArrayOfString", ArrayType(arrayOfString)) .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) .add("arrayOfMap", ArrayType(mapOfString)) - .add("arrayOfStruct", ArrayType(structOfString))) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) encodeDecodeTest( new StructType() @@ -130,18 +131,6 @@ class RowEncoderSuite extends SparkFunSuite { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) - test(s"encode/decode: arrayOfUDT") { - val schema = new StructType() - .add("arrayOfUDT", arrayOfUDT) - - val encoder = RowEncoder(schema) - - val input: Row = Row(Seq(new ExamplePoint(0.1, 0.2), new ExamplePoint(0.3, 0.4))) - val row = encoder.toRow(input) - val convertedBack = encoder.fromRow(row) - assert(input.getSeq[ExamplePoint](0) == convertedBack.getSeq[ExamplePoint](0)) - } - test(s"encode/decode: Product") { val schema = new StructType() .add("structAsProduct", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index a98e16c25321..43af3592070f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} -import java.util.{TimeZone, Calendar} +import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row @@ -297,7 +297,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { test("cast from string") { assert(cast("abcdef", StringType).nullable === false) assert(cast("abcdef", BinaryType).nullable === false) - assert(cast("abcdef", BooleanType).nullable === false) + assert(cast("abcdef", BooleanType).nullable === true) assert(cast("abcdef", TimestampType).nullable === true) assert(cast("abcdef", LongType).nullable === true) assert(cast("abcdef", IntegerType).nullable === true) @@ -547,7 +547,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(array_notNull, ArrayType(BooleanType, containsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Seq(null, true, false)) } @@ -606,7 +606,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { } { val ret = cast(map_notNull, MapType(StringType, BooleanType, valueContainsNull = false)) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Map("a" -> null, "b" -> true, "c" -> false)) } { @@ -713,7 +713,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", BooleanType, nullable = true), StructField("b", BooleanType, nullable = true), StructField("c", BooleanType, nullable = false)))) - 
assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, InternalRow(null, true, false)) } @@ -754,7 +754,7 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StructType(Seq( StructField("l", LongType, nullable = true))))))) - assert(ret.resolved === true) + assert(ret.resolved === false) checkEvaluation(ret, Row( Seq(123, null, null), Map("a" -> null, "b" -> true, "c" -> false), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 9f1b19253e7c..9c1688b261aa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedExtractValue import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 0df673bb9fa0..4029da592558 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.catalyst.expressions -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types._ - class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("if") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala index 511f0307901d..a8f758d625a0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DecimalExpressionSuite.scala @@ -18,8 +18,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{LongType, DecimalType, Decimal} - +import org.apache.spark.sql.types.{Decimal, DecimalType, LongType} class DecimalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index f869a96edb1c..e028d22a54ba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -57,8 +57,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, 
expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case _ => result == expected } } @@ -275,8 +275,8 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { (result, expected) match { case (result: Array[Byte], expected: Array[Byte]) => java.util.Arrays.equals(result, expected) - case (result: Double, expected: Spread[Double]) => - expected.isWithin(result) + case (result: Double, expected: Spread[Double @unchecked]) => + expected.asInstanceOf[Spread[Double]].isWithin(result) case (result: Double, expected: Double) if result.isNaN && expected.isNaN => true case (result: Float, expected: Float) if result.isNaN && expected.isNaN => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75d17417e5a0..64161bebdcbe 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) + private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) + private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + + testMurmur3Hash( + new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT)) + + testMurmur3Hash( + new StructType() + .add("arrayOfNull", arrayOfNull) + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) + + testMurmur3Hash( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) 
+ .add("mapOfStruct", MapType(structOfString, structOfString))) + + testMurmur3Hash( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + private def testMurmur3Hash(inputSchema: StructType): Unit = { + val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get + val encoder = RowEncoder(inputSchema) + val seed = scala.util.Random.nextInt() + test(s"murmur3 hash: ${inputSchema.simpleString}") { + for (_ <- 1 to 10) { + val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { + case (value, dt) => Literal.create(value, dt) + } + checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed)) + } + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala index 7ad8657bde12..b190d3a00dfb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/OrderingSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.math._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.{Row, RandomDataGenerator} -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala index 0d329497758c..83838294a991 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlusSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.util.Random -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.{SpecificMutableRow, MutableRow, BoundReference} -import org.apache.spark.sql.types.{DataType, IntegerType} - import scala.collection.mutable + import org.scalatest.Assertions._ +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{BoundReference, MutableRow, SpecificMutableRow} +import org.apache.spark.sql.types.{DataType, IntegerType} + class HyperLogLogPlusPlusSuite extends SparkFunSuite { /** Create a HLL++ instance and an input and output buffer. 
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala index 796d60032e1a..f8342214d9ae 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala @@ -90,13 +90,13 @@ class GenerateUnsafeRowJoinerBitsetSuite extends SparkFunSuite { } private def createUnsafeRow(numFields: Int): UnsafeRow = { - val row = new UnsafeRow + val row = new UnsafeRow(numFields) val sizeInBytes = numFields * 8 + ((numFields + 63) / 64) * 8 // Allocate a larger buffer than needed and point the UnsafeRow to somewhere in the middle. // This way we can test the joiner when the input UnsafeRows are not the entire arrays. val offset = numFields * 8 val buf = new Array[Byte](sizeInBytes + offset) - row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, numFields, sizeInBytes) + row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET + offset, sizeInBytes) row } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala index 2d080b95b129..37148a226f29 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/AggregateOptimizeSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LocalRelation, LogicalPlan} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index cde346e99eb1..000a3b7ecb7c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.SimpleCatalystConf import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { @@ -86,23 +86,27 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { checkCondition( ('a === 'b || 'b > 3) && ('a === 'b || 'a > 3) && ('a === 'b || 'a < 5), - ('a === 'b || 'b > 3 && 'a > 3 && 'a < 5)) + 'a === 'b || 'b > 3 && 'a > 
3 && 'a < 5) } test("a && (!a || b)") { - checkCondition(('a && (!('a) || 'b )), ('a && 'b)) + checkCondition('a && (!'a || 'b ), 'a && 'b) - checkCondition(('a && ('b || !('a) )), ('a && 'b)) + checkCondition('a && ('b || !'a ), 'a && 'b) - checkCondition(((!('a) || 'b ) && 'a), ('b && 'a)) + checkCondition((!'a || 'b ) && 'a, 'b && 'a) - checkCondition((('b || !('a) ) && 'a), ('b && 'a)) + checkCondition(('b || !'a ) && 'a, 'b && 'a) } - test("!(a && b) , !(a || b)") { - checkCondition((!('a && 'b)), (!('a) || !('b))) + test("DeMorgan's law") { + checkCondition(!('a && 'b), !'a || !'b) + + checkCondition(!('a || 'b), !'a && !'b) + + checkCondition(!(('a && 'b) || ('c && 'd)), (!'a || !'b) && (!'c || !'d)) - checkCondition(!('a || 'b), (!('a) && !('b))) + checkCondition(!(('a || 'b) && ('c || 'd)), (!'a && !'b) || (!'c && !'d)) } private val caseInsensitiveAnalyzer = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 9bf61ae09178..81f392803561 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.types.StringType class ColumnPruningSuite extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala index 06c592f4905a..9fe2b2d1f48c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombiningLimitsSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.optimizer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class CombiningLimitsSuite extends PlanTest { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala new file mode 100644 index 000000000000..10ed4e46ddd1 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ComputeCurrentTimeSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, CurrentDate, CurrentTimestamp, Literal} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.util.DateTimeUtils + +class ComputeCurrentTimeSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Seq(Batch("ComputeCurrentTime", Once, ComputeCurrentTime)) + } + + test("analyzer should replace current_timestamp with literals") { + val in = Project(Seq(Alias(CurrentTimestamp(), "a")(), Alias(CurrentTimestamp(), "b")()), + LocalRelation()) + + val min = System.currentTimeMillis() * 1000 + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = (System.currentTimeMillis() + 1) * 1000 + + val lits = new scala.collection.mutable.ArrayBuffer[Long] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Long] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } + + test("analyzer should replace current_date with literals") { + val in = Project(Seq(Alias(CurrentDate(), "a")(), Alias(CurrentDate(), "b")()), LocalRelation()) + + val min = DateTimeUtils.millisToDays(System.currentTimeMillis()) + val plan = Optimize.execute(in.analyze).asInstanceOf[Project] + val max = DateTimeUtils.millisToDays(System.currentTimeMillis()) + + val lits = new scala.collection.mutable.ArrayBuffer[Int] + plan.transformAllExpressions { case e: Literal => + lits += e.value.asInstanceOf[Int] + e + } + assert(lits.size == 2) + assert(lits(0) >= min && lits(0) <= max) + assert(lits(1) >= min && lits(1) <= max) + assert(lits(0) == lits(1)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala index 8aaefa84937c..48f9ac77b74c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ConstantFoldingSuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.analysis.{UnresolvedExtractValue, EliminateSubQueries} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedExtractValue} +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import 
org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class ConstantFoldingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index fba4c5ca77d6..f9f3bd55aa57 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.{LeftOuter, LeftSemi, PlanTest, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.{LeftSemi, PlanTest, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.types.IntegerType class FilterPushdownSuite extends PlanTest { @@ -75,8 +75,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a) - .select('a).analyze + .groupBy('a)('a).analyze comparePlans(optimized, correctAnswer) } @@ -91,8 +90,7 @@ class FilterPushdownSuite extends PlanTest { val correctAnswer = testRelation .select('a) - .groupBy('a)('a as 'c) - .select('c).analyze + .groupBy('a)('a as 'c).analyze comparePlans(optimized, correctAnswer) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index b3df487c84dc..741bc113cfcd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules._ -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - class LikeSimplificationSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index 48cab01ac100..3e384e473e5f 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -18,17 +18,17 @@ package org.apache.spark.sql.catalyst.optimizer import scala.collection.immutable.HashSet + import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, UnresolvedAttribute} +// For implicit conversions +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.types._ -// For implicit conversions -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ - class OptimizeInSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala new file mode 100644 index 000000000000..7e3da6bea75e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule + +/** + * This is a test for SPARK-7727, verifying that the Optimizer stays extendable. + */ +class OptimizerExtendableSuite extends SparkFunSuite { + + /** + * Dummy rule for test batches + */ + object DummyRule extends Rule[LogicalPlan] { + def apply(p: LogicalPlan): LogicalPlan = p + } + + /** + * This class represents a dummy extended optimizer that takes the batches of the + * Optimizer and adds custom ones. 
+ */ + class ExtendedOptimizer extends Optimizer { + + // rules set to DummyRule, would not be executed anyways + val myBatches: Seq[Batch] = { + Batch("once", Once, + DummyRule) :: + Batch("fixedPoint", FixedPoint(100), + DummyRule) :: Nil + } + + override def batches: Seq[Batch] = super.batches ++ myBatches + } + + test("Extending batches possible") { + // test simply instantiates the new extended optimizer + val extendedOptimizer = new ExtendedOptimizer() + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala index 1aa89991cc69..85b6530481b0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ProjectCollapsingSuite.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions.Rand import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor - class ProjectCollapsingSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala index 1595ad932742..a498b463a69e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SetOperationPushDownSuite.scala @@ -18,11 +18,11 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.analysis.EliminateSubQueries +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.dsl.expressions._ class SetOperationPushDownSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala index 6b1e53cd42b2..41455221cfdc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyCaseConversionExpressionsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.catalyst.optimizer +/* Implicit conversions */ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import 
org.apache.spark.sql.catalyst.rules._ -/* Implicit conversions */ -import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.dsl.plans._ - class SimplifyCaseConversionExpressionsSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 2efee1fc5470..f9874088b588 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} import org.apache.spark.sql.catalyst.util._ /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 62d5f6ac7488..fb4f34d059b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{ExprId, AttributeReference} +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExprId} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 965bdb1515e5..6a188e7e5512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.types.{IntegerType, StringType, NullType} +import org.apache.spark.sql.types.{IntegerType, NullType, StringType} case class Dummy(optKey: Option[Expression]) extends Expression with CodegenFallback { override def children: Seq[Expression] = optKey.toSeq diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala index 0ce5a2fb6950..6745b4b6c3c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/DateTimeUtilsSuite.scala @@ -22,8 +22,8 @@ import java.text.SimpleDateFormat import java.util.{Calendar, TimeZone} import org.apache.spark.SparkFunSuite -import org.apache.spark.unsafe.types.UTF8String import 
org.apache.spark.sql.catalyst.util.DateTimeUtils._ +import org.apache.spark.unsafe.types.UTF8String class DateTimeUtilsSuite extends SparkFunSuite { @@ -384,9 +384,6 @@ class DateTimeUtilsSuite extends SparkFunSuite { Timestamp.valueOf("1700-02-28 12:14:50.123456")).foreach { t => val us = fromJavaTimestamp(t) assert(toJavaTimestamp(us) === t) - assert(getHours(us) === t.getHours) - assert(getMinutes(us) === t.getMinutes) - assert(getSeconds(us) === t.getSeconds) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala index 4030a1b1df35..a0c1d97bfc3a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import org.json4s.jackson.JsonMethods.parse import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{MetadataBuilder, Metadata} +import org.apache.spark.sql.types.{Metadata, MetadataBuilder} class MetadataSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala index 50683947da22..e1675c95907a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/types/DecimalSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.types -import org.apache.spark.SparkFunSuite +import scala.language.postfixOps + import org.scalatest.PrivateMethodTester -import scala.language.postfixOps +import org.apache.spark.SparkFunSuite class DecimalSuite extends SparkFunSuite with PrivateMethodTester { /** Check that a Decimal has the given string representation, precision and scale */ diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 06841b094562..6db7a8a2dc52 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java index a2f99d566d47..6bf9d7bd0367 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMap.java @@ -61,7 +61,7 @@ public final class UnsafeFixedWidthAggregationMap { /** * Re-used pointer to the current aggregation buffer */ - private final UnsafeRow currentAggregationBuffer = new UnsafeRow(); + private final UnsafeRow currentAggregationBuffer; private final boolean enablePerfMetrics; @@ -98,6 +98,7 @@ public UnsafeFixedWidthAggregationMap( long pageSizeBytes, boolean enablePerfMetrics) { this.aggregationBufferSchema = aggregationBufferSchema; + this.currentAggregationBuffer = new UnsafeRow(aggregationBufferSchema.length()); this.groupingKeyProjection = UnsafeProjection.create(groupingKeySchema); this.groupingKeySchema = groupingKeySchema; this.map = @@ -147,7 +148,6 @@ public UnsafeRow getAggregationBufferFromUnsafeRow(UnsafeRow unsafeGroupingKeyRo currentAggregationBuffer.pointTo( address.getBaseObject(), address.getBaseOffset(), - aggregationBufferSchema.length(), loc.getValueLength() ); return currentAggregationBuffer; @@ 
-165,8 +165,8 @@ public KVIterator iterator() { private final BytesToBytesMap.MapIterator mapLocationIterator = map.destructiveIterator(); - private final UnsafeRow key = new UnsafeRow(); - private final UnsafeRow value = new UnsafeRow(); + private final UnsafeRow key = new UnsafeRow(groupingKeySchema.length()); + private final UnsafeRow value = new UnsafeRow(aggregationBufferSchema.length()); @Override public boolean next() { @@ -177,13 +177,11 @@ public boolean next() { key.pointTo( keyAddress.getBaseObject(), keyAddress.getBaseOffset(), - groupingKeySchema.length(), loc.getKeyLength() ); value.pointTo( valueAddress.getBaseObject(), valueAddress.getBaseOffset(), - aggregationBufferSchema.length(), loc.getValueLength() ); return true; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index 8c9b9c85e37f..0da26bf376a6 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -94,7 +94,7 @@ public UnsafeKVExternalSorter( // The only new memory we are allocating is the pointer/prefix array. BytesToBytesMap.MapIterator iter = map.iterator(); final int numKeyFields = keySchema.size(); - UnsafeRow row = new UnsafeRow(); + UnsafeRow row = new UnsafeRow(numKeyFields); while (iter.hasNext()) { final BytesToBytesMap.Location loc = iter.next(); final Object baseObject = loc.getKeyAddress().getBaseObject(); @@ -107,7 +107,7 @@ public UnsafeKVExternalSorter( long address = taskMemoryManager.encodePageNumberAndOffset(page, baseOffset - 8); // Compute prefix - row.pointTo(baseObject, baseOffset, numKeyFields, loc.getKeyLength()); + row.pointTo(baseObject, baseOffset, loc.getKeyLength()); final long prefix = prefixComputer.computePrefix(row); inMemSorter.insertRecord(address, prefix); @@ -194,12 +194,14 @@ public void cleanupResources() { private static final class KVComparator extends RecordComparator { private final BaseOrdering ordering; - private final UnsafeRow row1 = new UnsafeRow(); - private final UnsafeRow row2 = new UnsafeRow(); + private final UnsafeRow row1; + private final UnsafeRow row2; private final int numKeyFields; public KVComparator(BaseOrdering ordering, int numKeyFields) { this.numKeyFields = numKeyFields; + this.row1 = new UnsafeRow(numKeyFields); + this.row2 = new UnsafeRow(numKeyFields); this.ordering = ordering; } @@ -207,17 +209,15 @@ public KVComparator(BaseOrdering ordering, int numKeyFields) { public int compare(Object baseObj1, long baseOff1, Object baseObj2, long baseOff2) { // Note that since ordering doesn't need the total length of the record, we just pass -1 // into the row. 
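The mechanical change running through these Java files is the same one applied to GenerateUnsafeRowJoinerBitsetSuite earlier: UnsafeRow now takes its field count at construction, and pointTo drops that argument, so a row can no longer be rebound with an inconsistent schema width. A before/after sketch, sizing the buffer with the same formula the suite above uses:

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.unsafe.Platform

    val numFields = 2
    // bitset words plus 8 bytes per fixed-width field, as in the suite above
    val sizeInBytes = ((numFields + 63) / 64) * 8 + numFields * 8
    val buf = new Array[Byte](sizeInBytes)
    val row = new UnsafeRow(numFields)  // was: new UnsafeRow()
    row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET, sizeInBytes)
    // was: row.pointTo(buf, Platform.BYTE_ARRAY_OFFSET, numFields, sizeInBytes)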
- row1.pointTo(baseObj1, baseOff1 + 4, numKeyFields, -1); - row2.pointTo(baseObj2, baseOff2 + 4, numKeyFields, -1); + row1.pointTo(baseObj1, baseOff1 + 4, -1); + row2.pointTo(baseObj2, baseOff2 + 4, -1); return ordering.compare(row1, row2); } } public class KVSorterIterator extends KVIterator { - private UnsafeRow key = new UnsafeRow(); - private UnsafeRow value = new UnsafeRow(); - private final int numKeyFields = keySchema.size(); - private final int numValueFields = valueSchema.size(); + private UnsafeRow key = new UnsafeRow(keySchema.size()); + private UnsafeRow value = new UnsafeRow(valueSchema.size()); private final UnsafeSorterIterator underlying; private KVSorterIterator(UnsafeSorterIterator underlying) { @@ -237,8 +237,8 @@ public boolean next() throws IOException { // Note that recordLen = keyLen + valueLen + 4 bytes (for the keyLen itself) int keyLen = Platform.getInt(baseObj, recordOffset); int valueLen = recordLen - keyLen - 4; - key.pointTo(baseObj, recordOffset + 4, numKeyFields, keyLen); - value.pointTo(baseObj, recordOffset + 4 + keyLen, numValueFields, valueLen); + key.pointTo(baseObj, recordOffset + 4, keyLen); + value.pointTo(baseObj, recordOffset + 4 + keyLen, valueLen); return true; } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 842dcb8c93dc..6bcd155ccdc4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -19,7 +19,9 @@ package org.apache.spark.sql.execution.datasources.parquet; import java.io.ByteArrayInputStream; +import java.io.File; import java.io.IOException; +import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; @@ -36,6 +38,7 @@ import static org.apache.parquet.hadoop.ParquetInputFormat.getFilter; import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.RecordReader; @@ -56,9 +59,11 @@ import org.apache.parquet.hadoop.metadata.ParquetMetadata; import org.apache.parquet.hadoop.util.ConfigurationUtil; import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Types; +import org.apache.spark.sql.types.StructType; /** - * Base class for custom RecordReaaders for Parquet that directly materialize to `T`. + * Base class for custom RecordReaders for Parquet that directly materialize to `T`. * This class handles computing row groups, filtering on them, setting up the column readers, * etc. * This is heavily based on parquet-mr's RecordReader. @@ -69,7 +74,7 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader readSupport; + protected StructType sparkSchema; /** * The total number of rows this RecordReader will eventually read. 
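For reference, the calling convention that the UnsafeRow changes above converge on can be sketched in a few lines of Scala. This is a minimal illustration only; the field count, buffer size and variable names are assumptions invented for the example, not values taken from the diff:

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.unsafe.Platform

val numFields = 2                    // illustrative field count, now fixed at construction
val buffer = new Array[Byte](64)     // illustrative backing storage
val row = new UnsafeRow(numFields)
// pointTo no longer takes a numFields argument, only base object, offset and size in bytes:
row.pointTo(buffer, Platform.BYTE_ARRAY_OFFSET, buffer.length)

This is why every pointTo call site in the hunks above drops its schema-length argument while the constructors gain one.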
The sum of the @@ -79,6 +84,7 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader fileMetadata = footer.getFileMetaData().getKeyValueMetaData(); - this.readSupport = getReadSupportInstance( - (Class<? extends ReadSupport<T>>) getReadSupportClass(configuration)); + ReadSupport<T> readSupport = getReadSupportInstance(getReadSupportClass(configuration)); ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); this.requestedSchema = readContext.getRequestedSchema(); - this.fileSchema = fileSchema; + this.sparkSchema = new CatalystSchemaConverter(configuration).convert(requestedSchema); this.reader = new ParquetFileReader(configuration, file, blocks, requestedSchema.getColumns()); for (BlockMetaData block : blocks) { this.totalRowCount += block.getRowCount(); } }
+ /** + * Returns the list of files at 'path' recursively. This skips files that are ignored normally + * by MapReduce. + */ + public static List<String> listDirectory(File path) throws IOException { + List<String> result = new ArrayList<String>(); + if (path.isDirectory()) { + for (File f: path.listFiles()) { + result.addAll(listDirectory(f)); + } + } else { + char c = path.getName().charAt(0); + if (c != '.' && c != '_') { + result.add(path.getAbsolutePath()); + } + } + return result; + }
+ + /** + * Initializes the reader to read the file at `path` with `columns` projected. If columns is + * null, all the columns are projected. + * + * This is exposed for testing to be able to create this reader without the rest of the Hadoop + * split machinery. It is not intended for general use and does not support all the + * configurations. + */ + protected void initialize(String path, List<String> columns) throws IOException { + Configuration config = new Configuration(); + config.set("spark.sql.parquet.binaryAsString", "false"); + config.set("spark.sql.parquet.int96AsTimestamp", "false"); + config.set("spark.sql.parquet.writeLegacyFormat", "false"); + + this.file = new Path(path); + long length = FileSystem.get(config).getFileStatus(this.file).getLen(); + ParquetMetadata footer = readFooter(config, file, range(0, length)); + + List<BlockMetaData> blocks = footer.getBlocks(); + this.fileSchema = footer.getFileMetaData().getSchema(); + + if (columns == null) { + this.requestedSchema = fileSchema; + } else { + Types.MessageTypeBuilder builder = Types.buildMessage(); + for (String s: columns) { + if (!fileSchema.containsField(s)) { + throw new IOException("Can only project existing columns.
Unknown field: " + s + + " File schema:\n" + fileSchema); + } + builder.addFields(fileSchema.getType(s)); + } + this.requestedSchema = builder.named("spark_schema"); + } + this.sparkSchema = new CatalystSchemaConverter(config).convert(requestedSchema); + this.reader = new ParquetFileReader(config, file, blocks, requestedSchema.getColumns()); + for (BlockMetaData block : blocks) { + this.totalRowCount += block.getRowCount(); + } + } + @Override public Void getCurrentKey() throws IOException, InterruptedException { return null; @@ -218,8 +283,9 @@ private static Map> toSetMultiMap(Map map) { return Collections.unmodifiableMap(setMultiMap); } - private static Class getReadSupportClass(Configuration configuration) { - return ConfigurationUtil.getClassFromConfig(configuration, + @SuppressWarnings("unchecked") + private Class> getReadSupportClass(Configuration configuration) { + return (Class>) ConfigurationUtil.getClassFromConfig(configuration, ParquetInputFormat.READ_SUPPORT_CLASS, ReadSupport.class); } @@ -230,10 +296,9 @@ private static Class getReadSupportClass(Configuration configuration) { private static ReadSupport getReadSupportInstance( Class> readSupportClass){ try { - return readSupportClass.newInstance(); - } catch (InstantiationException e) { - throw new BadConfigurationException("could not instantiate read support class", e); - } catch (IllegalAccessException e) { + return readSupportClass.getConstructor().newInstance(); + } catch (InstantiationException | IllegalAccessException | + NoSuchMethodException | InvocationTargetException e) { throw new BadConfigurationException("could not instantiate read support class", e); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java index 0cc4566c9cdd..47818c0939f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/UnsafeRowParquetRecordReader.java @@ -21,35 +21,29 @@ import java.nio.ByteBuffer; import java.util.List; -import org.apache.spark.sql.catalyst.expressions.UnsafeRow; -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; -import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.Platform; -import org.apache.spark.unsafe.types.UTF8String; - -import static org.apache.parquet.column.ValuesType.DEFINITION_LEVEL; -import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; -import static org.apache.parquet.column.ValuesType.VALUES; - +import org.apache.hadoop.fs.Path; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.Preconditions; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.Dictionary; import org.apache.parquet.column.Encoding; -import org.apache.parquet.column.page.DataPage; -import org.apache.parquet.column.page.DataPageV1; -import org.apache.parquet.column.page.DataPageV2; -import org.apache.parquet.column.page.DictionaryPage; -import org.apache.parquet.column.page.PageReadStore; -import org.apache.parquet.column.page.PageReader; +import org.apache.parquet.column.page.*; import org.apache.parquet.column.values.ValuesReader; import org.apache.parquet.io.api.Binary; import 
org.apache.parquet.schema.OriginalType; import org.apache.parquet.schema.PrimitiveType; import org.apache.parquet.schema.Type; +import org.apache.spark.sql.catalyst.expressions.UnsafeRow; +import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; +import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.types.UTF8String; + +import static org.apache.parquet.column.ValuesType.*; + /** * A specialized RecordReader that reads into UnsafeRows directly using the Parquet column APIs. * @@ -128,14 +122,42 @@ public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttem public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) throws IOException, InterruptedException { super.initialize(inputSplit, taskAttemptContext); + initializeInternal(); + } + /** + * Utility API that will read all the data in path. This circumvents the need to create Hadoop + * objects to use this class. `columns` can contain the list of columns to project. + */ + @Override + public void initialize(String path, List columns) throws IOException { + super.initialize(path, columns); + initializeInternal(); + } + + @Override + public boolean nextKeyValue() throws IOException, InterruptedException { + if (batchIdx >= numBatched) { + if (!loadBatch()) return false; + } + ++batchIdx; + return true; + } + + @Override + public UnsafeRow getCurrentValue() throws IOException, InterruptedException { + return rows[batchIdx - 1]; + } + + @Override + public float getProgress() throws IOException, InterruptedException { + return (float) rowsReturned / totalRowCount; + } + + private void initializeInternal() throws IOException { /** * Check that the requested schema is supported. */ - if (requestedSchema.getFieldCount() == 0) { - // TODO: what does this mean? - throw new IOException("Empty request schema not supported."); - } int numVarLenFields = 0; originalTypes = new OriginalType[requestedSchema.getFieldCount()]; for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { @@ -181,34 +203,14 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont rowWriters = new UnsafeRowWriter[rows.length]; for (int i = 0; i < rows.length; ++i) { - rows[i] = new UnsafeRow(); + rows[i] = new UnsafeRow(requestedSchema.getFieldCount()); rowWriters[i] = new UnsafeRowWriter(); BufferHolder holder = new BufferHolder(rowByteSize); rowWriters[i].initialize(rows[i], holder, requestedSchema.getFieldCount()); - rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, requestedSchema.getFieldCount(), - holder.buffer.length); + rows[i].pointTo(holder.buffer, Platform.BYTE_ARRAY_OFFSET, holder.buffer.length); } } - @Override - public boolean nextKeyValue() throws IOException, InterruptedException { - if (batchIdx >= numBatched) { - if (!loadBatch()) return false; - } - ++batchIdx; - return true; - } - - @Override - public UnsafeRow getCurrentValue() throws IOException, InterruptedException { - return rows[batchIdx - 1]; - } - - @Override - public float getProgress() throws IOException, InterruptedException { - return (float) rowsReturned / totalRowCount; - } - /** * Decodes a batch of values into `rows`. This function is the hot path. 
*/ @@ -261,9 +263,19 @@ private boolean loadBatch() throws IOException { case INT96: throw new IOException("Unsupported " + columnReaders[i].descriptor.getType()); } - numBatched = num; - batchIdx = 0; } + + numBatched = num; + batchIdx = 0; + + // Update the total row lengths if the schema contained variable length. We did not maintain + // this as we populated the columns. + if (containsVarLenFields) { + for (int i = 0; i < numBatched; ++i) { + rows[i].setTotalSize(rowWriters[i].holder().totalSize()); + } + } + return true; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 297ef2299cb3..e8c61d6e01dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression - import scala.language.implicitConversions -import org.apache.spark.annotation.Experimental import org.apache.spark.Logging -import org.apache.spark.sql.functions.lit +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.SqlParser._ import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.functions.lit import org.apache.spark.sql.types._ - private[sql] object Column { def apply(colName: String): Column = new Column(colName) @@ -130,8 +129,11 @@ class Column(protected[sql] val expr: Expression) extends Logging { // Leave an unaliased generator with an empty list of names since the analyzer will generate // the correct defaults after the nested expression's type has been resolved. case explode: Explode => MultiAlias(explode, Nil) + case jt: JsonTuple => MultiAlias(jt, Nil) + case func: UnresolvedFunction => UnresolvedAlias(func, Some(func.prettyString)) + case expr: Expression => Alias(expr, expr.prettyString)() } @@ -706,18 +708,6 @@ class Column(protected[sql] val expr: Expression) extends Logging { */ def mod(other: Any): Column = this % other - /** - * A boolean expression that is evaluated to true if the value of this expression is contained - * by the evaluated values of the arguments. - * - * @group expr_ops - * @since 1.3.0 - * @deprecated As of 1.5.0. Use isin. This will be removed in Spark 2.0. - */ - @deprecated("use isin. This will be removed in Spark 2.0.", "1.5.0") - @scala.annotation.varargs - def in(list: Any*): Column = isin(list : _*) - /** * A boolean expression that is evaluated to true if the value of this expression is contained * by the evaluated values of the arguments. 
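Call sites of the removed varargs `in` migrate mechanically to `isin`; a minimal sketch, assuming a hypothetical DataFrame `df` with a string column `status`:

import org.apache.spark.sql.functions.col

// Before (deprecated since 1.5.0, removed above):
//   df.filter(col("status").in("active", "pending"))
// After:
val filtered = df.filter(col("status").isin("active", "pending"))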
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index d74131231499..60d2f05b8605 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -25,20 +25,18 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import com.fasterxml.jackson.core.JsonFactory -import org.apache.commons.lang3.StringUtils import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.python.PythonRDD import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection, SqlParser} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, QueryExecution, Queryable, SQLExecution} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, FileRelation, LogicalRDD, Queryable, QueryExecution, SQLExecution} import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation} import org.apache.spark.sql.execution.datasources.json.JacksonGenerator import org.apache.spark.sql.sources.HadoopFsRelation @@ -455,7 +453,8 @@ class DataFrame private[sql]( // Analyze the self join. The assumption is that the analyzer will disambiguate left vs right // by creating a new instance for one of the branch. val joined = sqlContext.executePlan( - Join(logicalPlan, right.logicalPlan, JoinType(joinType), None)).analyzed.asInstanceOf[Join] + Join(logicalPlan, right.logicalPlan, joinType = JoinType(joinType), None)) + .analyzed.asInstanceOf[Join] val condition = usingColumns.map { col => catalyst.expressions.EqualTo( @@ -473,15 +472,15 @@ class DataFrame private[sql]( usingColumns.map(col => withPlan(joined.right).resolve(col)) case FullOuter => usingColumns.map { col => - val leftCol = withPlan(joined.left).resolve(col) - val rightCol = withPlan(joined.right).resolve(col) + val leftCol = withPlan(joined.left).resolve(col).toAttribute.withNullability(true) + val rightCol = withPlan(joined.right).resolve(col).toAttribute.withNullability(true) Alias(Coalesce(Seq(leftCol, rightCol)), col)() } } // The nullability of output of joined could be different than original column, // so we can only compare them by exprId - val joinRefs = condition.map(_.references.toSeq.map(_.exprId)).getOrElse(Nil) - val resultCols = joinedCols ++ joined.output.filterNot(e => joinRefs.contains(e.exprId)) + val joinRefs = AttributeSet(condition.toSeq.flatMap(_.references)) + val resultCols = joinedCols ++ joined.output.filterNot(joinRefs.contains(_)) withPlan { Project( resultCols, @@ -1063,10 +1062,15 @@ class DataFrame private[sql]( * @since 1.4.0 */ def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame] = { + // It is possible that the underlying dataframe doesn't guarantee the ordering of rows in its + // constituent partitions each time a split is materialized which could result in + // overlapping splits. 
To prevent this, we explicitly sort each input partition to make the + // ordering deterministic. + val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => - new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, logicalPlan)) + new DataFrame(sqlContext, Sample(x(0), x(1), withReplacement = false, seed, sorted)) }.toArray } @@ -1172,13 +1176,17 @@ class DataFrame private[sql]( */ def withColumn(colName: String, col: Column): DataFrame = { val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName) else Column(name) + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName) + } else { + Column(field) + } } - select(colNames : _*) + select(columns : _*) } else { select(Column("*"), col.as(colName)) } @@ -1189,13 +1197,17 @@ class DataFrame private[sql]( */ private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { val resolver = sqlContext.analyzer.resolver - val replaced = schema.exists(f => resolver(f.name, colName)) - if (replaced) { - val colNames = schema.map { field => - val name = field.name - if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + val output = queryExecution.analyzed.output + val shouldReplace = output.exists(f => resolver(f.name, colName)) + if (shouldReplace) { + val columns = output.map { field => + if (resolver(field.name, colName)) { + col.as(colName, metadata) + } else { + Column(field) + } } - select(colNames : _*) + select(columns : _*) } else { select(Column("*"), col.as(colName, metadata)) } @@ -1743,344 +1755,6 @@ class DataFrame private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `toDF()`. This will be removed in Spark 2.0. - */ - @deprecated("Use toDF. This will be removed in Spark 2.0.", "1.3.0") - def toSchemaRDD: DataFrame = this - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * This will run a `CREATE TABLE` and a bunch of `INSERT INTO` statements. - * If you pass `true` for `allowExisting`, it will drop any table with the - * given name; if you pass `false`, it will throw if the table already - * exists. - * @group output - * @deprecated As of 1.340, replaced by `write().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def createJDBCTable(url: String, table: String, allowExisting: Boolean): Unit = { - val w = if (allowExisting) write.mode(SaveMode.Overwrite) else write - w.jdbc(url, table, new Properties) - } - - /** - * Save this [[DataFrame]] to a JDBC database at `url` under the table name `table`. - * Assumes the table already exists and has a compatible schema. 
If you - * pass `true` for `overwrite`, it will `TRUNCATE` the table before - * performing the `INSERT`s. - * - * The table must already exist on the database. It must have a schema - * that is compatible with the schema of this RDD; inserting the rows of - * the RDD in order via the simple statement - * `INSERT INTO table VALUES (?, ?, ..., ?)` should not fail. - * @group output - * @deprecated As of 1.4.0, replaced by `write().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def insertIntoJDBC(url: String, table: String, overwrite: Boolean): Unit = { - val w = if (overwrite) write.mode(SaveMode.Overwrite) else write.mode(SaveMode.Append) - w.jdbc(url, table, new Properties) - } - - /** - * Saves the contents of this [[DataFrame]] as a parquet file, preserving the schema. - * Files that are written out using this method can be read back in as a [[DataFrame]] - * using the `parquetFile` function in [[SQLContext]]. - * @group output - * @deprecated As of 1.4.0, replaced by `write().parquet()`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.parquet(path). This will be removed in Spark 2.0.", "1.4.0") - def saveAsParquetFile(path: String): Unit = { - write.format("parquet").mode(SaveMode.ErrorIfExists).save(path) - } - - /** - * Creates a table from the the contents of this DataFrame. - * It will use the default data source configured by spark.sql.sources.default. - * This will fail if the table already exists. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.saveAsTable(tableName). This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable(tableName: String): Unit = { - write.mode(SaveMode.ErrorIfExists).saveAsTable(tableName) - } - - /** - * Creates a table from the the contents of this DataFrame, using the default data source - * configured by spark.sql.sources.default and [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. 
- * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(mode).saveAsTable(tableName). This will be removed in Spark 2.0.", - "1.4.0") - def saveAsTable(tableName: String, mode: SaveMode): Unit = { - write.mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source and a set of options, - * using [[SaveMode.ErrorIfExists]] as the save mode. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).saveAsTable(tableName). This will be removed in Spark 2.0.", - "1.4.0") - def saveAsTable(tableName: String, source: String): Unit = { - write.format(source).saveAsTable(tableName) - } - - /** - * :: Experimental :: - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable(tableName: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).saveAsTable(tableName) - } - - /** - * Creates a table at the given path from the the contents of this DataFrame - * based on a given data source, [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. 
- * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * (Scala-specific) - * Creates a table from the the contents of this DataFrame based on a given data source, - * [[SaveMode]] specified by mode, and a set of options. - * - * Note that this currently only works with DataFrames that are created from a HiveContext as - * there is no notion of a persisted catalog in a standard SQL context. Instead you can write - * an RDD out to a parquet file, and then register that file as a table. This "table" can then - * be the target of an `insertInto`. - * - * When the DataFrame is created from a non-partitioned [[HadoopFsRelation]] with a single input - * path, and the data source provider can be mapped to an existing Hive builtin SerDe (i.e. ORC - * and Parquet), the table is persisted in a Hive compatible format, which means other systems - * like Hive will be able to read this table. Otherwise, the table is persisted in a Spark SQL - * specific format. - * - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def saveAsTable( - tableName: String, - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).saveAsTable(tableName) - } - - /** - * Saves the contents of this DataFrame to the given path, - * using the default data source configured by spark.sql.sources.default and - * [[SaveMode.ErrorIfExists]] as the save mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().save(path)`. This will be removed in Spark 2.0. - */ - @deprecated("Use write.save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String): Unit = { - write.save(path) - } - - /** - * Saves the contents of this DataFrame to the given path and [[SaveMode]] specified by mode, - * using the default data source configured by spark.sql.sources.default. - * @group output - * @deprecated As of 1.4.0, replaced by `write().mode(mode).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(mode).save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, mode: SaveMode): Unit = { - write.mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source, - * using [[SaveMode.ErrorIfExists]] as the save mode. 
- * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).save(path). This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, source: String): Unit = { - write.format(source).save(path) - } - - /** - * Saves the contents of this DataFrame to the given path based on the given data source and - * [[SaveMode]] specified by mode. - * @group output - * @deprecated As of 1.4.0, replaced by `write().format(source).mode(mode).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).save(path). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save(path: String, source: String, mode: SaveMode): Unit = { - write.format(source).mode(mode).save(path) - } - - /** - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: java.util.Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - /** - * (Scala-specific) - * Saves the contents of this DataFrame based on the given data source, - * [[SaveMode]] specified by mode, and a set of options - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().format(source).mode(mode).options(options).save(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.format(source).mode(mode).options(options).save(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def save( - source: String, - mode: SaveMode, - options: Map[String, String]): Unit = { - write.format(source).mode(mode).options(options).save() - } - - /** - * Adds the rows from this RDD to the specified table, optionally overwriting the existing data. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(SaveMode.Append|SaveMode.Overwrite).saveAsTable(tableName). " + - "This will be removed in Spark 2.0.", "1.4.0") - def insertInto(tableName: String, overwrite: Boolean): Unit = { - write.mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append).insertInto(tableName) - } - - /** - * Adds the rows from this RDD to the specified table. - * Throws an exception if the table already exists. - * @group output - * @deprecated As of 1.4.0, replaced by - * `write().mode(SaveMode.Append).saveAsTable(tableName)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use write.mode(SaveMode.Append).saveAsTable(tableName). 
" + - "This will be removed in Spark 2.0.", "1.4.0") - def insertInto(tableName: String): Unit = { - write.mode(SaveMode.Append).insertInto(tableName) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - /** * Wrap a DataFrame action to track all Spark jobs in the body so that we can connect them with * an execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 3ed1e55adec6..d948e4894253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -30,10 +30,10 @@ import org.apache.spark.api.java.JavaRDD import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.SqlParser +import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.{JDBCPartition, JDBCPartitioningInfo, JDBCRelation} import org.apache.spark.sql.execution.datasources.json.JSONRelation import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.types.StructType /** @@ -98,17 +98,6 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { this } - /** - * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by - * a local or distributed file system). - * - * @since 1.4.0 - */ - // TODO: Remove this one in Spark 2.0. - def load(path: String): DataFrame = { - option("path", path).load() - } - /** * Loads input in as a [[DataFrame]], for data sources that don't require a path (e.g. external * key-value stores). @@ -125,6 +114,16 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { DataFrame(sqlContext, LogicalRelation(resolved.relation)) } + /** + * Loads input in as a [[DataFrame]], for data sources that require a path (e.g. data backed by + * a local or distributed file system). + * + * @since 1.4.0 + */ + def load(path: String): DataFrame = { + option("path", path).load() + } + /** * Loads input in as a [[DataFrame]], for data sources that support multiple paths. * Only works if the source is a HadoopFsRelationProvider. @@ -154,13 +153,14 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash * your external database systems. * - * @param url JDBC database url of the form `jdbc:subprotocol:subname` + * @param url JDBC database url of the form `jdbc:subprotocol:subname`. * @param table Name of the table in the external database. * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. 
the range `minValue`-`maxValue` will be split - * evenly into this many partitions + * @param lowerBound the minimum value of `columnName` used to decide partition stride. + * @param upperBound the maximum value of `columnName` used to decide partition stride. + * @param numPartitions the number of partitions. This, along with `lowerBound` (inclusive), + * `upperBound` (exclusive), form partition strides for generated WHERE + * clause expressions used to split the column `columnName` evenly. * @param connectionProperties JDBC database connection arguments, a list of arbitrary string * tag/value. Normally at least a "user" and "password" property * should be included. @@ -257,6 +257,8 @@ class DataFrameReader private[sql](sqlContext: SQLContext) extends Logging { * *
* `allowNumericLeadingZeros` (default `false`): allows leading zeros in numbers * (e.g. 00012) + * `allowBackslashEscapingAnyCharacter` (default `false`): allows accepting quoting of all + * characters using the backslash quoting mechanism * * @since 1.6.0 */ @@ -339,7 +341,7 @@ } /** - * Loads a text file and returns a [[DataFrame]] with a single string column named "text". + * Loads a text file and returns a [[DataFrame]] with a single string column named "value". * Each line in the text file is a new row in the resulting DataFrame. For example: * {{{ * // Scala:
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 69c984717526..e66aa5f94718 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.{util => ju, lang => jl} +import java.{lang => jl, util => ju} import scala.collection.JavaConverters._
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 03867beb7822..00f9817b5397 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -23,13 +23,12 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} -import org.apache.spark.sql.catalyst.plans.logical.{Project, InsertIntoTable} +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, ResolvedDataSource} import org.apache.spark.sql.sources.HadoopFsRelation - /** * :: Experimental :: * Interface used to write a [[DataFrame]] to external storage systems (e.g. file systems, @@ -119,7 +118,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { * Partitions the output by the given columns on the file system. If specified, the output is * laid out on the file system similar to Hive's partitioning scheme. * - * This is only applicable for Parquet at the moment. + * This was initially applicable for Parquet but in 1.5+ covers JSON, text, ORC and Avro as well. * * @since 1.4.0 */ @@ -129,6 +128,34 @@ final class DataFrameWriter private[sql](df: DataFrame) { this }
+ /** + * Buckets the output by the given columns. If specified, the output is laid out on the file + * system similar to Hive's bucketing scheme. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def bucketBy(numBuckets: Int, colName: String, colNames: String*): DataFrameWriter = { + this.numBuckets = Option(numBuckets) + this.bucketColumnNames = Option(colName +: colNames) + this + } + + /** + * Sorts the output in each bucket by the given columns. + * + * This is applicable for Parquet, JSON and ORC. + * + * @since 2.0 + */ + @scala.annotation.varargs + def sortBy(colName: String, colNames: String*): DataFrameWriter = { + this.sortColumnNames = Option(colName +: colNames) + this + } +
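A minimal sketch of the bucketed-write API added above, assuming a hypothetical DataFrame `people` with columns `id` and `age`. Because save() and insertInto() gain an assertNotBucketed() check below, bucketed output currently has to go through saveAsTable():

people.write
  .format("parquet")
  .bucketBy(8, "id")      // getBucketSpec requires 0 < numBuckets < 100000
  .sortBy("age")          // sortBy is only valid together with bucketBy
  .saveAsTable("people_bucketed")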
/** * Saves the content of the [[DataFrame]] at the specified path. * @@ -145,10 +172,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { * @since 1.4.0 */ def save(): Unit = { + assertNotBucketed() ResolvedDataSource( df.sqlContext, source, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df) @@ -167,6 +196,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } private def insertInto(tableIdent: TableIdentifier): Unit = { + assertNotBucketed() val partitions = normalizedParCols.map(_.map(col => col -> (None: Option[String])).toMap) val overwrite = mode == SaveMode.Overwrite @@ -189,13 +219,47 @@ final class DataFrameWriter private[sql](df: DataFrame) { ifNotExists = false)).toRdd }
- private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { parCols => - parCols.map { col => - df.logicalPlan.output - .map(_.name) - .find(df.sqlContext.analyzer.resolver(_, col)) - .getOrElse(throw new AnalysisException(s"Partition column $col not found in existing " + - s"columns (${df.logicalPlan.output.map(_.name).mkString(", ")})")) + private def normalizedParCols: Option[Seq[String]] = partitioningColumns.map { cols => + cols.map(normalize(_, "Partition")) + } + + private def normalizedBucketColNames: Option[Seq[String]] = bucketColumnNames.map { cols => + cols.map(normalize(_, "Bucketing")) + } + + private def normalizedSortColNames: Option[Seq[String]] = sortColumnNames.map { cols => + cols.map(normalize(_, "Sorting")) + } + + private def getBucketSpec: Option[BucketSpec] = { + if (sortColumnNames.isDefined) { + require(numBuckets.isDefined, "sortBy must be used together with bucketBy") + } + + for { + n <- numBuckets + } yield { + require(n > 0 && n < 100000, "Bucket number must be greater than 0 and less than 100000.") + BucketSpec(n, normalizedBucketColNames.get, normalizedSortColNames.getOrElse(Nil)) + } + } + + /** + * The given column name may not be equal to any of the existing column names in a + * case-insensitive context. Normalize the given column name to the real one so that we don't + * need to care about case sensitivity afterwards.
+ */ + private def normalize(columnName: String, columnType: String): String = { + val validColumnNames = df.logicalPlan.output.map(_.name) + validColumnNames.find(df.sqlContext.analyzer.resolver(_, columnName)) + .getOrElse(throw new AnalysisException(s"$columnType column $columnName not found in " + + s"existing columns (${validColumnNames.mkString(", ")})")) + } + + private def assertNotBucketed(): Unit = { + if (numBuckets.isDefined || sortColumnNames.isDefined) { + throw new IllegalArgumentException( + "Currently we don't support writing bucketed data to this data source.") } } @@ -245,6 +309,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { source, temporary = false, partitioningColumns.map(_.toArray).getOrElse(Array.empty[String]), + getBucketSpec, mode, extraOptions.toMap, df.logicalPlan) @@ -275,7 +340,7 @@ final class DataFrameWriter private[sql](df: DataFrame) { } // connectionProperties should override settings in extraOptions props.putAll(connectionProperties) - val conn = JdbcUtils.createConnection(url, props) + val conn = JdbcUtils.createConnectionFactory(url, props)() try { var tableExists = JdbcUtils.tableExists(conn, url, table) @@ -297,7 +362,12 @@ final class DataFrameWriter private[sql](df: DataFrame) { if (!tableExists) { val schema = JdbcUtils.schemaString(df, url) val sql = s"CREATE TABLE $table ($schema)" - conn.createStatement.executeUpdate(sql) + val statement = conn.createStatement + try { + statement.executeUpdate(sql) + } finally { + statement.close() + } } } finally { conn.close() @@ -368,4 +438,9 @@ final class DataFrameWriter private[sql](df: DataFrame) { private var partitioningColumns: Option[Seq[String]] = None + private var bucketColumnNames: Option[Seq[String]] = None + + private var numBuckets: Option[Int] = None + + private var sortColumnNames: Option[Seq[String]] = None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 79b4244ac0cd..42f01e9359c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import scala.collection.JavaConverters._ +import org.apache.spark.Logging import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} @@ -64,7 +65,7 @@ import org.apache.spark.util.Utils class Dataset[T] private[sql]( @transient override val sqlContext: SQLContext, @transient override val queryExecution: QueryExecution, - tEncoder: Encoder[T]) extends Queryable with Serializable { + tEncoder: Encoder[T]) extends Queryable with Serializable with Logging { /** * An unresolved version of the internal encoder for the type of this [[Dataset]]. 
This one is @@ -450,7 +451,7 @@ class Dataset[T] private[sql]( */ @scala.annotation.varargs def groupBy(cols: Column*): GroupedDataset[Row, T] = { - val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias) + val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias(_)) val withKey = Project(withKeyColumns, logicalPlan) val executed = sqlContext.executePlan(withKey) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala index 13341a88a6b7..c74ef2c03541 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala @@ -21,13 +21,12 @@ import scala.collection.JavaConverters._ import scala.language.implicitConversions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Pivot} import org.apache.spark.sql.types.NumericType - /** * :: Experimental :: * A set of methods for aggregations on a [[DataFrame]], created by [[DataFrame.groupBy]]. @@ -58,10 +57,10 @@ class GroupedData protected[sql]( df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan)) case GroupedData.RollupType => DataFrame( - df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Rollup(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.CubeType => DataFrame( - df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg)) + df.sqlContext, Aggregate(Seq(Cube(groupingExprs)), aliasedAgg, df.logicalPlan)) case GroupedData.PivotType(pivotCol, values) => val aliasedGrps = groupingExprs.map(alias) DataFrame( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 4bf0b256fcb4..a819ddceb1b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -21,8 +21,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, OuterScopes} -import org.apache.spark.sql.catalyst.expressions.{Alias, CreateStruct, Attribute} +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.Aggregator diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala index 3d819262859f..7976795ff591 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala @@ -25,6 +25,8 @@ import scala.collection.JavaConverters._ import org.apache.parquet.hadoop.ParquetOutputCommitter 
import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.parser.ParserConf +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -114,6 +116,25 @@ private[spark] object SQLConf { } }, _.toString, doc, isPublic) + def longMemConf( + key: String, + defaultValue: Option[Long] = None, + doc: String = "", + isPublic: Boolean = true): SQLConfEntry[Long] = + SQLConfEntry(key, defaultValue, { v => + try { + v.toLong + } catch { + case _: NumberFormatException => + try { + Utils.byteStringAsBytes(v) + } catch { + case _: NumberFormatException => + throw new IllegalArgumentException(s"$key should be long, but was $v") + } + } + }, _.toString, doc, isPublic) + def doubleConf( key: String, defaultValue: Option[Double] = None, @@ -234,7 +255,7 @@ private[spark] object SQLConf { doc = "The default number of partitions to use when shuffling data for joins or aggregations.") val SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE = - longConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", + longMemConf("spark.sql.adaptive.shuffle.targetPostShuffleInputSize", defaultValue = Some(64 * 1024 * 1024), doc = "The target post-shuffle input size in bytes of a task.") @@ -334,7 +355,8 @@ private[spark] object SQLConf { val HIVE_VERIFY_PARTITION_PATH = booleanConf("spark.sql.hive.verifyPartitionPath", defaultValue = Some(false), - doc = "") + doc = "When true, check all the partition paths under the table\'s root directory " + + "when reading data stored in HDFS.") val HIVE_METASTORE_PARTITION_PRUNING = booleanConf("spark.sql.hive.metastorePartitionPruning", defaultValue = Some(false), @@ -352,7 +374,7 @@ private[spark] object SQLConf { val COLUMN_NAME_OF_CORRUPT_RECORD = stringConf("spark.sql.columnNameOfCorruptRecord", defaultValue = Some("_corrupt_record"), - doc = "") + doc = "The name of internal column for storing raw/un-parsed JSON records that fail to parse.") val BROADCAST_TIMEOUT = intConf("spark.sql.broadcastTimeout", defaultValue = Some(5 * 60), @@ -413,7 +435,8 @@ private[spark] object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = intConf( key = "spark.sql.sources.parallelPartitionDiscovery.threshold", defaultValue = Some(32), - doc = "") + doc = "The degree of parallelism for schema merging and partition discovery of " + + "Parquet data sources.") // Whether to perform eager analysis when constructing a dataframe. // Set to false when debugging requires the ability to look at invalid query plans. @@ -449,6 +472,19 @@ private[spark] object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query" ) + val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", + defaultValue = Some(true), + isPublic = false, + doc = "Whether to use quoted identifier.\n false: default(past) behavior. 
Implies only" + + "alphaNumeric and underscore are valid characters in identifiers.\n" + + " true: implies column names can contain any character.") + + val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf( + "spark.sql.parser.supportSQL11ReservedKeywords", + defaultValue = Some(false), + isPublic = false, + doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -469,7 +505,7 @@ private[spark] object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -private[sql] class SQLConf extends Serializable with CatalystConf { +private[sql] class SQLConf extends Serializable with CatalystConf with ParserConf { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @@ -567,6 +603,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf { private[spark] def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID) + + def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index db286ea8700b..e827427c19e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,28 +26,28 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index db286ea8700b..e827427c19e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -26,28 +26,28 @@ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal +import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} +import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.sql.SQLConf.SQLConfEntry +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.sql.{execution => sparkexecution} import org.apache.spark.util.Utils -import org.apache.spark.{SparkContext, SparkException} /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -785,9 +785,20 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long): DataFrame = { - createDataFrame( - sparkContext.range(start, end).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in a range from `start` to `end` (exclusive) with the given step value. + * + * @since 2.0.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long, step: Long): DataFrame = { + range(start, end, step, numPartitions = sparkContext.defaultParallelism) } /** @@ -801,9 +812,7 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { - createDataFrame( - sparkContext.range(start, end, step, numPartitions).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + DataFrame(this, Range(start, end, step, numPartitions)) } /** @@ -879,9 +888,6 @@ class SQLContext private[sql]( }.toArray } - @deprecated("use org.apache.spark.sql.SparkPlanner", "1.6.0") - protected[sql] class SparkPlanner extends sparkexecution.SparkPlanner(this) - @transient protected[sql] val planner: sparkexecution.SparkPlanner = new sparkexecution.SparkPlanner(this) @@ -895,15 +901,10 @@ class SQLContext private[sql]( @transient protected[sql] val prepareForExecution = new RuleExecutor[SparkPlan] { val batches = Seq( - Batch("Add exchange", Once, EnsureRequirements(self)), - Batch("Add row converters", Once, EnsureRowFormats) + Batch("Add exchange", Once, EnsureRequirements(self)) ) } - @deprecated("use org.apache.spark.sql.QueryExecution", "1.6.0") - protected[sql] class QueryExecution(logical: LogicalPlan) - extends sparkexecution.QueryExecution(this, logical) - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. @@ -944,301 +945,6 @@ class SQLContext private[sql]( } } - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // Deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame. This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * @deprecated As of 1.3.0, replaced by `createDataFrame()`. This will be removed in Spark 2.0. - */ - @deprecated("Use createDataFrame.
This will be removed in Spark 2.0.", "1.3.0") - def applySchema(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd, beanClass) - } - - /** - * Loads a Parquet file, returning the result as a [[DataFrame]]. This function returns an empty - * [[DataFrame]] if no paths are passed in. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().parquet()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.parquet(). This will be removed in Spark 2.0.", "1.4.0") - @scala.annotation.varargs - def parquetFile(paths: String*): DataFrame = { - if (paths.isEmpty) { - emptyDataFrame - } else { - read.parquet(paths : _*) - } - } - - /** - * Loads a JSON file (one object per line), returning the result as a [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String): DataFrame = { - read.json(path) - } - - /** - * Loads a JSON file (one object per line) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String, schema: StructType): DataFrame = { - read.schema(schema).json(path) - } - - /** - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonFile(path: String, samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(path) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record), returning the result as a - * [[DataFrame]]. - * It goes through the entire dataset once to determine the schema. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String]): DataFrame = read.json(json) - - /** - * Loads an RDD[String] storing JSON objects (one object per record) and applies the given schema, - * returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an JavaRDD storing JSON objects (one object per record) and applies the given - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. 
- */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String], schema: StructType): DataFrame = { - read.schema(schema).json(json) - } - - /** - * Loads an RDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Loads a JavaRDD[String] storing JSON objects (one object per record) inferring the - * schema, returning the result as a [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().json()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.json(). This will be removed in Spark 2.0.", "1.4.0") - def jsonRDD(json: JavaRDD[String], samplingRatio: Double): DataFrame = { - read.option("samplingRatio", samplingRatio.toString).json(json) - } - - /** - * Returns the dataset stored at path as a DataFrame, - * using the default data source configured by spark.sql.sources.default. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().load(path)`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.load(path). This will be removed in Spark 2.0.", "1.4.0") - def load(path: String): DataFrame = { - read.load(path) - } - - /** - * Returns the dataset stored at path as a DataFrame, using the given data source. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).load(path)`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use read.format(source).load(path). This will be removed in Spark 2.0.", "1.4.0") - def load(path: String, source: String): DataFrame = { - read.format(source).load(path) - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - * This will be removed in Spark 2.0. - */ - @deprecated("Use read.format(source).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, options: java.util.Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by `read().format(source).options(options).load()`. - */ - @deprecated("Use read.format(source).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, options: Map[String, String]): DataFrame = { - read.options(options).format(source).load() - } - - /** - * (Java-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load(). 
" + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = - { - read.format(source).schema(schema).options(options).load() - } - - /** - * (Scala-specific) Returns the dataset specified by the given data source and - * a set of options as a DataFrame, using the given schema as the schema of the DataFrame. - * - * @group genericdata - * @deprecated As of 1.4.0, replaced by - * `read().format(source).schema(schema).options(options).load()`. - */ - @deprecated("Use read.format(source).schema(schema).options(options).load(). " + - "This will be removed in Spark 2.0.", "1.4.0") - def load(source: String, schema: StructType, options: Map[String, String]): DataFrame = { - read.format(source).schema(schema).options(options).load() - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc(url: String, table: String): DataFrame = { - read.jdbc(url, table, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. Partitions of the table will be retrieved in parallel based on the parameters - * passed to this function. - * - * @param columnName the name of a column of integral type that will be used for partitioning. - * @param lowerBound the minimum value of `columnName` used to decide partition stride - * @param upperBound the maximum value of `columnName` used to decide partition stride - * @param numPartitions the number of partitions. the range `minValue`-`maxValue` will be split - * evenly into this many partitions - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc( - url: String, - table: String, - columnName: String, - lowerBound: Long, - upperBound: Long, - numPartitions: Int): DataFrame = { - read.jdbc(url, table, columnName, lowerBound, upperBound, numPartitions, new Properties) - } - - /** - * Construct a [[DataFrame]] representing the database table accessible via JDBC URL - * url named table. The theParts parameter gives a list expressions - * suitable for inclusion in WHERE clauses; each one defines one partition - * of the [[DataFrame]]. - * - * @group specificdata - * @deprecated As of 1.4.0, replaced by `read().jdbc()`. This will be removed in Spark 2.0. - */ - @deprecated("Use read.jdbc(). This will be removed in Spark 2.0.", "1.4.0") - def jdbc(url: String, table: String, theParts: Array[String]): DataFrame = { - read.jdbc(url, table, theParts, new Properties) - } - - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - // End of deprecated methods - //////////////////////////////////////////////////////////////////////////// - //////////////////////////////////////////////////////////////////////////// - - // Register a succesfully instantiatd context to the singleton. This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the // construction of the instance. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6735d02954b8..ab414799f1a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -21,11 +21,10 @@ import scala.language.implicitConversions import scala.reflect.runtime.universe.TypeTag import org.apache.spark.rdd.RDD -import org.apache.spark.sql.types._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.types.StructField +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String /**
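The UDFRegistration.scala changes that follow switch every generated register method from Try(...).getOrElse(Nil) to Try(...).toOption, with ScalaUDF falling back to Nil via inputTypes.getOrElse(Nil); that implies UserDefinedFunction's inputTypes parameter becomes an Option elsewhere in this change. Read that way, "no arguments" and "input types could not be derived" stay distinct; a small sketch of the distinction (illustrative only):

    import scala.util.Try
    import org.apache.spark.sql.types.DataType

    // Some(Nil): a zero-argument UDF with known (empty) input types.
    val noArgs: Option[Seq[DataType]] = Try(Nil).toOption
    // None: ScalaReflection.schemaFor threw, so the input types are unknown
    // (previously this collapsed to Nil, same as the zero-argument case).
    val underivable: Option[Seq[DataType]] =
      Try[Seq[DataType]](sys.error("unsupported input type")).toOption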
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index 051694c0d43a..f87a88d49744 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql -import java.util.{List => JList, Map => JMap} - import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -69,7 +67,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { udaf } - // scalastyle:off + // scalastyle:off line.size.limit /* register 0-22 were generated by this script @@ -85,8 +83,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try($inputTypes).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try($inputTypes).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) }""") @@ -102,7 +100,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { | * Register a user-defined function with ${i} arguments. | * @since 1.3.0 | */ - |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType) = { + |def register(name: String, f: UDF$i[$extTypeArgs, _], returnType: DataType): Unit = { | functionRegistry.registerFunction( | name, | (e: Seq[Expression]) => ScalaUDF(f$anyCast.call($anyParams), returnType, e)) @@ -117,8 +115,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag](name: String, func: Function0[RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -130,8 +128,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag](name: String, func: Function1[A1, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -143,8 +141,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag](name: String, func: Function2[A1, A2, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -156,8 +154,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](name: String, func: Function3[A1, A2, A3, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -169,8 +167,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](name: String, func: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { val dataType =
ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -182,8 +180,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](name: String, func: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -195,8 +193,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](name: String, func: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -208,8 +206,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](name: String, func: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType 
:: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -221,8 +219,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](name: String, func: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -234,8 +232,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](name: String, func: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) 
functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -247,8 +245,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](name: String, func: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -260,8 +258,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag](name: String, func: Function11[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -273,8 +271,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: 
TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag](name: String, func: Function12[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -286,8 +284,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag](name: String, func: Function13[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) 
UserDefinedFunction(func, dataType, inputTypes) } @@ -299,8 +297,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag](name: String, func: Function14[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -312,8 +310,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag](name: String, func: Function15[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: 
ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -325,8 +323,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag](name: String, func: Function16[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -338,8 +336,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag](name: String, func: Function17[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType 
:: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -351,8 +349,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag](name: String, func: Function18[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: 
ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -364,8 +362,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag](name: String, func: Function19[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -377,8 +375,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag](name: String, 
func: Function20[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -390,8 +388,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag](name: String, func: Function21[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: 
ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -403,8 +401,8 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { */ def register[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag, A11: TypeTag, A12: TypeTag, A13: TypeTag, A14: TypeTag, A15: TypeTag, A16: TypeTag, A17: TypeTag, A18: TypeTag, A19: TypeTag, A20: TypeTag, A21: TypeTag, A22: TypeTag](name: String, func: Function22[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12, A13, A14, A15, A16, A17, A18, A19, A20, A21, A22, RT]): UserDefinedFunction = { val dataType = ScalaReflection.schemaFor[RT].dataType - val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).getOrElse(Nil) - def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes) + val inputTypes = Try(ScalaReflection.schemaFor[A1].dataType :: ScalaReflection.schemaFor[A2].dataType :: ScalaReflection.schemaFor[A3].dataType :: ScalaReflection.schemaFor[A4].dataType :: ScalaReflection.schemaFor[A5].dataType :: ScalaReflection.schemaFor[A6].dataType :: ScalaReflection.schemaFor[A7].dataType :: ScalaReflection.schemaFor[A8].dataType :: 
ScalaReflection.schemaFor[A9].dataType :: ScalaReflection.schemaFor[A10].dataType :: ScalaReflection.schemaFor[A11].dataType :: ScalaReflection.schemaFor[A12].dataType :: ScalaReflection.schemaFor[A13].dataType :: ScalaReflection.schemaFor[A14].dataType :: ScalaReflection.schemaFor[A15].dataType :: ScalaReflection.schemaFor[A16].dataType :: ScalaReflection.schemaFor[A17].dataType :: ScalaReflection.schemaFor[A18].dataType :: ScalaReflection.schemaFor[A19].dataType :: ScalaReflection.schemaFor[A20].dataType :: ScalaReflection.schemaFor[A21].dataType :: ScalaReflection.schemaFor[A22].dataType :: Nil).toOption + def builder(e: Seq[Expression]) = ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil)) functionRegistry.registerFunction(name, builder) UserDefinedFunction(func, dataType, inputTypes) } @@ -416,7 +414,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 1 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF1[_, _], returnType: DataType) = { + def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF1[Any, Any]].call(_: Any), returnType, e)) @@ -426,7 +424,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 2 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF2[_, _, _], returnType: DataType) = { + def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any), returnType, e)) @@ -436,7 +434,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 3 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF3[_, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any), returnType, e)) @@ -446,7 +444,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 4 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -456,7 +454,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 5 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -466,7 +464,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 6 arguments. 
* @since 1.3.0 */ - def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -476,7 +474,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 7 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -486,7 +484,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 8 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -496,7 +494,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 9 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -506,7 +504,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 10 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -516,7 +514,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 11 arguments. 
* @since 1.3.0 */ - def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF11[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -526,7 +524,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 12 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF12[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -536,7 +534,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 13 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -546,7 +544,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 14 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -556,7 +554,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 15 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -566,7 +564,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 16 arguments. 
* @since 1.3.0 */ - def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -576,7 +574,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 17 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -586,7 +584,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 18 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -596,7 +594,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 19 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -606,7 +604,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 20 arguments. 
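Apart from the `Option` change, the only edit to these Java-facing overloads is the explicit `: Unit` result type. With an inferred result type, each `register` exposes whatever `registerFunction` happens to return, so changing the registry's internals silently becomes a breaking API change. A sketch of the hazard, with a hypothetical registry:

```scala
// Hypothetical registry; the real FunctionRegistry has a different signature.
object Registry {
  private var fns = Map.empty[String, Int => Int]
  def registerFunction(name: String, f: Int => Int): Map[String, Int => Int] = {
    fns += (name -> f)
    fns // returning internal state is what makes inference risky below
  }
}

// Inferred result type: Map[String, Int => Int]. Callers can come to depend on it.
def registerInferred(name: String, f: Int => Int) = Registry.registerFunction(name, f)

// Pinned to Unit, as in the patch: the registry's return type is free to evolve.
def registerExplicit(name: String, f: Int => Int): Unit = {
  Registry.registerFunction(name, f)
}
```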
* @since 1.3.0 */ - def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -616,7 +614,7 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 21 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) @@ -626,11 +624,12 @@ class UDFRegistration private[sql] (sqlContext: SQLContext) extends Logging { * Register a user-defined function with 22 arguments. * @since 1.3.0 */ - def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType) = { + def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { functionRegistry.registerFunction( name, (e: Seq[Expression]) => ScalaUDF(f.asInstanceOf[UDF22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any), returnType, e)) } - // scalastyle:on + // scalastyle:on line.size.limit + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala index 0f8cd280b5ac..2fb3bf07aa60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UserDefinedFunction.scala @@ -44,10 +44,10 @@ import org.apache.spark.sql.types.DataType case class UserDefinedFunction protected[sql] ( f: AnyRef, dataType: DataType, - inputTypes: Seq[DataType] = Nil) { + inputTypes: Option[Seq[DataType]]) { def apply(exprs: Column*): Column = { - Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes)) + Column(ScalaUDF(f, dataType, exprs.map(_.expr), inputTypes.getOrElse(Nil))) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index b3f134614c6b..d912aeb70d51 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -19,20 +19,20 @@ package org.apache.spark.sql.api.r 
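In the SQLUtils hunk just below, `createSQLContext` stops constructing a fresh `SQLContext` on every call and delegates to `SQLContext.getOrCreate(jsc.sc)`, so repeated initializations from the R frontend share one context (and hence one configuration and temp-table namespace). The idiom reduced to its core, with simplified names:

```scala
// Minimal getOrCreate sketch; the real SQLContext also tracks a per-thread
// "active" context, which is elided here.
class Context
object Context {
  @volatile private var instance: Option[Context] = None
  def getOrCreate(): Context = synchronized {
    instance match {
      case Some(ctx) => ctx // reuse the existing context
      case None =>
        val ctx = new Context
        instance = Some(ctx)
        ctx
    }
  }
}

// Two "initializations" now observe the same object:
// Context.getOrCreate() eq Context.getOrCreate()   // true
```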
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import scala.util.matching.Regex + import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} import org.apache.spark.api.r.SerDe import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression, GenericRowWithSchema} +import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SaveMode, SQLContext} +import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, GenericRowWithSchema, NamedExpression} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, GroupedData, Row, SQLContext, SaveMode} - -import scala.util.matching.Regex private[r] object SQLUtils { SerDe.registerSqlSerDe((readSqlObject, writeSqlObject)) def createSQLContext(jsc: JavaSparkContext): SQLContext = { - new SQLContext(jsc) + SQLContext.getOrCreate(jsc.sc) } def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala index 663bc904f39c..33475bea9af4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, SortOrder} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 62cbc518e02a..6b100577077c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors.attachTree import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.util.MutablePair @@ -50,26 +49,14 @@ case class Exchange( case None => "" } - val simpleNodeName = if (tungstenMode) "TungstenExchange" else "Exchange" + val simpleNodeName = "Exchange" s"$simpleNodeName$extraInfo" } - /** - * Returns true iff we can support the data type, and we are not doing range partitioning. - */ - private lazy val tungstenMode: Boolean = !newPartitioning.isInstanceOf[RangePartitioning] - override def outputPartitioning: Partitioning = newPartitioning override def output: Seq[Attribute] = child.output - // This setting is somewhat counterintuitive: - // If the schema works with UnsafeRow, then we tell the planner that we don't support safe row, - // so the planner inserts a converter to convert data into UnsafeRow if needed. - override def outputsUnsafeRows: Boolean = tungstenMode - override def canProcessSafeRows: Boolean = !tungstenMode - override def canProcessUnsafeRows: Boolean = tungstenMode - /** * Determines whether records must be defensively copied before being sent to the shuffle. 
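Two things happen to Exchange in this stretch of the diff: the safe-row path (`tungstenMode`, the Kryo-based `SparkSqlSerializer`) disappears because every operator now exchanges UnsafeRows, and `EnsureRequirements` gains a case, further down, that removes a shuffle whose input is already shuffled compatibly. A toy model of that new case (plan nodes and `guarantees` are simplified stand-ins for `SparkPlan` and `Partitioning.guarantees`):

```scala
// Toy plan nodes; "guarantees" is plain equality here, whereas the real
// Partitioning.guarantees is a compatibility check between distributions.
sealed trait Plan { def children: Seq[Plan] }
case object Scan extends Plan { val children: Seq[Plan] = Nil }
case class Project(child: Plan) extends Plan { val children: Seq[Plan] = Seq(child) }
case class Exchange(partitioning: String, child: Plan) extends Plan {
  val children: Seq[Plan] = Seq(child)
}

def collapse(plan: Plan): Plan = plan match {
  case outer @ Exchange(p, child) =>
    child.children match {
      // The operator under `outer` is fed by an Exchange that already produces
      // a compatible distribution, so the outer shuffle is redundant.
      case Exchange(cp, _) :: Nil if cp == p => child
      case _ => outer
    }
  case other => other
}

// collapse(Exchange("hash(a)", Project(Exchange("hash(a)", Scan))))
//   == Project(Exchange("hash(a)", Scan))
```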
* Several of Spark's shuffle components will buffer deserialized Java objects in memory. The @@ -130,15 +117,7 @@ case class Exchange( } } - @transient private lazy val sparkConf = child.sqlContext.sparkContext.getConf - - private val serializer: Serializer = { - if (tungstenMode) { - new UnsafeRowSerializer(child.output.size) - } else { - new SparkSqlSerializer(sparkConf) - } - } + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) override protected def doPrepare(): Unit = { // If an ExchangeCoordinator is needed, we register this Exchange operator @@ -488,6 +467,12 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { + case operator @ Exchange(partitioning, child, _) => + child.children match { + case Exchange(childPartitioning, baseChild, _)::Nil => + if (childPartitioning.guarantees(partitioning)) child else operator + case _ => operator + } case operator: SparkPlan => ensureDistributionAndOrdering(operator) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala index 827fdd278460..07015e5a5aae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution -import java.util.{Map => JMap, HashMap => JHashMap} +import java.util.{HashMap => JHashMap, Map => JMap} import javax.annotation.concurrent.GuardedBy import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{Logging, SimpleFutureAction, ShuffleDependency, MapOutputStatistics} +import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index b8a43025882e..569a21feaa8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -18,14 +18,13 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation -import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation} +import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} import org.apache.spark.sql.types.DataType -import org.apache.spark.sql.{Row, SQLContext} - object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -74,9 +73,7 @@ private[sql] case class LogicalRDD( override def children: Seq[LogicalPlan] = Nil - override protected final def otherCopyArgs: Seq[AnyRef] = { - sqlContext :: Nil - } + override protected final def otherCopyArgs: Seq[AnyRef] 
= sqlContext :: Nil override def newInstance(): LogicalRDD.this.type = LogicalRDD(output.map(_.newInstance()), rdd)(sqlContext).asInstanceOf[this.type] @@ -86,6 +83,8 @@ private[sql] case class LogicalRDD( case _ => false } + override def producedAttributes: AttributeSet = outputSet + @transient override lazy val statistics: Statistics = Statistics( // TODO: Instead of returning a default value here, find a way to return a meaningful size // estimate for RDDs. See PR 1238 for more discussions. @@ -99,10 +98,19 @@ private[sql] case class PhysicalRDD( rdd: RDD[InternalRow], override val nodeName: String, override val metadata: Map[String, String] = Map.empty, - override val outputsUnsafeRows: Boolean = false) + isUnsafeRow: Boolean = false) extends LeafNode { - protected override def doExecute(): RDD[InternalRow] = rdd + protected override def doExecute(): RDD[InternalRow] = { + if (isUnsafeRow) { + rdd + } else { + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } + } + } override def simpleString: String = { val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index 91530bd63798..c3683cc4e7aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -41,20 +41,11 @@ case class Expand( // as UNKNOWN partitioning override def outputPartitioning: Partitioning = UnknownPartitioning(0) - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) - private[this] val projection = { - if (outputsUnsafeRows) { - (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) - } else { - (exprs: Seq[Expression]) => newMutableProjection(exprs, child.output)() - } - } + private[this] val projection = + (exprs: Seq[Expression]) => UnsafeProjection.create(exprs, child.output) protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { child.execute().mapPartitions { iter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 54b8cb58285c..4db88a09d815 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -54,6 +54,8 @@ case class Generate( child: SparkPlan) extends UnaryNode { + override def expressions: Seq[Expression] = generator :: Nil + val boundGenerator = BindReferences.bindReference(generator, child.output) protected override def doExecute(): RDD[InternalRow] = { @@ -62,6 +64,7 @@ case class Generate( child.execute().mapPartitionsInternal { iter => val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null)) val joinedRow = new JoinedRow + val proj = UnsafeProjection.create(output, output) iter.flatMap { row => // we should always set the left (child output) @@ -75,13 +78,14 @@ case class Generate( } ++ LazyIterator(() => boundGenerator.terminate()).map { row => // we leave the left side as the last element of its child output // keep it the same as Hive does - 
joinedRow.withRight(row) + proj(joinedRow.withRight(row)) } } } else { child.execute().mapPartitionsInternal { iter => - iter.flatMap(row => boundGenerator.eval(row)) ++ - LazyIterator(() => boundGenerator.terminate()) + val proj = UnsafeProjection.create(output, output) + (iter.flatMap(row => boundGenerator.eval(row)) ++ + LazyIterator(() => boundGenerator.terminate())).map(proj) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala index 6a8850129f1a..ef84992e6979 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GroupedIterator.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateOrdering} -import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder, Ascending, Expression} +import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Expression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateOrdering, GenerateUnsafeProjection} object GroupedIterator { def apply( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala index ba7f6287ac6c..59057bf9666e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScan.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection} /** @@ -29,15 +29,20 @@ private[sql] case class LocalTableScan( output: Seq[Attribute], rows: Seq[InternalRow]) extends LeafNode { - private lazy val rdd = sqlContext.sparkContext.parallelize(rows) + private val unsafeRows: Array[InternalRow] = { + val proj = UnsafeProjection.create(output, output) + rows.map(r => proj(r).copy()).toArray + } + + private lazy val rdd = sqlContext.sparkContext.parallelize(unsafeRows) protected override def doExecute(): RDD[InternalRow] = rdd override def executeCollect(): Array[InternalRow] = { - rows.toArray + unsafeRows } override def executeTake(limit: Int): Array[InternalRow] = { - rows.take(limit).toArray + unsafeRows.take(limit) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala index b397d42612cf..38263af0f7e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Queryable.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils + import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types.StructType @@ -31,7 +32,20 @@ private[sql] trait Queryable { override def toString: String = { try { - schema.map(f => s"${f.name}: ${f.dataType.simpleString}").mkString("[", ", ", "]") + val builder = new StringBuilder + val fields = schema.take(2).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("[") + 
builder.append(fields.mkString(", ")) + if (schema.length > 2) { + if (schema.length - fields.size == 1) { + builder.append(" ... 1 more field") + } else { + builder.append(" ... " + (schema.length - 2) + " more fields") + } + } + builder.append("]").toString() } catch { case NonFatal(e) => s"Invalid tree; ${e.getMessage}:\n$queryExecution" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala index 34971986261c..0a11b16d0ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SQLExecution.scala @@ -21,8 +21,8 @@ import java.util.concurrent.atomic.AtomicLong import org.apache.spark.SparkContext import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionStart, - SparkListenerSQLExecutionEnd} +import org.apache.spark.sql.execution.ui.{SparkListenerSQLExecutionEnd, + SparkListenerSQLExecutionStart} import org.apache.spark.util.Utils private[sql] object SQLExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index 24207cb46fd2..73dc8cb98447 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -39,10 +39,6 @@ case class Sort( testSpillFrequency: Int = 0) extends UnaryNode { - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override def output: Seq[Attribute] = child.output override def outputOrdering: Seq[SortOrder] = sortOrder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala index e17b50edc62d..909f124d2c9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SortPrefixUtils.scala @@ -21,8 +21,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ -import org.apache.spark.util.collection.unsafe.sort.{PrefixComparators, PrefixComparator} - +import org.apache.spark.util.collection.unsafe.sort.{PrefixComparator, PrefixComparators} object SortPrefixUtils { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index ec98f8104134..2355de3d0586 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -97,17 +97,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ /** Specifies sort order for each partition requirements on the input data for this operator. */ def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil) - /** Specifies whether this operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. 
rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true /** * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute @@ -115,18 +104,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * Concrete implementations of SparkPlan should override doExecute instead. */ final def execute(): RDD[InternalRow] = { - if (children.nonEmpty) { - val hasUnsafeInputs = children.exists(_.outputsUnsafeRows) - val hasSafeInputs = children.exists(!_.outputsUnsafeRows) - assert(!(hasSafeInputs && hasUnsafeInputs), - "Child operators should output rows in the same format") - assert(canProcessSafeRows || canProcessUnsafeRows, - "Operator must be able to process at least one row format") - assert(!hasSafeInputs || canProcessSafeRows, - "Operator will receive safe rows as input but cannot process safe rows") - assert(!hasUnsafeInputs || canProcessUnsafeRows, - "Operator will receive unsafe rows as input but cannot process unsafe rows") - } RDDOperationScope.withScope(sparkContext, nodeName, false, true) { prepare() doExecute() @@ -192,7 +169,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ while (buf.size < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1 + var numPartsToTry = 1L if (partsScanned > 0) { // If we didn't find any rows after the first iteration, just try all partitions next. // Otherwise, interpolate the number of partitions we need to try, but overestimate it @@ -206,13 +183,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ numPartsToTry = math.max(0, numPartsToTry) // guard against negative num of partitions val left = n - buf.size - val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts) + val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt) val sc = sqlContext.sparkContext - val res = - sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) + val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p) res.foreach(buf ++= _.take(n - buf.size)) - partsScanned += numPartsToTry + partsScanned += p.size } buf.toArray @@ -279,6 +255,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[sql] trait LeafNode extends SparkPlan { override def children: Seq[SparkPlan] = Nil + override def producedAttributes: AttributeSet = outputSet } private[sql] trait UnaryNode extends SparkPlan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala new file mode 100644 index 000000000000..a322688a259e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
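Two small fixes ride along in the `SparkPlan.executeTake` hunk above: `numPartsToTry` becomes a `Long` so the geometric growth cannot overflow `Int`, and the loop now advances by `p.size`, the number of partitions actually scanned after clipping at `totalParts`, rather than by the requested `numPartsToTry`. The loop's shape, with illustrative constants and a stand-in for `sc.runJob`:

```scala
// Incremental take: scan a few partitions, then grow the batch until enough
// rows are collected. Growth factor and row counts here are illustrative only;
// the real code interpolates from the rows seen so far.
val totalParts = 64
val n = 10
var partsScanned = 0
var buf = Vector.empty[Int]
while (buf.size < n && partsScanned < totalParts) {
  var numPartsToTry = 1L
  if (partsScanned > 0) numPartsToTry = 4L * partsScanned // Long math: no overflow
  val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts.toLong).toInt)
  val res = p.map(_ => Seq(1)) // stand-in for sc.runJob(childRDD, it.take(left).toArray, p)
  res.foreach(rows => buf ++= rows.take(n - buf.size))
  partsScanned += p.size // not numPartsToTry: p may have been clipped at totalParts
}
```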
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation +import org.apache.spark.sql.catalyst.parser.{ASTNode, ParserConf, SimpleParserConf} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.{CatalystQl, TableIdentifier} + +private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { + /** Check if a command should not be explained. */ + protected def isNoExplainCommand(command: String): Boolean = "TOK_DESCTABLE" == command + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Just fake explain for any of the native commands. + case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => + ExplainCommand(OneRowRelation) + + case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text => + val Some(crtTbl) :: _ :: extended :: Nil = + getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined) + + case Token("TOK_EXPLAIN", explainArgs) => + // Ignore FORMATTED if present. + val Some(query) :: _ :: extended :: Nil = + getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) + ExplainCommand(nodeToPlan(query), extended = extended.isDefined) + + case Token("TOK_DESCTABLE", describeArgs) => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL + val Some(tableType) :: formatted :: extended :: pretty :: Nil = + getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) + if (formatted.isDefined || pretty.isDefined) { + // FORMATTED and PRETTY are not supported and this statement will be treated as + // a Hive native command. + nodeToDescribeFallback(node) + } else { + tableType match { + case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts :: Nil) :: Nil) => + nameParts match { + case Token(".", dbName :: tableName :: Nil) => + // It is describing a table with the format like "describe db.table". + // TODO: Actually, a user may mean tableName.columnName. Need to resolve this + // issue. + val tableIdent = extractTableIdent(nameParts) + datasources.DescribeCommand( + UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) + case Token(".", dbName :: tableName :: colName :: Nil) => + // It is describing a column with the format like "describe db.table column". + nodeToDescribeFallback(node) + case tableName => + // It is describing a table with the format like "describe table". + datasources.DescribeCommand( + UnresolvedRelation(TableIdentifier(tableName.text), None), + isExtended = extended.isDefined) + } + // All other cases. 
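The new `SparkQl` below dispatches purely on the shape of the parsed tree: a token's text plus the layout of its children decide whether a statement is planned natively or handed to `nodeToDescribeFallback`. The dispatch style in miniature, over a toy token type (the real `ASTNode` carries positions and source text):

```scala
case class Token(text: String, children: List[Token])

def isNoExplainCommand(command: String): Boolean = command == "TOK_DESCTABLE"

def nodeToPlan(node: Token): String = node match {
  // EXPLAIN of a command we will not explain: plan a no-op explain.
  case Token("TOK_EXPLAIN", head :: _) if isNoExplainCommand(head.text) =>
    "ExplainCommand(OneRowRelation)"
  // EXPLAIN of anything else: recurse into the explained statement.
  case Token("TOK_EXPLAIN", head :: _) =>
    s"ExplainCommand(${nodeToPlan(head)})"
  case Token("TOK_QUERY", _) =>
    "Query"
  case other =>
    s"fallback(${other.text})" // the patch falls back via noParseRule here
}

// nodeToPlan(Token("TOK_EXPLAIN", List(Token("TOK_DESCTABLE", Nil))))
//   == "ExplainCommand(OneRowRelation)"
```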
+ case _ => nodeToDescribeFallback(node) + } + } + + case _ => + super.nodeToPlan(node) + } + } + + protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node) +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala index 45a8e0324826..c590f7c6c3e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala @@ -22,16 +22,15 @@ import java.util.{HashMap => JavaHashMap} import scala.reflect.ClassTag -import com.esotericsoftware.kryo.io.{Input, Output} import com.esotericsoftware.kryo.{Kryo, Serializer} +import com.esotericsoftware.kryo.io.{Input, Output} import com.twitter.chill.ResourcePool +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.{KryoSerializer, SerializerInstance} import org.apache.spark.sql.types.Decimal import org.apache.spark.util.MutablePair -import org.apache.spark.{SparkConf, SparkEnv} - private[sql] class SparkSqlSerializer(conf: SparkConf) extends KryoSerializer(conf) { override def newKryo(): Kryo = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 688555cf136e..482130a18d93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.{execution, Strategy} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression @@ -24,10 +25,9 @@ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _} -import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} -import org.apache.spark.sql.{Strategy, execution} private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { self: SparkPlanner => @@ -358,6 +358,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil + case r @ logical.Range(start, end, step, numSlices, output) => + execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil @@ -380,13 +382,12 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case c: CreateTableUsing if c.temporary && c.allowExisting => sys.error("allowExisting should be set to false 
when creating a temporary table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, partitionsCols, mode, opts, query) - if partitionsCols.nonEmpty => + case c: CreateTableUsingAsSelect if c.temporary && c.partitionColumns.nonEmpty => sys.error("Cannot create temporary partitioned table.") - case CreateTableUsingAsSelect(tableIdent, provider, true, _, mode, opts, query) => + case c: CreateTableUsingAsSelect if c.temporary => val cmd = CreateTempTableUsingAsSelect( - tableIdent, provider, Array.empty[String], mode, opts, query) + c.tableIdent, c.provider, Array.empty[String], c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case c: CreateTableUsingAsSelect if !c.temporary => sys.error("Tables created with SQLContext must be TEMPORARY. Use a HiveContext instead.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala index 7e981268de39..a23ebec95333 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/UnsafeRowSerializer.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import com.google.common.io.ByteStreams -import org.apache.spark.serializer.{SerializationStream, DeserializationStream, SerializerInstance, Serializer} +import org.apache.spark.serializer.{DeserializationStream, SerializationStream, Serializer, SerializerInstance} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.unsafe.Platform @@ -94,7 +94,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst private[this] val dIn: DataInputStream = new DataInputStream(new BufferedInputStream(in)) // 1024 is a default buffer size; this buffer will grow to accommodate larger rows private[this] var rowBuffer: Array[Byte] = new Array[Byte](1024) - private[this] var row: UnsafeRow = new UnsafeRow() + private[this] var row: UnsafeRow = new UnsafeRow(numFields) private[this] var rowTuple: (Int, UnsafeRow) = (0, row) private[this] val EOF: Int = -1 @@ -117,7 +117,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) rowSize = readSize() if (rowSize == EOF) { // We are returning the last row in this stream dIn.close() @@ -152,7 +152,7 @@ private class UnsafeRowSerializerInstance(numFields: Int) extends SerializerInst rowBuffer = new Array[Byte](rowSize) } ByteStreams.readFully(dIn, rowBuffer, 0, rowSize) - row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, numFields, rowSize) + row.pointTo(rowBuffer, Platform.BYTE_ARRAY_OFFSET, rowSize) row.asInstanceOf[T] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index b1280c32a6a4..be885397a7d4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -17,12 +17,19 @@ package org.apache.spark.sql.execution +import java.util + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import 
org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.IntegerType -import org.apache.spark.rdd.RDD -import org.apache.spark.util.collection.CompactBuffer +import org.apache.spark.util.collection.unsafe.sort.{UnsafeExternalSorter, UnsafeSorterIterator} +import org.apache.spark.{SparkEnv, TaskContext} /** * This class calculates and outputs (windowed) aggregates over the rows in a single (sorted) @@ -42,6 +49,8 @@ import org.apache.spark.util.collection.CompactBuffer * - Moving frame: Every time we move to a new row to process, we remove some rows from the frame * and we add some rows to the frame. Examples are: * 1 PRECEDING AND CURRENT ROW and 1 FOLLOWING AND 2 FOLLOWING. + * - Offset frame: The frame consists of one row, which is an offset number of rows away from the + * current row. Only [[OffsetWindowFunction]]s can be processed in an offset frame. * * Different frame boundaries can be used in Growing, Shrinking and Moving frames. A frame * boundary can be either Row or Range based: @@ -95,8 +104,6 @@ case class Window( override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def canProcessUnsafeRows: Boolean = true - /** * Create a bound ordering object for a given frame type and offset. A bound ordering object is * used to determine which input row lies within the frame boundaries of an output row. @@ -122,12 +129,10 @@ case class Window( // Create the projection which returns the current 'value'. val current = newMutableProjection(expr :: Nil, child.output)() // Flip the sign of the offset when processing the order is descending - val boundOffset = - if (sortExpr.direction == Descending) { - -offset - } else { - offset - } + val boundOffset = sortExpr.direction match { + case Descending => -offset + case Ascending => offset + } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) val bound = newMutableProjection(boundExpr :: Nil, child.output)() @@ -149,43 +154,102 @@ case class Window( } /** - * Create a frame processor. - * - * This method uses Code Generation. It can only be used on the executor side. - * - * @param frame boundaries. - * @param functions to process in the frame. - * @param ordinal at which the processor starts writing to the output. - * @return a frame processor. + * Collection containing an entry for each window frame to process. Each entry contains a frame's + * WindowExpressions and a factory function for the WindowFunctionFrame. */ - private[this] def createFrameProcessor( - frame: WindowFrame, - functions: Array[WindowFunction], - ordinal: Int): WindowFunctionFrame = frame match { - // Growing Frame. - case SpecifiedWindowFrame(frameType, UnboundedPreceding, FrameBoundaryExtractor(high)) => - val uBoundOrdering = createBoundOrdering(frameType, high) - new UnboundedPrecedingWindowFunctionFrame(ordinal, functions, uBoundOrdering) - - // Shrinking Frame. - case SpecifiedWindowFrame(frameType, FrameBoundaryExtractor(low), UnboundedFollowing) => - val lBoundOrdering = createBoundOrdering(frameType, low) - new UnboundedFollowingWindowFunctionFrame(ordinal, functions, lBoundOrdering) - - // Moving Frame.
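The frame taxonomy above gains its fourth member, the offset frame, used by functions such as `lead` and `lag` that read exactly one row at a fixed distance from the current one. Stripped of Catalyst machinery, the semantics are plain indexed access with nulls (here `None`) outside the partition:

```scala
// lag over an in-memory "partition": each output row reads the row `offset`
// positions earlier, or None when that falls before the partition start.
def lag[T](partition: IndexedSeq[T], offset: Int): IndexedSeq[Option[T]] =
  partition.indices.map { i =>
    val j = i - offset
    if (j >= 0 && j < partition.length) Some(partition(j)) else None
  }

// lag(Vector(10, 20, 30), 1) == Vector(None, Some(10), Some(20))
// lead is the same computation with j = i + offset.
```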
- case SpecifiedWindowFrame(frameType, - FrameBoundaryExtractor(low), FrameBoundaryExtractor(high)) => - val lBoundOrdering = createBoundOrdering(frameType, low) - val uBoundOrdering = createBoundOrdering(frameType, high) - new SlidingWindowFunctionFrame(ordinal, functions, lBoundOrdering, uBoundOrdering) - - // Entire Partition Frame. - case SpecifiedWindowFrame(_, UnboundedPreceding, UnboundedFollowing) => - new UnboundedWindowFunctionFrame(ordinal, functions) - - // Error - case fr => - sys.error(s"Unsupported Frame $fr for functions: $functions") + private[this] lazy val windowFrameExpressionFactoryPairs = { + type FrameKey = (String, FrameType, Option[Int], Option[Int]) + type ExpressionBuffer = mutable.Buffer[Expression] + val framedFunctions = mutable.Map.empty[FrameKey, (ExpressionBuffer, ExpressionBuffer)] + + // Add a window expression and its function to the map for a given frame. + def collect(tpe: String, fr: SpecifiedWindowFrame, e: Expression, fn: Expression): Unit = { + val key = (tpe, fr.frameType, FrameBoundary(fr.frameStart), FrameBoundary(fr.frameEnd)) + val (es, fns) = framedFunctions.getOrElseUpdate( + key, (ArrayBuffer.empty[Expression], ArrayBuffer.empty[Expression])) + es.append(e) + fns.append(fn) + } + + // Collect all valid window functions and group them by their frame. + windowExpression.foreach { x => + x.foreach { + case e @ WindowExpression(function, spec) => + val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] + function match { + case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) + case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) + case f => sys.error(s"Unsupported window function: $f") + } + case _ => + } + } + + // Map the groups to an (unbound) expression and frame factory pair. + var numExpressions = 0 + framedFunctions.toSeq.map { + case (key, (expressions, functionSeq)) => + val ordinal = numExpressions + val functions = functionSeq.toArray + + // Construct an aggregate processor if we need one. + def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + + // Create the factory + val factory = key match { + // Offset Frame + case ("OFFSET", RowFrame, Some(offset), Some(h)) if offset == h => + target: MutableRow => + new OffsetWindowFunctionFrame( + target, + ordinal, + functions, + child.output, + newMutableProjection, + offset) + + // Growing Frame. + case ("AGGREGATE", frameType, None, Some(high)) => + target: MutableRow => { + new UnboundedPrecedingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, high)) + } + + // Shrinking Frame. + case ("AGGREGATE", frameType, Some(low), None) => + target: MutableRow => { + new UnboundedFollowingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low)) + } + + // Moving Frame. + case ("AGGREGATE", frameType, Some(low), Some(high)) => + target: MutableRow => { + new SlidingWindowFunctionFrame( + target, + processor, + createBoundOrdering(frameType, low), + createBoundOrdering(frameType, high)) + } + + // Entire Partition Frame. + case ("AGGREGATE", frameType, None, None) => + target: MutableRow => { + new UnboundedWindowFunctionFrame(target, processor) + } + } + + // Keep track of the number of expressions. This is a side-effect in a map... + numExpressions += expressions.size + + // Create the Frame Expression - Factory pair.
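`windowFrameExpressionFactoryPairs` replaces the old grouping by raw frame specification: expressions are bucketed by a `FrameKey` of processor kind, frame type, and boundaries, so every function that shares a frame is evaluated by a single `WindowFunctionFrame` in one pass over the partition. The bucketing step in isolation, with strings standing in for expressions:

```scala
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

// Simplified FrameKey: (kind, frameStart, frameEnd); the patch also keys on
// the frame type (row vs. range).
type FrameKey = (String, Option[Int], Option[Int])
val framedFunctions = mutable.Map.empty[FrameKey, ArrayBuffer[String]]

def collect(kind: String, start: Option[Int], end: Option[Int], fn: String): Unit =
  framedFunctions.getOrElseUpdate((kind, start, end), ArrayBuffer.empty[String]) += fn

collect("AGGREGATE", None, Some(0), "sum(x)")    // UNBOUNDED PRECEDING .. CURRENT ROW
collect("AGGREGATE", None, Some(0), "avg(x)")    // same frame, same bucket
collect("OFFSET", Some(-1), Some(-1), "lag(x)")  // one-row offset frame

// framedFunctions.size == 2: two frames to evaluate, not three separate scans.
```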
+ (expressions, factory) + } } /** @@ -197,107 +261,119 @@ case class Window( * @return the final resulting projection. */ private[this] def createResultProjection( - expressions: Seq[Expression]): MutableProjection = { + expressions: Seq[Expression]): UnsafeProjection = { val references = expressions.zipWithIndex.map{ case (e, i) => // Results of window expressions will be on the right side of child's output BoundReference(child.output.size + i, e.dataType, e.nullable) } val unboundToRefMap = expressions.zip(references).toMap val patchedWindowExpression = windowExpression.map(_.transform(unboundToRefMap)) - newMutableProjection( + UnsafeProjection.create( projectList ++ patchedWindowExpression, - child.output)() + child.output) } protected override def doExecute(): RDD[InternalRow] = { - // Prepare processing. - // Group the window expression by their processing frame. - val windowExprs = windowExpression.flatMap { - _.collect { - case e: WindowExpression => e - } - } - - // Create Frame processor factories and order the unbound window expressions by the frame they - // are processed in; this is the order in which their results will be written to window - // function result buffer. - val framedWindowExprs = windowExprs.groupBy(_.windowSpec.frameSpecification) - val factories = Array.ofDim[() => WindowFunctionFrame](framedWindowExprs.size) - val unboundExpressions = scala.collection.mutable.Buffer.empty[Expression] - framedWindowExprs.zipWithIndex.foreach { - case ((frame, unboundFrameExpressions), index) => - // Track the ordinal. - val ordinal = unboundExpressions.size - - // Track the unbound expressions - unboundExpressions ++= unboundFrameExpressions - - // Bind the expressions. - val functions = unboundFrameExpressions.map { e => - BindReferences.bindReference(e.windowFunction, child.output) - }.toArray - - // Create the frame processor factory. - factories(index) = () => createFrameProcessor(frame, functions, ordinal) - } + // Unwrap the expressions and factories from the map. + val expressions = windowFrameExpressionFactoryPairs.flatMap(_._1) + val factories = windowFrameExpressionFactoryPairs.map(_._2).toArray // Start processing. child.execute().mapPartitions { stream => new Iterator[InternalRow] { // Get all relevant projections. - val result = createResultProjection(unboundExpressions) + val result = createResultProjection(expressions) val grouping = UnsafeProjection.create(partitionSpec, child.output) // Manage the stream and the grouping. - var nextRow: InternalRow = EmptyRow - var nextGroup: InternalRow = EmptyRow + var nextRow: UnsafeRow = null + var nextGroup: UnsafeRow = null var nextRowAvailable: Boolean = false private[this] def fetchNextRow() { nextRowAvailable = stream.hasNext if (nextRowAvailable) { - nextRow = stream.next() + nextRow = stream.next().asInstanceOf[UnsafeRow] nextGroup = grouping(nextRow) } else { - nextRow = EmptyRow - nextGroup = EmptyRow + nextRow = null + nextGroup = null } } fetchNextRow() // Manage the current partition. - var rows: CompactBuffer[InternalRow] = _ - val frames: Array[WindowFunctionFrame] = factories.map(_()) + val rows = ArrayBuffer.empty[UnsafeRow] + val inputFields = child.output.length + var sorter: UnsafeExternalSorter = null + var rowBuffer: RowBuffer = null + val windowFunctionResult = new SpecificMutableRow(expressions.map(_.dataType)) + val frames = factories.map(_(windowFunctionResult)) val numFrames = frames.length private[this] def fetchNextPartition() { // Collect all the rows in the current partition. 
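For all the new machinery, the rewritten `doExecute` keeps the old outer shape: consume input rows while the grouping projection yields the same key, hand the buffered partition to every frame, then emit input rows joined with the frames' results. The run-detection loop, extracted:

```scala
// Group consecutive elements sharing a key, the way fetchNextPartition keeps
// consuming `stream` while nextGroup == currentGroup.
def groupRuns[A, K](it: Iterator[A])(key: A => K): Iterator[Vector[A]] =
  new Iterator[Vector[A]] {
    private val in = it.buffered
    def hasNext: Boolean = in.hasNext
    def next(): Vector[A] = {
      val k = key(in.head)
      val out = Vector.newBuilder[A]
      while (in.hasNext && key(in.head) == k) out += in.next()
      out.result()
    }
  }

// groupRuns(Iterator(1, 1, 2, 3, 3))(identity).toList
//   == List(Vector(1, 1), Vector(2), Vector(3, 3))
```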
// Before we start to fetch new input rows, make a copy of nextGroup. val currentGroup = nextGroup.copy() - rows = new CompactBuffer + + // clear last partition + if (sorter != null) { + // the last sorter of this task will be cleaned up via task completion listener + sorter.cleanupResources() + sorter = null + } else { + rows.clear() + } + while (nextRowAvailable && nextGroup == currentGroup) { - rows += nextRow.copy() + if (sorter == null) { + rows += nextRow.copy() + + if (rows.length >= 4096) { + // We will not sort the rows, so prefixComparator and recordComparator are null. + sorter = UnsafeExternalSorter.create( + TaskContext.get().taskMemoryManager(), + SparkEnv.get.blockManager, + TaskContext.get(), + null, + null, + 1024, + SparkEnv.get.memoryManager.pageSizeBytes) + rows.foreach { r => + sorter.insertRecord(r.getBaseObject, r.getBaseOffset, r.getSizeInBytes, 0) + } + rows.clear() + } + } else { + sorter.insertRecord(nextRow.getBaseObject, nextRow.getBaseOffset, + nextRow.getSizeInBytes, 0) + } fetchNextRow() } + if (sorter != null) { + rowBuffer = new ExternalRowBuffer(sorter, inputFields) + } else { + rowBuffer = new ArrayRowBuffer(rows) + } // Setup the frames. var i = 0 while (i < numFrames) { - frames(i).prepare(rows) + frames(i).prepare(rowBuffer.copy()) i += 1 } // Setup iteration rowIndex = 0 - rowsSize = rows.size + rowsSize = rowBuffer.size() } // Iteration var rowIndex = 0 - var rowsSize = 0 + var rowsSize = 0L + override final def hasNext: Boolean = rowIndex < rowsSize || nextRowAvailable val join = new JoinedRow - val windowFunctionResult = new GenericMutableRow(unboundExpressions.size) override final def next(): InternalRow = { // Load the next partition if we need to. if (rowIndex >= rowsSize && nextRowAvailable) { @@ -307,13 +383,14 @@ case class Window( if (rowIndex < rowsSize) { // Get the results for the window frames. var i = 0 + val current = rowBuffer.next() while (i < numFrames) { - frames(i).write(windowFunctionResult) + frames(i).write(rowIndex, current) i += 1 } // 'Merge' the input row with the window function result - join(rows(rowIndex), windowFunctionResult) + join(current, windowFunctionResult) rowIndex += 1 // Return the projection. @@ -329,14 +406,18 @@ case class Window( * Function for comparing boundary values. */ private[execution] abstract class BoundOrdering { - def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int + def compare(inputRow: InternalRow, inputIndex: Int, outputRow: InternalRow, outputIndex: Int): Int } /** * Compare the input index to the bound of the output index. 
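The `fetchNextPartition` rewrite above is what removes the old requirement that a whole partition fit in a `CompactBuffer`: rows accumulate on the heap up to a threshold (4096 in the patch) and then migrate into an `UnsafeExternalSorter`, which can spill to disk. The control flow, with a placeholder for `insertRecord`:

```scala
import scala.collection.mutable.ArrayBuffer

// Buffer rows on heap until `threshold`; after that, this row and all later
// rows go to an external store. `spill` stands in for sorter.insertRecord.
final class SpillableBuffer[T](threshold: Int, spill: T => Unit) {
  private val rows = ArrayBuffer.empty[T]
  private var spilled = false

  def add(row: T): Unit = {
    if (spilled) {
      spill(row)
    } else {
      rows += row
      if (rows.length >= threshold) {
        rows.foreach(spill) // migrate the on-heap prefix, as the patch does
        rows.clear()
        spilled = true
      }
    }
  }
}
```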
*/ private[execution] final case class RowBoundOrdering(offset: Int) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = inputIndex - (outputIndex + offset) } @@ -347,148 +428,198 @@ private[execution] final case class RangeBoundOrdering( ordering: Ordering[InternalRow], current: Projection, bound: Projection) extends BoundOrdering { - override def compare(input: Seq[InternalRow], inputIndex: Int, outputIndex: Int): Int = - ordering.compare(current(input(inputIndex)), bound(input(outputIndex))) + override def compare( + inputRow: InternalRow, + inputIndex: Int, + outputRow: InternalRow, + outputIndex: Int): Int = + ordering.compare(current(inputRow), bound(outputRow)) } /** - * A window function calculates the results of a number of window functions for a window frame. - * Before use a frame must be prepared by passing it all the rows in the current partition. After - * preparation the update method can be called to fill the output rows. - * - * TODO How to improve performance? A few thoughts: - * - Window functions are expensive due to its distribution and ordering requirements. - * Unfortunately it is up to the Spark engine to solve this. Improvements in the form of project - * Tungsten are on the way. - * - The window frame processing bit can be improved though. But before we start doing that we - * need to see how much of the time and resources are spent on partitioning and ordering, and - * how much time and resources are spent processing the partitions. There are a couple ways to - * improve on the current situation: - * - Reduce memory footprint by performing streaming calculations. This can only be done when - * there are no Unbound/Unbounded Following calculations present. - * - Use Tungsten style memory usage. - * - Use code generation in general, and use the approach to aggregation taken in the - * GeneratedAggregate class in specific. - * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. - */ -private[execution] abstract class WindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) { - - // Make sure functions are initialized. - functions.foreach(_.init()) + * The interface of row buffer for a partition + */ +private[execution] abstract class RowBuffer { - /** Number of columns the window function frame is managing */ - val numColumns = functions.length + /** Number of rows. */ + def size(): Int - /** - * Create a fresh thread safe copy of the frame. - * - * @return the copied frame. - */ - def copy: WindowFunctionFrame - - /** - * Create new instances of the functions. - * - * @return an array containing copies of the current window functions. - */ - protected final def copyFunctions: Array[WindowFunction] = functions.map(_.newInstance()) + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow - /** - * Prepare the frame for calculating the results for a partition. - * - * @param rows to calculate the frame results for. - */ - def prepare(rows: CompactBuffer[InternalRow]): Unit + /** Skip the next `n` rows. */ + def skip(n: Int): Unit - /** - * Write the result for the current row to the given target row. - * - * @param target row to write the result for the current row to. - */ - def write(target: GenericMutableRow): Unit + /** Return a new RowBuffer that has the same rows. 
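+ * The returned buffer has its own iteration state, so the copy and the original can be
+ * consumed independently.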
*/ + def copy(): RowBuffer +} - /** Reset the current window functions. */ - protected final def reset(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).reset() - i += 1 +/** + * A row buffer backed by an ArrayBuffer (so the number of rows it can hold is limited by memory) + */ +private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { + + private[this] var cursor: Int = -1 + + /** Number of rows. */ + def size(): Int = buffer.length + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + cursor += 1 + if (cursor < buffer.length) { + buffer(cursor) + } else { + null } } - /** Prepare an input row for processing. */ - protected final def prepare(input: InternalRow): Array[AnyRef] = { - val prepared = new Array[AnyRef](numColumns) - var i = 0 - while (i < numColumns) { - prepared(i) = functions(i).prepareInputParameters(input) - i += 1 - } - prepared + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { + cursor += n } - /** Evaluate a prepared buffer (iterator). */ - protected final def evaluatePrepared(iterator: java.util.Iterator[Array[AnyRef]]): Unit = { - reset() - while (iterator.hasNext) { - val prepared = iterator.next() - var i = 0 - while (i < numColumns) { - functions(i).update(prepared(i)) - i += 1 - } + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ArrayRowBuffer(buffer) + } +} + +/** + * An external buffer of rows based on UnsafeExternalSorter + */ +private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) + extends RowBuffer { + + private[this] val iter: UnsafeSorterIterator = sorter.getIterator + + private[this] val currentRow = new UnsafeRow(numFields) + + /** Number of rows. */ + def size(): Int = iter.getNumRecords() + + /** Return next row in the buffer, null if no more left. */ + def next(): InternalRow = { + if (iter.hasNext) { + iter.loadNext() + currentRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) + currentRow + } else { + null } - evaluate() } - /** Evaluate a prepared buffer (array). */ - protected final def evaluatePrepared(prepared: Array[Array[AnyRef]], - fromIndex: Int, toIndex: Int): Unit = { + /** Skip the next `n` rows. */ + def skip(n: Int): Unit = { var i = 0 - while (i < numColumns) { - val function = functions(i) - function.reset() - var j = fromIndex - while (j < toIndex) { - function.update(prepared(j)(i)) - j += 1 - } - function.evaluate() + while (i < n && iter.hasNext) { + iter.loadNext() i += 1 } } - /** Update an array of window functions. */ - protected final def update(input: InternalRow): Unit = { - var i = 0 - while (i < numColumns) { - val aggregate = functions(i) - val preparedInput = aggregate.prepareInputParameters(input) - aggregate.update(preparedInput) - i += 1 + /** Return a new RowBuffer that has the same rows. */ + def copy(): RowBuffer = { + new ExternalRowBuffer(sorter, numFields) + } +} + +/** + * A window function frame calculates the results of a number of window functions for a window + * frame. Before use, a frame must be prepared by passing it all the rows in the current partition. + * After preparation, the write method can be called to fill the output rows. + */ +private[execution] abstract class WindowFunctionFrame { + /** + * Prepare the frame for calculating the results for a partition. + * + * @param rows to calculate the frame results for. + */ + def prepare(rows: RowBuffer): Unit + + /** + * Write the current results to the target row.
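+   *
+   * @param index the position of the current row within the partition.
+   * @param current the current input row.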
+ */ + def write(index: Int, current: InternalRow): Unit +} + +/** + * The offset window frame calculates frames containing LEAD/LAG statements. + * + * @param target to write results to. + * @param expressions to shift a number of rows. + * @param inputSchema required for creating a projection. + * @param newMutableProjection function used to create the projection. + * @param offset by which rows get moved within a partition. + */ +private[execution] final class OffsetWindowFunctionFrame( + target: MutableRow, + ordinal: Int, + expressions: Array[Expression], + inputSchema: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + offset: Int) extends WindowFunctionFrame { + + /** Rows of the partition currently being processed. */ + private[this] var input: RowBuffer = null + + /** Index of the input row currently used for output. */ + private[this] var inputIndex = 0 + + /** Row used when there is no valid input. */ + private[this] val emptyRow = new GenericInternalRow(inputSchema.size) + + /** Row used to combine the offset and the current row. */ + private[this] val join = new JoinedRow + + /** Create the projection. */ + private[this] val projection = { + // Collect the expressions and bind them. + val inputAttrs = inputSchema.map(_.withNullability(true)) + val numInputAttributes = inputAttrs.size + val boundExpressions = Seq.fill(ordinal)(NoOp) ++ expressions.toSeq.map { + case e: OffsetWindowFunction => + val input = BindReferences.bindReference(e.input, inputAttrs) + if (e.default == null || e.default.foldable && e.default.eval() == null) { + // Without default value. + input + } else { + // With default value. + val default = BindReferences.bindReference(e.default, inputAttrs).transform { + // Shift the input reference to its default version. + case BoundReference(o, dataType, nullable) => + BoundReference(o + numInputAttributes, dataType, nullable) + } + org.apache.spark.sql.catalyst.expressions.Coalesce(input :: default :: Nil) + } + case e => + BindReferences.bindReference(e, inputAttrs) } + + // Create the projection. + newMutableProjection(boundExpressions, Nil)().target(target) } - /** Evaluate the window functions. */ - protected final def evaluate(): Unit = { - var i = 0 - while (i < numColumns) { - functions(i).evaluate() - i += 1 + override def prepare(rows: RowBuffer): Unit = { + input = rows + // drain the first few rows if offset is larger than zero + inputIndex = 0 + while (inputIndex < offset) { + input.next() + inputIndex += 1 } + inputIndex = offset } - /** Fill a target row with the current window function results. */ - protected final def fill(target: GenericMutableRow, rowIndex: Int): Unit = { - var i = 0 - while (i < numColumns) { - target.update(ordinal + i, functions(i).get(rowIndex)) - i += 1 + override def write(index: Int, current: InternalRow): Unit = { + if (inputIndex >= 0 && inputIndex < input.size) { + val r = input.next() + join(r, current) + } else { + join(emptyRow, current) } + projection(join) + inputIndex += 1 } } @@ -496,19 +627,25 @@ private[execution] abstract class WindowFunctionFrame( * The sliding window frame calculates frames with the following SQL form: * ... BETWEEN 1 PRECEDING AND 1 FOLLOWING * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. 
* @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class SlidingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], + target: MutableRow, + processor: AggregateProcessor, lbound: BoundOrdering, - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null + + /** The rows within current sliding window. */ + private[this] val buffer = new util.ArrayDeque[InternalRow]() /** Index of the first input row with a value greater than the upper bound of the current * output row. */ @@ -518,56 +655,46 @@ private[execution] final class SlidingWindowFunctionFrame( * current output row. */ private[this] var inputLowIndex = 0 - /** Buffer used for storing prepared input for the window functions. */ - private[this] val buffer = new java.util.ArrayDeque[Array[AnyRef]] - - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. Reset all variables. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputHighIndex = 0 inputLowIndex = 0 - outputIndex = 0 buffer.clear() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the buffer for which the input row value is equal to or less than // the output row upper bound. - while (inputHighIndex < input.size && - ubound.compare(input, inputHighIndex, outputIndex) <= 0) { - buffer.offer(prepare(input(inputHighIndex))) + while (nextRow != null && ubound.compare(nextRow, inputHighIndex, current, index) <= 0) { + buffer.add(nextRow.copy()) + nextRow = input.next() inputHighIndex += 1 bufferUpdated = true } // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputLowIndex < inputHighIndex && - lbound.compare(input, inputLowIndex, outputIndex) < 0) { - buffer.pop() + while (!buffer.isEmpty && lbound.compare(buffer.peek(), inputLowIndex, current, index) < 0) { + buffer.remove() inputLowIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluatePrepared(buffer.iterator()) - fill(target, outputIndex) + processor.initialize(input.size) + val iter = buffer.iterator() + while (iter.hasNext) { + processor.update(iter.next()) + } + processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } - - /** Copy the frame. */ - override def copy: SlidingWindowFunctionFrame = - new SlidingWindowFunctionFrame(ordinal, copyFunctions, lbound, ubound) } /** @@ -578,36 +705,30 @@ private[execution] final class SlidingWindowFunctionFrame( * Its results are the same for each and every row in the partition. This class can be seen as a * special case of a sliding window, but is optimized for the unbound case. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. 
+ * @param target to write results to. + * @param processor to calculate the row values with. */ private[execution] final class UnboundedWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction]) extends WindowFunctionFrame(ordinal, functions) { - - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 + target: MutableRow, + processor: AggregateProcessor) extends WindowFunctionFrame { /** Prepare the frame for calculating a new partition. Process all rows eagerly. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() - outputIndex = 0 - val iterator = rows.iterator - while (iterator.hasNext) { - update(iterator.next()) + override def prepare(rows: RowBuffer): Unit = { + val size = rows.size() + processor.initialize(size) + var i = 0 + while (i < size) { + processor.update(rows.next()) + i += 1 } - evaluate() } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - fill(target, outputIndex) - outputIndex += 1 + override def write(index: Int, current: InternalRow): Unit = { + // Unfortunately we cannot assume that evaluation is deterministic. So we need to re-evaluate + // for each row. + processor.evaluate(target) } - - /** Copy the frame. */ - override def copy: UnboundedWindowFunctionFrame = - new UnboundedWindowFunctionFrame(ordinal, copyFunctions) } /** @@ -620,58 +741,51 @@ private[execution] final class UnboundedWindowFunctionFrame( * is not the case when there is no lower bound, given the additive nature of most aggregates * streaming updates and partial evaluation suffice and no buffering is needed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param ubound comparator used to identify the upper bound of an output row. */ private[execution] final class UnboundedPrecedingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - ubound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { + target: MutableRow, + processor: AggregateProcessor, + ubound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null + + /** The next row from `input`. */ + private[this] var nextRow: InternalRow = null /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + * output row. */ private[this] var inputIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { - reset() + override def prepare(rows: RowBuffer): Unit = { input = rows + nextRow = rows.next() inputIndex = 0 - outputIndex = 0 + processor.initialize(input.size) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 // Add all rows to the aggregates for which the input row value is equal to or less than // the output row upper bound. 
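+    // The lower bound is fixed at UNBOUNDED PRECEDING, so the frame only ever grows: new rows
+    // are folded into the running aggregate as the upper bound advances and none are evicted,
+    // which is why the processor is initialized once per partition rather than once per row.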
- while (inputIndex < input.size && ubound.compare(input, inputIndex, outputIndex) <= 0) { - update(input(inputIndex)) + while (nextRow != null && ubound.compare(nextRow, inputIndex, current, index) <= 0) { + processor.update(nextRow) + nextRow = input.next() inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes. if (bufferUpdated) { - evaluate() - fill(target, outputIndex) + processor.evaluate(target) } - - // Move to the next row. - outputIndex += 1 } - - /** Copy the frame. */ - override def copy: UnboundedPrecedingWindowFunctionFrame = - new UnboundedPrecedingWindowFunctionFrame(ordinal, copyFunctions, ubound) } /** @@ -686,65 +800,183 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * buffer and must do full recalculation after each row. Reverse iteration would be possible if * the commutativity of the used window functions can be guaranteed. * - * @param ordinal of the first column written by this frame. - * @param functions to calculate the row values with. + * @param target to write results to. + * @param processor to calculate the row values with. * @param lbound comparator used to identify the lower bound of an output row. */ private[execution] final class UnboundedFollowingWindowFunctionFrame( - ordinal: Int, - functions: Array[WindowFunction], - lbound: BoundOrdering) extends WindowFunctionFrame(ordinal, functions) { - - /** Buffer used for storing prepared input for the window functions. */ - private[this] var buffer: Array[Array[AnyRef]] = _ + target: MutableRow, + processor: AggregateProcessor, + lbound: BoundOrdering) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ - private[this] var input: CompactBuffer[InternalRow] = null + private[this] var input: RowBuffer = null /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + * current output row. */ private[this] var inputIndex = 0 - /** Index of the row we are currently writing. */ - private[this] var outputIndex = 0 - /** Prepare the frame for calculating a new partition. */ - override def prepare(rows: CompactBuffer[InternalRow]): Unit = { + override def prepare(rows: RowBuffer): Unit = { input = rows inputIndex = 0 - outputIndex = 0 - val size = input.size - buffer = Array.ofDim(size) - var i = 0 - while (i < size) { - buffer(i) = prepare(input(i)) - i += 1 - } - evaluatePrepared(buffer, 0, buffer.length) } /** Write the frame columns for the current row to the given target row. */ - override def write(target: GenericMutableRow): Unit = { - var bufferUpdated = outputIndex == 0 + override def write(index: Int, current: InternalRow): Unit = { + var bufferUpdated = index == 0 + + // Duplicate the input to have a new iterator + val tmp = input.copy() // Drop all rows from the buffer for which the input row value is smaller than // the output row lower bound. - while (inputIndex < input.size && lbound.compare(input, inputIndex, outputIndex) < 0) { + tmp.skip(inputIndex) + var nextRow = tmp.next() + while (nextRow != null && lbound.compare(nextRow, inputIndex, current, index) < 0) { + nextRow = tmp.next() inputIndex += 1 bufferUpdated = true } // Only recalculate and update when the buffer changes.
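+    // A frame with an UNBOUNDED FOLLOWING upper bound shrinks as the lower bound advances,
+    // and aggregates support updates but not retraction, so the processor is re-initialized
+    // and every row from the new lower bound to the end of the partition is replayed.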
if (bufferUpdated) { - evaluatePrepared(buffer, inputIndex, buffer.length) - fill(target, outputIndex) + processor.initialize(input.size) + while (nextRow != null) { + processor.update(nextRow) + nextRow = tmp.next() + } + processor.evaluate(target) } + } +} - - // Move to the next row. - outputIndex += 1 +/** + * This class prepares and manages the processing of a number of [[AggregateFunction]]s within a + * single frame. The [[WindowFunctionFrame]] takes care of processing the frame in the correct way; + * this reduces the processing of a [[AggregateWindowFunction]] to processing the underlying + * [[AggregateFunction]]. All [[AggregateFunction]]s are processed in [[Complete]] mode. + * + * [[SizeBasedWindowFunction]]s are initialized in a slightly different way. These functions + * require the size of the partition being processed; this value is exposed to them when the + * processor is constructed. + * + * Processing of distinct aggregates is currently not supported. + * + * The implementation is split into an object which takes care of construction, and the actual + * processor class. + */ +private[execution] object AggregateProcessor { + def apply(functions: Array[Expression], + ordinal: Int, + inputAttributes: Seq[Attribute], + newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + AggregateProcessor = { + val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] + val initialValues = mutable.Buffer.empty[Expression] + val updateExpressions = mutable.Buffer.empty[Expression] + val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) + val imperatives = mutable.Buffer.empty[ImperativeAggregate] + + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to + // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. + val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) + if (trackPartitionSize) { + aggBufferAttributes += SizeBasedWindowFunction.n + initialValues += NoOp + updateExpressions += NoOp + } + + // Add an AggregateFunction to the AggregateProcessor. + functions.foreach { + case agg: DeclarativeAggregate => + aggBufferAttributes ++= agg.aggBufferAttributes + initialValues ++= agg.initialValues + updateExpressions ++= agg.updateExpressions + evaluateExpressions += agg.evaluateExpression + case agg: ImperativeAggregate => + val offset = aggBufferAttributes.size + val imperative = BindReferences.bindReference(agg + .withNewInputAggBufferOffset(offset) + .withNewMutableAggBufferOffset(offset), + inputAttributes) + imperatives += imperative + aggBufferAttributes ++= imperative.aggBufferAttributes + val noOps = Seq.fill(imperative.aggBufferAttributes.size)(NoOp) + initialValues ++= noOps + updateExpressions ++= noOps + evaluateExpressions += imperative + case other => + sys.error(s"Unsupported Aggregate Function: $other") + } + + // Create the projections.
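+    // Three projections mirror the aggregate lifecycle: one seeds the buffer, one folds an
+    // input row into it, and one computes the final output values. Imperative aggregates fill
+    // their initialize/update slots with NoOp and are instead driven directly by the processor.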
+ val initialProjection = newMutableProjection( + initialValues, + Seq(SizeBasedWindowFunction.n))() + val updateProjection = newMutableProjection( + updateExpressions, + aggBufferAttributes ++ inputAttributes)() + val evaluateProjection = newMutableProjection( + evaluateExpressions, + aggBufferAttributes)() + + // Create the processor + new AggregateProcessor( + aggBufferAttributes.toArray, + initialProjection, + updateProjection, + evaluateProjection, + imperatives.toArray, + trackPartitionSize) + } +} + +/** + * This class manages the processing of a number of aggregate functions. See the documentation of + * the object for more information. + */ +private[execution] final class AggregateProcessor( + private[this] val bufferSchema: Array[AttributeReference], + private[this] val initialProjection: MutableProjection, + private[this] val updateProjection: MutableProjection, + private[this] val evaluateProjection: MutableProjection, + private[this] val imperatives: Array[ImperativeAggregate], + private[this] val trackPartitionSize: Boolean) { + + private[this] val join = new JoinedRow + private[this] val numImperatives = imperatives.length + private[this] val buffer = new SpecificMutableRow(bufferSchema.toSeq.map(_.dataType)) + initialProjection.target(buffer) + updateProjection.target(buffer) + + /** Create the initial state. */ + def initialize(size: Int): Unit = { + // Some initialization expressions are dependent on the partition size so we have to + // initialize the size before initializing all other fields, and we have to pass the buffer to + // the initialization projection. + if (trackPartitionSize) { + buffer.setInt(0, size) + } + initialProjection(buffer) + var i = 0 + while (i < numImperatives) { + imperatives(i).initialize(buffer) + i += 1 + } + } + + /** Update the buffer. */ + def update(input: InternalRow): Unit = { + updateProjection(join(buffer, input)) + var i = 0 + while (i < numImperatives) { + imperatives(i).update(buffer, input) + i += 1 + } } - /** Copy the frame. */ - override def copy: UnboundedFollowingWindowFunctionFrame = - new UnboundedFollowingWindowFunctionFrame(ordinal, copyFunctions, lbound) + /** Evaluate buffer. 
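+   * The evaluate projection is pointed at `target`, so the final values are written directly
+   * into the caller's row.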
*/ + def evaluate(target: MutableRow): Unit = + evaluateProjection.target(target)(buffer) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index c5470a6989de..1d56592c40b9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.physical.{UnspecifiedDistribution, ClusteredDistribution, AllTuples, Distribution} +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, UnspecifiedDistribution} import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.execution.metric.SQLMetrics @@ -36,14 +36,19 @@ case class SortBasedAggregate( child: SparkPlan) extends UnaryNode { + private[this] val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } + + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + override private[sql] lazy val metrics = Map( "numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"), "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) override def requiredChildDistribution: List[Distribution] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index ac920aa8bc7f..6501634ff998 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -87,6 +87,10 @@ class SortBasedAggregationIterator( // The aggregation buffer used by the sort-based aggregation. private[this] val sortBasedAggregationBuffer: MutableRow = newBuffer + // A SafeProjection to turn UnsafeRow into GenericInternalRow, because UnsafeRow can't be + // compared to MutableRow (the aggregation buffer) directly. + private[this] val safeProj: Projection = FromUnsafeProjection(valueAttributes.map(_.dataType)) + protected def initialize(): Unit = { if (inputIterator.hasNext) { initializeBuffer(sortBasedAggregationBuffer) @@ -110,7 +114,7 @@ class SortBasedAggregationIterator( // We create a variable to track if we see the next group. var findNextPartition = false // firstRowInNextGroup is the first row of this group. We first process it. - processRow(sortBasedAggregationBuffer, safeProj(firstRowInNextGroup)) // The search will stop when we see the next group or there is no // input row left in the iter.
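The conversion performed by the `safeProj` added above can be pictured with a minimal, self-contained sketch (illustrative only; the field types and values are made up): `FromUnsafeProjection` rebuilds a safe `GenericInternalRow` from Tungsten's binary `UnsafeRow` format, which is what allows the row to be fed to `processRow` against the mutable aggregation buffer.

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}
    import org.apache.spark.sql.types.{IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val fieldTypes = Seq(IntegerType, StringType)
    val toUnsafe = UnsafeProjection.create(fieldTypes.toArray)  // safe -> unsafe (binary) rows
    val toSafe = FromUnsafeProjection(fieldTypes)               // unsafe -> safe rows

    val unsafeRow = toUnsafe(InternalRow(1, UTF8String.fromString("a")))
    val safeRow = toSafe(unsafeRow)  // a GenericInternalRow holding the same two values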
@@ -122,7 +126,7 @@ class SortBasedAggregationIterator( // Check if the current row belongs the current input row. if (currentGroupingKey == groupingKey) { - processRow(sortBasedAggregationBuffer, currentRow) + processRow(sortBasedAggregationBuffer, safeProj(currentRow)) } else { // We find a new group. findNextPartition = true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index b8849c827048..a9cf04388d2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -23,8 +23,8 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.types.StructType case class TungstenAggregate( @@ -49,12 +49,13 @@ case class TungstenAggregate( "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"), "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) + override def producedAttributes: AttributeSet = + AttributeSet(aggregateAttributes) ++ + AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++ + AttributeSet(aggregateBufferAttributes) + override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { case Some(exprs) if exprs.length == 0 => AllTuples :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 582fdbe54706..41799c596b6d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,15 +17,15 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.{InternalAccumulator, Logging, TaskContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner -import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, UnsafeKVExternalSorter} +import org.apache.spark.sql.execution.metric.LongSQLMetric import org.apache.spark.sql.types.StructType import org.apache.spark.unsafe.KVIterator -import org.apache.spark.{InternalAccumulator, Logging, TaskContext} /** * An iterator used to evaluate aggregate functions. It operates on [[UnsafeRow]]s. 
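The `producedAttributes` overrides introduced above (in SortBasedAggregate and TungstenAggregate) feed the plan's missing-input bookkeeping. Roughly, as a sketch of the convention rather than the exact `QueryPlan` code:

    // An attribute an operator references is "missing" only if it is neither
    // supplied by the children (inputSet) nor generated by the operator itself:
    //
    //   def missingInput: AttributeSet = references -- inputSet -- producedAttributes

Declaring the aggregate result and buffer attributes as produced keeps aggregation plans from being flagged as having unresolved references now that the unsafe-row capability flags are gone.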
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index a9719128a626..1df38f7ff59c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -21,11 +21,11 @@ import scala.language.existentials import org.apache.spark.Logging import org.apache.spark.sql.Encoder -import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{OuterScopes, encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index c0d00104e8bf..5a19920add71 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.Logging import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, InterpretedMutableProjection, MutableRow} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, ImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.expressions.{MutableRow, InterpretedMutableProjection, AttributeReference, Expression} -import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, AggregateFunction} import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b3e4688557ba..95bef683238a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.{HashPartitioner, SparkEnv} import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD, ShuffledRDD} import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow @@ -25,20 +26,15 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import 
org.apache.spark.util.random.PoissonSampler -import org.apache.spark.{HashPartitioner, SparkEnv} - case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { override private[sql] lazy val metrics = Map( "numRows" -> SQLMetrics.createLongMetric(sparkContext, "number of rows")) - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - override def output: Seq[Attribute] = projectList.map(_.toAttribute) protected override def doExecute(): RDD[InternalRow] = { @@ -79,12 +75,6 @@ case class Filter(condition: Expression, child: SparkPlan) extends UnaryNode { } override def outputOrdering: Seq[SortOrder] = child.outputOrdering - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - - override def canProcessUnsafeRows: Boolean = true - - override def canProcessSafeRows: Boolean = true } /** @@ -107,10 +97,6 @@ case class Sample( { override def output: Seq[Attribute] = child.output - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { if (withReplacement) { // Disable gap sampling since the gap sampling method buffers two rows internally, @@ -126,6 +112,65 @@ case class Sample( } } +case class Range( + start: Long, + step: Long, + numSlices: Int, + numElements: BigInt, + output: Seq[Attribute]) + extends LeafNode { + + protected override def doExecute(): RDD[InternalRow] = { + sqlContext + .sparkContext + .parallelize(0 until numSlices, numSlices) + .mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if (bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize + val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1) + + new Iterator[InternalRow] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + + unsafeRow.setLong(0, ret) + unsafeRow + } + } + }) + } +} + /** * Union two plans, without a distinct. This is UNION ALL in SQL. */ @@ -137,9 +182,6 @@ case class Union(children: Seq[SparkPlan]) extends SparkPlan { } } } - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true protected override def doExecute(): RDD[InternalRow] = sparkContext.union(children.map(_.execute())) } @@ -206,12 +248,14 @@ case class TakeOrderedAndProject( // and this ordering needs to be created on the driver in order to be passed into Spark core code. 
private val ord: InterpretedOrdering = new InterpretedOrdering(sortOrder, child.output) - // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable. - @transient private val projection = projectList.map(new InterpretedProjection(_, child.output)) - private def collectData(): Array[InternalRow] = { val data = child.execute().map(_.copy()).takeOrdered(limit)(ord) - projection.map(data.map(_)).getOrElse(data) + if (projectList.isDefined) { + val proj = UnsafeProjection.create(projectList.get, child.output) + data.map(r => proj(r).copy()) + } else { + data + } } override def executeCollect(): Array[InternalRow] = { @@ -249,10 +293,6 @@ case class Coalesce(numPartitions: Int, child: SparkPlan) extends UnaryNode { protected override def doExecute(): RDD[InternalRow] = { child.execute().coalesce(numPartitions, shuffle = false) } - - override def outputsUnsafeRows: Boolean = child.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -265,10 +305,6 @@ case class Except(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).subtract(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -281,10 +317,6 @@ case class Intersect(left: SparkPlan, right: SparkPlan) extends BinaryNode { protected override def doExecute(): RDD[InternalRow] = { left.execute().map(_.copy()).intersection(right.execute().map(_.copy())) } - - override def outputsUnsafeRows: Boolean = children.exists(_.outputsUnsafeRows) - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true } /** @@ -307,6 +339,7 @@ case class MapPartitions[T, U]( uEncoder: ExpressionEncoder[U], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = outputSet override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => @@ -325,10 +358,7 @@ case class AppendColumns[T, U]( uEncoder: ExpressionEncoder[U], newColumns: Seq[Attribute], child: SparkPlan) extends UnaryNode { - - // We are using an unsafe combiner. 
- override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true + override def producedAttributes: AttributeSet = AttributeSet(newColumns) override def output: Seq[Attribute] = child.output ++ newColumns @@ -357,6 +387,7 @@ case class MapGroups[K, T, U]( groupingAttributes: Seq[Attribute], output: Seq[Attribute], child: SparkPlan) extends UnaryNode { + override def producedAttributes: AttributeSet = outputSet override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -396,6 +427,7 @@ case class CoGroup[Key, Left, Right, Result]( rightGroup: Seq[Attribute], left: SparkPlan, right: SparkPlan) extends BinaryNode { + override def producedAttributes: AttributeSet = outputSet override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala index c52ee9ffd6d2..5d4476989a36 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, GenericInternalRow} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index c9f2329db4b6..9c908b2877e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -574,11 +574,10 @@ private[columnar] case class STRUCT(dataType: StructType) assert(buffer.hasArray) val cursor = buffer.position() buffer.position(cursor + sizeInBytes) - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(numOfFields) unsafeRow.pointTo( buffer.array(), Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, - numOfFields, sizeInBytes) unsafeRow } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index eaafc96e4d2e..55e2c0ed7002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator, UnsafeRowWriter} import org.apache.spark.sql.types._ /** @@ -131,7 +131,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera private ByteOrder nativeOrder = null; private byte[][] 
buffers = null; - private UnsafeRow unsafeRow = new UnsafeRow(); + private UnsafeRow unsafeRow = new UnsafeRow($numFields); private BufferHolder bufferHolder = new BufferHolder(); private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); private MutableUnsafeRow mutableRow = null; @@ -183,7 +183,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera bufferHolder.reset(); rowWriter.initialize(bufferHolder, $numFields); ${extractors.mkString("\n")} - unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()); return unsafeRow; } }""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala index 3c5a8cb2aa93..9084b74d1a74 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Accumulable, Accumulator, Accumulators} import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow @@ -27,10 +28,9 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan} +import org.apache.spark.sql.execution.{LeafNode, SparkPlan} import org.apache.spark.sql.types.UserDefinedType import org.apache.spark.storage.StorageLevel -import org.apache.spark.{Accumulable, Accumulator, Accumulators} private[sql] object InMemoryRelation { def apply( @@ -39,9 +39,7 @@ private[sql] object InMemoryRelation { storageLevel: StorageLevel, child: SparkPlan, tableName: Option[String]): InMemoryRelation = - new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, - if (child.outputsUnsafeRows) child else ConvertToUnsafe(child), - tableName)() + new InMemoryRelation(child.output, useCompression, batchSize, storageLevel, child, tableName)() } /** @@ -61,11 +59,13 @@ private[sql] case class InMemoryRelation( storageLevel: StorageLevel, @transient child: SparkPlan, tableName: Option[String])( - @transient private var _cachedColumnBuffers: RDD[CachedBatch] = null, - @transient private var _statistics: Statistics = null, - private var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) + @transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null, + @transient private[sql] var _statistics: Statistics = null, + private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null) extends LogicalPlan with MultiInstanceRelation { + override def producedAttributes: AttributeSet = outputSet + private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = if (_batchStats == null) { child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow]) @@ -224,8 +224,6 @@ private[sql] case class InMemoryColumnarTableScan( // The cached version does not change the outputOrdering of the original SparkPlan. 
override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering - override def outputsUnsafeRows: Boolean = true - private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a) // Returned filter predicate should return false iff it is impossible for the input expression diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala index 8d99546924de..2465633162c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessor.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.expressions.MutableRow diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 920381f9c63d..b90d00b15b18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.{ByteBuffer, ByteOrder} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.MutableRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala index 24a79f289aa8..2e2ce88211a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/commands.scala @@ -21,13 +21,13 @@ import java.util.NoSuchElementException import org.apache.spark.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters} +import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLConf, SQLContext} /** * A logical command that is executed for its side-effects. 
`RunnableCommand`s are @@ -148,8 +148,6 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) - (keyValueOutput, runFunc) - case Some((SQLConf.Deprecated.SORTMERGE_JOIN, Some(value))) => val runFunc = (sqlContext: SQLContext) => { logWarning( @@ -232,7 +230,7 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = - Seq(AttributeReference("plan", StringType, nullable = false)()), + Seq(AttributeReference("plan", StringType, nullable = true)()), extended: Boolean = false) extends RunnableCommand { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala index f22508b21090..d8d21b06b8b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DDLParser.scala @@ -22,7 +22,7 @@ import scala.util.matching.Regex import org.apache.spark.Logging import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{TableIdentifier, AbstractSparkSQLParser} +import org.apache.spark.sql.catalyst.{AbstractSparkSQLParser, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.DataTypeParser @@ -109,6 +109,7 @@ class DDLParser(parseQuery: String => LogicalPlan) provider, temp.isDefined, Array.empty[String], + bucketSpec = None, mode, options, queryPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 8a15a51d825e..1d6290e027f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -19,22 +19,22 @@ package org.apache.spark.sql.execution.datasources import scala.collection.mutable.ArrayBuffer +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{MapPartitionsRDD, RDD, UnionRDD} +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, expressions} import org.apache.spark.sql.execution.PhysicalRDD.{INPUT_PATHS, PUSHED_FILTERS} import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{SaveMode, Strategy, execution, sources, _} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} /** * A Strategy for planning scans over data sources defined using the sources API. 
@@ -77,7 +77,8 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val pushedFilters = filters.filter(_.references.intersect(partitionColumns).isEmpty) // Predicates with both partition keys and attributes - val combineFilters = filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet + val partitionAndNormalColumnFilters = + filters.toSet -- partitionFilters.toSet -- pushedFilters.toSet val selectedPartitions = prunePartitions(partitionFilters, t.partitionSpec).toArray @@ -88,16 +89,33 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { s"Selected $selected partitions out of $total, pruned $percentPruned% partitions." } + // need to add projections from "partitionAndNormalColumnAttrs" in if it is not empty + val partitionAndNormalColumnAttrs = AttributeSet(partitionAndNormalColumnFilters) + val partitionAndNormalColumnProjs = if (partitionAndNormalColumnAttrs.isEmpty) { + projects + } else { + (partitionAndNormalColumnAttrs ++ projects).toSeq + } + val scan = buildPartitionedTableScan( l, - projects, + partitionAndNormalColumnProjs, pushedFilters, t.partitionSpec.partitionColumns, selectedPartitions) - combineFilters - .reduceLeftOption(expressions.And) - .map(execution.Filter(_, scan)).getOrElse(scan) :: Nil + // Add a Projection to guarantee the original projection: + // this is because "partitionAndNormalColumnAttrs" may be different + // from the original "projects", in elements or their ordering + + partitionAndNormalColumnFilters.reduceLeftOption(expressions.And).map(cf => + if (projects.isEmpty || projects == partitionAndNormalColumnProjs) { + // if the original projection is empty, no need for the additional Project either + execution.Filter(cf, scan) + } else { + execution.Project(projects, execution.Filter(cf, scan)) + } + ).getOrElse(scan) :: Nil // Scanning non-partitioned HadoopFsRelation case PhysicalOperation(projects, filters, l @ LogicalRelation(t: HadoopFsRelation, _)) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala index 735d52f80886..7a8691e7cb9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelation.scala @@ -22,6 +22,7 @@ import java.io.IOException import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat + import org.apache.spark._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -93,7 +94,7 @@ private[sql] case class InsertIntoHadoopFsRelation( val isAppend = pathExists && (mode == SaveMode.Append) if (doInsertion) { - val job = new Job(hadoopConf) + val job = Job.getInstance(hadoopConf) job.setOutputKeyClass(classOf[Void]) job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, qualifiedOutputPath) @@ -124,7 +125,7 @@ private[sql] case class InsertIntoHadoopFsRelation( |Actual: ${partitionColumns.mkString(", ")} """.stripMargin) - val writerContainer = if (partitionColumns.isEmpty) { + val writerContainer = if (partitionColumns.isEmpty && relation.bucketSpec.isEmpty) { new DefaultWriterContainer(relation, job, isAppend) } else { val output = df.queryExecution.executedPlan.output diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala index 86a306b8f941..ece9b8a9a917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ResolvedDataSource.scala @@ -21,14 +21,14 @@ import java.util.ServiceLoader import scala.collection.JavaConverters._ import scala.language.{existentials, implicitConversions} -import scala.util.{Success, Failure, Try} +import scala.util.{Failure, Success, Try} import org.apache.hadoop.fs.Path import org.apache.hadoop.util.StringUtils import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.sql.{DataFrame, SaveMode, AnalysisException, SQLContext} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{CalendarIntervalType, StructType} import org.apache.spark.util.Utils @@ -57,24 +57,38 @@ object ResolvedDataSource extends Logging { val serviceLoader = ServiceLoader.load(classOf[DataSourceRegister], loader) serviceLoader.asScala.filter(_.shortName().equalsIgnoreCase(provider)).toList match { - /** the provider format did not match any given registered aliases */ - case Nil => Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { - case Success(dataSource) => dataSource - case Failure(error) => - if (provider.startsWith("org.apache.spark.sql.hive.orc")) { - throw new ClassNotFoundException( - "The ORC data source must be used with Hive support enabled.", error) - } else { - throw new ClassNotFoundException( - s"Failed to load class for data source: $provider.", error) - } - } - /** there is exactly one registered alias */ - case head :: Nil => head.getClass - /** There are multiple registered aliases for the input */ - case sources => sys.error(s"Multiple sources found for $provider, " + - s"(${sources.map(_.getClass.getName).mkString(", ")}), " + - "please specify the fully qualified class name.") + // the provider format did not match any given registered aliases + case Nil => + Try(loader.loadClass(provider)).orElse(Try(loader.loadClass(provider2))) match { + case Success(dataSource) => + // Found the data source using fully qualified path + dataSource + case Failure(error) => + if (provider.startsWith("org.apache.spark.sql.hive.orc")) { + throw new ClassNotFoundException( + "The ORC data source must be used with Hive support enabled.", error) + } else { + if (provider == "avro" || provider == "com.databricks.spark.avro") { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. Please use Spark package " + + "http://spark-packages.org/package/databricks/spark-avro", + error) + } else { + throw new ClassNotFoundException( + s"Failed to find data source: $provider. 
Please find packages at " + + "http://spark-packages.org", + error) + } + } + } + case head :: Nil => + // there is exactly one registered alias + head.getClass + case sources => + // There are multiple registered aliases for the input + sys.error(s"Multiple sources found for $provider " + + s"(${sources.map(_.getClass.getName).mkString(", ")}), " + + "please specify the fully qualified class name.") } } @@ -196,6 +210,7 @@ object ResolvedDataSource extends Logging { sqlContext: SQLContext, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], data: DataFrame): ResolvedDataSource = { @@ -230,6 +245,7 @@ object ResolvedDataSource extends Logging { Array(outputPath.toString), Some(dataSchema.asNullable), Some(partitionColumnsSchema(data.schema, partitionColumns, caseSensitive)), + bucketSpec, caseInsensitiveOptions) // For partitioned relation r, r.schema's column ordering can be different from the column diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index eea780cbaa7e..d45d2db62f3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -26,16 +26,16 @@ import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} +import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} + +import org.apache.spark.{Partition => SparkPartition, _} import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.executor.DataReadMethod -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.execution.datasources.parquet.UnsafeRowParquetRecordReader import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} -import org.apache.spark.{Partition => SparkPartition, _} - private[spark] class SqlNewHadoopPartition( rddId: Int, @@ -68,16 +68,14 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( initLocalJobFuncOpt: Option[Job => Unit], inputFormatClass: Class[_ <: InputFormat[Void, V]], valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) - with SparkHadoopMapReduceUtil - with Logging { + extends RDD[V](sqlContext.sparkContext, Nil) with Logging { protected def getJob(): Job = { - val conf: Configuration = broadcastedConf.value.value + val conf = broadcastedConf.value.value // "new Job" will make a copy of the conf. Then, it is // safe to mutate conf properties with initLocalJobFuncOpt // and initDriverSideJobFuncOpt. 
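This patch repeatedly swaps the deprecated Hadoop 1-era `new Job(conf)` constructor for the `Job.getInstance` factory, as in the replacement just below (and again in InsertIntoHadoopFsRelation, WriterContainer, and JSONRelation). A minimal sketch of the pattern, assuming only stock Hadoop APIs; the property key is hypothetical:

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapreduce.Job

object JobFactorySketch {
  def main(args: Array[String]): Unit = {
    val conf = new Configuration()
    // Job.getInstance copies the Configuration, so it is safe to mutate the
    // job's conf without affecting the conf that was passed in.
    val job = Job.getInstance(conf) // replaces the deprecated `new Job(conf)`
    job.getConfiguration.set("example.property", "value") // hypothetical key
    assert(conf.get("example.property") == null) // the original conf is untouched
  }
}
```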
- val newJob = new Job(conf) + val newJob = Job.getInstance(conf) initLocalJobFuncOpt.map(f => f(newJob)) newJob } @@ -87,7 +85,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( if (isDriverSide) { initDriverSideJobFuncOpt.map(f => f(job)) } - SparkHadoopUtil.get.getConfigurationFromJobContext(job) + job.getConfiguration } private val jobTrackerId: String = { @@ -110,7 +108,7 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val jobContext = newJobContext(conf, jobId) + val jobContext = new JobContextImpl(conf, jobId) val rawSplits = inputFormat.getSplits(jobContext).toArray val result = new Array[SparkPartition](rawSplits.size) for (i <- 0 until rawSplits.size) { @@ -154,8 +152,8 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( configurable.setConf(conf) case _ => } - val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0) - val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId) + val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) + val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) private[this] var reader: RecordReader[Void, V] = null /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index ad5536725889..4f8524f4b967 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -24,16 +24,16 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter} +import org.apache.hadoop.mapreduce.task.TaskAttemptContextImpl + import org.apache.spark._ -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.mapred.SparkHadoopMapRedUtil -import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory} -import org.apache.spark.sql.types.{StructType, StringType} +import org.apache.spark.sql.types.{IntegerType, StructType, StringType} import org.apache.spark.util.SerializableConfiguration @@ -41,14 +41,12 @@ private[sql] abstract class BaseWriterContainer( @transient val relation: HadoopFsRelation, @transient private val job: Job, isAppend: Boolean) - extends SparkHadoopMapReduceUtil - with Logging - with Serializable { + extends Logging with Serializable { protected val dataSchema = relation.dataSchema protected val serializableConf = - new SerializableConfiguration(SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + new SerializableConfiguration(job.getConfiguration) // This UUID is used to avoid output file name collision between different appending write jobs. // These jobs may belong to different SparkContext instances. Concrete data source implementations @@ -90,8 +88,7 @@ private[sql] abstract class BaseWriterContainer( // This UUID is sent to executor side together with the serialized `Configuration` object within // the `Job` instance. 
`OutputWriters` on the executor side should use this UUID to generate // unique task output files. - SparkHadoopUtil.get.getConfigurationFromJobContext(job). - set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) + job.getConfiguration.set("spark.sql.sources.writeJobUUID", uniqueWriteJobId.toString) // Order of the following two lines is important. For Hadoop 1, TaskAttemptContext constructor // clones the Configuration object passed in. If we initialize the TaskAttemptContext first, @@ -101,7 +98,7 @@ private[sql] abstract class BaseWriterContainer( // committer, since their initialization involve the job configuration, which can be potentially // decorated in `prepareJobForWrite`. outputWriterFactory = relation.prepareJobForWrite(job) - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputFormatClass = job.getOutputFormatClass outputCommitter = newOutputCommitter(taskAttemptContext) @@ -111,7 +108,7 @@ private[sql] abstract class BaseWriterContainer( def executorSideSetup(taskContext: TaskContext): Unit = { setupIDs(taskContext.stageId(), taskContext.partitionId(), taskContext.attemptNumber()) setupConf() - taskAttemptContext = newTaskAttemptContext(serializableConf.value, taskAttemptId) + taskAttemptContext = new TaskAttemptContextImpl(serializableConf.value, taskAttemptId) outputCommitter = newOutputCommitter(taskAttemptContext) outputCommitter.setupTask(taskAttemptContext) } @@ -124,9 +121,9 @@ private[sql] abstract class BaseWriterContainer( } } - protected def newOutputWriter(path: String): OutputWriter = { + protected def newOutputWriter(path: String, bucketId: Option[Int] = None): OutputWriter = { try { - outputWriterFactory.newInstance(path, dataSchema, taskAttemptContext) + outputWriterFactory.newInstance(path, bucketId, dataSchema, taskAttemptContext) } catch { case e: org.apache.hadoop.fs.FileAlreadyExistsException => if (outputCommitter.isInstanceOf[parquet.DirectParquetOutputCommitter]) { @@ -166,7 +163,7 @@ private[sql] abstract class BaseWriterContainer( "because spark.speculation is configured to be true.") defaultOutputCommitter } else { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val committerClass = configuration.getClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, null, classOf[OutputCommitter]) @@ -201,10 +198,8 @@ private[sql] abstract class BaseWriterContainer( private def setupIDs(jobId: Int, splitId: Int, attemptId: Int): Unit = { this.jobId = SparkHadoopWriter.createJobID(new Date, jobId) - this.taskId = new TaskID(this.jobId, true, splitId) - // scalastyle:off jobcontext + this.taskId = new TaskID(this.jobId, TaskType.MAP, splitId) this.taskAttemptId = new TaskAttemptID(taskId, attemptId) - // scalastyle:on jobcontext } private def setupConf(): Unit = { @@ -250,7 +245,7 @@ private[sql] class DefaultWriterContainer( def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { executorSideSetup(taskContext) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) + val configuration = taskAttemptContext.getConfiguration configuration.set("spark.sql.sources.output.path", outputPath) val writer = newOutputWriter(getWorkPath) writer.initConverter(dataSchema) @@ -317,26 +312,148 @@ private[sql] class DynamicPartitionWriterContainer( isAppend: Boolean) extends BaseWriterContainer(relation, job, 
isAppend) { - def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { - val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] - executorSideSetup(taskContext) + private val bucketSpec = relation.bucketSpec - var outputWritersCleared = false + private val bucketColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.bucketColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Returns the partition key given an input row - val getPartitionKey = UnsafeProjection.create(partitionColumns, inputSchema) - // Returns the data columns to be written given an input row - val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) + private val sortColumns: Seq[Attribute] = bucketSpec.toSeq.flatMap { + spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get) + } - // Expressions that given a partition key build a string like: col1=val/col2=val/... - val partitionStringExpression = partitionColumns.zipWithIndex.flatMap { case (c, i) => + private def bucketIdExpression: Option[Expression] = for { + BucketSpec(numBuckets, _, _) <- bucketSpec + } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets)) + + // Expressions that given a partition key build a string like: col1=val/col2=val/... + private def partitionStringExpression: Seq[Expression] = { + partitionColumns.zipWithIndex.flatMap { case (c, i) => val escaped = ScalaUDF( - PartitioningUtils.escapePathName _, StringType, Seq(Cast(c, StringType)), Seq(StringType)) + PartitioningUtils.escapePathName _, + StringType, + Seq(Cast(c, StringType)), + Seq(StringType)) val str = If(IsNull(c), Literal(defaultPartitionName), escaped) val partitionName = Literal(c.name + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } + } + + private def getBucketIdFromKey(key: InternalRow): Option[Int] = { + if (bucketSpec.isDefined) { + Some(key.getInt(partitionColumns.length)) + } else { + None + } + } + + private def sameBucket(key1: UnsafeRow, key2: UnsafeRow): Boolean = { + val bucketIdIndex = partitionColumns.length + if (key1.getInt(bucketIdIndex) != key2.getInt(bucketIdIndex)) { + false + } else { + var i = partitionColumns.length - 1 + while (i >= 0) { + val dt = partitionColumns(i).dataType + if (key1.get(i, dt) != key2.get(i, dt)) return false + i -= 1 + } + true + } + } + + private def sortBasedWrite( + sorter: UnsafeKVExternalSorter, + iterator: Iterator[InternalRow], + getSortingKey: UnsafeProjection, + getOutputRow: UnsafeProjection, + getPartitionString: UnsafeProjection, + outputWriters: java.util.HashMap[InternalRow, OutputWriter]): Unit = { + while (iterator.hasNext) { + val currentRow = iterator.next() + sorter.insertKV(getSortingKey(currentRow), getOutputRow(currentRow)) + } + + logInfo(s"Sorting complete. Writing out partition files one at a time.") + + val needNewWriter: (UnsafeRow, UnsafeRow) => Boolean = if (sortColumns.isEmpty) { + (key1, key2) => key1 != key2 + } else { + (key1, key2) => key1 == null || !sameBucket(key1, key2) + } + + val sortedIterator = sorter.sortedIterator() + var currentKey: UnsafeRow = null + var currentWriter: OutputWriter = null + try { + while (sortedIterator.next()) { + if (needNewWriter(currentKey, sortedIterator.getKey)) { + if (currentWriter != null) { + currentWriter.close() + } + currentKey = sortedIterator.getKey.copy() + logDebug(s"Writing partition: $currentKey") + + // Either use an existing file from before, or open a new one. 
+ currentWriter = outputWriters.remove(currentKey) + if (currentWriter == null) { + currentWriter = newOutputWriter(currentKey, getPartitionString) + } + } + + currentWriter.writeInternal(sortedIterator.getValue) + } + } finally { + if (currentWriter != null) { currentWriter.close() } + } + } + + /** + * Opens and returns a new OutputWriter given a partition key and an optional bucket id. + * If a bucket id is specified, we will append it to the end of the file name, but before the + * file extension, e.g. part-r-00009-ea518ad4-455a-4431-b471-d24e03814677-00002.gz.parquet + */ + private def newOutputWriter( + key: InternalRow, + getPartitionString: UnsafeProjection): OutputWriter = { + val configuration = taskAttemptContext.getConfiguration + val path = if (partitionColumns.nonEmpty) { + val partitionPath = getPartitionString(key).getString(0) + configuration.set( + "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) + new Path(getWorkPath, partitionPath).toString + } else { + configuration.set("spark.sql.sources.output.path", outputPath) + getWorkPath + } + val bucketId = getBucketIdFromKey(key) + val newWriter = super.newOutputWriter(path, bucketId) + newWriter.initConverter(dataSchema) + newWriter + } + + def writeRows(taskContext: TaskContext, iterator: Iterator[InternalRow]): Unit = { + val outputWriters = new java.util.HashMap[InternalRow, OutputWriter] + executorSideSetup(taskContext) + + var outputWritersCleared = false + + // We should first sort by partition columns, then bucket id, and finally sorting columns. + val getSortingKey = + UnsafeProjection.create(partitionColumns ++ bucketIdExpression ++ sortColumns, inputSchema) + + val sortingKeySchema = if (bucketSpec.isEmpty) { + StructType.fromAttributes(partitionColumns) + } else { // If it's bucketed, we should also consider bucket id as part of the key. + val fields = StructType.fromAttributes(partitionColumns) + .add("bucketId", IntegerType, nullable = false) ++ StructType.fromAttributes(sortColumns) + StructType(fields) + } + + // Returns the data columns to be written given an input row + val getOutputRow = UnsafeProjection.create(dataColumns, inputSchema) // Returns the partition path given a partition key. val getPartitionString = @@ -344,22 +461,34 @@ private[sql] class DynamicPartitionWriterContainer( // If anything below fails, we should abort the task. try { - // This will be filled in if we have to fall back on sorting. - var sorter: UnsafeKVExternalSorter = null + // If there are no sorting columns, we set the sorter to null and try hash-based writing + // first, filling the sorter only if there are too many writers and we need to fall back on + // sorting. If there are sorting columns, we have to sort the data anyway, so there is no + // need to try hash-based writing first. + var sorter: UnsafeKVExternalSorter = if (sortColumns.nonEmpty) { + new UnsafeKVExternalSorter( + sortingKeySchema, + StructType.fromAttributes(dataColumns), + SparkEnv.get.blockManager, + TaskContext.get().taskMemoryManager().pageSizeBytes) + } else { + null + } while (iterator.hasNext && sorter == null) { val inputRow = iterator.next() - val currentKey = getPartitionKey(inputRow) + // When we reach here, the `sortColumns` must be empty, so the sorting key is the hashing key.
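The sort-based path implemented in `sortBasedWrite` above keys every row by (partition columns, bucket id, sort columns), sorts, and then walks the sorted iterator, opening a new OutputWriter whenever the partition/bucket prefix changes. A simplified sketch of that writer-switching loop over plain Scala tuples rather than `UnsafeRow` and `UnsafeKVExternalSorter`:

```scala
// Simplified stand-ins: a "key" is (partitionValue, bucketId), rows are strings.
type Key = (String, Int)

def writeSorted(rows: Seq[(Key, String)])(openWriter: Key => java.io.Writer): Unit = {
  val sorted = rows.sortBy { case ((part, bucket), _) => (part, bucket) }
  var currentKey: Option[Key] = None
  var writer: java.io.Writer = null
  try {
    for ((key, row) <- sorted) {
      if (!currentKey.contains(key)) { // key changed: switch output files
        if (writer != null) writer.close()
        writer = openWriter(key)
        currentKey = Some(key)
      }
      writer.write(row + "\n")
    }
  } finally {
    if (writer != null) writer.close()
  }
}
```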
+ val currentKey = getSortingKey(inputRow) var currentWriter = outputWriters.get(currentKey) if (currentWriter == null) { if (outputWriters.size < maxOpenFiles) { - currentWriter = newOutputWriter(currentKey) + currentWriter = newOutputWriter(currentKey, getPartitionString) outputWriters.put(currentKey.copy(), currentWriter) currentWriter.writeInternal(getOutputRow(inputRow)) } else { logInfo(s"Maximum partitions reached, falling back on sorting.") sorter = new UnsafeKVExternalSorter( - StructType.fromAttributes(partitionColumns), + sortingKeySchema, StructType.fromAttributes(dataColumns), SparkEnv.get.blockManager, TaskContext.get().taskMemoryManager().pageSizeBytes) @@ -371,39 +500,15 @@ private[sql] class DynamicPartitionWriterContainer( } // If the sorter is not null that means that we reached the maxFiles above and need to finish - // using external sort. + // using external sort, or there are sorting columns and we need to sort the whole data set. if (sorter != null) { - while (iterator.hasNext) { - val currentRow = iterator.next() - sorter.insertKV(getPartitionKey(currentRow), getOutputRow(currentRow)) - } - - logInfo(s"Sorting complete. Writing out partition files one at a time.") - - val sortedIterator = sorter.sortedIterator() - var currentKey: InternalRow = null - var currentWriter: OutputWriter = null - try { - while (sortedIterator.next()) { - if (currentKey != sortedIterator.getKey) { - if (currentWriter != null) { - currentWriter.close() - } - currentKey = sortedIterator.getKey.copy() - logDebug(s"Writing partition: $currentKey") - - // Either use an existing file from before, or open a new one. - currentWriter = outputWriters.remove(currentKey) - if (currentWriter == null) { - currentWriter = newOutputWriter(currentKey) - } - } - - currentWriter.writeInternal(sortedIterator.getValue) - } - } finally { - if (currentWriter != null) { currentWriter.close() } - } + sortBasedWrite( + sorter, + iterator, + getSortingKey, + getOutputRow, + getPartitionString, + outputWriters) } commitTask() @@ -414,18 +519,6 @@ private[sql] class DynamicPartitionWriterContainer( throw new SparkException("Task failed while writing rows.", cause) } - /** Open and returns a new OutputWriter given a partition key. */ - def newOutputWriter(key: InternalRow): OutputWriter = { - val partitionPath = getPartitionString(key).getString(0) - val path = new Path(getWorkPath, partitionPath) - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(taskAttemptContext) - configuration.set( - "spark.sql.sources.output.path", new Path(outputPath, partitionPath).toString) - val newWriter = super.newOutputWriter(path.toString) - newWriter.initConverter(dataSchema) - newWriter - } - def clearOutputWriters(): Unit = { if (!outputWritersCleared) { outputWriters.asScala.values.foreach(_.close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala new file mode 100644 index 000000000000..82287c896713 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/bucket.scala @@ -0,0 +1,57 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
* The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import org.apache.hadoop.mapreduce.TaskAttemptContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory, HadoopFsRelationProvider, HadoopFsRelation} +import org.apache.spark.sql.types.StructType + +/** + * A container for bucketing information. + * Bucketing is a technique for decomposing data sets into more manageable parts; the number + * of buckets is fixed, so it does not fluctuate with the data. + * + * @param numBuckets number of buckets. + * @param bucketColumnNames the names of the columns used to generate the bucket id. + * @param sortColumnNames the names of the columns used to sort data in each bucket. + */ +private[sql] case class BucketSpec( + numBuckets: Int, + bucketColumnNames: Seq[String], + sortColumnNames: Seq[String]) + +private[sql] trait BucketedHadoopFsRelationProvider extends HadoopFsRelationProvider { + final override def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + parameters: Map[String, String]): HadoopFsRelation = + // TODO: throw an exception here, as this method won't be called during execution once + // bucketed read support is finished. + createRelation(sqlContext, paths, dataSchema, partitionColumns, bucketSpec = None, parameters) +} + +private[sql] abstract class BucketedOutputWriterFactory extends OutputWriterFactory { + final override def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + throw new UnsupportedOperationException("use bucket version") +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index e7deeff13dc4..0897fcadbc01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Returned for the "DESCRIBE [EXTENDED] [dbName.]tableName" command.
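The `BucketSpec` introduced above fixes the bucket count up front; at write time a row's bucket id is a non-negative modulus (`Pmod`) of a Murmur3 hash over the bucketing columns. A rough sketch of that assignment, using `scala.util.hashing.MurmurHash3` as a stand-in for Catalyst's `Murmur3Hash` expression (the real hash runs over the internal row encoding, so the ids here are illustrative only):

```scala
import scala.util.hashing.MurmurHash3

case class BucketSpecSketch(numBuckets: Int, bucketColumnNames: Seq[String])

// Non-negative modulus, mirroring Catalyst's Pmod: the result always
// lands in [0, numBuckets), even for negative hash values.
def pmod(a: Int, n: Int): Int = { val r = a % n; if (r < 0) r + n else r }

def bucketIdOf(row: Map[String, Any], spec: BucketSpecSketch): Int = {
  val hash = MurmurHash3.orderedHash(spec.bucketColumnNames.map(row(_)))
  pmod(hash, spec.numBuckets)
}

// e.g. bucketIdOf(Map("user_id" -> 42L, "name" -> "a"), BucketSpecSketch(8, Seq("user_id")))
```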
@@ -42,7 +42,7 @@ case class DescribeCommand( new MetadataBuilder().putString("comment", "name of the column").build())(), AttributeReference("data_type", StringType, nullable = false, new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = false, + AttributeReference("comment", StringType, nullable = true, new MetadataBuilder().putString("comment", "comment of the column").build())() ) } @@ -76,6 +76,7 @@ case class CreateTableUsingAsSelect( provider: String, temporary: Boolean, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], child: LogicalPlan) extends UnaryNode { @@ -109,7 +110,14 @@ case class CreateTempTableUsingAsSelect( override def run(sqlContext: SQLContext): Seq[Row] = { val df = DataFrame(sqlContext, query) - val resolved = ResolvedDataSource(sqlContext, provider, partitionColumns, mode, options, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec = None, + mode, + options, + df) sqlContext.catalog.registerTable( tableIdent, DataFrame(sqlContext, LogicalRelation(resolved.relation)).logicalPlan) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala index f522303be94a..4dcd261f5cbe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DefaultSource.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.util.Properties import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, DataSourceRegister} +import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider} class DefaultSource extends RelationProvider with DataSourceRegister { @@ -31,15 +31,12 @@ class DefaultSource extends RelationProvider with DataSourceRegister { sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = { val url = parameters.getOrElse("url", sys.error("Option 'url' not specified")) - val driver = parameters.getOrElse("driver", null) val table = parameters.getOrElse("dbtable", sys.error("Option 'dbtable' not specified")) val partitionColumn = parameters.getOrElse("partitionColumn", null) val lowerBound = parameters.getOrElse("lowerBound", null) val upperBound = parameters.getOrElse("upperBound", null) val numPartitions = parameters.getOrElse("numPartitions", null) - if (driver != null) DriverRegistry.register(driver) - if (partitionColumn != null && (lowerBound == null || upperBound == null || numPartitions == null)) { sys.error("Partitioning incompletely specified") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala index 7ccd61ed469e..65af397451c5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/DriverRegistry.scala @@ -51,10 +51,5 @@ object DriverRegistry extends Logging { } } } - - def getDriverClassName(url: String): String = DriverManager.getDriver(url) match { - case wrapper: DriverWrapper => wrapper.wrapped.getClass.getCanonicalName - 
case driver => driver.getClass.getCanonicalName - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala index 2d38562e0901..d867e144e517 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRDD.scala @@ -17,22 +17,22 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, Date, DriverManager, ResultSet, ResultSetMetaData, SQLException, Timestamp} +import java.sql.{Connection, Date, ResultSet, ResultSetMetaData, SQLException, Timestamp} import java.util.Properties import scala.util.control.NonFatal import org.apache.commons.lang3.StringUtils +import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow -import org.apache.spark.sql.catalyst.util.{GenericArrayData, DateTimeUtils} +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, GenericArrayData} import org.apache.spark.sql.jdbc.JdbcDialects import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String -import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} /** * Data corresponding to one partition of a JDBCRDD. @@ -41,7 +41,6 @@ private[sql] case class JDBCPartition(whereClause: String, idx: Int) extends Par override def index: Int = idx } - private[sql] object JDBCRDD extends Logging { /** @@ -120,32 +119,37 @@ private[sql] object JDBCRDD extends Logging { */ def resolveTable(url: String, table: String, properties: Properties): StructType = { val dialect = JdbcDialects.get(url) - val conn: Connection = getConnector(properties.getProperty("driver"), url, properties)() + val conn: Connection = JdbcUtils.createConnectionFactory(url, properties)() try { - val rs = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0").executeQuery() + val statement = conn.prepareStatement(s"SELECT * FROM $table WHERE 1=0") try { - val rsmd = rs.getMetaData - val ncols = rsmd.getColumnCount - val fields = new Array[StructField](ncols) - var i = 0 - while (i < ncols) { - val columnName = rsmd.getColumnLabel(i + 1) - val dataType = rsmd.getColumnType(i + 1) - val typeName = rsmd.getColumnTypeName(i + 1) - val fieldSize = rsmd.getPrecision(i + 1) - val fieldScale = rsmd.getScale(i + 1) - val isSigned = rsmd.isSigned(i + 1) - val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls - val metadata = new MetadataBuilder().putString("name", columnName) - val columnType = - dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( - getCatalystType(dataType, fieldSize, fieldScale, isSigned)) - fields(i) = StructField(columnName, columnType, nullable, metadata.build()) - i = i + 1 + val rs = statement.executeQuery() + try { + val rsmd = rs.getMetaData + val ncols = rsmd.getColumnCount + val fields = new Array[StructField](ncols) + var i = 0 + while (i < ncols) { + val columnName = rsmd.getColumnLabel(i + 1) + val dataType = rsmd.getColumnType(i + 1) + val typeName = rsmd.getColumnTypeName(i + 1) + val fieldSize = rsmd.getPrecision(i + 1) + val fieldScale = rsmd.getScale(i + 1) + val isSigned = rsmd.isSigned(i + 1) + val nullable = rsmd.isNullable(i + 1) != ResultSetMetaData.columnNoNulls + val metadata 
= new MetadataBuilder().putString("name", columnName) + val columnType = + dialect.getCatalystType(dataType, typeName, fieldSize, metadata).getOrElse( + getCatalystType(dataType, fieldSize, fieldScale, isSigned)) + fields(i) = StructField(columnName, columnType, nullable, metadata.build()) + i = i + 1 + } + return new StructType(fields) + } finally { + rs.close() } - return new StructType(fields) } finally { - rs.close() + statement.close() } } finally { conn.close() @@ -163,40 +167,73 @@ private[sql] object JDBCRDD extends Logging { * @return A Catalyst schema corresponding to columns in the given order. */ private def pruneSchema(schema: StructType, columns: Array[String]): StructType = { - val fieldMap = Map(schema.fields map { x => x.metadata.getString("name") -> x }: _*) - new StructType(columns map { name => fieldMap(name) }) + val fieldMap = Map(schema.fields.map(x => x.metadata.getString("name") -> x): _*) + new StructType(columns.map(name => fieldMap(name))) } /** - * Given a driver string and an url, return a function that loads the - * specified driver string then returns a connection to the JDBC url. - * getConnector is run on the driver code, while the function it returns - * is run on the executor. - * - * @param driver - The class name of the JDBC driver for the given url, or null if the class name - * is not necessary. - * @param url - The JDBC url to connect to. - * - * @return A function that loads the driver and connects to the url. + * Converts value to SQL expression. */ - def getConnector(driver: String, url: String, properties: Properties): () => Connection = { - () => { - try { - if (driver != null) DriverRegistry.register(driver) - } catch { - case e: ClassNotFoundException => - logWarning(s"Couldn't find class $driver", e) - } - DriverManager.getConnection(url, properties) - } + private def compileValue(value: Any): Any = value match { + case stringValue: String => s"'${escapeSql(stringValue)}'" + case timestampValue: Timestamp => "'" + timestampValue + "'" + case dateValue: Date => "'" + dateValue + "'" + case arrayValue: Array[Any] => arrayValue.map(compileValue).mkString(", ") + case _ => value } + private def escapeSql(value: String): String = + if (value == null) null else StringUtils.replace(value, "'", "''") + + /** + * Turns a single Filter into a String representing a SQL expression. + * Returns None for an unhandled filter. 
+ */ + private def compileFilter(f: Filter): Option[String] = { + Option(f match { + case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" + case EqualNullSafe(attr, value) => + s"(NOT ($attr != ${compileValue(value)} OR $attr IS NULL OR " + + s"${compileValue(value)} IS NULL) OR ($attr IS NULL AND ${compileValue(value)} IS NULL))" + case LessThan(attr, value) => s"$attr < ${compileValue(value)}" + case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" + case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" + case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" + case IsNull(attr) => s"$attr IS NULL" + case IsNotNull(attr) => s"$attr IS NOT NULL" + case StringStartsWith(attr, value) => s"${attr} LIKE '${value}%'" + case StringEndsWith(attr, value) => s"${attr} LIKE '%${value}'" + case StringContains(attr, value) => s"${attr} LIKE '%${value}%'" + case In(attr, value) => s"$attr IN (${compileValue(value)})" + case Not(f) => compileFilter(f).map(p => s"(NOT ($p))").getOrElse(null) + case Or(f1, f2) => + // We can't compile an Or filter unless both sub-filters compile successfully. + // The same applies to the And filter below. + // If we can make sure compileFilter supports all filters, we can remove this check. + val or = Seq(f1, f2).map(compileFilter(_)).flatten + if (or.size == 2) { + or.map(p => s"($p)").mkString(" OR ") + } else { + null + } + case And(f1, f2) => + val and = Seq(f1, f2).map(compileFilter(_)).flatten + if (and.size == 2) { + and.map(p => s"($p)").mkString(" AND ") + } else { + null + } + case _ => null + }) + } + + + /** * Build and return JDBCRDD from the given information. * * @param sc - Your SparkContext. * @param schema - The Catalyst schema of the underlying database table. - * @param driver - The class name of the JDBC driver for the given url. * @param url - The JDBC url to connect to. * @param fqTable - The fully-qualified table name (or paren'd SQL query) to use. * @param requiredColumns - The names of the columns to SELECT. @@ -209,7 +246,6 @@ private[sql] object JDBCRDD extends Logging { def scanTable( sc: SparkContext, schema: StructType, - driver: String, url: String, properties: Properties, fqTable: String, @@ -220,7 +256,7 @@ private[sql] object JDBCRDD extends Logging { val quotedColumns = requiredColumns.map(colName => dialect.quoteIdentifier(colName)) new JDBCRDD( sc, - getConnector(driver, url, properties), + JdbcUtils.createConnectionFactory(url, properties), pruneSchema(schema, requiredColumns), fqTable, quotedColumns, @@ -262,57 +298,24 @@ private[sql] class JDBCRDD( if (sb.length == 0) "1" else sb.substring(1) } - /** - * Converts value to SQL expression. - */ - private def compileValue(value: Any): Any = value match { - case stringValue: String => s"'${escapeSql(stringValue)}'" - case timestampValue: Timestamp => "'" + timestampValue + "'" - case dateValue: Date => "'" + dateValue + "'" - case _ => value - } - - private def escapeSql(value: String): String = - if (value == null) null else StringUtils.replace(value, "'", "''") - - /** - * Turns a single Filter into a String representing a SQL expression. - * Returns null for an unhandled filter.
- */ - private def compileFilter(f: Filter): String = f match { - case EqualTo(attr, value) => s"$attr = ${compileValue(value)}" - case Not(EqualTo(attr, value)) => s"$attr != ${compileValue(value)}" - case LessThan(attr, value) => s"$attr < ${compileValue(value)}" - case GreaterThan(attr, value) => s"$attr > ${compileValue(value)}" - case LessThanOrEqual(attr, value) => s"$attr <= ${compileValue(value)}" - case GreaterThanOrEqual(attr, value) => s"$attr >= ${compileValue(value)}" - case IsNull(attr) => s"$attr IS NULL" - case IsNotNull(attr) => s"$attr IS NOT NULL" - case _ => null - } - /** * `filters`, but as a WHERE clause suitable for injection into a SQL query. */ - private val filterWhereClause: String = { - val filterStrings = filters map compileFilter filter (_ != null) - if (filterStrings.size > 0) { - val sb = new StringBuilder("WHERE ") - filterStrings.foreach(x => sb.append(x).append(" AND ")) - sb.substring(0, sb.length - 5) - } else "" - } + private val filterWhereClause: String = + filters.map(JDBCRDD.compileFilter).flatten.mkString(" AND ") /** * A WHERE clause representing both `filters`, if any, and the current partition. */ private def getWhereClause(part: JDBCPartition): String = { if (part.whereClause != null && filterWhereClause.length > 0) { - filterWhereClause + " AND " + part.whereClause + "WHERE " + filterWhereClause + " AND " + part.whereClause } else if (part.whereClause != null) { "WHERE " + part.whereClause + } else if (filterWhereClause.length > 0) { + "WHERE " + filterWhereClause } else { - filterWhereClause + "" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index f9300dc2cb52..572be823ca87 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -23,9 +23,9 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.Partition import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SaveMode, SQLContext} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode} /** * Instructions on how to partition the table among workers. @@ -91,12 +91,10 @@ private[sql] case class JDBCRelation( override val schema: StructType = JDBCRDD.resolveTable(url, table, properties) override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = { - val driver: String = DriverRegistry.getDriverClassName(url) // Rely on a type erasure hack to pass RDD[InternalRow] back as RDD[Row] JDBCRDD.scanTable( sqlContext.sparkContext, schema, - driver, url, properties, table, @@ -110,4 +108,9 @@ private[sql] case class JDBCRelation( .mode(if (overwrite) SaveMode.Overwrite else SaveMode.Append) .jdbc(url, table, properties) } + + override def toString: String = { + // credentials should not be included in the plan output, table information is sufficient. 
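The reworked `compileFilter` above returns an `Option[String]` so that conjunctions and disjunctions only compile when both children do: pushing one side of an `Or` alone would change query results. A condensed sketch of that composition rule, with a simplified filter ADT (`Filt`, `Eq`, `Unsupported` are stand-ins, not Spark's classes):

```scala
sealed trait Filt
case class Eq(attr: String, value: String) extends Filt
case class OrF(left: Filt, right: Filt) extends Filt
case object Unsupported extends Filt

def compile(f: Filt): Option[String] = f match {
  case Eq(a, v) => Some(s"$a = '$v'")
  case OrF(l, r) =>
    // Both sides must compile; otherwise the whole Or is unhandled, since
    // "WHERE (a)" alone would drop rows that only match the other side.
    for (ls <- compile(l); rs <- compile(r)) yield s"($ls) OR ($rs)"
  case _ => None
}

// compile(OrF(Eq("x", "1"), Unsupported)) == None, so the filter stays
// in Spark and is evaluated after the scan instead of being pushed down.
```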
+ s"JDBCRelation(${table})" + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 252f1cfd5d9c..69ba84646f08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -17,16 +17,17 @@ package org.apache.spark.sql.execution.datasources.jdbc -import java.sql.{Connection, PreparedStatement} +import java.sql.{Connection, Driver, DriverManager, PreparedStatement} import java.util.Properties +import scala.collection.JavaConverters._ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcType, JdbcDialects} -import org.apache.spark.sql.types._ import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.jdbc.{JdbcDialect, JdbcDialects, JdbcType} +import org.apache.spark.sql.types._ /** * Util functions for JDBC tables. @@ -34,10 +35,31 @@ import org.apache.spark.sql.{DataFrame, Row} object JdbcUtils extends Logging { /** - * Establishes a JDBC connection. + * Returns a factory for creating connections to the given JDBC URL. + * + * @param url the JDBC url to connect to. + * @param properties JDBC connection properties. */ - def createConnection(url: String, connectionProperties: Properties): Connection = { - JDBCRDD.getConnector(connectionProperties.getProperty("driver"), url, connectionProperties)() + def createConnectionFactory(url: String, properties: Properties): () => Connection = { + val userSpecifiedDriverClass = Option(properties.getProperty("driver")) + userSpecifiedDriverClass.foreach(DriverRegistry.register) + // Performing this part of the logic on the driver guards against the corner-case where the + // driver returned for a URL is different on the driver and executors due to classpath + // differences. + val driverClass: String = userSpecifiedDriverClass.getOrElse { + DriverManager.getDriver(url).getClass.getCanonicalName + } + () => { + userSpecifiedDriverClass.foreach(DriverRegistry.register) + val driver: Driver = DriverManager.getDrivers.asScala.collectFirst { + case d: DriverWrapper if d.wrapped.getClass.getCanonicalName == driverClass => d + case d if d.getClass.getCanonicalName == driverClass => d + }.getOrElse { + throw new IllegalStateException( + s"Did not find registered driver with class $driverClass") + } + driver.connect(url, properties) + } } /** @@ -49,28 +71,36 @@ object JdbcUtils extends Logging { // Somewhat hacky, but there isn't a good way to identify whether a table exists for all // SQL database systems using JDBC meta data calls, considering "table" could also include // the database name. Query used to find table exists can be overriden by the dialects. - Try(conn.prepareStatement(dialect.getTableExistsQuery(table)).executeQuery()).isSuccess + Try { + val statement = conn.prepareStatement(dialect.getTableExistsQuery(table)) + try { + statement.executeQuery() + } finally { + statement.close() + } + }.isSuccess } /** * Drops a table from the JDBC database. */ def dropTable(conn: Connection, table: String): Unit = { - conn.createStatement.executeUpdate(s"DROP TABLE $table") + val statement = conn.createStatement + try { + statement.executeUpdate(s"DROP TABLE $table") + } finally { + statement.close() + } } /** * Returns a PreparedStatement that inserts a row into table via conn. 
*/ def insertStatement(conn: Connection, table: String, rddSchema: StructType): PreparedStatement = { - val sql = new StringBuilder(s"INSERT INTO $table VALUES (") - var fieldsLeft = rddSchema.fields.length - while (fieldsLeft > 0) { - sql.append("?") - if (fieldsLeft > 1) sql.append(", ") else sql.append(")") - fieldsLeft = fieldsLeft - 1 - } - conn.prepareStatement(sql.toString()) + val columns = rddSchema.fields.map(_.name).mkString(",") + val placeholders = rddSchema.fields.map(_ => "?").mkString(",") + val sql = s"INSERT INTO $table ($columns) VALUES ($placeholders)" + conn.prepareStatement(sql) } /** @@ -234,15 +264,14 @@ object JdbcUtils extends Logging { df: DataFrame, url: String, table: String, - properties: Properties = new Properties()) { + properties: Properties) { val dialect = JdbcDialects.get(url) val nullTypes: Array[Int] = df.schema.fields.map { field => getJdbcType(field.dataType, dialect).jdbcNullType } val rddSchema = df.schema - val driver: String = DriverRegistry.getDriverClassName(url) - val getConnection: () => Connection = JDBCRDD.getConnector(driver, url, properties) + val getConnection: () => Connection = createConnectionFactory(url, properties) val batchSize = properties.getProperty("batchsize", "1000").toInt df.foreachPartition { iterator => savePartition(getConnection, table, iterator, rddSchema, nullTypes, batchSize, dialect) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala index 59ba4ae2cba0..563c3903daa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/InferSchema.scala @@ -169,7 +169,6 @@ private[json] object InferSchema { None } - case NullType => Some(StringType) case other => Some(other) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala index c132ead20e7d..aee9cf2bdbca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONOptions.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.datasources.json -import com.fasterxml.jackson.core.{JsonParser, JsonFactory} +import com.fasterxml.jackson.core.{JsonFactory, JsonParser} /** * Options for the JSON data source. @@ -31,7 +31,8 @@ case class JSONOptions( allowUnquotedFieldNames: Boolean = false, allowSingleQuotes: Boolean = true, allowNumericLeadingZeros: Boolean = false, - allowNonNumericNumbers: Boolean = false) { + allowNonNumericNumbers: Boolean = false, + allowBackslashEscapingAnyCharacter: Boolean = false) { /** Sets config options on a Jackson [[JsonFactory]]. 
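The new `allowBackslashEscapingAnyCharacter` option maps directly onto a Jackson parser feature, as the `setJacksonOptions` wiring below shows. A minimal sketch of enabling such a feature on a `JsonFactory`, assuming only the jackson-core API; the flag value is an example:

```scala
import com.fasterxml.jackson.core.{JsonFactory, JsonParser}

// Mirrors the option-to-feature wiring below: each boolean data source
// option toggles one JsonParser.Feature on the shared JsonFactory.
val factory = new JsonFactory()
factory.configure(
  JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER,
  true) // e.g. lets "\a" parse as 'a' instead of raising a parse error
```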
*/ def setJacksonOptions(factory: JsonFactory): Unit = { @@ -40,6 +41,8 @@ case class JSONOptions( factory.configure(JsonParser.Feature.ALLOW_SINGLE_QUOTES, allowSingleQuotes) factory.configure(JsonParser.Feature.ALLOW_NUMERIC_LEADING_ZEROS, allowNumericLeadingZeros) factory.configure(JsonParser.Feature.ALLOW_NON_NUMERIC_NUMBERS, allowNonNumericNumbers) + factory.configure(JsonParser.Feature.ALLOW_BACKSLASH_ESCAPING_ANY_CHARACTER, + allowBackslashEscapingAnyCharacter) } } @@ -59,6 +62,8 @@ object JSONOptions { allowNumericLeadingZeros = parameters.get("allowNumericLeadingZeros").map(_.toBoolean).getOrElse(false), allowNonNumericNumbers = - parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true) + parameters.get("allowNonNumericNumbers").map(_.toBoolean).getOrElse(true), + allowBackslashEscapingAnyCharacter = + parameters.get("allowBackslashEscapingAnyCharacter").map(_.toBoolean).getOrElse(false) ) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala index 3e61ba35bea8..8a6fa4aeebc0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JSONRelation.scala @@ -24,25 +24,23 @@ import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{LongWritable, NullWritable, Text} import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "json" @@ -51,6 +49,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { new JSONRelation( @@ -58,6 +57,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { maybeDataSchema = dataSchema, maybePartitionSpec = None, userDefinedPartitionColumns = partitionColumns, + bucketSpec = bucketSpec, paths = paths, parameters = parameters)(sqlContext) } @@ -68,6 +68,7 @@ private[sql] class JSONRelation( val maybeDataSchema: Option[StructType], val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: 
Option[StructType], + override val bucketSpec: Option[BucketSpec] = None, override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) @@ -89,8 +90,8 @@ private[sql] class JSONRelation( override val needConversion: Boolean = false private def createBaseRdd(inputPaths: Array[FileStatus]): RDD[String] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath) @@ -160,13 +161,14 @@ private[sql] class JSONRelation( partitionColumns) } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - new OutputWriterFactory { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new JsonOutputWriter(path, dataSchema, context) + new JsonOutputWriter(path, bucketId, dataSchema, context) } } } @@ -174,9 +176,10 @@ private[sql] class JSONRelation( private[json] class JsonOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with Logging { + extends OutputWriter with Logging { private[this] val writer = new CharArrayWriter() // create the Generator without separator inserted between 2 records @@ -186,11 +189,12 @@ private[json] class JsonOutputWriter( private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } }.getRecordWriter(context) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala index 3f34520afe6b..078e1cbec577 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonGenerator.scala @@ -17,14 +17,13 @@ package org.apache.spark.sql.execution.datasources.json -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.{MapData, ArrayData, DateTimeUtils} - import scala.collection.Map import com.fasterxml.jackson.core._ import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.util.{ArrayData, DateTimeUtils, MapData} import org.apache.spark.sql.types._ private[sql] object JacksonGenerator { diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala index 55a1c24e9e00..2e3fe3da1538 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/json/JacksonParser.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.ByteArrayOutputStream + import scala.collection.mutable.ArrayBuffer import com.fasterxml.jackson.core._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala index a958373eb769..c3b7483e80ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystReadSupport.scala @@ -22,11 +22,11 @@ import java.util.{Map => JMap} import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration -import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.hadoop.api.{InitContext, ReadSupport} +import org.apache.parquet.hadoop.api.ReadSupport.ReadContext import org.apache.parquet.io.api.RecordMaterializer -import org.apache.parquet.schema.Type.Repetition import org.apache.parquet.schema._ +import org.apache.parquet.schema.Type.Repetition import org.apache.spark.Logging import org.apache.spark.deploy.SparkHadoopUtil @@ -58,9 +58,7 @@ private[parquet] class CatalystReadSupport extends ReadSupport[InternalRow] with */ override def init(context: InitContext): ReadContext = { catalystRequestedSchema = { - // scalastyle:off jobcontext val conf = context.getConfiguration - // scalastyle:on jobcontext val schemaString = conf.get(CatalystReadSupport.SPARK_ROW_REQUESTED_SCHEMA) assert(schemaString != null, "Parquet requested schema not set.") StructType.fromString(schemaString) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala index 8851bc23cd05..42d89f4bf81d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystRowConverter.scala @@ -25,14 +25,14 @@ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} -import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{DOUBLE, INT32, INT64, BINARY, FIXED_LEN_BYTE_ARRAY} import org.apache.parquet.schema.{GroupType, MessageType, PrimitiveType, Type} +import org.apache.parquet.schema.OriginalType.{INT_32, LIST, UTF8} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, DOUBLE, FIXED_LEN_BYTE_ARRAY, INT32, INT64} import org.apache.spark.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData, DateTimeUtils} +import 
org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala index 5f9f9083098a..fb97a03df60f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystSchemaConverter.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.datasources.parquet import scala.collection.JavaConverters._ import org.apache.hadoop.conf.Configuration +import org.apache.parquet.schema._ import org.apache.parquet.schema.OriginalType._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ import org.apache.parquet.schema.Type.Repetition._ -import org.apache.parquet.schema._ -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, maxPrecisionForBytes} -import org.apache.spark.sql.types._ import org.apache.spark.sql.{AnalysisException, SQLConf} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{maxPrecisionForBytes, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} +import org.apache.spark.sql.types._ /** * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]] and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala index 6862dea5e6c3..e78afa5ae6d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/CatalystWriteSupport.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64, minBytesForPrecision} +import org.apache.spark.sql.execution.datasources.parquet.CatalystSchemaConverter.{minBytesForPrecision, MAX_PRECISION_FOR_INT32, MAX_PRECISION_FOR_INT64} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala index 1a4e99ff10af..ecadb9e7c6ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/DirectParquetOutputCommitter.scala @@ -19,11 +19,11 @@ package org.apache.spark.sql.execution.datasources.parquet import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter import org.apache.parquet.Log 
-import org.apache.parquet.hadoop.util.ContextUtil import org.apache.parquet.hadoop.{ParquetFileReader, ParquetFileWriter, ParquetOutputCommitter, ParquetOutputFormat} +import org.apache.parquet.hadoop.util.ContextUtil /** * An output committer for writing Parquet files. Instead of writing to the `_temporary` folder @@ -54,11 +54,7 @@ private[datasources] class DirectParquetOutputCommitter( override def setupTask(taskContext: TaskAttemptContext): Unit = {} override def commitJob(jobContext: JobContext) { - val configuration = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(jobContext) - // scalastyle:on jobcontext - } + val configuration = ContextUtil.getConfiguration(jobContext) val fileSystem = outputPath.getFileSystem(configuration) if (configuration.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, true)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala index 07714329370a..e9b734b0abf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilters.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.Serializable -import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.filter2.predicate._ +import org.apache.parquet.filter2.predicate.FilterApi._ import org.apache.parquet.io.api.Binary import org.apache.parquet.schema.OriginalType import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName @@ -256,8 +256,21 @@ private[sql] object ParquetFilters { case sources.GreaterThanOrEqual(name, value) => makeGtEq.lift(dataTypeOf(name)).map(_(name, value)) + case sources.In(name, valueSet) => + makeInSet.lift(dataTypeOf(name)).map(_(name, valueSet.toSet)) + case sources.And(lhs, rhs) => - (createFilter(schema, lhs) ++ createFilter(schema, rhs)).reduceOption(FilterApi.and) + // It is not safe here to convert just one side if we do not understand the + // other side. Here is an example that explains why. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. 
+ for { + lhsFilter <- createFilter(schema, lhs) + rhsFilter <- createFilter(schema, rhs) + } yield FilterApi.and(lhsFilter, rhsFilter) case sources.Or(lhs, rhs) => for { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 1af2a394f399..ca8d01009040 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.Writable import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.task.JobContextImpl import org.apache.parquet.filter2.predicate.FilterApi import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.metadata.CompressionCodecName @@ -40,18 +41,17 @@ import org.apache.parquet.{Log => ApacheParquetLog} import org.slf4j.bridge.SLF4JBridgeHandler import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{RDD, SqlNewHadoopPartition, SqlNewHadoopRDD} import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.catalyst.util.LegacyTypeStringParser +import org.apache.spark.sql.catalyst.{InternalRow, SqlParser, TableIdentifier} +import org.apache.spark.sql.execution.datasources.{PartitionSpec, _} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.{Logging, Partition => SparkPartition, SparkException} - -private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "parquet" @@ -60,13 +60,17 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], schema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { - new ParquetRelation(paths, schema, None, partitionColumns, parameters)(sqlContext) + new ParquetRelation(paths, schema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } // NOTE: This class is instantiated and used on executor side only, no need to be serializable. -private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext) +private[sql] class ParquetOutputWriter( + path: String, + bucketId: Option[Int], + context: TaskAttemptContext) extends OutputWriter { private val recordWriter: RecordWriter[Void, InternalRow] = { @@ -82,11 +86,12 @@ private[sql] class ParquetOutputWriter(path: String, context: TaskAttemptContext // `FileOutputCommitter.getWorkPath()`, which points to the base directory of all // partitions in the case of dynamic partitioning. 
override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId - new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$bucketString$extension") } } } @@ -107,6 +112,7 @@ private[sql] class ParquetRelation( // This is for metastore conversion. private val maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -123,6 +129,7 @@ private[sql] class ParquetRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -140,6 +147,12 @@ private[sql] class ParquetRelation( .get(ParquetRelation.METASTORE_SCHEMA) .map(DataType.fromJson(_).asInstanceOf[StructType]) + // If this relation is converted from a Hive metastore table, this method returns the name of the + // original Hive metastore table. + private[sql] def metastoreTableName: Option[TableIdentifier] = { + parameters.get(ParquetRelation.METASTORE_TABLE_NAME).map(SqlParser.parseTableIdentifier) + } + private lazy val metadataCache: MetadataCache = { val meta = new MetadataCache meta.refresh() @@ -216,12 +229,8 @@ private[sql] class ParquetRelation( override def sizeInBytes: Long = metadataCache.dataStatuses.map(_.getLen).sum - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - val conf = { - // scalastyle:off jobcontext - ContextUtil.getConfiguration(job) - // scalastyle:on jobcontext - } + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + val conf = ContextUtil.getConfiguration(job) // SPARK-9849 DirectParquetOutputCommitter qualified name should be backward compatible val committerClassName = conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) @@ -280,10 +289,13 @@ private[sql] class ParquetRelation( sqlContext.conf.parquetCompressionCodec.toUpperCase, CompressionCodecName.UNCOMPRESSED).name()) - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( - path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new ParquetOutputWriter(path, context) + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = { + new ParquetOutputWriter(path, bucketId, context) } } } @@ -340,7 +352,7 @@ private[sql] class ParquetRelation( // URI of the path to create a new Path. 
val pathWithEscapedAuthority = escapePathUserInfo(f.getPath) new FileStatus( - f.getLen, f.isDir, f.getReplication, f.getBlockSize, f.getModificationTime, + f.getLen, f.isDirectory, f.getReplication, f.getBlockSize, f.getModificationTime, f.getAccessTime, f.getPermission, f.getOwner, f.getGroup, pathWithEscapedAuthority) }.toSeq @@ -359,7 +371,7 @@ private[sql] class ParquetRelation( } } - val jobContext = newJobContext(getConf(isDriverSide = true), jobId) + val jobContext = new JobContextImpl(getConf(isDriverSide = true), jobId) val rawSplits = inputFormat.getSplits(jobContext) Array.tabulate[SparkPartition](rawSplits.size) { i => @@ -564,7 +576,7 @@ private[sql] object ParquetRelation extends Logging { parquetFilterPushDown: Boolean, assumeBinaryIsString: Boolean, assumeInt96IsTimestamp: Boolean)(job: Job): Unit = { - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val conf = job.getConfiguration conf.set(ParquetInputFormat.READ_SUPPORT_CLASS, classOf[CatalystReadSupport].getName) // Try to push down filters when filter push-down is enabled. @@ -607,7 +619,7 @@ private[sql] object ParquetRelation extends Logging { FileInputFormat.setInputPaths(job, inputFiles.map(_.getPath): _*) } - overrideMinSplitSize(parquetBlockSize, SparkHadoopUtil.get.getConfigurationFromJobContext(job)) + overrideMinSplitSize(parquetBlockSize, job.getConfiguration) } private[parquet] def readSchema( @@ -642,7 +654,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(serializedSchema.get) + LegacyTypeStringParser.parse(serializedSchema.get) } .recover { case cause: Throwable => logWarning( @@ -825,7 +837,7 @@ private[sql] object ParquetRelation extends Logging { logInfo( s"Serialized Spark schema in Parquet key-value metadata is not in JSON format, " + "falling back to the deprecated DataType.fromCaseClassString parser.") - DataType.fromCaseClassString(schemaString).asInstanceOf[StructType] + LegacyTypeStringParser.parse(schemaString).asInstanceOf[StructType] }.recoverWith { case cause: Throwable => logWarning( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 1a8e7ab202dc..d484403d1c64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -17,13 +17,13 @@ package org.apache.spark.sql.execution.datasources +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Cast, RowOrdering} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} /** * Tries to replace [[UnresolvedRelation]]s with [[ResolvedDataSource]]. 
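The sources.And case added in ParquetFilters above is the subtle piece of this patch: a conjunction is pushed down only when both sides convert, because under a NOT a lone converted conjunct changes the predicate's meaning. A minimal Option-based sketch of that shape (convertAnd and the string "filters" are illustrative, not Spark API):

// Both conjuncts must convert, or the whole AND is abandoned; this mirrors
// the for-comprehension added to ParquetFilters.createFilter.
def convertAnd[F](lhs: Option[F], rhs: Option[F])(combine: (F, F) => F): Option[F] =
  for {
    l <- lhs
    r <- rhs
  } yield combine(l, r)

// For NOT(a = 2 AND b in ('1')), suppose `b in ('1')` fails to convert (None).
// The old (lhs ++ rhs).reduceOption approach kept `a = 2` alone, so the pushed
// filter became the incorrect NOT(a = 2); convertAnd instead yields None.
val pushed = convertAnd(Some("a = 2"), Option.empty[String])((l, r) => s"($l) AND ($r)")
assert(pushed.isEmpty)
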
@@ -165,22 +165,22 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => // OK } - case CreateTableUsingAsSelect(tableIdent, _, _, partitionColumns, mode, _, query) => + case c: CreateTableUsingAsSelect => // When the SaveMode is Overwrite, we need to check if the table is an input table of // the query. If so, we will throw an AnalysisException to let users know it is not allowed. - if (mode == SaveMode.Overwrite && catalog.tableExists(tableIdent)) { + if (c.mode == SaveMode.Overwrite && catalog.tableExists(c.tableIdent)) { // Need to remove SubQuery operator. - EliminateSubQueries(catalog.lookupRelation(tableIdent)) match { + EliminateSubQueries(catalog.lookupRelation(c.tableIdent)) match { // Only do the check if the table is a data source table // (the relation is a BaseRelation). case l @ LogicalRelation(dest: BaseRelation, _) => // Get all input data source relations of the query. - val srcRelations = query.collect { + val srcRelations = c.child.collect { case LogicalRelation(src: BaseRelation, _) => src } if (srcRelations.contains(dest)) { failAnalysis( - s"Cannot overwrite table $tableIdent that is also being read from.") + s"Cannot overwrite table ${c.tableIdent} that is also being read from.") } else { // OK } @@ -192,7 +192,17 @@ private[sql] case class PreWriteCheck(catalog: Catalog) extends (LogicalPlan => } PartitioningUtils.validatePartitionColumnDataTypes( - query.schema, partitionColumns, catalog.conf.caseSensitiveAnalysis) + c.child.schema, c.partitionColumns, catalog.conf.caseSensitiveAnalysis) + + for { + spec <- c.bucketSpec + sortColumnName <- spec.sortColumnNames + sortColumn <- c.child.schema.find(_.name == sortColumnName) + } { + if (!RowOrdering.isOrderable(sortColumn.dataType)) { + failAnalysis(s"Cannot use ${sortColumn.dataType.simpleString} for sorting column.") + } + } case _ => // OK } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index fbd387bc2ef4..bd2d17c0189e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -18,21 +18,19 @@ package org.apache.spark.sql.execution.datasources.text import com.google.common.base.Objects -import org.apache.hadoop.fs.{Path, FileStatus} -import org.apache.hadoop.io.{NullWritable, Text, LongWritable} -import org.apache.hadoop.mapred.{TextInputFormat, JobConf} -import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat -import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, Job} +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.io.{LongWritable, NullWritable, Text} +import org.apache.hadoop.mapred.{JobConf, TextInputFormat} +import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} import org.apache.hadoop.mapreduce.lib.input.FileInputFormat +import org.apache.hadoop.mapreduce.lib.output.TextOutputFormat import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow -import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, BufferHolder} 
-import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} import org.apache.spark.sql.execution.datasources.PartitionSpec import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} @@ -50,7 +48,7 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation = { dataSchema.foreach(verifySchema) - new TextRelation(None, partitionColumns, paths)(sqlContext) + new TextRelation(None, dataSchema, partitionColumns, paths)(sqlContext) } override def shortName(): String = "text" @@ -70,15 +68,16 @@ class DefaultSource extends HadoopFsRelationProvider with DataSourceRegister { private[sql] class TextRelation( val maybePartitionSpec: Option[PartitionSpec], + val textSchema: Option[StructType], override val userDefinedPartitionColumns: Option[StructType], override val paths: Array[String] = Array.empty[String], parameters: Map[String, String] = Map.empty[String, String]) (@transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) { - /** Data schema is always a single column, named "text". */ - override def dataSchema: StructType = new StructType().add("value", StringType) - + /** Data schema is always a single column, named "value", if the original data source has no schema. */ + override def dataSchema: StructType = + textSchema.getOrElse(new StructType().add("value", StringType)) /** This is an internal data source that outputs internal row format. */ override val needConversion: Boolean = false @@ -88,8 +87,8 @@ private[sql] class TextRelation( filters: Array[Filter], inputPaths: Array[FileStatus], broadcastedConf: Broadcast[SerializableConfiguration]): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration val paths = inputPaths.map(_.getPath).sortBy(_.toUri) if (paths.nonEmpty) { @@ -101,14 +100,14 @@ private[sql] class TextRelation( .mapPartitions { iter => val bufferHolder = new BufferHolder val unsafeRowWriter = new UnsafeRowWriter - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(1) iter.map { case (_, line) => // Writes to an UnsafeRow directly bufferHolder.reset() unsafeRowWriter.initialize(bufferHolder, 1) unsafeRowWriter.write(0, line.getBytes, 0, line.getLength) - unsafeRow.pointTo(bufferHolder.buffer, 1, bufferHolder.totalSize()) + unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize()) unsafeRow } } @@ -138,17 +137,16 @@ private[sql] class TextRelation( } class TextOutputWriter(path: String, dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter - with SparkHadoopMapRedUtil { + extends OutputWriter { private[this] val buffer = new Text() private val recordWriter: RecordWriter[NullWritable, Text] = { new TextOutputFormat[NullWritable, Text]() { override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID 
val split = taskAttemptId.getTaskID.getId new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 74892e4e13fa..dbb6b654b1a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql.execution import scala.collection.mutable.HashSet +import org.apache.spark.{Accumulator, AccumulatorParam, Logging} import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.trees.TreeNodeRef -import org.apache.spark.{Accumulator, AccumulatorParam, Logging} /** * Contains methods for debugging query execution. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 1d381e2eaef3..0a818cc2c2a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.ThreadUtils -import org.apache.spark.{InternalAccumulator, TaskContext} /** * Performs an inner hash join of two child relations. 
When the output RDD of this operator is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index ab81bd7b3fc0..6c7fa2eee5bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.execution.joins import scala.concurrent._ import scala.concurrent.duration._ +import org.apache.spark.{InternalAccumulator, TaskContext} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.{BinaryNode, SQLExecution, SparkPlan} +import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution} +import org.apache.spark.sql.execution.{BinaryNode, SparkPlan, SQLExecution} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.{InternalAccumulator, TaskContext} /** * Performs an outer hash join for two child relations. When the output RDD of this operator is diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index aab177b2e842..e55f8694781a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.collection.{BitSet, CompactBuffer} @@ -46,15 +46,8 @@ case class BroadcastNestedLoopJoin( case BuildLeft => (right, left) } - override def outputsUnsafeRows: Boolean = left.outputsUnsafeRows || right.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } } override def outputPartitioning: Partitioning = streamed.outputPartitioning diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fa2bc7672131..93d32e1fb93a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -20,10 +20,10 @@ package org.apache.spark.sql.execution.joins import org.apache.spark._ import org.apache.spark.rdd.{CartesianPartition, CartesianRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow -import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} +import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter @@ -56,15 +56,14 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField // Create an iterator from sorter and wrap it as Iterator[UnsafeRow] def createIter(): Iterator[UnsafeRow] = { val iter = sorter.getIterator - val unsafeRow = new UnsafeRow + val unsafeRow = new UnsafeRow(numFieldsOfRight) new Iterator[UnsafeRow] { override def hasNext: Boolean = { iter.hasNext } override def next(): UnsafeRow = { iter.loadNext() - unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, numFieldsOfRight, - iter.getRecordLength) + unsafeRow.pointTo(iter.getBaseObject, iter.getBaseOffset, iter.getRecordLength) unsafeRow } } @@ -82,10 +81,6 @@ class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numField case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output ++ right.output - override def canProcessSafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def outputsUnsafeRows: Boolean = true - override private[sql] lazy val metrics = Map( "numLeftRows" -> SQLMetrics.createLongMetric(sparkContext, "number of left rows"), "numRightRows" -> SQLMetrics.createLongMetric(sparkContext, "number of right rows"), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index fb961d97c3c3..7f9d9daa5ab2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -44,10 +44,6 @@ trait HashJoin { override def output: Seq[Attribute] = left.output ++ right.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildSideKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index ed626fef56af..6d464d6946b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -64,10 +64,6 @@ trait HashOuterJoin { s"HashOuterJoin should not take $x as the JoinType") } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def buildKeyGenerator: Projection = UnsafeProjection.create(buildKeys, buildPlan.output) @@ -75,7 +71,7 @@ trait HashOuterJoin { UnsafeProjection.create(streamedKeys, streamedPlan.output) protected[this] def resultProjection: InternalRow => InternalRow = - UnsafeProjection.create(self.schema) + UnsafeProjection.create(output, output) @transient private[this] lazy val 
DUMMY_LIST = CompactBuffer[InternalRow](null) @transient protected[this] lazy val EMPTY_LIST = CompactBuffer[InternalRow]() @@ -151,82 +147,4 @@ trait HashOuterJoin { } ret.iterator } - - protected[this] def fullOuterIterator( - key: InternalRow, - leftIter: Iterable[InternalRow], - rightIter: Iterable[InternalRow], - joinedRow: JoinedRow, - resultProjection: InternalRow => InternalRow, - numOutputRows: LongSQLMetric): Iterator[InternalRow] = { - if (!key.anyNull) { - // Store the positions of records in right, if one of its associated row satisfy - // the join condition. - val rightMatchedSet = scala.collection.mutable.Set[Int]() - leftIter.iterator.flatMap[InternalRow] { l => - joinedRow.withLeft(l) - var matched = false - rightIter.zipWithIndex.collect { - // 1. For those matched (satisfy the join condition) records with both sides filled, - // append them directly - - case (r, idx) if boundCondition(joinedRow.withRight(r)) => - numOutputRows += 1 - matched = true - // if the row satisfy the join condition, add its index into the matched set - rightMatchedSet.add(idx) - resultProjection(joinedRow) - - } ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // 2. For those unmatched records in left, append additional records with empty right. - - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all - // of the records in right side. - // If we didn't get any proper row, then append a single row with empty right. - numOutputRows += 1 - resultProjection(joinedRow.withRight(rightNullRow)) - }) - } ++ rightIter.zipWithIndex.collect { - // 3. For those unmatched records in right, append additional records with empty left. - - // Re-visiting the records in right, and append additional row with empty left, if its not - // in the matched set. 
- case (r, idx) if !rightMatchedSet.contains(idx) => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } else { - leftIter.iterator.map[InternalRow] { l => - numOutputRows += 1 - resultProjection(joinedRow(l, rightNullRow)) - } ++ rightIter.iterator.map[InternalRow] { r => - numOutputRows += 1 - resultProjection(joinedRow(leftNullRow, r)) - } - } - } - - // This is only used by FullOuter - protected[this] def buildHashTable( - iter: Iterator[InternalRow], - numIterRows: LongSQLMetric, - keyGenerator: Projection): java.util.HashMap[InternalRow, CompactBuffer[InternalRow]] = { - val hashTable = new java.util.HashMap[InternalRow, CompactBuffer[InternalRow]]() - while (iter.hasNext) { - val currentRow = iter.next() - numIterRows += 1 - val rowKey = keyGenerator(currentRow) - - var existingMatchList = hashTable.get(rowKey) - if (existingMatchList == null) { - existingMatchList = new CompactBuffer[InternalRow]() - hashTable.put(rowKey.copy(), existingMatchList) - } - - existingMatchList += currentRow.copy() - } - - hashTable - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala index f23a1830e91c..3e0f74cd98c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala @@ -33,10 +33,6 @@ trait HashSemiJoin { override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - protected def leftKeyGenerator: Projection = UnsafeProjection.create(leftKeys, left.output) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index 8c7099ab5a34..ee7a1bdc343c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -21,7 +21,8 @@ import java.io.{Externalizable, IOException, ObjectInput, ObjectOutput} import java.nio.ByteOrder import java.util.{HashMap => JavaHashMap} -import org.apache.spark.memory.{TaskMemoryManager, StaticMemoryManager} +import org.apache.spark.{SparkConf, SparkEnv} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkSqlSerializer @@ -30,10 +31,8 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.unsafe.memory.MemoryLocation -import org.apache.spark.util.{SizeEstimator, KnownSizeEstimation, Utils} +import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils} import org.apache.spark.util.collection.CompactBuffer -import org.apache.spark.{SparkConf, SparkEnv} - /** * Interface for a hashed relation by some key. 
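The UnsafeHashedRelation hunk below shows the UnsafeRow API migration repeated throughout this patch: the field count is now fixed in the constructor, so pointTo takes only the base object, offset, and byte length. A self-contained sketch under the new API, mirroring the TextRelation writer earlier in this patch (the "hello" payload is illustrative):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}

val bufferHolder = new BufferHolder
val unsafeRowWriter = new UnsafeRowWriter
val unsafeRow = new UnsafeRow(1)              // field count belongs to the row now
val bytes = "hello".getBytes("UTF-8")

bufferHolder.reset()
unsafeRowWriter.initialize(bufferHolder, 1)   // one field
unsafeRowWriter.write(0, bytes, 0, bytes.length)
unsafeRow.pointTo(bufferHolder.buffer, bufferHolder.totalSize())  // no numFields argument
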
Use [[HashedRelation.apply]] to create a concrete @@ -245,8 +244,8 @@ private[joins] final class UnsafeHashedRelation( val sizeInBytes = Platform.getInt(base, offset + 4) offset += 8 - val row = new UnsafeRow - row.pointTo(base, offset, numFields, sizeInBytes) + val row = new UnsafeRow(numFields) + row.pointTo(base, offset, sizeInBytes) buffer += row offset += sizeInBytes } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala index efa7b49410ed..82498ee39564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala @@ -42,9 +42,6 @@ case class LeftSemiJoinBNL( override def output: Seq[Attribute] = left.output - override def outputsUnsafeRows: Boolean = streamed.outputsUnsafeRows - override def canProcessUnsafeRows: Boolean = true - /** The Streamed Relation */ override def left: SparkPlan = streamed diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index bf3b05be981f..25b3b5ca2377 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 4bf7b521c77d..812f881d06fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -53,10 +53,6 @@ case class SortMergeJoin( override def requiredChildOrdering: Seq[Seq[SortOrder]] = requiredOrders(leftKeys) :: requiredOrders(rightKeys) :: Nil - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def requiredOrders(keys: Seq[Expression]): Seq[SortOrder] = { // This must be ascending in order to agree with the `keyOrdering` defined in `doExecute()`. 
keys.map(SortOrder(_, Ascending)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala index efaa69c1d322..ed41ad2a005e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeOuterJoin.scala @@ -22,10 +22,10 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} -import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, RowIterator, SparkPlan} +import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics} import org.apache.spark.util.collection.BitSet /** @@ -89,10 +89,6 @@ case class SortMergeOuterJoin( keys.map(SortOrder(_, Ascending)) } - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - private def createLeftKeyGenerator(): Projection = UnsafeProjection.create(leftKeys, left.output) @@ -114,7 +110,7 @@ case class SortMergeOuterJoin( (r: InternalRow) => true } } - val resultProj: InternalRow => InternalRow = UnsafeProjection.create(schema) + val resultProj: InternalRow => InternalRow = UnsafeProjection.create(output, output) joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala index 3dcef9409564..59345046da49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} /** * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala index 6a882c9234df..a0dfe996ccd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import scala.util.control.NonFatal import org.apache.spark.Logging -import org.apache.spark.sql.{SQLConf, Row} +import org.apache.spark.sql.{Row, SQLConf} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ @@ -69,18 +69,6 @@ abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Loggin */ def close(): Unit - /** Specifies whether this 
operator outputs UnsafeRows */ - def outputsUnsafeRows: Boolean = false - - /** Specifies whether this operator is capable of processing UnsafeRows */ - def canProcessUnsafeRows: Boolean = false - - /** - * Specifies whether this operator is capable of processing Java-object-based Rows (i.e. rows - * that are not UnsafeRows). - */ - def canProcessSafeRows: Boolean = true - /** * Returns the content through the [[Iterator]] interface. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala index 7321fc66b4dd..b93bde58a55e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNode.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{FullOuter, RightOuter, LeftOuter, JoinType} +import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter} import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.util.collection.{BitSet, CompactBuffer} @@ -47,11 +47,7 @@ case class NestedLoopJoinNode( } private[this] def genResultProjection: InternalRow => InternalRow = { - if (outputsUnsafeRows) { - UnsafeProjection.create(schema) - } else { - identity[InternalRow] - } + UnsafeProjection.create(schema) } private[this] var currentRow: InternalRow = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala index 11529d6dd9b8..bd73b08263f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.local import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, UnsafeProjection} case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala index 6c0f6f8a52dc..52735c9d7f8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/metric/SQLMetrics.scala @@ -17,8 +17,8 @@ package org.apache.spark.sql.execution.metric -import org.apache.spark.util.Utils import org.apache.spark.{Accumulable, AccumulableParam, SparkContext} +import org.apache.spark.util.Utils /** * Create a layer for specialized metric. 
We cannot add `@specialized` to diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala index defcec95fb55..41e35fd724cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala @@ -24,8 +24,8 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle._ -import org.apache.spark.{Logging => SparkLogging, TaskContext, Accumulator} -import org.apache.spark.api.python.{PythonRunner, PythonBroadcast, PythonRDD, SerDeUtil} +import org.apache.spark.{Accumulator, Logging => SparkLogging, TaskContext} +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD, PythonRunner, SerDeUtil} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -351,10 +351,6 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: def children: Seq[SparkPlan] = child :: Nil - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = true - protected override def doExecute(): RDD[InternalRow] = { val inputRDD = child.execute().map(_.copy()) val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) @@ -400,13 +396,14 @@ case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: val unpickle = new Unpickler val row = new GenericMutableRow(1) val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) outputIterator.flatMap { pickedResult => val unpickledBatch = unpickle.loads(pickedResult) unpickledBatch.asInstanceOf[java.util.ArrayList[Any]].asScala }.map { result => row(0) = EvaluatePython.fromJava(result, udf.dataType) - joined(queue.poll(), row) + resultProj(joined(queue.poll(), row)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala deleted file mode 100644 index 5f8fc2de8b46..000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/rowFormatConverters.scala +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.Partitioning -import org.apache.spark.sql.catalyst.rules.Rule - -/** - * Converts Java-object-based rows into [[UnsafeRow]]s. - */ -case class ConvertToUnsafe(child: SparkPlan) extends UnaryNode { - - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = true - override def canProcessUnsafeRows: Boolean = false - override def canProcessSafeRows: Boolean = true - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToUnsafe = UnsafeProjection.create(child.schema) - iter.map(convertToUnsafe) - } - } -} - -/** - * Converts [[UnsafeRow]]s back into Java-object-based rows. - */ -case class ConvertToSafe(child: SparkPlan) extends UnaryNode { - override def output: Seq[Attribute] = child.output - override def outputPartitioning: Partitioning = child.outputPartitioning - override def outputOrdering: Seq[SortOrder] = child.outputOrdering - override def outputsUnsafeRows: Boolean = false - override def canProcessUnsafeRows: Boolean = true - override def canProcessSafeRows: Boolean = false - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - val convertToSafe = FromUnsafeProjection(child.output.map(_.dataType)) - iter.map(convertToSafe) - } - } -} - -private[sql] object EnsureRowFormats extends Rule[SparkPlan] { - - private def onlyHandlesSafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && !operator.canProcessUnsafeRows - - private def onlyHandlesUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessUnsafeRows && !operator.canProcessSafeRows - - private def handlesBothSafeAndUnsafeRows(operator: SparkPlan): Boolean = - operator.canProcessSafeRows && operator.canProcessUnsafeRows - - override def apply(operator: SparkPlan): SparkPlan = operator.transformUp { - case operator: SparkPlan if onlyHandlesSafeRows(operator) => - if (operator.children.exists(_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (c.outputsUnsafeRows) ConvertToSafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if onlyHandlesUnsafeRows(operator) => - if (operator.children.exists(!_.outputsUnsafeRows)) { - operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - case operator: SparkPlan if handlesBothSafeAndUnsafeRows(operator) => - if (operator.children.map(_.outputsUnsafeRows).toSet.size != 1) { - // If this operator's children produce both unsafe and safe rows, - // convert everything unsafe rows. 
- operator.withNewChildren { - operator.children.map { - c => if (!c.outputsUnsafeRows) ConvertToUnsafe(c) else c - } - } - } else { - operator - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala index db463029aedf..a191759813de 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/FrequentItems.scala @@ -20,9 +20,9 @@ package org.apache.spark.sql.execution.stat import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.Logging +import org.apache.spark.sql.{Column, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, Column, DataFrame} private[sql] object FrequentItems extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index 00231d65a7d5..7d701949afcf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.execution.stat import org.apache.spark.Logging -import org.apache.spark.sql.{Row, Column, DataFrame} -import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Cast} +import org.apache.spark.sql.{Column, DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{Cast, GenericMutableRow} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -29,7 +29,7 @@ private[sql] object StatFunctions extends Logging { /** Calculate the Pearson Correlation Coefficient for the given columns */ private[sql] def pearsonCorrelation(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "correlation") counts.Ck / math.sqrt(counts.MkX * counts.MkY) } @@ -73,13 +73,14 @@ private[sql] object StatFunctions extends Logging { def cov: Double = Ck / (count - 1) } - private def collectStatisticalData(df: DataFrame, cols: Seq[String]): CovarianceCounter = { - require(cols.length == 2, "Currently cov supports calculating the covariance " + + private def collectStatisticalData(df: DataFrame, cols: Seq[String], + functionName: String): CovarianceCounter = { + require(cols.length == 2, s"Currently $functionName calculation is supported " + "between two columns.") cols.map(name => (name, df.schema.fields.find(_.name == name))).foreach { case (name, data) => require(data.nonEmpty, s"Couldn't find column with name $name") - require(data.get.dataType.isInstanceOf[NumericType], "Covariance calculation for columns " + - s"with dataType ${data.get.dataType} not supported.") + require(data.get.dataType.isInstanceOf[NumericType], s"Currently $functionName calculation " + + s"for columns with dataType ${data.get.dataType} not supported.") } val columns = cols.map(n => Column(Cast(Column(n).expr, DoubleType))) df.select(columns: _*).queryExecution.toRdd.aggregate(new CovarianceCounter)( @@ -98,7 +99,7 @@ private[sql] object StatFunctions extends Logging { * @return the covariance of the two columns. 
*/ private[sql] def calculateCov(df: DataFrame, cols: Seq[String]): Double = { - val counts = collectStatisticalData(df, cols) + val counts = collectStatisticalData(df, cols, "covariance") counts.cov } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala index e19a1e3e5851..cd5613692708 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala @@ -19,12 +19,11 @@ package org.apache.spark.sql.execution.ui import scala.collection.mutable +import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler._ -import org.apache.spark.sql.execution.SQLExecution -import org.apache.spark.sql.execution.SparkPlanInfo -import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricValue, SQLMetricParam} -import org.apache.spark.{JobExecutionStatus, Logging, SparkConf} +import org.apache.spark.sql.execution.{SparkPlanInfo, SQLExecution} +import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetricParam, SQLMetricValue} import org.apache.spark.ui.SparkUI @DeveloperApi @@ -160,12 +159,14 @@ private[sql] class SQLListener(conf: SparkConf) extends SparkListener with Loggi } override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized { - updateTaskAccumulatorValues( - taskEnd.taskInfo.taskId, - taskEnd.stageId, - taskEnd.stageAttemptId, - taskEnd.taskMetrics.accumulatorUpdates(), - finishTask = true) + if (taskEnd.taskMetrics != null) { + updateTaskAccumulatorValues( + taskEnd.taskInfo.taskId, + taskEnd.stageId, + taskEnd.stageAttemptId, + taskEnd.taskMetrics.accumulatorUpdates(), + finishTask = true) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 65117d582475..6eea92451734 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.expressions +import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} /** * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala index 893e800a6143..3921147857a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/WindowSpec.scala @@ -18,11 +18,10 @@ package org.apache.spark.sql.expressions import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.{Column, catalyst} +import org.apache.spark.sql.{catalyst, Column} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ - /** * :: Experimental :: * A window specification that defines the partitioning, ordering, and frame boundaries. 
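The hunk below collapses withAggregate's hand-maintained whitelist of aggregate functions into a single wrapper around whatever aggregate expression the caller supplies. A minimal sketch of the call path this serves, assuming a DataFrame `df` with columns `key` and `value` and the usual SQL implicits in scope:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val w = Window.partitionBy($"key").orderBy($"value")
    // Column.over(w) delegates to WindowSpec.withAggregate, which after this change
    // simply builds WindowExpression(aggregate.expr, WindowSpecDefinition(partitionSpec,
    // orderSpec, frame)) instead of pattern-matching avg/sum/count/first/last/min/max.
    val runningSum = df.select($"key", $"value", sum($"value").over(w).as("running_sum"))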
@@ -140,57 +139,7 @@ class WindowSpec private[sql]( * Converts this [[WindowSpec]] into a [[Column]] with an aggregate expression. */ private[sql] def withAggregate(aggregate: Column): Column = { - val windowExpr = aggregate.expr match { - // First, we check if we get an aggregate function without the DISTINCT keyword. - // Right now, we do not support using a DISTINCT aggregate function as a - // window function. - case AggregateExpression(aggregateFunction, _, isDistinct) if !isDistinct => - aggregateFunction match { - case Average(child) => WindowExpression( - UnresolvedWindowFunction("avg", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Sum(child) => WindowExpression( - UnresolvedWindowFunction("sum", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Count(children) => WindowExpression( - UnresolvedWindowFunction("count", children), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case First(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF first_value - UnresolvedWindowFunction( - "first_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Last(child, ignoreNulls) => WindowExpression( - // TODO this is a hack for Hive UDAF last_value - UnresolvedWindowFunction( - "last_value", - child :: ignoreNulls :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Min(child) => WindowExpression( - UnresolvedWindowFunction("min", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case Max(child) => WindowExpression( - UnresolvedWindowFunction("max", child :: Nil), - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - case AggregateExpression(aggregateFunction, _, isDistinct) if isDistinct => - throw new UnsupportedOperationException( - s"Distinct aggregate function ${aggregateFunction} is not supported " + - s"in window operation.") - - case wf: WindowFunction => - WindowExpression( - wf, - WindowSpecDefinition(partitionSpec, orderSpec, frame)) - - case x => - throw new UnsupportedOperationException(s"$x is not supported in a window operation.") - } - - new Column(windowExpr) + val spec = WindowSpecDefinition(partitionSpec, orderSpec, frame) + new Column(WindowExpression(aggregate.expr, spec)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala index 11dbf391cff9..8b355befc34a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/udaf.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, AggregateExpression} -import org.apache.spark.sql.execution.aggregate.ScalaUDAF +import org.apache.spark.annotation.Experimental import org.apache.spark.sql.{Column, Row} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} +import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.types._ -import org.apache.spark.annotation.Experimental /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index e79defbbbdee..592d79df3109 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql import scala.language.implicitConversions -import scala.reflect.runtime.universe.{TypeTag, typeTag} +import scala.reflect.runtime.universe.{typeTag, TypeTag} import scala.util.Try import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} +import org.apache.spark.sql.catalyst.{ScalaReflection, SqlParser} +import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -558,13 +558,6 @@ object functions extends LegacyFunctions { // Window functions ////////////////////////////////////////////////////////////////////////////////////////////// - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `cume_dist`. This will be removed in Spark 2.0. - */ - @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") - def cumeDist(): Column = cume_dist() - /** * Window function: returns the cumulative distribution of values within a window partition, * i.e. the fraction of rows that are below the current row. @@ -577,14 +570,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def cume_dist(): Column = withExpr { UnresolvedWindowFunction("cume_dist", Nil) } - - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `dense_rank`. This will be removed in Spark 2.0. - */ - @deprecated("Use dense_rank. This will be removed in Spark 2.0.", "1.6.0") - def denseRank(): Column = dense_rank() + def cume_dist(): Column = withExpr { new CumeDist } /** * Window function: returns the rank of rows within a window partition, without any gaps. @@ -597,7 +583,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def dense_rank(): Column = withExpr { UnresolvedWindowFunction("dense_rank", Nil) } + def dense_rank(): Column = withExpr { new DenseRank } /** * Window function: returns the value that is `offset` rows before the current row, and @@ -648,7 +634,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lag(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lag", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lag(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -700,7 +686,7 @@ object functions extends LegacyFunctions { * @since 1.4.0 */ def lead(e: Column, offset: Int, defaultValue: Any): Column = withExpr { - UnresolvedWindowFunction("lead", e.expr :: Literal(offset) :: Literal(defaultValue) :: Nil) + Lead(e.expr, Literal(offset), Literal(defaultValue)) } /** @@ -713,14 +699,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def ntile(n: Int): Column = withExpr { UnresolvedWindowFunction("ntile", lit(n).expr :: Nil) } - - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `percent_rank`. This will be removed in Spark 2.0. - */ - @deprecated("Use percent_rank. This will be removed in Spark 2.0.", "1.6.0") - def percentRank(): Column = percent_rank() + def ntile(n: Int): Column = withExpr { new NTile(Literal(n)) } /** * Window function: returns the relative rank (i.e. 
percentile) of rows within a window partition. @@ -735,7 +714,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def percent_rank(): Column = withExpr { UnresolvedWindowFunction("percent_rank", Nil) } + def percent_rank(): Column = withExpr { new PercentRank } /** * Window function: returns the rank of rows within a window partition. @@ -750,14 +729,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.4.0 */ - def rank(): Column = withExpr { UnresolvedWindowFunction("rank", Nil) } - - /** - * @group window_funcs - * @deprecated As of 1.6.0, replaced by `row_number`. This will be removed in Spark 2.0. - */ - @deprecated("Use row_number. This will be removed in Spark 2.0.", "1.6.0") - def rowNumber(): Column = row_number() + def rank(): Column = withExpr { new Rank } /** * Window function: returns a sequential number starting at 1 within a window partition. * @@ -765,7 +737,7 @@ object functions extends LegacyFunctions { * @group window_funcs * @since 1.6.0 */ - def row_number(): Column = withExpr { UnresolvedWindowFunction("row_number", Nil) } + def row_number(): Column = withExpr { RowNumber() } ////////////////////////////////////////////////////////////////////////////////////////////// // Non-aggregate functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -827,13 +799,6 @@ object functions extends LegacyFunctions { @scala.annotation.varargs def coalesce(e: Column*): Column = withExpr { Coalesce(e.map(_.expr)) } - /** - * @group normal_funcs - * @deprecated As of 1.6.0, replaced by `input_file_name`. This will be removed in Spark 2.0. - */ - @deprecated("Use input_file_name. This will be removed in Spark 2.0.", "1.6.0") - def inputFileName(): Column = input_file_name() - /** * Creates a string column for the file name of the current Spark task. * * @group normal_funcs */ def input_file_name(): Column = withExpr { InputFileName() } - /** - * @group normal_funcs - * @deprecated As of 1.6.0, replaced by `isnan`. This will be removed in Spark 2.0. - */ - @deprecated("Use isnan. This will be removed in Spark 2.0.", "1.6.0") - def isNaN(e: Column): Column = isnan(e) - /** * Return true iff the column is NaN. * @@ -972,14 +930,6 @@ object functions extends LegacyFunctions { */ def randn(): Column = randn(Utils.random.nextLong) - /** - * @group normal_funcs - * @since 1.4.0 - * @deprecated As of 1.6.0, replaced by `spark_partition_id`. This will be removed in Spark 2.0. - */ - @deprecated("Use cume_dist. This will be removed in Spark 2.0.", "1.6.0") - def sparkPartitionId(): Column = spark_partition_id() - /** * Partition ID of the Spark task. * @@ -1863,6 +1813,17 @@ object functions extends LegacyFunctions { */ def crc32(e: Column): Column = withExpr { Crc32(e.expr) } + /** + * Calculates the hash code of given columns, and returns the result as an int column. 
+ * + * @group misc_funcs + * @since 2.0 + */ + @scala.annotation.varargs + def hash(cols: Column*): Column = withExpr { + new Murmur3Hash(cols.map(_.expr)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2512,7 +2473,8 @@ object functions extends LegacyFunctions { ////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off + // scalastyle:off line.size.limit + // scalastyle:off parameter.number /* Use the following code to generate: (0 to 10).map { x => @@ -2528,29 +2490,11 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val inputTypes = Try($inputTypes).getOrElse(Nil) + val inputTypes = Try($inputTypes).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) }""") } - (0 to 10).map { x => - val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") - val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") - println(s""" - /** - * Call a Scala function of ${x} arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - */ - @deprecated("Use udf", "1.5.0") - def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = withExpr { - ScalaUDF(f, returnType, Seq($argsInUDF)) - }""") - } */ /** * Defines a user-defined function of 0 arguments as user-defined function (UDF). 
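A short usage sketch for the `hash` function added above, assuming a DataFrame `df` with columns `id` and `name` and the usual SQL implicits in scope. The function is variadic and lowers to the new Murmur3Hash expression, yielding an IntegerType column:

    import org.apache.spark.sql.functions.hash

    // One Murmur3 hash per row, computed across both columns; usable, for example,
    // as a cheap row fingerprint when comparing two loads of the same table.
    val fingerprinted = df.select($"id", hash($"id", $"name").as("row_hash"))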
@@ -2560,7 +2504,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag](f: Function0[RT]): UserDefinedFunction = { - val inputTypes = Try(Nil).getOrElse(Nil) + val inputTypes = Try(Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2572,7 +2516,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag](f: Function1[A1, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2584,7 +2528,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag](f: Function2[A1, A2, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2596,7 +2540,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag](f: Function3[A1, A2, A3, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2608,7 +2552,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag](f: Function4[A1, A2, A3, A4, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2620,7 +2564,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag](f: Function5[A1, A2, A3, A4, A5, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: 
ScalaReflection.schemaFor(typeTag[A5]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2632,7 +2576,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag](f: Function6[A1, A2, A3, A4, A5, A6, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2644,7 +2588,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag](f: Function7[A1, A2, A3, A4, A5, A6, A7, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2656,7 +2600,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag](f: Function8[A1, A2, A3, A4, A5, A6, A7, A8, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: Nil).toOption UserDefinedFunction(f, 
ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2668,7 +2612,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag](f: Function9[A1, A2, A3, A4, A5, A6, A7, A8, A9, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } @@ -2680,185 +2624,25 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def udf[RT: TypeTag, A1: TypeTag, A2: TypeTag, A3: TypeTag, A4: TypeTag, A5: TypeTag, A6: TypeTag, A7: TypeTag, A8: TypeTag, A9: TypeTag, A10: TypeTag](f: Function10[A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, RT]): UserDefinedFunction = { - val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).getOrElse(Nil) + val inputTypes = Try(ScalaReflection.schemaFor(typeTag[A1]).dataType :: ScalaReflection.schemaFor(typeTag[A2]).dataType :: ScalaReflection.schemaFor(typeTag[A3]).dataType :: ScalaReflection.schemaFor(typeTag[A4]).dataType :: ScalaReflection.schemaFor(typeTag[A5]).dataType :: ScalaReflection.schemaFor(typeTag[A6]).dataType :: ScalaReflection.schemaFor(typeTag[A7]).dataType :: ScalaReflection.schemaFor(typeTag[A8]).dataType :: ScalaReflection.schemaFor(typeTag[A9]).dataType :: ScalaReflection.schemaFor(typeTag[A10]).dataType :: Nil).toOption UserDefinedFunction(f, ScalaReflection.schemaFor(typeTag[RT]).dataType, inputTypes) } - ////////////////////////////////////////////////////////////////////////////////////////////////// - /** - * Call a Scala function of 0 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. 
This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function0[_], returnType: DataType): Column = withExpr { - ScalaUDF(f, returnType, Seq()) - } - - /** - * Call a Scala function of 1 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr)) - } - - /** - * Call a Scala function of 2 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) - } - - /** - * Call a Scala function of 3 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) - } - - /** - * Call a Scala function of 4 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) - } - - /** - * Call a Scala function of 5 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) - } - - /** - * Call a Scala function of 6 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. 
This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) - } - - /** - * Call a Scala function of 7 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) - } - - /** - * Call a Scala function of 8 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf() - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) - } - - /** - * Call a Scala function of 9 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf(). - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) - } - - /** - * Call a Scala function of 10 arguments as user-defined function (UDF). This requires - * you to specify the return data type. - * - * @group udf_funcs - * @since 1.3.0 - * @deprecated As of 1.5.0, since it's redundant with udf(). - * This will be removed in Spark 2.0. - */ - @deprecated("Use udf. This will be removed in Spark 2.0.", "1.5.0") - def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = withExpr { - ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) - } - - // scalastyle:on + // scalastyle:on parameter.number + // scalastyle:on line.size.limit /** - * Call an user-defined function. - * Example: - * {{{ - * import org.apache.spark.sql._ + * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must + * specify the output data type, and there is no automatic input type coercion. 
* - * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) - * df.select($"id", callUDF("simpleUDF", $"value")) - * }}} + * @param f A closure in Scala + * @param dataType The output data type of the UDF * * @group udf_funcs - * @since 1.5.0 + * @since 2.0.0 */ - @scala.annotation.varargs - def callUDF(udfName: String, cols: Column*): Column = withExpr { - UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) + def udf(f: AnyRef, dataType: DataType): UserDefinedFunction = { + UserDefinedFunction(f, dataType, None) } /** @@ -2870,24 +2654,15 @@ object functions extends LegacyFunctions { * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUDF", $"value")) + * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * * @group udf_funcs - * @since 1.4.0 - * @deprecated As of 1.5.0, since it was not coherent to have two functions callUdf and callUDF. - * This will be removed in Spark 2.0. - */ - @deprecated("Use callUDF. This will be removed in Spark 2.0.", "1.5.0") - def callUdf(udfName: String, cols: Column*): Column = withExpr { - // Note: we avoid using closures here because on file systems that are case-insensitive, the - // compiled class file for the closure here will conflict with the one in callUDF (upper case). - val exprs = new Array[Expression](cols.size) - var i = 0 - while (i < cols.size) { - exprs(i) = cols(i).expr - i += 1 - } - UnresolvedFunction(udfName, exprs, isDistinct = false) + * @since 1.5.0 + */ + @scala.annotation.varargs + def callUDF(udfName: String, cols: Column*): Column = withExpr { + UnresolvedFunction(udfName, cols.map(_.expr), isDistinct = false) } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala index b1cb0e55026b..f12b6ca9d6ad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/DB2Dialect.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.jdbc -import org.apache.spark.sql.types.{BooleanType, StringType, DataType} - +import org.apache.spark.sql.types.{BooleanType, DataType, StringType} private object DB2Dialect extends JdbcDialect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index 13db141f27db..ca2d909e2ccc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.jdbc import java.sql.Connection -import org.apache.spark.sql.types._ import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.types._ /** * :: DeveloperApi :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index da413ed1f08b..e1717049f383 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.jdbc import java.sql.Types -import org.apache.spark.sql.types.{BooleanType, LongType, DataType, MetadataBuilder} - +import 
org.apache.spark.sql.types.{BooleanType, DataType, LongType, MetadataBuilder} private case object MySQLDialect extends JdbcDialect { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index 3cf80f576e92..ad9e31690b2d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -64,6 +64,7 @@ private object PostgresDialect extends JdbcDialect { getJDBCType(et).map(_.databaseTypeDefinition) .orElse(JdbcUtils.getCommonJDBCType(et).map(_.databaseTypeDefinition)) .map(typeName => JdbcType(s"$typeName[]", java.sql.Types.ARRAY)) + case ByteType => throw new IllegalArgumentException(s"Unsupported type in postgresql: $dt"); case _ => None } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index a9c600b139b1..bd73a36fd40b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -42,10 +42,4 @@ package object sql { @DeveloperApi type Strategy = org.apache.spark.sql.catalyst.planning.GenericStrategy[SparkPlan] - /** - * Type alias for [[DataFrame]]. Kept here for backward source compatibility for Scala. - * @deprecated As of 1.3.0, replaced by `DataFrame`. - */ - @deprecated("use DataFrame", "1.3.0") - type SchemaRDD = DataFrame } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index fc8ce6901dfc..c35f33132f60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -21,21 +21,21 @@ import scala.collection.mutable import scala.util.Try import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{PathFilter, FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{JobConf, FileInputFormat} +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection import org.apache.spark.sql.execution.{FileRelation, RDDConversions} -import org.apache.spark.sql.execution.datasources.{PartitioningUtils, PartitionSpec, Partition} +import org.apache.spark.sql.execution.datasources.{BucketSpec, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql._ import org.apache.spark.util.SerializableConfiguration /** @@ -161,6 +161,20 @@ trait HadoopFsRelationProvider { dataSchema: Option[StructType], partitionColumns: Option[StructType], parameters: Map[String, String]): HadoopFsRelation + + // TODO: expose bucket API to users. 
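A sketch of how the default implementation added below behaves for an existing source that only implements the original five-argument createRelation. The provider class here is hypothetical; the signatures are taken from this hunk:

    import org.apache.spark.sql.{AnalysisException, SQLContext}
    import org.apache.spark.sql.sources.{HadoopFsRelation, HadoopFsRelationProvider}
    import org.apache.spark.sql.types.StructType

    // Hypothetical data source written before the bucketing API existed.
    class LegacyProvider extends HadoopFsRelationProvider {
      override def createRelation(
          sqlContext: SQLContext,
          paths: Array[String],
          dataSchema: Option[StructType],
          partitionColumns: Option[StructType],
          parameters: Map[String, String]): HadoopFsRelation = ???
    }

    // Through the new six-argument overload, bucketSpec = None falls through to the
    // method above, while Some(...) fails with AnalysisException("Currently we don't
    // support bucketing for this data source.").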
+ private[sql] def createRelation( + sqlContext: SQLContext, + paths: Array[String], + dataSchema: Option[StructType], + partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], + parameters: Map[String, String]): HadoopFsRelation = { + if (bucketSpec.isDefined) { + throw new AnalysisException("Currently we don't support bucketing for this data source.") + } + createRelation(sqlContext, paths, dataSchema, partitionColumns, parameters) + } } /** @@ -351,7 +365,18 @@ abstract class OutputWriterFactory extends Serializable { * * @since 1.4.0 */ - def newInstance(path: String, dataSchema: StructType, context: TaskAttemptContext): OutputWriter + def newInstance( + path: String, + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter + + // TODO: expose bucket API to users. + private[sql] def newInstance( + path: String, + bucketId: Option[Int], + dataSchema: StructType, + context: TaskAttemptContext): OutputWriter = + newInstance(path, dataSchema, context) } /** @@ -435,6 +460,9 @@ abstract class HadoopFsRelation private[sql]( private var _partitionSpec: PartitionSpec = _ + // TODO: expose bucket API to users. + private[sql] def bucketSpec: Option[BucketSpec] = None + private class FileStatusCache { var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] @@ -462,7 +490,7 @@ abstract class HadoopFsRelation private[sql]( name.toLowerCase == "_temporary" || name.startsWith(".") } - val (dirs, files) = statuses.partition(_.isDir) + val (dirs, files) = statuses.partition(_.isDirectory) // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) if (dirs.isEmpty) { @@ -858,10 +886,10 @@ private[sql] object HadoopFsRelation extends Logging { val jobConf = new JobConf(fs.getConf, this.getClass()) val pathFilter = FileInputFormat.getInputPathFilter(jobConf) if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDir) + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) } } @@ -896,7 +924,7 @@ private[sql] object HadoopFsRelation extends Logging { FakeFileStatus( status.getPath.toString, status.getLen, - status.isDir, + status.isDirectory, status.getReplication, status.getBlockSize, status.getModificationTime, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala index 8d4854b698ed..20a17ba82be9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/test/ExamplePointUDT.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.types._ /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index ac432e2baa3c..e6f8779929d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -18,6 +18,7 @@ package 
org.apache.spark.sql.util import java.util.concurrent.locks.ReentrantReadWriteLock + import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal @@ -25,7 +26,6 @@ import org.apache.spark.Logging import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.sql.execution.QueryExecution - /** * :: Experimental :: * The interface of query execution listener that can be used to analyze execution metrics. diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java index 7b50aad4ad49..640efcc737ea 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java @@ -107,7 +107,7 @@ public Row call(Person person) throws Exception { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + DataFrame df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); Row[] actual = sqlContext.sql("SELECT * FROM people").collect(); @@ -143,7 +143,7 @@ public Row call(Person person) { fields.add(DataTypes.createStructField("age", DataTypes.IntegerType, false)); StructType schema = DataTypes.createStructType(fields); - DataFrame df = sqlContext.applySchema(rowRDD, schema); + DataFrame df = sqlContext.createDataFrame(rowRDD, schema); df.registerTempTable("people"); List actual = sqlContext.sql("SELECT * FROM people").toJavaRDD().map(new Function() { @Override diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 383a2d0badb5..9f8db39e33d7 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -23,6 +23,8 @@ import java.sql.Timestamp; import java.util.*; +import com.google.common.base.Objects; +import org.junit.rules.ExpectedException; import scala.Tuple2; import scala.Tuple3; import scala.Tuple4; @@ -608,6 +610,44 @@ public int hashCode() { } } + public class SimpleJavaBean2 implements Serializable { + private Timestamp a; + private Date b; + private java.math.BigDecimal c; + + public Timestamp getA() { return a; } + + public void setA(Timestamp a) { this.a = a; } + + public Date getB() { return b; } + + public void setB(Date b) { this.b = b; } + + public java.math.BigDecimal getC() { return c; } + + public void setC(java.math.BigDecimal c) { this.c = c; } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + SimpleJavaBean2 that = (SimpleJavaBean2) o; + + if (!a.equals(that.a)) return false; + if (!b.equals(that.b)) return false; + return c.equals(that.c); + } + + @Override + public int hashCode() { + int result = a.hashCode(); + result = 31 * result + b.hashCode(); + result = 31 * result + c.hashCode(); + return result; + } + } + public class NestedJavaBean implements Serializable { private SimpleJavaBean a; @@ -689,4 +729,140 @@ public void testJavaBeanEncoder() { .as(Encoders.bean(SimpleJavaBean.class)); Assert.assertEquals(data, ds3.collectAsList()); } + + @Test + public void testJavaBeanEncoder2() { + // This is a regression test of SPARK-12404 + OuterScopes.addOuterScope(this); + SimpleJavaBean2 obj = 
new SimpleJavaBean2(); + obj.setA(new Timestamp(0)); + obj.setB(new Date(0)); + obj.setC(java.math.BigDecimal.valueOf(1)); + Dataset ds = + context.createDataset(Arrays.asList(obj), Encoders.bean(SimpleJavaBean2.class)); + ds.collect(); + } + + public class SmallBean implements Serializable { + private String a; + + private int b; + + public int getB() { + return b; + } + + public void setB(int b) { + this.b = b; + } + + public String getA() { + return a; + } + + public void setA(String a) { + this.a = a; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + SmallBean smallBean = (SmallBean) o; + return b == smallBean.b && com.google.common.base.Objects.equal(a, smallBean.a); + } + + @Override + public int hashCode() { + return Objects.hashCode(a, b); + } + } + + public class NestedSmallBean implements Serializable { + private SmallBean f; + + public SmallBean getF() { + return f; + } + + public void setF(SmallBean f) { + this.f = f; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + NestedSmallBean that = (NestedSmallBean) o; + return Objects.equal(f, that.f); + } + + @Override + public int hashCode() { + return Objects.hashCode(f); + } + } + + @Rule + public transient ExpectedException nullabilityCheck = ExpectedException.none(); + + @Test + public void testRuntimeNullabilityCheck() { + OuterScopes.addOuterScope(this); + + StructType schema = new StructType() + .add("f", new StructType() + .add("a", StringType, true) + .add("b", IntegerType, true), true); + + // Shouldn't throw runtime exception since it passes nullability check. + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", 1 + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + SmallBean smallBean = new SmallBean(); + smallBean.setA("hello"); + smallBean.setB(1); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + nestedSmallBean.setF(smallBean); + + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + { + Row row = new GenericRow(new Object[] { null }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + NestedSmallBean nestedSmallBean = new NestedSmallBean(); + Assert.assertEquals(ds.collectAsList(), Collections.singletonList(nestedSmallBean)); + } + + nullabilityCheck.expect(RuntimeException.class); + nullabilityCheck.expectMessage( + "Null value appeared in non-nullable field " + + "test.org.apache.spark.sql.JavaDatasetSuite$SmallBean.b of type int."); + + { + Row row = new GenericRow(new Object[] { + new GenericRow(new Object[] { + "hello", null + }) + }); + + DataFrame df = context.createDataFrame(Collections.singletonList(row), schema); + Dataset ds = df.as(Encoders.bean(NestedSmallBean.class)); + + ds.collect(); + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index d86df4cfb9b4..89b9a687682d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -17,20 +17,18 @@ 
package org.apache.spark.sql -import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.execution.Exchange -import org.apache.spark.sql.execution.PhysicalRDD - import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators +import org.apache.spark.sql.execution.Exchange +import org.apache.spark.sql.execution.PhysicalRDD import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext} -import org.apache.spark.storage.{StorageLevel, RDDBlockId} +import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} +import org.apache.spark.storage.{RDDBlockId, StorageLevel} private case class BigData(s: String) @@ -289,7 +287,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext testData.select('key).registerTempTable("t1") sqlContext.table("t1") sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) } test("Drops cached temporary table") { @@ -301,7 +299,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assert(sqlContext.isCached("t2")) sqlContext.dropTempTable("t1") - intercept[NoSuchTableException](sqlContext.table("t1")) + intercept[AnalysisException](sqlContext.table("t1")) assert(!sqlContext.isCached("t2")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 38c0eb589f96..eb4efcd1d4e4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.scalatest.Matchers._ +import org.apache.spark.sql.catalyst.expressions.NamedExpression import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext @@ -298,7 +298,7 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( - testData.select(isNaN($"a"), isNaN($"b")), + testData.select(isnan($"a"), isnan($"b")), Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil) checkAnswer( @@ -580,26 +580,26 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { ) } - test("sparkPartitionId") { + test("spark_partition_id") { // Make sure we have 2 partitions, each with 2 records. 
val df = sparkContext.parallelize(Seq[Int](), 2).mapPartitions { _ => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( - df.select(sparkPartitionId()), + df.select(spark_partition_id()), Row(0) :: Row(0) :: Row(1) :: Row(1) :: Nil ) } - test("InputFileName") { + test("input_file_name") { withTempPath { dir => val data = sparkContext.parallelize(0 to 10).toDF("id") data.write.parquet(dir.getCanonicalPath) - val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(inputFileName()) + val answer = sqlContext.read.parquet(dir.getCanonicalPath).select(input_file_name()) .head.getString(0) assert(answer.contains(dir.getCanonicalPath)) - checkAnswer(data.select(inputFileName()).limit(1), Row("")) + checkAnswer(data.select(input_file_name()).limit(1), Row("")) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala index 09f7b507670c..b76fc73b7fa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameComplexTypeSuite.scala @@ -43,4 +43,13 @@ class DataFrameComplexTypeSuite extends QueryTest with SharedSQLContext { val df = sparkContext.parallelize(Seq((1, 1))).toDF("a", "b") df.select(array($"a").as("s")).select(f(expr("s[0]"))).collect() } + + test("SPARK-12477 accessing null element in array field") { + val df = sparkContext.parallelize(Seq((Seq("val1", null, "val2"), + Seq(Some(1), None, Some(2))))).toDF("s", "i") + val nullStringRow = df.selectExpr("s[1]").collect()(0) + assert(nullStringRow == org.apache.spark.sql.Row(null)) + val nullIntRow = df.selectExpr("i[1]").collect()(0) + assert(nullIntRow == org.apache.spark.sql.Row(null)) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index b15af42caa3a..63ad6c439a87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -62,6 +62,28 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { } } + test("randomSplit on reordered partitions") { + // This test ensures that randomSplit does not create overlapping splits even when the + // underlying dataframe (such as the one below) doesn't guarantee a deterministic ordering of + // rows in each partition. 
+ val data = sparkContext.parallelize(1 to 600, 2).mapPartitions(scala.util.Random.shuffle(_)).toDF("id") + val splits = data.randomSplit(Array[Double](2, 3), seed = 1) + + assert(splits.length == 2, "wrong number of splits") + + // Verify that the splits span the entire dataset + assert(splits.flatMap(_.collect()).toSet == data.collect().toSet) + + // Verify that the splits don't overlap + assert(splits(0).intersect(splits(1)).collect().isEmpty) + + // Verify that the results are deterministic across multiple runs + val firstRun = splits.toSeq.map(_.collect().toSeq) + val secondRun = data.randomSplit(Array[Double](2, 3), seed = 1).toSeq.map(_.collect().toSeq) + assert(firstRun == secondRun) + } + test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") val corr1 = df.stat.corr("a", "b", "pearson") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 0644bdaaa35c..983dfbdedeef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -29,8 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation import org.apache.spark.sql.execution.Exchange import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} +import org.apache.spark.sql.test.SQLTestData.TestData2 import org.apache.spark.sql.types._ class DataFrameSuite extends QueryTest with SharedSQLContext { @@ -176,6 +176,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData.select("key").collect().toSeq) } + test("selectExpr with udtf") { + val df = Seq((Map("1" -> 1), 1)).toDF("a", "b") + checkAnswer( + df.selectExpr("explode(a)"), + Row("1", 1) :: Nil) + } + test("filterExpr") { val res = testData.collect().filter(_.getInt(0) > 90).toSeq checkAnswer(testData.filter("key > 90"), res) @@ -301,6 +308,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( mapData.toDF().limit(1), mapData.take(1).map(r => Row.fromSeq(r.productIterator.toSeq))) + + // SPARK-12340: overstep the bounds of Int in SparkPlan.executeTake + checkAnswer( + sqlContext.range(2).limit(2147483638), + Row(0) :: Row(1) :: Nil + ) } test("except") { @@ -334,15 +347,6 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { ) } - test("deprecated callUdf in SQLContext") { - val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") - val sqlctx = df.sqlContext - sqlctx.udf.register("simpleUdf", (v: Int) => v * v) - checkAnswer( - df.select($"id", callUdf("simpleUdf", $"value")), - Row("id1", 1) :: Row("id2", 16) :: Row("id3", 25) :: Nil) - } - test("callUDF in SQLContext") { val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") val sqlctx = df.sqlContext @@ -762,6 +766,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) + + // using the default slice number + val res12 = sqlContext.range(3, 15, 3).select("id") + assert(res12.count == 4) + assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } test("SPARK-8621: support empty string column name") { @@ -1170,4 +1179,50 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val primitiveUDF = 
udf((i: Int) => i * 2) checkAnswer(df.select(primitiveUDF($"age")), Row(44) :: Row(null) :: Nil) } + + test("SPARK-12398 truncated toString") { + val df1 = Seq((1L, "row1")).toDF("id", "name") + assert(df1.toString() === "[id: bigint, name: string]") + + val df2 = Seq((1L, "c2", false)).toDF("c1", "c2", "c3") + assert(df2.toString === "[c1: bigint, c2: string ... 1 more field]") + + val df3 = Seq((1L, "c2", false, 10)).toDF("c1", "c2", "c3", "c4") + assert(df3.toString === "[c1: bigint, c2: string ... 2 more fields]") + + val df4 = Seq((1L, Tuple2(1L, "val"))).toDF("c1", "c2") + assert(df4.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string>]") + + val df5 = Seq((1L, Tuple2(1L, "val"), 20.0)).toDF("c1", "c2", "c3") + assert(df5.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 1 more field]") + + val df6 = Seq((1L, Tuple2(1L, "val"), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert(df6.toString === "[c1: bigint, c2: struct<_1: bigint, _2: string> ... 2 more fields]") + + val df7 = Seq((1L, Tuple3(1L, "val", 2), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df7.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 1 more field> ... 2 more fields]") + + val df8 = Seq((1L, Tuple7(1L, "val", 2, 3, 4, 5, 6), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df8.toString === + "[c1: bigint, c2: struct<_1: bigint, _2: string ... 5 more fields> ... 2 more fields]") + + val df9 = + Seq((1L, Tuple4(1L, Tuple4(1L, 2L, 3L, 4L), 2L, 3L), 20.0, 1)).toDF("c1", "c2", "c3", "c4") + assert( + df9.toString === + "[c1: bigint, c2: struct<_1: bigint," + + " _2: struct<_1: bigint," + + " _2: bigint ... 2 more fields> ... 2 more fields> ... 2 more fields]") + + } + + test("SPARK-12512: support `.` in column name for withColumn()") { + val df = Seq("a" -> "b").toDF("col.a", "col.b") + checkAnswer(df.select(df("*")), Row("a", "b")) + checkAnswer(df.withColumn("col.a", lit("c")), Row("c", "b")) + checkAnswer(df.withColumn("col.c", lit("c")), Row("a", "b", "c")) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala similarity index 54% rename from sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala rename to sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala index 2c98f1c3cc49..09a56f6f3ae2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameWindowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowSuite.scala @@ -15,16 +15,15 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.hive +package org.apache.spark.sql -import org.apache.spark.sql.{Row, QueryTest} -import org.apache.spark.sql.expressions.Window +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction, Window} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DataType, LongType, StructType} -class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { - import hiveContext.implicits._ - import hiveContext.sql +class DataFrameWindowSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("reuse window partitionBy") { val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") @@ -55,10 +54,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lead(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row("1") :: Row(null) :: Row("2") :: Row(null) :: Nil) } test("lag") { @@ -68,10 +64,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 1).over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value) OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Row(null) :: Row("1") :: Row(null) :: Row("2") :: Nil) } test("lead with default value") { @@ -81,10 +74,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lead("value", 2, "n/a").over(Window.partitionBy("key").orderBy("value"))), - sql( - """SELECT - | lead(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"), Row("n/a"), Row("n/a"))) } test("lag with default value") { @@ -94,10 +84,7 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { checkAnswer( df.select( lag("value", 2, "n/a").over(Window.partitionBy($"key").orderBy($"value"))), - sql( - """SELECT - | lag(value, 2, "n/a") OVER (PARTITION BY key ORDER BY value) - | FROM window_table""".stripMargin).collect()) + Seq(Row("n/a"), Row("n/a"), Row("1"), Row("1"), Row("n/a"), Row("n/a"), Row("2"))) } test("rank functions in unspecific window") { @@ -112,78 +99,52 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { count("key").over(Window.partitionBy("value").orderBy("key")), sum("key").over(Window.partitionBy("value").orderBy("key")), ntile(2).over(Window.partitionBy("value").orderBy("key")), - rowNumber().over(Window.partitionBy("value").orderBy("key")), - denseRank().over(Window.partitionBy("value").orderBy("key")), + row_number().over(Window.partitionBy("value").orderBy("key")), + dense_rank().over(Window.partitionBy("value").orderBy("key")), rank().over(Window.partitionBy("value").orderBy("key")), - cumeDist().over(Window.partitionBy("value").orderBy("key")), - percentRank().over(Window.partitionBy("value").orderBy("key"))), - sql( - s"""SELECT - |key, - |max(key) over (partition by value order by key), - |min(key) over (partition by value order by key), - |avg(key) over (partition by value order by key), - |count(key) over (partition by value order by key), - |sum(key) over (partition by value order by key), - |ntile(2) over (partition by 
value order by key), - |row_number() over (partition by value order by key), - |dense_rank() over (partition by value order by key), - |rank() over (partition by value order by key), - |cume_dist() over (partition by value order by key), - |percent_rank() over (partition by value order by key) - |FROM window_table""".stripMargin).collect()) + cume_dist().over(Window.partitionBy("value").orderBy("key")), + percent_rank().over(Window.partitionBy("value").orderBy("key"))), + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d, 0.0d) :: + Row(1, 1, 1, 1.0d, 1, 1, 1, 1, 1, 1, 1.0d / 3.0d, 0.0d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 1, 2, 2, 2, 1.0d, 0.5d) :: + Row(2, 2, 1, 5.0d / 3.0d, 3, 5, 2, 3, 2, 2, 1.0d, 0.5d) :: Nil) } test("aggregation and rows between") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + val df = Seq((1, "1"), (2, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 2))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key ROWS BETWEEN 1 preceding and 2 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(3.0d / 2.0d), Row(2.0d), Row(2.0d))) } - test("aggregation and range betweens") { - val df = Seq((1, "1"), (2, "2"), (1, "1"), (2, "2")).toDF("key", "value") + test("aggregation and range between") { + val df = Seq((1, "1"), (1, "1"), (3, "1"), (2, "2"), (2, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( avg("key").over(Window.partitionBy($"value").orderBy($"key").rangeBetween(-1, 1))), - sql( - """SELECT - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and 1 following) - | FROM window_table""".stripMargin).collect()) + Seq(Row(4.0d / 3.0d), Row(4.0d / 3.0d), Row(7.0d / 4.0d), Row(5.0d / 2.0d), + Row(2.0d), Row(2.0d))) } - test("aggregation and rows betweens with unbounded") { + test("aggregation and rows between with unbounded") { val df = Seq((1, "1"), (2, "2"), (2, "3"), (1, "3"), (3, "2"), (4, "3")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( df.select( $"key", - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(0, Long.MaxValue)), - last("value").over( + last("key").over( Window.partitionBy($"value").orderBy($"key").rowsBetween(Long.MinValue, 0)), - last("value").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 3))), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between current row and unbounded following), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between unbounded preceding and current row), - | last_value(value) OVER - | (PARTITION BY value ORDER BY key ROWS between 1 preceding and 3 following) - | FROM window_table""".stripMargin).collect()) + last("key").over(Window.partitionBy($"value").orderBy($"key").rowsBetween(-1, 1))), + Seq(Row(1, 1, 1, 1), Row(2, 3, 2, 3), Row(3, 3, 3, 3), Row(1, 4, 1, 2), Row(2, 4, 2, 4), + Row(4, 4, 4, 4))) } - test("aggregation and range betweens with unbounded") { + test("aggregation and range between with unbounded") { val df = Seq((5, "1"), (5, "2"), (4, "2"), (6, "2"), (3, "1"), (2, "2")).toDF("key", "value") df.registerTempTable("window_table") checkAnswer( @@ -200,18 +161,12 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton 
{ avg("key").over(Window.partitionBy("value").orderBy("key").rangeBetween(-1, 0)) .as("avg_key3") ), - sql( - """SELECT - | key, - | last_value(value) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 2 preceding and 1 preceding) == "2", - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN unbounded preceding and 1 following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN current row and unbounded following), - | avg(key) OVER - | (PARTITION BY value ORDER BY key RANGE BETWEEN 1 preceding and current row) - | FROM window_table""".stripMargin).collect()) + Seq(Row(3, null, 3.0d, 4.0d, 3.0d), + Row(5, false, 4.0d, 5.0d, 5.0d), + Row(2, null, 2.0d, 17.0d / 4.0d, 2.0d), + Row(4, true, 11.0d / 3.0d, 5.0d, 4.0d), + Row(5, true, 17.0d / 4.0d, 11.0d / 2.0d, 4.5d), + Row(6, true, 17.0d / 4.0d, 6.0d, 11.0d / 2.0d))) } test("reverse sliding range frame") { @@ -254,6 +209,107 @@ class HiveDataFrameWindowSuite extends QueryTest with TestHiveSingleton { sum($"value").over(window.rangeBetween(1, Long.MaxValue))), Row(1, 13, null) :: Row(2, 13, 2) :: Row(4, 7, 9) :: Row(3, 11, 6) :: Row(2, 13, 2) :: Row(1, 13, null) :: Nil) + } + + test("statistical functions") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.partitionBy($"key") + checkAnswer( + df.select( + $"key", + var_pop($"value").over(window), + var_samp($"value").over(window), + approxCountDistinct($"value").over(window)), + Seq.fill(4)(Row("a", 1.0d / 4.0d, 1.0d / 3.0d, 2)) + ++ Seq.fill(3)(Row("b", 2.0d / 3.0d, 1.0d, 3))) + } + + test("window function with aggregates") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)). + toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.groupBy($"key") + .agg( + sum($"value"), + sum(sum($"value")).over(window) - sum($"value")), + Seq(Row("a", 6, 9), Row("b", 9, 6))) + } + + test("window function with udaf") { + val udaf = new UserDefinedAggregateFunction { + def inputSchema: StructType = new StructType() + .add("a", LongType) + .add("b", LongType) + + def bufferSchema: StructType = new StructType() + .add("product", LongType) + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer(0) = 0L + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + if (!(input.isNullAt(0) || input.isNullAt(1))) { + buffer(0) = buffer.getLong(0) + input.getLong(0) * input.getLong(1) + } + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0) + } + + def evaluate(buffer: Row): Any = + buffer.getLong(0) + } + val df = Seq( + ("a", 1, 1), + ("a", 1, 5), + ("a", 2, 10), + ("a", 2, -1), + ("b", 4, 7), + ("b", 3, 8), + ("b", 2, 4)) + .toDF("key", "a", "b") + val window = Window.partitionBy($"key").orderBy($"a").rangeBetween(Long.MinValue, 0L) + checkAnswer( + df.select( + $"key", + $"a", + $"b", + udaf($"a", $"b").over(window)), + Seq( + Row("a", 1, 1, 6), + Row("a", 1, 5, 6), + Row("a", 2, 10, 24), + Row("a", 2, -1, 24), + Row("b", 4, 7, 60), + Row("b", 3, 8, 32), + Row("b", 2, 4, 8))) + } + + test("null inputs") { + val df = Seq(("a", 1), ("a", 1), ("a", 2), ("a", 2), ("b", 4), ("b", 3), ("b", 2)) + .toDF("key", "value") + val window = Window.orderBy() + checkAnswer( + df.select( + $"key", + $"value", + avg(lit(null)).over(window), + sum(lit(null)).over(window)), + Seq( 
+ Row("a", 1, null, null), + Row("a", 1, null, null), + Row("a", 2, null, null), + Row("a", 2, null, null), + Row("b", 4, null, null), + Row("b", 3, null, null), + Row("b", 2, null, null))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index c6d2bf07b280..3258f3782d8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql - import scala.language.postfixOps -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.functions._ import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext /** An `Aggregator` that adds up any numeric type returned by the given function. */ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index f1b6b98dc160..53b5f45c2d4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -17,13 +17,14 @@ package org.apache.spark.sql -import java.io.{ObjectInput, ObjectOutput, Externalizable} +import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.sql.{Date, Timestamp} import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext - +import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} class DatasetSuite extends QueryTest with SharedSQLContext { import testImplicits._ @@ -42,6 +43,17 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 1, 1, 1) } + + test("SPARK-12404: Datatype Helper Serializablity") { + val ds = sparkContext.parallelize(( + new Timestamp(0), + new Date(0), + java.math.BigDecimal.valueOf(1), + scala.math.BigDecimal(1)) :: Nil).toDS() + + ds.collect() + } + test("collect, first, and take should use encoders for serialization") { val item = NonSerializableCaseClass("abcd") val ds = Seq(item).toDS() @@ -501,7 +513,47 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { ds.as[ClassData2].collect() } - assert(e.getMessage.contains("cannot resolve 'c' given input columns a, b"), e.getMessage) + assert(e.getMessage.contains("cannot resolve 'c' given input columns: [a, b]"), e.getMessage) + } + + test("runtime nullability check") { + val schema = StructType(Seq( + StructField("f", StructType(Seq( + StructField("a", StringType, nullable = true), + StructField("b", IntegerType, nullable = false) + )), nullable = true) + )) + + def buildDataset(rows: Row*): Dataset[NestedStruct] = { + val rowRDD = sqlContext.sparkContext.parallelize(rows) + sqlContext.createDataFrame(rowRDD, schema).as[NestedStruct] + } + + checkAnswer( + buildDataset(Row(Row("hello", 1))), + NestedStruct(ClassData("hello", 1)) + ) + + // Shouldn't throw runtime exception when parent object (`ClassData`) is null + assert(buildDataset(Row(null)).collect() === Array(NestedStruct(null))) + + val message = intercept[RuntimeException] { + buildDataset(Row(Row("hello", null))).collect() + }.getMessage + + assert(message.contains( + "Null value appeared in non-nullable field org.apache.spark.sql.ClassData.b of 
type Int." + )) + } + + test("SPARK-12478: top level null field") { + val ds0 = Seq(NestedStruct(null)).toDS() + checkAnswer(ds0, NestedStruct(null)) + checkAnswer(ds0.toDF(), Row(null)) + + val ds1 = Seq(DeepNestedStruct(NestedStruct(null))).toDS() + checkAnswer(ds1, DeepNestedStruct(NestedStruct(null))) + checkAnswer(ds1.toDF(), Row(Row(null))) } } @@ -509,6 +561,9 @@ case class ClassData(a: String, b: Int) case class ClassData2(c: String, d: Int) case class ClassNullableData(a: String, b: Integer) +case class NestedStruct(f: ClassData) +case class DeepNestedStruct(f: NestedStruct) + /** * A class used to test serialization using encoders. This class throws exceptions when using * Java serialization -- so the only way it can be "serialized" is through our encoders. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala index a61c3aa48a73..f7aa3b747ae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.sql.{Timestamp, Date} +import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat import org.apache.spark.sql.catalyst.util.DateTimeUtils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala index 78a98798eff6..2c4b4f80ff9e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala @@ -15,16 +15,14 @@ * limitations under the License. */ -package test.org.apache.spark.sql +package org.apache.spark.sql import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute} -import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.{Row, Strategy, QueryTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.unsafe.types.UTF8String case class FastOperator(output: Seq[Attribute]) extends SparkPlan { @@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan { sparkContext.parallelize(Seq(row)) } + override def producedAttributes: AttributeSet = outputSet override def children: Seq[SparkPlan] = Nil } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 1f384edf321b..1391c9d57ff7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -73,6 +73,10 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { checkAnswer( df.select($"key", functions.json_tuple($"jstring", "f1", "f2", "f3", "f4", "f5")), expected) + + checkAnswer( + df.selectExpr("key", "json_tuple(jstring, 'f1', 'f2', 'f3', 'f4', 'f5')"), + expected) } test("json_tuple filter and group") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 5688f46e5e3d..3d7c576965fc 
100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,9 +19,9 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} -import org.apache.spark.sql.catalyst.TableIdentifier class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext { import testImplicits._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala index 162c0b56c6e1..6a375a33bfcf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MultiSQLContextsSuite.scala @@ -17,9 +17,10 @@ package org.apache.spark.sql -import org.apache.spark._ import org.scalatest.BeforeAndAfterAll +import org.apache.spark._ + class MultiSQLContextsSuite extends SparkFunSuite with BeforeAndAfterAll { private var originalActiveSQLContext: Option[SQLContext] = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index bc22fb8b7bdb..fac26bd0c026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -20,11 +20,17 @@ package org.apache.spark.sql import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ +import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.{LogicalRDD, Queryable} import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.Queryable +import org.apache.spark.sql.execution.datasources.LogicalRelation abstract class QueryTest extends PlanTest { @@ -123,6 +129,10 @@ abstract class QueryTest extends PlanTest { |""".stripMargin) } + checkJsonFormat(analyzedDF) + + assertEmptyMissingInput(df) + QueryTest.checkAnswer(analyzedDF, expectedAnswer) match { case Some(errorMessage) => fail(errorMessage) case None => @@ -177,6 +187,109 @@ abstract class QueryTest extends PlanTest { s"Expected query to contain $numCachedTables, but it actually had ${cachedData.size}\n" + planWithCaching) } + + private def checkJsonFormat(df: DataFrame): Unit = { + val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. + logicalPlan.transform { + case _: MapPartitions[_, _] => return + case _: MapGroups[_, _, _] => return + case _: AppendColumns[_, _] => return + case _: CoGroup[_, _, _, _] => return + case _: LogicalRelation => return + }.transformAllExpressions { + case a: ImperativeAggregate => return + } + + // bypass hive tests until we fix all corner cases in the hive module.
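+ // The round trip that follows serializes the analyzed plan with toJSON, rebuilds it
+ // with TreeNode.fromJSON, and compares the two after swapping out pieces that cannot
+ // survive JSON (Scala closures, RDD references, cached column buffers).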
+ if (this.getClass.getName.startsWith("org.apache.spark.sql.hive")) return + + val jsonString = try { + logicalPlan.toJSON + } catch { + case NonFatal(e) => + fail( + s""" + |Failed to serialize the logical plan to JSON: + |${logicalPlan.treeString} + """.stripMargin, e) + } + + // Scala functions are not serializable to JSON, so replace them with null so that the + // plans can be compared later. + val normalized1 = logicalPlan.transformAllExpressions { + case udf: ScalaUDF => udf.copy(function = null) + case gen: UserDefinedGenerator => gen.copy(function = null) + } + + // RDDs/data are not serializable to JSON, so we collect the LogicalPlans that contain + // this non-serializable state, and use the original nodes to replace the null placeholders + // in the logical plans parsed from JSON. + var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } + var localRelations = logicalPlan.collect { case l: LocalRelation => l } + var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + + val jsonBackPlan = try { + TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) + } catch { + case NonFatal(e) => + fail( + s""" + |Failed to rebuild the logical plan from JSON: + |${logicalPlan.treeString} + | + |${logicalPlan.prettyJson} + """.stripMargin, e) + } + + val normalized2 = jsonBackPlan transformDown { + case l: LogicalRDD => + val origin = logicalRDDs.head + logicalRDDs = logicalRDDs.drop(1) + LogicalRDD(l.output, origin.rdd)(sqlContext) + case l: LocalRelation => + val origin = localRelations.head + localRelations = localRelations.drop(1) + l.copy(data = origin.data) + case l: InMemoryRelation => + val origin = inMemoryRelations.head + inMemoryRelations = inMemoryRelations.drop(1) + InMemoryRelation( + l.output, + l.useCompression, + l.batchSize, + l.storageLevel, + origin.child, + l.tableName)( + origin.cachedColumnBuffers, + l._statistics, + origin._batchStats) + } + + assert(logicalRDDs.isEmpty) + assert(localRelations.isEmpty) + assert(inMemoryRelations.isEmpty) + + if (normalized1 != normalized2) { + fail( + s""" + |== FAIL: the logical plan parsed from JSON does not match the original one == + |${sideBySide(logicalPlan.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** + * Asserts that a given [[Queryable]] does not have missing inputs in any of its analyzed plans.
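+ * Here "missing inputs" are attributes that a plan node references but that are neither
+ * produced by the node itself nor supplied by its children; a non-empty set usually
+ * indicates a bug in an operator's `producedAttributes` or in a planner rule.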
+ */ + def assertEmptyMissingInput(query: Queryable): Unit = { + assert(query.queryExecution.analyzed.missingInput.isEmpty, + s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") + assert(query.queryExecution.optimizedPlan.missingInput.isEmpty, + s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}") + assert(query.queryExecution.executedPlan.missingInput.isEmpty, + s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}") + } } object QueryTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 3ba14d7602a6..4552eb6ce00a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} +import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 3d2bd236ceea..cf0701eca29e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.{TestSQLContext, SharedSQLContext} - +import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} class SQLConfSuite extends QueryTest with SharedSQLContext { private val testKey = "test.key.0" @@ -93,4 +92,41 @@ class SQLConfSuite extends QueryTest with SharedSQLContext { } assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10") } + + test("Test SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE's method") { + sqlContext.conf.clear() + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "100") + assert(sqlContext.conf.targetPostShuffleInputSize === 100) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1k") + assert(sqlContext.conf.targetPostShuffleInputSize === 1024) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1M") + assert(sqlContext.conf.targetPostShuffleInputSize === 1048576) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "1g") + assert(sqlContext.conf.targetPostShuffleInputSize === 1073741824) + + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1") + assert(sqlContext.conf.targetPostShuffleInputSize === -1) + + // Test overflow exception + intercept[IllegalArgumentException] { + // This value exceeds Long.MaxValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "90000000000g") + } + + intercept[IllegalArgumentException] { + // This value is less than Long.MinValue + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-90000000000g") + } + + // Test invalid input + intercept[IllegalArgumentException] { + // A negative size with a unit suffix is not a valid value + sqlContext.setConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, "-1g") + } + sqlContext.conf.clear() + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index bb82b562aaaa..5de0979606b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} +import org.apache.spark.sql.test.SQLTestData._ import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -2028,4 +2028,43 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { Row(false) :: Row(true) :: Nil) } + test("rollup") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by rollup(course, year)" + + " order by course, year"), + Row(null, null, 113000.0) :: + Row("Java", null, 50000.0) :: + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("dotNET", null, 63000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: Nil + ) + } + + test("cube") { + checkAnswer( + sql("select course, year, sum(earnings) from courseSales group by cube(course, year)"), + Row("Java", 2012, 20000.0) :: + Row("Java", 2013, 30000.0) :: + Row("Java", null, 50000.0) :: + Row("dotNET", 2012, 15000.0) :: + Row("dotNET", 2013, 48000.0) :: + Row("dotNET", null, 63000.0) :: + Row(null, 2012, 35000.0) :: + Row(null, 2013, 78000.0) :: + Row(null, null, 113000.0) :: Nil + ) + } + + test("hash function") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTempTable("tbl") { + df.registerTempTable("tbl") + checkAnswer( + df.select(hash($"i", $"j")), + sql("SELECT hash(i, j) from tbl") + ) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala index 00f1526576cc..a32763db054f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UnsafeRowSuite.scala @@ -34,8 +34,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Java serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new JavaSerializer(new SparkConf).newInstance() @@ -47,8 +47,8 @@ class UnsafeRowSuite extends SparkFunSuite { test("UnsafeRow Kryo serialization") { // serializing an UnsafeRow pointing to a large buffer should only serialize the relevant data val data = new Array[Byte](1024) - val row = new UnsafeRow - row.pointTo(data, 1, 16) + val row = new UnsafeRow(1) + row.pointTo(data, 16) row.setLong(0, 19285) val ser = new KryoSerializer(new SparkConf).newInstance() @@ -86,11 +86,10 @@ class UnsafeRowSuite extends SparkFunSuite { offheapRowPage.getBaseOffset, arrayBackedUnsafeRow.getSizeInBytes ) - val offheapUnsafeRow: UnsafeRow = new UnsafeRow() + val offheapUnsafeRow: UnsafeRow = new UnsafeRow(3) offheapUnsafeRow.pointTo( offheapRowPage.getBaseObject, offheapRowPage.getBaseOffset, - 3, // num fields arrayBackedUnsafeRow.getSizeInBytes ) assert(offheapUnsafeRow.getBaseObject === null) diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index f602f2fb89ca..6800a8ddf6e3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,19 +17,17 @@ package org.apache.spark.sql -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayData} - import scala.beans.{BeanInfo, BeanProperty} import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.collection.OpenHashSet - @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -65,6 +63,11 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] private[spark] override def asNullable: MyDenseVectorUDT = this + + override def equals(other: Any): Boolean = other match { + case _: MyDenseVectorUDT => true + case _ => false + } } class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 180050bdac00..35ff1c40fe6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.execution import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql._ -import org.apache.spark.{SparkFunSuite, SparkContext, SparkConf, MapOutputStatistics} class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { @@ -260,6 +260,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, targetNumPostShufflePartitions.toString) @@ -318,7 +319,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 1536, minNumPostShufflePartitions) + withSQLContext(test, 2000, minNumPostShufflePartitions) } test(s"determining the number of reducers: join operator$testNameNote") { @@ -421,7 +422,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { } } - withSQLContext(test, 6144, minNumPostShufflePartitions) + withSQLContext(test, 6644, minNumPostShufflePartitions) } test(s"determining the number of reducers: complex query 2$testNameNote") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 911d12e93e50..87bff3295f5b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -28,7 +28,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ConvertToSafe(Exchange(SinglePartition, ConvertToUnsafe(plan))), + plan => Exchange(SinglePartition, plan), input.map(Row.fromTuple) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala deleted file mode 100644 index faef76d52ae7..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExpandSuite.scala +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, BoundReference, Alias, Literal} -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.IntegerType - -class ExpandSuite extends SparkPlanTest with SharedSQLContext { - import testImplicits.localSeqToDataFrameHolder - - private def testExpand(f: SparkPlan => SparkPlan): Unit = { - val input = (1 to 1000).map(Tuple1.apply) - val projections = Seq.tabulate(2) { i => - Alias(BoundReference(0, IntegerType, false), "id")() :: Alias(Literal(i), "gid")() :: Nil - } - val attributes = projections.head.map(_.toAttribute) - checkAnswer( - input.toDF(), - plan => Expand(projections, attributes, f(plan)), - input.flatMap(i => Seq.tabulate(2)(j => Row(i._1, j))) - ) - } - - test("inheriting child row type") { - val exprs = AttributeReference("a", IntegerType, false)() :: Nil - val plan = Expand(Seq(exprs), exprs, ConvertToUnsafe(LocalTableScan(exprs, Seq.empty))) - assert(plan.outputsUnsafeRows, "Expand should inherits the created row type from its child.") - } - - test("expanding UnsafeRows") { - testExpand(ConvertToUnsafe) - } - - test("expanding SafeRows") { - testExpand(identity) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala index e7a08481cfa8..6f10e4b80577 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.dsl.expressions._ import 
org.apache.spark.sql.catalyst.encoders.RowEncoder -import org.apache.spark.sql.types.{LongType, StringType, IntegerType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} class GroupedIteratorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 2fb439f50117..03a1b8e11d45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.joins.{SortMergeJoin, BroadcastHashJoin} +import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -417,6 +417,45 @@ class PlannerSuite extends SharedSQLContext { } } + test("EnsureRequirements eliminates Exchange if child has Exchange with same partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = Exchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.size == 2) { + fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") + } + } + + test("EnsureRequirements does not eliminate Exchange with different partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) + val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = Exchange(finalPartitioning, + DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)), + None) + + val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan) + assertDistributionRequirementsAreSatisfied(outputPlan) + if (outputPlan.collect { case e: Exchange => true }.size == 1) { + fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") + } + } + // --------------------------------------------------------------------------------------------- } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala deleted file mode 100644 index 2328899bb2f8..000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala +++ /dev/null @@ -1,164 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license 
agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{SQLContext, Row} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull} -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{ArrayType, StringType} -import org.apache.spark.unsafe.types.UTF8String - -class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext { - - private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect { - case c: ConvertToUnsafe => c - case c: ConvertToSafe => c - } - - private val outputsSafe = ReferenceSort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(!outputsSafe.outputsUnsafeRows) - private val outputsUnsafe = Sort(Nil, false, PhysicalRDD(Seq.empty, null, "name")) - assert(outputsUnsafe.outputsUnsafeRows) - - test("planner should insert unsafe->safe conversions when required") { - val plan = Limit(10, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe]) - } - - test("filter can process unsafe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("filter can process safe rows") { - val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).isEmpty) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("coalesce can process unsafe rows") { - val plan = Coalesce(1, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 1) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except can process unsafe rows") { - val plan = Except(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("except requires all of its input rows' formats to agree") { - val plan = Except(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect can process unsafe rows") { - val plan = Intersect(outputsUnsafe, outputsUnsafe) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(getConverters(preparedPlan).size === 2) - assert(preparedPlan.outputsUnsafeRows) - } - - test("intersect requires 
all of its input rows' formats to agree") { - val plan = Intersect(outputsSafe, outputsUnsafe) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("execute() fails an assertion if inputs rows are of different formats") { - val e = intercept[AssertionError] { - Union(Seq(outputsSafe, outputsUnsafe)).execute() - } - assert(e.getMessage.contains("format")) - } - - test("union requires all of its input rows' formats to agree") { - val plan = Union(Seq(outputsSafe, outputsUnsafe)) - assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("union can process safe rows") { - val plan = Union(Seq(outputsSafe, outputsSafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(!preparedPlan.outputsUnsafeRows) - } - - test("union can process unsafe rows") { - val plan = Union(Seq(outputsUnsafe, outputsUnsafe)) - val preparedPlan = sqlContext.prepareForExecution.execute(plan) - assert(preparedPlan.outputsUnsafeRows) - } - - test("round trip with ConvertToUnsafe and ConvertToSafe") { - val input = Seq(("hello", 1), ("world", 2)) - checkAnswer( - sqlContext.createDataFrame(input), - plan => ConvertToSafe(ConvertToUnsafe(plan)), - input.map(Row.fromTuple) - ) - } - - test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") { - SQLContext.setActive(sqlContext) - val schema = ArrayType(StringType) - val rows = (1 to 100).map { i => - InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString)))) - } - val relation = LocalTableScan(Seq(AttributeReference("t", schema)()), rows) - - val plan = - DummyPlan( - ConvertToSafe( - ConvertToUnsafe(relation))) - assert(plan.execute().collect().map(_.getUTF8String(0).toString) === (1 to 100).map(_.toString)) - } -} - -case class DummyPlan(child: SparkPlan) extends UnaryNode { - - override protected def doExecute(): RDD[InternalRow] = { - child.execute().mapPartitions { iter => - // This `DummyPlan` is in safe mode, so we don't need to do copy even we hold some - // values gotten from the incoming rows. - // we cache all strings here to make sure we have deep copied UTF8String inside incoming - // safe InternalRow. - val strings = new scala.collection.mutable.ArrayBuffer[UTF8String] - iter.foreach { row => - strings += row.getArray(0).getUTF8String(0) - } - strings.map(InternalRow(_)).iterator - } - } - - override def output: Seq[Attribute] = Seq(AttributeReference("a", StringType)()) -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala index e5d34be4c65e..6259453da26a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution import scala.util.Random import org.apache.spark.AccumulatorSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -import org.apache.spark.sql.{RandomDataGenerator, Row} - /** * Test sorting. 
Many of the test cases generate random data and compares the sorted result with one @@ -99,7 +98,7 @@ class SortSuite extends SparkPlanTest with SharedSQLContext { ) checkThatPlansAgree( inputDf, - p => ConvertToSafe(Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23)), + p => Sort(sortOrder, global = true, p: SparkPlan, testSpillFrequency = 23), ReferenceSort(sortOrder, global = true, _: SparkPlan), sortAnswers = false ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 5a8406789ab8..9c258cb31f46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -17,16 +17,16 @@ package org.apache.spark.sql.execution -import scala.util.control.NonFatal import scala.collection.mutable -import scala.util.{Try, Random} +import scala.util.{Random, Try} +import scala.util.control.NonFatal import org.scalatest.Matchers -import org.apache.spark.{SparkConf, TaskContextImpl, TaskContext, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext, TaskContextImpl} import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 29027a664b4b..95c9550aebb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark._ import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 09e258299de5..9f09eb4429c1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -17,20 +17,19 @@ package org.apache.spark.sql.execution -import java.io.{File, ByteArrayInputStream, ByteArrayOutputStream} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} +import org.apache.spark._ import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD -import org.apache.spark.storage.ShuffleBlockId -import 
org.apache.spark.util.collection.ExternalSorter -import org.apache.spark.util.Utils import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow} import org.apache.spark.sql.types._ -import org.apache.spark._ - +import org.apache.spark.storage.ShuffleBlockId +import org.apache.spark.util.collection.ExternalSorter +import org.apache.spark.util.Utils /** * used to test close InputStream in UnsafeRowSerializer diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala index 706ff1f99850..9ca8c4d2ed2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnTypeSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql.execution.columnar -import java.nio.{ByteOrder, ByteBuffer} +import java.nio.{ByteBuffer, ByteOrder} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.types._ -import org.apache.spark.{Logging, SparkFunSuite} - class ColumnTypeSuite extends SparkFunSuite with Logging { private val DEFAULT_BUFFER_SIZE = 512 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala index 9cae65ef6f5d..1529313dfbd5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ColumnarTestUtils.scala @@ -22,7 +22,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, GenericMutableRow} -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types.{AtomicType, Decimal} import org.apache.spark.unsafe.types.UTF8String @@ -60,6 +60,7 @@ object ColumnarTestUtils { case MAP(_) => ArrayBasedMapData( Map(Random.nextInt() -> UTF8String.fromString(Random.nextString(Random.nextInt(32))))) + case _ => throw new IllegalArgumentException(s"Unknown column type $columnType") }).asInstanceOf[JvmType] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala index 35dc9a276cef..dc22d3e8e4d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnAccessorSuite.scala @@ -21,7 +21,7 @@ import java.nio.ByteBuffer import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import 
org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnAccessor[JvmType]( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala index 93be3e16a5ed..cdd4551d64b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/NullableColumnBuilderSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.columnar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.CatalystTypeConverters -import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, GenericMutableRow} +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeProjection} import org.apache.spark.sql.types._ class TestNullableColumnBuilder[JvmType](columnType: ColumnType[JvmType]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index ccbddef0fad3..f67e9c7dae27 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala index 4cc0a3a9585d..1742df31bba9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonParsingOptionsSuite.scala @@ -111,4 +111,23 @@ class JsonParsingOptionsSuite extends QueryTest with SharedSQLContext { assert(df.schema.head.name == "age") assert(df.first().getDouble(0).isNaN) } + + test("allowBackslashEscapingAnyCharacter off") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "false").json(rdd) + + assert(df.schema.head.name == "_corrupt_record") + } + + test("allowBackslashEscapingAnyCharacter on") { + val str = """{"name": "Cazen Lee", "price": "\$10"}""" + val rdd = sqlContext.sparkContext.parallelize(Seq(str)) + val df = sqlContext.read.option("allowBackslashEscapingAnyCharacter", "true").json(rdd) + + assert(df.schema.head.name == "name") + assert(df.schema.last.name == "price") + assert(df.first().getString(0) == "Cazen Lee") + assert(df.first().getString(1) == "$10") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index baa258ad2615..b3b6b7df0c1d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.json import java.io.{File, StringWriter} import java.sql.{Date, Timestamp} + import scala.collection.JavaConverters._ import com.fasterxml.jackson.core.JsonFactory diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala index 0835bd123049..4217c81ff3e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala @@ -21,9 +21,9 @@ import scala.collection.JavaConverters.{collectionAsScalaIterableConverter, mapA import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{Path, PathFilter} +import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext -import org.apache.parquet.hadoop.{ParquetFileReader, ParquetWriter} import org.apache.parquet.io.api.RecordConsumer import org.apache.parquet.schema.{MessageType, MessageTypeParser} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 6178e37d2a58..587aa5fd30d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.parquet.filter2.predicate.Operators._ import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators} +import org.apache.parquet.filter2.predicate.FilterApi._ +import org.apache.parquet.filter2.predicate.Operators.{Column => _, _} -import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf} +import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types._ /** * A test suite that tests Parquet filter2 API based filter pushdown optimization. @@ -323,6 +325,47 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } + test("SPARK-12231: test the filter and empty project in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). 
+ write.partitionBy("a").parquet(path) + + // The filter "a > 1 or b < 2" will not get pushed down, and the projection is empty, + // this query will throw an exception since the project from combinedFilter expect + // two projection while the + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + assert(df1.filter("a > 1 or b < 2").count() == 2) + } + } + } + + test("SPARK-12231: test the new projection in partitioned DataSource scan") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}" + (1 to 3).map(i => (i, i + 1, i + 2, i + 3)).toDF("a", "b", "c", "d"). + write.partitionBy("a").parquet(path) + + // test the generate new projection case + // when projects != partitionAndNormalColumnProjs + + val df1 = sqlContext.read.parquet(dir.getCanonicalPath) + + checkAnswer( + df1.filter("a > 1 or b > 2").orderBy("a").selectExpr("a", "b", "c", "d"), + (2 to 3).map(i => Row(i, i + 1, i + 2, i + 3))) + } + } + } + + test("SPARK-11103: Filter applied on merged Parquet schema with new column fails") { import testImplicits._ @@ -362,4 +405,89 @@ class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContex } } } + + test("SPARK-12218: 'Not' is included in Parquet filter pushdown") { + import testImplicits._ + + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.parquet(path) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2) or not(b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + + checkAnswer( + sqlContext.read.parquet(path).where("not (a = 2 and b in ('1'))"), + (1 to 5).map(i => Row(i, (i % 2).toString))) + } + } + } + + test("SPARK-12218 Converting conjunctions into Parquet filter predicates") { + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = false), + StructField("b", StringType, nullable = true), + StructField("c", DoubleType, nullable = true) + )) + + assertResult(Some(and( + lt(intColumn("a"), 10: Integer), + gt(doubleColumn("c"), 1.5: java.lang.Double))) + ) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.GreaterThan("c", 1.5D))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.And( + sources.LessThan("a", 10), + sources.StringContains("b", "prefix"))) + } + + assertResult(None) { + ParquetFilters.createFilter( + schema, + sources.Not( + sources.And( + sources.GreaterThan("a", 1), + sources.StringContains("b", "prefix")))) + } + } + + test("SPARK-11164: test the parquet filter in") { + import testImplicits._ + withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") { + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + withTempPath { dir => + val path = s"${dir.getCanonicalPath}/table1" + (1 to 5).map(i => (i.toFloat, i%3)).toDF("a", "b").write.parquet(path) + + // When a filter is pushed to Parquet, Parquet can apply it to every row. + // So, we can check the number of rows returned from the Parquet + // to make sure our filter pushdown work. 
+ val df = sqlContext.read.parquet(path).where("b in (0,2)") + assert(stripSparkFilter(df).count == 3) + + val df1 = sqlContext.read.parquet(path).where("not (b in (1))") + assert(stripSparkFilter(df1).count == 3) + + val df2 = sqlContext.read.parquet(path).where("not (b in (1,3) or a <= 2)") + assert(stripSparkFilter(df2).count == 2) + + val df3 = sqlContext.read.parquet(path).where("not (b in (1,3) and a <= 2)") + assert(stripSparkFilter(df3).count == 4) + + val df4 = sqlContext.read.parquet(path).where("not (a <= 2)") + assert(stripSparkFilter(df4).count == 3) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 0c5d4887ed79..ab48e971b507 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -17,17 +17,17 @@ package org.apache.spark.sql.execution.datasources.parquet -import org.apache.parquet.column.{Encoding, ParquetProperties} - import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} -import org.apache.parquet.example.data.simple.SimpleGroup +import org.apache.parquet.column.{Encoding, ParquetProperties} import org.apache.parquet.example.data.{Group, GroupWriter} +import org.apache.parquet.example.data.simple.SimpleGroup import org.apache.parquet.hadoop._ import org.apache.parquet.hadoop.api.WriteSupport import org.apache.parquet.hadoop.api.WriteSupport.WriteContext @@ -38,6 +38,7 @@ import org.apache.parquet.schema.{MessageType, MessageTypeParser} import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -618,6 +619,100 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { readResourceParquetFile("dec-in-fixed-len.parquet"), sqlContext.range(1 << 4).select('id % 10 cast DecimalType(10, 2) as 'fixed_len_dec)) } + + test("SPARK-12589 copy() on rows returned from reader works for strings") { + withTempPath { dir => + val data = (1, "abc") ::(2, "helloabcde") :: Nil + data.toDF().write.parquet(dir.getCanonicalPath) + var hash1: Int = 0 + var hash2: Int = 0 + (false :: true :: Nil).foreach { v => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> v.toString) { + val df = sqlContext.read.parquet(dir.getCanonicalPath) + val rows = df.queryExecution.toRdd.map(_.copy()).collect() + val unsafeRows = rows.map(_.asInstanceOf[UnsafeRow]) + if (!v) { + hash1 = unsafeRows(0).hashCode() + hash2 = unsafeRows(1).hashCode() + } else { + assert(hash1 == unsafeRows(0).hashCode()) + assert(hash2 == unsafeRows(1).hashCode()) + } + } + } + } + } + + test("UnsafeRowParquetRecordReader - direct path read") { + val data = (0 to 10).map(i => (i, ((i + 'a').toChar.toString))) + withTempPath { dir => + sqlContext.createDataFrame(data).repartition(1).write.parquet(dir.getCanonicalPath) + val file = 
SpecificParquetRecordReaderBase.listDirectory(dir).get(0); + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, null) + val result = mutable.ArrayBuffer.empty[(Int, String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + val v = (row.getInt(0), row.getString(1)) + result += v + } + assert(data == result) + } finally { + reader.close() + } + } + + // Project just one column + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, ("_2" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + result += row.getString(0) + } + assert(data.map(_._2) == result) + } finally { + reader.close() + } + } + + // Project columns in opposite order + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, ("_2" :: "_1" :: Nil).asJava) + val result = mutable.ArrayBuffer.empty[(String, Int)] + while (reader.nextKeyValue()) { + val row = reader.getCurrentValue + val v = (row.getString(0), row.getInt(1)) + result += v + } + assert(data.map { x => (x._2, x._1) } == result) + } finally { + reader.close() + } + } + + // Empty projection + { + val reader = new UnsafeRowParquetRecordReader + try { + reader.initialize(file, List[String]().asJava) + var result = 0 + while (reader.nextKeyValue()) { + result += 1 + } + assert(result == data.length) + } finally { + reader.close() + } + } + } + } } class JobCommitFailureParquetOutputCommitter(outputPath: Path, context: TaskAttemptContext) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 71e9034d9779..0feb945fbb79 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils} +import org.apache.spark.sql.execution.datasources.{LogicalRelation, Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index f777e973052d..0bc64404f164 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.fs.Path import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow import org.apache.spark.sql.execution.datasources.parquet.TestingUDT.{NestedStruct, NestedStructUDT} import 
org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala new file mode 100644 index 000000000000..cab6abde6da2 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadBenchmark.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ +import scala.util.Try + +import org.apache.spark.sql.{SQLConf, SQLContext} +import org.apache.spark.util.{Benchmark, Utils} +import org.apache.spark.{SparkConf, SparkContext} + +/** + * Benchmark to measure parquet read performance. + * To run this: + *  spark-submit --class <this class> --jars <spark sql test jar> + */ +object ParquetReadBenchmark { + val conf = new SparkConf() + conf.set("spark.sql.parquet.compression.codec", "snappy") + val sc = new SparkContext("local[1]", "test-sql-context", conf) + val sqlContext = new SQLContext(sc) + + def withTempPath(f: File => Unit): Unit = { + val path = Utils.createTempDir() + path.delete() + try f(path) finally Utils.deleteRecursively(path) + } + + def withTempTable(tableNames: String*)(f: => Unit): Unit = { + try f finally tableNames.foreach(sqlContext.dropTempTable) + } + + def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { + val (keys, values) = pairs.unzip + val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption) + (keys, values).zipped.foreach(sqlContext.conf.setConfString) + try f finally { + keys.zip(currentValues).foreach { + case (key, Some(value)) => sqlContext.conf.setConfString(key, value) + case (key, None) => sqlContext.conf.unsetConf(key) + } + } + } + + def intScanBenchmark(values: Int): Unit = { + withTempPath { dir => + sqlContext.range(values).write.parquet(dir.getCanonicalPath) + withTempTable("tempTable") { + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + val benchmark = new Benchmark("Single Int Column Scan", values) + + benchmark.addCase("SQL Parquet Reader") { iter => + sqlContext.sql("select sum(id) from tempTable").collect() + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(id) from tempTable").collect() + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + benchmark.addCase("ParquetReader") { num => + var sum = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader +
reader.initialize(p, ("id" :: Nil).asJava) + + while (reader.nextKeyValue()) { + val record = reader.getCurrentValue + if (!record.isNullAt(0)) sum += record.getInt(0) + } + reader.close() + }} + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + SQL Parquet Reader 1910.0 13.72 1.00 X + SQL Parquet MR 2330.0 11.25 0.82 X + ParquetReader 1252.6 20.93 1.52 X + */ + benchmark.run() + } + } + } + + def intStringScanBenchmark(values: Int): Unit = { + withTempPath { dir => + withTempTable("t1", "tempTable") { + sqlContext.range(values).registerTempTable("t1") + sqlContext.sql("select id as c1, cast(id as STRING) as c2 from t1") + .write.parquet(dir.getCanonicalPath) + sqlContext.read.parquet(dir.getCanonicalPath).registerTempTable("tempTable") + + val benchmark = new Benchmark("Int and String Scan", values) + + benchmark.addCase("SQL Parquet Reader") { iter => + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + + benchmark.addCase("SQL Parquet MR") { iter => + withSQLConf(SQLConf.PARQUET_UNSAFE_ROW_RECORD_READER_ENABLED.key -> "false") { + sqlContext.sql("select sum(c1), sum(length(c2)) from tempTable").collect + } + } + + val files = SpecificParquetRecordReaderBase.listDirectory(dir).toArray + benchmark.addCase("ParquetReader") { num => + var sum1 = 0L + var sum2 = 0L + files.map(_.asInstanceOf[String]).foreach { p => + val reader = new UnsafeRowParquetRecordReader + reader.initialize(p, null) + while (reader.nextKeyValue()) { + val record = reader.getCurrentValue + if (!record.isNullAt(0)) sum1 += record.getInt(0) + if (!record.isNullAt(1)) sum2 += record.getUTF8String(1).numBytes() + } + reader.close() + } + } + + /* + Intel(R) Core(TM) i7-4870HQ CPU @ 2.50GHz + Int and String Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------- + SQL Parquet Reader 2245.6 7.00 1.00 X + SQL Parquet MR 2914.2 5.40 0.77 X + ParquetReader 1544.6 10.18 1.45 X + */ + benchmark.run() + } + } + } + + def main(args: Array[String]): Unit = { + intScanBenchmark(1024 * 1024 * 15) + intStringScanBenchmark(1024 * 1024 * 10) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index fdd7697c91f5..449fcc860fac 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File -import org.apache.parquet.schema.MessageType - import scala.collection.JavaConverters._ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag @@ -28,12 +26,13 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.parquet.format.converter.ParquetMetadataConverter -import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} import org.apache.parquet.hadoop.{Footer, ParquetFileReader, ParquetFileWriter} +import org.apache.parquet.hadoop.metadata.{BlockMetaData, FileMetaData, ParquetMetadata} +import org.apache.parquet.schema.MessageType +import org.apache.spark.sql.{DataFrame, SaveMode, SQLConf} 
import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{DataFrame, SQLConf, SaveMode} /** * A helper trait that provides convenient facilities for Parquet testing. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala index 914e516613f9..f95272530d58 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/text/TextSuite.scala @@ -17,12 +17,11 @@ package org.apache.spark.sql.execution.datasources.text +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.util.Utils - class TextSuite extends QueryTest with SharedSQLContext { test("reading text file") { @@ -33,8 +32,8 @@ class TextSuite extends QueryTest with SharedSQLContext { verifyFrame(sqlContext.read.text(testFile)) } - test("writing") { - val df = sqlContext.read.text(testFile) + test("SPARK-12562 verify write.text() can handle column name beyond `value`") { + val df = sqlContext.read.text(testFile).withColumnRenamed("value", "adwrasdf") val tempFile = Utils.createTempDir() tempFile.delete() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 5b2998c3c76d..58581d71e1bc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,8 +22,8 @@ import scala.reflect.ClassTag import org.scalatest.BeforeAndAfterAll import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} /** * Test various broadcast join operators. 
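The TextSuite change above tracks the contract of the text data source: the written DataFrame must consist of exactly one string column, but after SPARK-12562 that column no longer has to be named `value`. A minimal round-trip sketch of that contract follows; the object name, output path, and column name are illustrative, not from the patch:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

// Sketch only: exercises the text data source contract described above.
object TextRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[1]").setAppName("text-sketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // One string column; after SPARK-12562 any column name is accepted on write.
    val df = Seq(Tuple1("first line"), Tuple1("second line")).toDF("payload")
    df.write.text("/tmp/text-sketch")  // each row becomes one line in the output

    // Reading back always yields a single string column named "value".
    val readBack = sqlContext.read.text("/tmp/text-sketch")
    assert(readBack.schema.fieldNames.sameElements(Array("value")))
    assert(readBack.count() == 2)
    sc.stop()
  }
}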
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala index 2ec17146476f..42fadaa8e221 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf} +import org.apache.spark.sql.{execution, DataFrame, Row, SQLConf} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala index 9c80714a9af4..3d3e9a7b9092 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.joins import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} class OuterJoinSuite extends SparkPlanTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala index 3afd762942bc..9c86084f9b8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.{SQLConf, DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SQLConf} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression} import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala index c30327185e16..eb70747926fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala @@ -22,8 +22,8 @@ import org.mockito.Mockito.{mock, when} import 
org.apache.spark.broadcast.TorrentBroadcast import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression} -import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.expressions.{Expression, InterpretedMutableProjection, UnsafeProjection} +import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation} class HashJoinNodeSuite extends LocalNodeTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala index 615c41709361..1a485f967dd3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala @@ -20,10 +20,9 @@ package org.apache.spark.sql.execution.local import org.apache.spark.SparkFunSuite import org.apache.spark.sql.SQLConf import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression} import org.apache.spark.sql.types.{IntegerType, StringType} - class LocalNodeTest extends SparkFunSuite { protected val conf: SQLConf = new SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala index 12a4e1356fed..eef3c1f3e34d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.ui import java.util.Properties -import org.apache.spark.{SparkException, SparkContext, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.sql.{DataFrame, SQLContext} @@ -336,39 +336,45 @@ class SQLListenerSuite extends SparkFunSuite with SharedSQLContext { class SQLListenerMemoryLeakSuite extends SparkFunSuite { test("no memory leak") { - val conf = new SparkConf() - .setMaster("local") - .setAppName("test") - .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly - .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly - val sc = new SparkContext(conf) + val oldLogLevel = org.apache.log4j.Logger.getRootLogger().getLevel() try { - SQLContext.clearSqlListener() - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - // Run 100 successful executions and 100 failed executions. - // Each execution only has one job and one stage. 
- for (i <- 0 until 100) { - val df = Seq( - (1, 1), - (2, 2) - ).toDF() - df.collect() - try { - df.foreach(_ => throw new RuntimeException("Oops")) - } catch { - case e: SparkException => // This is expected for a failed job + org.apache.log4j.Logger.getRootLogger().setLevel(org.apache.log4j.Level.FATAL) + val conf = new SparkConf() + .setMaster("local") + .setAppName("test") + .set("spark.task.maxFailures", "1") // Don't retry the tasks to run this test quickly + .set("spark.sql.ui.retainedExecutions", "50") // Set it to 50 to run this test quickly + val sc = new SparkContext(conf) + try { + SQLContext.clearSqlListener() + val sqlContext = new SQLContext(sc) + import sqlContext.implicits._ + // Run 100 successful executions and 100 failed executions. + // Each execution only has one job and one stage. + for (i <- 0 until 100) { + val df = Seq( + (1, 1), + (2, 2) + ).toDF() + df.collect() + try { + df.foreach(_ => throw new RuntimeException("Oops")) + } catch { + case e: SparkException => // This is expected for a failed job + } } + sc.listenerBus.waitUntilEmpty(10000) + assert(sqlContext.listener.getCompletedExecutions.size <= 50) + assert(sqlContext.listener.getFailedExecutions.size <= 50) + // 50 for successful executions and 50 for failed executions + assert(sqlContext.listener.executionIdToData.size <= 100) + assert(sqlContext.listener.jobIdToExecutionId.size <= 100) + assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) + } finally { + sc.stop() } - sc.listenerBus.waitUntilEmpty(10000) - assert(sqlContext.listener.getCompletedExecutions.size <= 50) - assert(sqlContext.listener.getFailedExecutions.size <= 50) - // 50 for successful executions and 50 for failed executions - assert(sqlContext.listener.executionIdToData.size <= 100) - assert(sqlContext.listener.jobIdToExecutionId.size <= 100) - assert(sqlContext.listener.stageIdToStageMetrics.size <= 100) } finally { - sc.stop() + org.apache.log4j.Logger.getRootLogger().setLevel(oldLogLevel) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 2b91f62c2fa2..1fa22e293331 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,18 +18,25 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.DriverManager +import java.sql.{Date, DriverManager, Timestamp} import java.util.{Calendar, GregorianCalendar, Properties} import org.h2.jdbc.JdbcSQLException import org.scalatest.BeforeAndAfter +import org.scalatest.PrivateMethodTester import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.Row +import org.apache.spark.sql.execution.ExplainCommand +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.jdbc.JDBCRDD +import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext { +class JDBCSuite extends SparkFunSuite + with BeforeAndAfter with PrivateMethodTester with SharedSQLContext { import testImplicits._ val url = "jdbc:h2:mem:testdb0" @@ -180,10 +187,29 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID != 2")).collect().size == 2) assert(stripSparkFilter(sql("SELECT 
* FROM foobar WHERE THEID = 1")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME = 'fred'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME <=> 'fred'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME > 'fred'")).collect().size == 2) assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME != 'fred'")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME IN ('mary', 'fred')")) + .collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME NOT IN ('fred')")) + .collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary'")) + .collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE THEID = 1 OR NAME = 'mary' " + + "AND THEID = 2")).collect().size == 2) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE 'fr%'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%ed'")).collect().size == 1) + assert(stripSparkFilter(sql("SELECT * FROM foobar WHERE NAME LIKE '%re%'")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NULL")).collect().size == 1) assert(stripSparkFilter(sql("SELECT * FROM nulltypes WHERE A IS NOT NULL")).collect().size == 0) + + // This is a test to reflect discussion in SPARK-12218. + // The older versions of spark have this kind of bugs in parquet data source. + val df1 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2 AND NAME != 'mary')") + val df2 = sql("SELECT * FROM foobar WHERE NOT (THEID != 2) OR NOT (NAME != 'mary')") + assert(df1.collect.toSet === Set(Row("mary", 2))) + assert(df2.collect.toSet === Set(Row("mary", 2))) } test("SELECT * WHERE (quoted strings)") { @@ -429,6 +455,32 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(DerbyColumns === Seq(""""abc"""", """"key"""")) } + test("compile filters") { + val compileFilter = PrivateMethod[Option[String]]('compileFilter) + def doCompileFilter(f: Filter): String = JDBCRDD invokePrivate compileFilter(f) getOrElse("") + assert(doCompileFilter(EqualTo("col0", 3)) === "col0 = 3") + assert(doCompileFilter(Not(EqualTo("col1", "abc"))) === "(NOT (col1 = 'abc'))") + assert(doCompileFilter(And(EqualTo("col0", 0), EqualTo("col1", "def"))) + === "(col0 = 0) AND (col1 = 'def')") + assert(doCompileFilter(Or(EqualTo("col0", 2), EqualTo("col1", "ghi"))) + === "(col0 = 2) OR (col1 = 'ghi')") + assert(doCompileFilter(LessThan("col0", 5)) === "col0 < 5") + assert(doCompileFilter(LessThan("col3", + Timestamp.valueOf("1995-11-21 00:00:00.0"))) === "col3 < '1995-11-21 00:00:00.0'") + assert(doCompileFilter(LessThan("col4", Date.valueOf("1983-08-04"))) === "col4 < '1983-08-04'") + assert(doCompileFilter(LessThanOrEqual("col0", 5)) === "col0 <= 5") + assert(doCompileFilter(GreaterThan("col0", 3)) === "col0 > 3") + assert(doCompileFilter(GreaterThanOrEqual("col0", 3)) === "col0 >= 3") + assert(doCompileFilter(In("col1", Array("jkl"))) === "col1 IN ('jkl')") + assert(doCompileFilter(Not(In("col1", Array("mno", "pqr")))) + === "(NOT (col1 IN ('mno', 'pqr')))") + assert(doCompileFilter(IsNull("col1")) === "col1 IS NULL") + assert(doCompileFilter(IsNotNull("col1")) === "col1 IS NOT NULL") + assert(doCompileFilter(And(EqualNullSafe("col0", "abc"), EqualTo("col1", "def"))) + === "((NOT (col0 != 'abc' OR col0 IS NULL OR 'abc' IS NULL) " + + "OR (col0 IS NULL AND 'abc' IS NULL))) AND (col1 
= 'def')") + } + test("Dialect unregister") { JdbcDialects.registerDialect(testH2Dialect) JdbcDialects.unregisterDialect(testH2Dialect) @@ -462,6 +514,10 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext val Postgres = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") assert(Postgres.getCatalystType(java.sql.Types.OTHER, "json", 1, null) === Some(StringType)) assert(Postgres.getCatalystType(java.sql.Types.OTHER, "jsonb", 1, null) === Some(StringType)) + val errMsg = intercept[IllegalArgumentException] { + Postgres.getJDBCType(ByteType) + } + assert(errMsg.getMessage contains "Unsupported type in postgresql: ByteType") } test("DerbyDialect jdbc type mapping") { @@ -497,4 +553,24 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext assert(rows(0).getAs[java.sql.Timestamp](2) === java.sql.Timestamp.valueOf("2002-02-20 11:22:33.543543")) } + + test("test credentials in the properties are not in plan output") { + val df = sql("SELECT * FROM parts") + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + // test the JdbcRelation toString output + df.queryExecution.analyzed.collect { + case r: LogicalRelation => assert(r.relation.toString == "JDBCRelation(TEST.PEOPLE)") + } + } + + test("test credentials in the connection url are not in the plan output") { + val df = sqlContext.read.jdbc(urlWithUserAndPass, "TEST.PEOPLE", new Properties) + val explain = ExplainCommand(df.queryExecution.logical, extended = true) + sqlContext.executePlan(explain).executedPlan.executeCollect().foreach { + r => assert(!List("testPass", "testUser").exists(r.toString.contains)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index 3eaa817f9c0b..27b02d6e1ab3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala index 27d1cd92fca1..cb6e5179b31f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/ResolvedDataSourceSuite.scala @@ -57,4 +57,21 @@ class ResolvedDataSourceSuite extends SparkFunSuite { ResolvedDataSource.lookupDataSource("org.apache.spark.sql.parquet") === classOf[org.apache.spark.sql.execution.datasources.parquet.DefaultSource]) } + + test("error message for unknown data sources") { + val error1 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("avro") + } + assert(error1.getMessage.contains("spark-packages")) + + val error2 = intercept[ClassNotFoundException] { + ResolvedDataSource.lookupDataSource("com.databricks.spark.avro") + } + assert(error2.getMessage.contains("spark-packages")) + + val error3 = intercept[ClassNotFoundException] { + 
ResolvedDataSource.lookupDataSource("asfdwefasdfasdf") + } + assert(error3.getMessage.contains("spark-packages")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala index 10d261368993..e055da9e8a39 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame} +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLConf} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala index 152c9c8459de..df530d8587ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/ProcessTestUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.test -import java.io.{IOException, InputStream} +import java.io.{InputStream, IOException} import scala.sys.process.BasicIO diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index e87da1527c4d..7df344edb4ed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.test import java.io.File import java.util.UUID -import scala.util.Try import scala.language.implicitConversions +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.scalatest.BeforeAndAfterAll diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index b5b2143292a6..435e565f6345 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala index 2228f651e238..60bb4dc5e77b 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/hive/service/server/HiveServerServerOptionsProcessor.scala @@ -16,7 +16,7 @@ */ package org.apache.hive.service.server -import org.apache.hive.service.server.HiveServer2.{StartOptionExecutor, ServerOptionsProcessor} +import org.apache.hive.service.server.HiveServer2.{ServerOptionsProcessor, StartOptionExecutor} /** * Class to upgrade a package-private class to public, and diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index a4fd0c3ce970..66eaa3ebcd73 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -27,8 +27,9 @@ 
import org.apache.commons.logging.LogFactory import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hive.service.cli.thrift.{ThriftBinaryCLIService, ThriftHttpCLIService} -import org.apache.hive.service.server.{HiveServerServerOptionsProcessor, HiveServer2} +import org.apache.hive.service.server.{HiveServer2, HiveServerServerOptionsProcessor} +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd, SparkListenerJobStart} import org.apache.spark.sql.SQLConf @@ -36,8 +37,6 @@ import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{Logging, SparkContext} - /** * The main entry point for the Spark SQL port of HiveServer2. Starts up a `SparkSQLContext` and a @@ -67,6 +66,7 @@ object HiveThriftServer2 extends Logging { } def main(args: Array[String]) { + Utils.initDaemon(log) val optionsProcessor = new HiveServerServerOptionsProcessor("HiveThriftServer2") if (!optionsProcessor.process(args)) { System.exit(-1) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index e022ee86a763..cd2167c4ecb1 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.hive.thriftserver import java.security.PrivilegedExceptionAction import java.sql.{Date, Timestamp} +import java.util.{Arrays, Map => JMap, UUID} import java.util.concurrent.RejectedExecutionException -import java.util.{Arrays, UUID, Map => JMap} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, Map => SMap} @@ -33,11 +33,10 @@ import org.apache.hive.service.cli.operation.ExecuteStatementOperation import org.apache.hive.service.cli.session.HiveSession import org.apache.spark.Logging +import org.apache.spark.sql.{DataFrame, Row => SparkRow, SQLConf} import org.apache.spark.sql.execution.SetCommand import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.types._ -import org.apache.spark.sql.{DataFrame, SQLConf, Row => SparkRow} - private[hive] class SparkExecuteStatementOperation( parentSession: HiveSession, diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala index 8e7aa75bc3b2..03bc830df203 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIDriver.scala @@ -20,13 +20,10 @@ package org.apache.spark.sql.hive.thriftserver import java.io._ import java.util.{ArrayList => JArrayList, Locale} -import org.apache.spark.sql.AnalysisException - import scala.collection.JavaConverters._ import jline.console.ConsoleReader import jline.console.history.FileHistory - import 
org.apache.commons.lang3.StringUtils import org.apache.commons.logging.LogFactory import org.apache.hadoop.conf.Configuration @@ -35,11 +32,12 @@ import org.apache.hadoop.hive.common.{HiveInterruptCallback, HiveInterruptUtils} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, SetProcessor, CommandProcessor, CommandProcessorFactory} +import org.apache.hadoop.hive.ql.processors.{AddResourceProcessor, CommandProcessor, CommandProcessorFactory, SetProcessor} import org.apache.hadoop.hive.ql.session.SessionState import org.apache.thrift.transport.TSocket import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.HiveContext import org.apache.spark.util.{ShutdownHookManager, Utils} diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala index 5ad8c54f296d..6fe57554cf58 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLCLIService.scala @@ -27,11 +27,11 @@ import org.apache.commons.logging.Log import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.shims.Utils import org.apache.hadoop.security.UserGroupInformation +import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.hive.service.Service.STATE import org.apache.hive.service.auth.HiveAuthFactory import org.apache.hive.service.cli._ import org.apache.hive.service.server.HiveServer2 -import org.apache.hive.service.{AbstractService, Service, ServiceException} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.hive.thriftserver.ReflectionUtils._ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index f1ec7238520a..4278aa30fbbd 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.hive.thriftserver -import java.util.{Arrays, ArrayList => JArrayList, List => JList} -import org.apache.log4j.LogManager -import org.apache.spark.sql.AnalysisException +import java.util.{ArrayList => JArrayList, Arrays, List => JList} import scala.collection.JavaConverters._ @@ -27,8 +25,10 @@ import org.apache.commons.lang3.exception.ExceptionUtils import org.apache.hadoop.hive.metastore.api.{FieldSchema, Schema} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse +import org.apache.log4j.LogManager import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} private[hive] class SparkSQLDriver( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index bacf6cc458fd..ca25d23c3e37 100644 --- 
a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -21,9 +21,9 @@ import java.io.PrintStream import scala.collection.JavaConverters._ +import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.scheduler.StatsReportListener import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.{Logging, SparkConf, SparkContext} import org.apache.spark.util.Utils /** A singleton object for the master program. The slaves should not access this. */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 476651a559d2..9954d3436d37 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.hive.thriftserver.server import java.util.{Map => JMap} + import scala.collection.mutable.Map import org.apache.hive.service.cli._ import org.apache.hive.service.cli.operation.{ExecuteStatementOperation, Operation, OperationManager} import org.apache.hive.service.cli.session.HiveSession + import org.apache.spark.Logging import org.apache.spark.sql.hive.HiveContext -import org.apache.spark.sql.hive.thriftserver.{SparkExecuteStatementOperation, ReflectionUtils} +import org.apache.spark.sql.hive.thriftserver.{ReflectionUtils, SparkExecuteStatementOperation} /** * Executes queries using Spark SQL, and maintains a list of handles to active queries. 
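A large share of the hunks in this patch are mechanical import reorders. The convention they converge on is the Spark style guide's grouping: java/javax first, then scala, then third-party libraries, then org.apache.spark, with a blank line between groups and the members inside braces in ASCII order. A before/after sketch (the imported names are arbitrary examples):

// Before: groups interleaved and brace members unsorted.
import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest}
import java.util.{Map => JMap}
import org.apache.commons.logging.LogFactory
import scala.collection.JavaConverters._

// After: java, then scala, then third-party, then org.apache.spark,
// one blank line between groups, brace members in ASCII order.
import java.util.{Map => JMap}

import scala.collection.JavaConverters._

import org.apache.commons.logging.LogFactory

import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext}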
diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala index e990bd06011f..3719da4925cc 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging -import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{SessionInfo, ExecutionState, ExecutionInfo} -import org.apache.spark.ui.UIUtils._ +import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState, SessionInfo} import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a thrift server */ diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala index af16cb31df18..27d1c8bab4d9 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerSessionPage.scala @@ -23,10 +23,11 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node import org.apache.commons.lang3.StringEscapeUtils + import org.apache.spark.Logging import org.apache.spark.sql.hive.thriftserver.HiveThriftServer2.{ExecutionInfo, ExecutionState} -import org.apache.spark.ui.UIUtils._ import org.apache.spark.ui._ +import org.apache.spark.ui.UIUtils._ /** Page for Spark Web UI that shows statistics of a thrift server session */ private[ui] class ThriftServerSessionPage(parent: ThriftServerTab) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala index 4eabeaa6735e..1dc7d79436d7 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/ui/ThriftServerTab.scala @@ -17,10 +17,10 @@ package org.apache.spark.sql.hive.thriftserver.ui +import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.sql.hive.thriftserver.{HiveThriftServer2, SparkSQLEnv} import org.apache.spark.sql.hive.thriftserver.ui.ThriftServerTab._ import org.apache.spark.ui.{SparkUI, SparkUITab} -import org.apache.spark.{SparkContext, Logging, SparkException} /** * Spark Web UI tab that shows statistics of a thrift server.
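The JDBCSuite hunk earlier in this patch exercises JDBCRDD's private compileFilter method through ScalaTest's PrivateMethodTester rather than widening its visibility. A self-contained sketch of that pattern; the FilterCompiler class and its compile method are invented here purely for illustration:

import org.scalatest.PrivateMethodTester

// Hypothetical class whose private method we want to exercise from a test.
class FilterCompiler {
  private def compile(expr: String): Option[String] =
    if (expr.nonEmpty) Some(expr.toUpperCase) else None
}

object PrivateMethodSketch extends PrivateMethodTester {
  def main(args: Array[String]): Unit = {
    // PrivateMethod is parameterized on the return type and named by symbol.
    val compile = PrivateMethod[Option[String]]('compile)
    // invokePrivate reflectively calls the private method on the target instance.
    val result = new FilterCompiler invokePrivate compile("col0 = 3")
    assert(result == Some("COL0 = 3"))
  }
}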
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index fcf039916913..ab31d45a79a2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -22,15 +22,15 @@ import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.concurrent.duration._ import scala.concurrent.{Await, Promise} -import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import scala.concurrent.duration._ import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfterAll -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer +import org.apache.spark.util.Utils /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 139d8e897ba1..e598284ab22f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -23,9 +23,8 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.{future, Await, ExecutionContext, Promise} import scala.concurrent.duration._ -import scala.concurrent.{Await, Promise, future} import scala.io.Source import scala.util.{Random, Try} @@ -41,10 +40,10 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkFunSuite} +import org.apache.spark.util.{ThreadUtils, Utils} object TestData { def getTestDataFilePath(name: String): URL = { @@ -356,31 +355,54 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { s"LOAD DATA LOCAL INPATH '${TestData.smallKv}' OVERWRITE INTO TABLE test_map") queries.foreach(statement.execute) - - val largeJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(10)("join test_map").mkString(" ") - val f = future { Thread.sleep(100); statement.cancel(); } - val e = intercept[SQLException] { - statement.executeQuery(largeJoin) + implicit val ec = ExecutionContext.fromExecutorService( + ThreadUtils.newDaemonSingleThreadExecutor("test-jdbc-cancel")) + try { + // Start a very-long-running query that will take hours to finish, then cancel it in order + // to demonstrate that cancellation works. + val f = future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(10)("join test_map").mkString(" ")) + } + // Note that this is slightly race-prone: if the cancel is issued before the statement + // begins executing then we'll fail with a timeout. 
As a result, this fixed delay is set + // slightly more conservatively than may be strictly necessary. + Thread.sleep(1000) + statement.cancel() + val e = intercept[SQLException] { + Await.result(f, 3.minute) + } + assert(e.getMessage.contains("cancelled")) + + // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false + statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") + try { + val sf = future { + statement.executeQuery( + "SELECT COUNT(*) FROM test_map " + + List.fill(4)("join test_map").mkString(" ") + ) + } + // Similarly, this is also slightly race-prone on fast machines where the query above + // might race and complete before we issue the cancel. + Thread.sleep(1000) + statement.cancel() + val rs1 = Await.result(sf, 3.minute) + rs1.next() + assert(rs1.getInt(1) === math.pow(5, 5)) + rs1.close() + + val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") + rs2.next() + assert(rs2.getInt(1) === 5) + rs2.close() + } finally { + statement.executeQuery("SET spark.sql.hive.thriftServer.async=true") + } + } finally { + ec.shutdownNow() } - assert(e.getMessage contains "cancelled") - Await.result(f, 3.minute) - - // cancel is a noop - statement.executeQuery("SET spark.sql.hive.thriftServer.async=false") - val sf = future { Thread.sleep(100); statement.cancel(); } - val smallJoin = "SELECT COUNT(*) FROM test_map " + - List.fill(4)("join test_map").mkString(" ") - val rs1 = statement.executeQuery(smallJoin) - Await.result(sf, 3.minute) - rs1.next() - assert(rs1.getInt(1) === math.pow(5, 5)) - rs1.close() - - val rs2 = statement.executeQuery("SELECT COUNT(*) FROM test_map") - rs2.next() - assert(rs2.getInt(1) === 5) - rs2.close() } } @@ -817,6 +839,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def beforeAll(): Unit = { + super.beforeAll() // Chooses a random port between 10000 and 19999 listeningPort = 10000 + Random.nextInt(10000) diagnosisBuffer.clear() @@ -838,7 +861,11 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl } override protected def afterAll(): Unit = { - stopThriftServer() - logInfo("HiveThriftServer2 stopped") + try { + stopThriftServer() + logInfo("HiveThriftServer2 stopped") + } finally { + super.afterAll() + } } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2d0d7b8af358..afd2f611580f 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -41,9 +41,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - def testCases = hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + def testCases: Seq[(String, File)] = { + hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) + } override def beforeAll() { + super.beforeAll() TestHive.cacheTables = true // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) @@ -53,6 +56,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with 
BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + // Use Hive hash expression instead of the native one + TestHive.functionRegistry.unregisterFunction("hash") RuleExecutor.resetTime() } @@ -62,13 +67,15 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) + super.afterAll() } /** A list of tests deemed out of scope currently and thus completely disregarded. */ - override def blackList = Seq( + override def blackList: Seq[String] = Seq( // These tests use hooks that are not on the classpath and thus break all subsequent execution. "hook_order", "hook_context_cs", @@ -103,7 +110,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_merge", "alter_concatenate_indexed_table", "protectmode2", - //"describe_table", + // "describe_table", "describe_comment_nonascii", "create_merge_compressed", @@ -308,14 +315,19 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // The difference between the double numbers generated by Hive and Spark // can be ignored (e.g., 0.6633880657639323 and 0.6633880657639322) - "udaf_corr" + "udaf_corr", + + // Feature removed in HIVE-11145 + "alter_partition_protect_mode", + "drop_partitions_ignore_protection", + "protectmode" ) /** * The set of tests that are believed to be working in catalyst. Tests not on whiteList or * blacklist are implicitly marked as ignored. 
*/ - override def whiteList = Seq( + override def whiteList: Seq[String] = Seq( "add_part_exist", "add_part_multiple", "add_partition_no_whitelist", @@ -328,7 +340,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter_index", "alter_merge_2", "alter_partition_format_loc", - "alter_partition_protect_mode", "alter_partition_with_whitelist", "alter_rename_partition", "alter_table_serde", @@ -460,7 +471,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "drop_partitions_filter", "drop_partitions_filter2", "drop_partitions_filter3", - "drop_partitions_ignore_protection", "drop_table", "drop_table2", "drop_table_removes_partition_dirs", @@ -778,7 +788,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "ppr_pushdown2", "ppr_pushdown3", "progress_1", - "protectmode", "push_or", "query_with_semi", "quote1", diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index 92bb9e6d73af..bad3ca6da231 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -104,6 +104,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.reset() + super.afterAll() } ///////////////////////////////////////////////////////////////////////////// @@ -454,6 +455,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_name rows between 2 preceding and 2 following) """.stripMargin, reset = false) + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 15. testExpressions", s""" |select p_mfgr,p_name, p_size, @@ -472,7 +476,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name | rows between 2 preceding and 2 following) """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 16. testMultipleWindows", s""" |select p_mfgr,p_name, p_size, @@ -530,6 +534,9 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte // when running this test suite under Java 7 and 8. // We change the original sql query a little bit for making the test suite passed // under different JDK + /* Disabled because: + - Spark uses a different default stddev. + - Tiny numerical differences in stddev results. createQueryTest("windowing.q -- 20. testSTATs", """ |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp @@ -547,7 +554,7 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte |) t lateral view explode(uniq_size) d as uniq_data |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp """.stripMargin, reset = false) - + */ createQueryTest("windowing.q -- 21. 
testDISTs", """ |select p_mfgr,p_name, p_size, diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index d96f3e2b9f62..cd0c2aeb93a9 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../../pom.xml @@ -232,6 +232,7 @@ v${hive.version.short}/src/main/scala + ${project.build.directory/generated-sources/antlr @@ -260,6 +261,7 @@ + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala index 7f8449cdc282..395c8bff53f4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/ExtendedHiveQlParser.scala @@ -21,7 +21,7 @@ import scala.language.implicitConversions import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.AbstractSparkSQLParser -import org.apache.spark.sql.hive.execution.{AddJar, AddFile, HiveNativeCommand} +import org.apache.spark.sql.hive.execution.{AddFile, AddJar, HiveNativeCommand} /** * A parser that recognizes all HiveQL constructs together with Spark SQL specific extensions. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 5958777b0d06..cbaf00603e18 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -37,25 +37,24 @@ import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo +import org.apache.spark.{Logging, SparkContext} import org.apache.spark.api.java.JavaSparkContext +import org.apache.spark.sql._ import org.apache.spark.sql.SQLConf.SQLConfEntry import org.apache.spark.sql.SQLConf.SQLConfEntry._ -import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, SqlParser} -import org.apache.spark.sql.execution.datasources.{ResolveDataSource, DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, PreInsertCastAndRename, PreWriteCheck, ResolveDataSource} import org.apache.spark.sql.execution.ui.SQLListener -import org.apache.spark.sql.execution.{CacheManager, ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkContext} - /** * This is the HiveQL Dialect, this dialect is strongly bind with HiveContext @@ -380,7 +379,7 @@ class HiveContext private[hive]( def calculateTableSize(fs: FileSystem, path: Path): Long = { val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDir) { + val size = if 
(fileStatus.isDirectory) { fs.listStatus(path) .map { status => if (!status.getPath().getName().startsWith(stagingDir)) { @@ -476,7 +475,6 @@ class HiveContext private[hive]( catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUDFs :: - ResolveHiveWindowFunction :: PreInsertCastAndRename :: (if (conf.runSQLOnFile) new ResolveDataSource(self) :: Nil else Nil) @@ -568,7 +566,7 @@ class HiveContext private[hive]( } @transient - private val hivePlanner = new SparkPlanner with HiveStrategies { + private val hivePlanner = new SparkPlanner(this) with HiveStrategies { val hiveContext = self override def strategies: Seq[Strategy] = experimental.extraStrategies ++ Seq( @@ -693,11 +691,14 @@ private[hive] object HiveContext { val CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING = booleanConf( "spark.sql.hive.convertMetastoreParquet.mergeSchema", defaultValue = Some(false), - doc = "TODO") + doc = "When true, also tries to merge possibly different but compatible Parquet schemas in " + + "different Parquet data files. This configuration is only effective " + + "when \"spark.sql.hive.convertMetastoreParquet\" is true.") val CONVERT_CTAS = booleanConf("spark.sql.hive.convertCTAS", defaultValue = Some(false), - doc = "TODO") + doc = "When true, a table created by a Hive CTAS statement (no USING clause) will be " + + "converted to a data source table, using the data source set by spark.sql.sources.default.") val HIVE_METASTORE_SHARED_PREFIXES = stringSeqConf("spark.sql.hive.metastore.sharedPrefixes", defaultValue = Some(jdbcPrefixes), @@ -718,7 +719,7 @@ private[hive] object HiveContext { val HIVE_THRIFT_SERVER_ASYNC = booleanConf("spark.sql.hive.thriftServer.async", defaultValue = Some(true), - doc = "TODO") + doc = "When set to true, Hive Thrift server executes SQL queries in an asynchronous way.") /** Constructs a configuration for hive, where the metastore is located in a temp directory. 
*/ def newTemporaryConfiguration(useInMemoryDerby: Boolean): Map[String, String] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 95b57d6ad124..7a260e72eb45 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -19,18 +19,19 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ +import org.apache.hadoop.{io => hadoopIo} import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ +import org.apache.hadoop.hive.serde2.{io => hiveIo} import org.apache.hadoop.hive.serde2.objectinspector.{StructField => HiveStructField, _} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.typeinfo.{DecimalTypeInfo, TypeInfoFactory} -import org.apache.hadoop.hive.serde2.{io => hiveIo} -import org.apache.hadoop.{io => hadoopIo} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, types} import org.apache.spark.unsafe.types.UTF8String /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 08b291e08823..43d84d507b20 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -31,21 +31,21 @@ import org.apache.hadoop.hive.ql.metadata._ import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.Logging -import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} +import org.apache.spark.sql.{AnalysisException, SaveMode, SQLContext} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{Catalog, MultiInstanceRelation, OverrideCatalog} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.util.DataTypeParser +import org.apache.spark.sql.execution.{datasources, FileRelation} +import org.apache.spark.sql.execution.datasources.{Partition => ParquetPartition, _} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation -import org.apache.spark.sql.execution.datasources.{CreateTableUsingAsSelect, LogicalRelation, Partition => ParquetPartition, PartitionSpec, ResolvedDataSource} -import org.apache.spark.sql.execution.{FileRelation, datasources} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{AnalysisException, SQLContext, SaveMode} private[hive] case class HiveSerDe( inputFormat: Option[String] = None, @@ -211,6 +211,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], partitionColumns: 
Array[String], + bucketSpec: Option[BucketSpec], provider: String, options: Map[String, String], isExternal: Boolean): Unit = { @@ -240,6 +241,25 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive } } + if (userSpecifiedSchema.isDefined && bucketSpec.isDefined) { + val BucketSpec(numBuckets, bucketColumnNames, sortColumnNames) = bucketSpec.get + + tableProperties.put("spark.sql.sources.schema.numBuckets", numBuckets.toString) + tableProperties.put("spark.sql.sources.schema.numBucketCols", + bucketColumnNames.length.toString) + bucketColumnNames.zipWithIndex.foreach { case (bucketCol, index) => + tableProperties.put(s"spark.sql.sources.schema.bucketCol.$index", bucketCol) + } + + if (sortColumnNames.nonEmpty) { + tableProperties.put("spark.sql.sources.schema.numSortCols", + sortColumnNames.length.toString) + sortColumnNames.zipWithIndex.foreach { case (sortCol, index) => + tableProperties.put(s"spark.sql.sources.schema.sortCol.$index", sortCol) + } + } + } + if (userSpecifiedSchema.isEmpty && partitionColumns.length > 0) { // The table does not have a specified schema, which means that the schema will be inferred // when we load the table. So, we are not expecting partition columns and we will discover @@ -596,6 +616,7 @@ private[hive] class HiveMetastoreCatalog(val client: ClientInterface, hive: Hive conf.defaultDataSourceName, temporary = false, Array.empty[String], + bucketSpec = None, mode, options = Map.empty[String, String], child @@ -728,6 +749,8 @@ private[hive] case class MetastoreRelation Objects.hashCode(databaseName, tableName, alias, output) } + override protected def otherCopyArgs: Seq[AnyRef] = table :: sqlContext :: Nil + @transient val hiveQlTable: Table = { // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. 
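Because the metastore only stores flat string properties, the HiveMetastoreCatalog hunk above flattens the new BucketSpec into indexed key/value pairs rather than a structured field. A small illustration of the encoding, with a hypothetical table bucketed 4 ways on user_id and sorted by ts (the BucketSpec shape is assumed from the pattern match in the patch):

// Hypothetical spec; column names are illustrative.
case class BucketSpec(
    numBuckets: Int,
    bucketColumnNames: Seq[String],
    sortColumnNames: Seq[String])

val spec = BucketSpec(4, Seq("user_id"), Seq("ts"))

// The hunk above would write these entries into tableProperties:
//   spark.sql.sources.schema.numBuckets    -> "4"
//   spark.sql.sources.schema.numBucketCols -> "1"
//   spark.sql.sources.schema.bucketCol.0   -> "user_id"
//   spark.sql.sources.schema.numSortCols   -> "1"
//   spark.sql.sources.schema.sortCol.0     -> "ts"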
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 091caab921fe..5b13dbe47370 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -17,40 +17,30 @@ package org.apache.spark.sql.hive -import java.sql.Date import java.util.Locale import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.lib.Node -import org.apache.hadoop.hive.ql.parse._ -import org.apache.hadoop.hive.ql.plan.PlanUtils +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry, FunctionInfo} +import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Context, ErrorMsg} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - import org.apache.spark.Logging -import org.apache.spark.sql.{AnalysisException, catalyst} -import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.{logical, _} +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.ParseUtils._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.execution.ExplainCommand -import org.apache.spark.sql.execution.datasources.DescribeCommand -import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.execution.SparkQl +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{AnalyzeTable, DropTable, HiveNativeCommand, HiveScriptIOSchema} +import org.apache.spark.sql.hive.execution.{HiveNativeCommand, AnalyzeTable, DropTable, HiveScriptIOSchema} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler +import org.apache.spark.sql.AnalysisException /** * Used when we need to start parsing the AST before deciding that we are going to pass the command @@ -70,7 +60,7 @@ private[hive] case class CreateTableAsSelect( override def output: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = tableDesc.specifiedDatabase.isDefined && - tableDesc.schema.size > 0 && + tableDesc.schema.nonEmpty && tableDesc.serde.isDefined && tableDesc.inputFormat.isDefined && tableDesc.outputFormat.isDefined && @@ -88,7 +78,7 @@ private[hive] case class CreateViewAsSelect( } /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. 
*/ -private[hive] object HiveQl extends Logging { +private[hive] object HiveQl extends SparkQl with Logging { protected val nativeCommands = Seq( "TOK_ALTERDATABASE_OWNER", "TOK_ALTERDATABASE_PROPERTIES", @@ -179,102 +169,6 @@ private[hive] object HiveQl extends Logging { protected val hqlParser = new ExtendedHiveQlParser - /** - * A set of implicit transformations that allow Hive ASTNodes to be rewritten by transformations - * similar to [[catalyst.trees.TreeNode]]. - * - * Note that this should be considered very experimental and is not indented as a replacement - * for TreeNode. Primarily it should be noted ASTNodes are not immutable and do not appear to - * have clean copy semantics. Therefore, users of this class should take care when - * copying/modifying trees that might be used elsewhere. - */ - implicit class TransformableNode(n: ASTNode) { - /** - * Returns a copy of this node where `rule` has been recursively applied to it and all of its - * children. When `rule` does not apply to a given node it is left unchanged. - * @param rule the function use to transform this nodes children - */ - def transform(rule: PartialFunction[ASTNode, ASTNode]): ASTNode = { - try { - val afterRule = rule.applyOrElse(n, identity[ASTNode]) - afterRule.withChildren( - nilIfEmpty(afterRule.getChildren) - .asInstanceOf[Seq[ASTNode]] - .map(ast => Option(ast).map(_.transform(rule)).orNull)) - } catch { - case e: Exception => - logError(dumpTree(n).toString) - throw e - } - } - - /** - * Returns a scala.Seq equivalent to [s] or Nil if [s] is null. - */ - private def nilIfEmpty[A](s: java.util.List[A]): Seq[A] = - Option(s).map(_.asScala).getOrElse(Nil) - - /** - * Returns this ASTNode with the text changed to `newText`. - */ - def withText(newText: String): ASTNode = { - n.token.asInstanceOf[org.antlr.runtime.CommonToken].setText(newText) - n - } - - /** - * Returns this ASTNode with the children changed to `newChildren`. - */ - def withChildren(newChildren: Seq[ASTNode]): ASTNode = { - (1 to n.getChildCount).foreach(_ => n.deleteChild(0)) - n.addChildren(newChildren.asJava) - n - } - - /** - * Throws an error if this is not equal to other. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. - */ - def checkEquals(other: ASTNode): Unit = { - def check(field: String, f: ASTNode => Any): Unit = if (f(n) != f(other)) { - sys.error(s"$field does not match for trees. " + - s"'${f(n)}' != '${f(other)}' left: ${dumpTree(n)}, right: ${dumpTree(other)}") - } - check("name", _.getName) - check("type", _.getType) - check("text", _.getText) - check("numChildren", n => nilIfEmpty(n.getChildren).size) - - val leftChildren = nilIfEmpty(n.getChildren).asInstanceOf[Seq[ASTNode]] - val rightChildren = nilIfEmpty(other.getChildren).asInstanceOf[Seq[ASTNode]] - leftChildren zip rightChildren foreach { - case (l, r) => l checkEquals r - } - } - } - - /** - * Returns the AST for the given SQL string. - */ - def getAst(sql: String): ASTNode = { - /* - * Context has to be passed in hive0.13.1. - * Otherwise, there will be Null pointer exception, - * when retrieving properties form HiveConf. 
- */ - val hContext = createContext() - val node = getAst(sql, hContext) - hContext.clear() - node - } - - private def createContext(): Context = new Context(hiveConf) - - private def getAst(sql: String, context: Context) = - ParseUtils.findRootNonNullToken((new ParseDriver).parse(sql, context)) - /** * Returns the HiveConf */ @@ -294,216 +188,16 @@ private[hive] object HiveQl extends Logging { /** Returns a LogicalPlan for a given HiveQL string. */ def parseSql(sql: String): LogicalPlan = hqlParser.parse(sql) - val errorRegEx = "line (\\d+):(\\d+) (.*)".r - - /** Creates LogicalPlan for a given HiveQL string. */ - def createPlan(sql: String): LogicalPlan = { - try { - val context = createContext() - val tree = getAst(sql, context) - val plan = if (nativeCommands contains tree.getText) { - HiveNativeCommand(sql) - } else { - nodeToPlan(tree, context) match { - case NativePlaceholder => HiveNativeCommand(sql) - case other => other - } - } - context.clear() - plan - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - pe.getMessage match { - case errorRegEx(line, start, message) => - throw new AnalysisException(message, Some(line.toInt), Some(start.toInt)) - case otherMessage => - throw new AnalysisException(otherMessage) - } - case e: MatchError => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s""" - |Unsupported language features in query: $sql - |${dumpTree(getAst(sql))} - |$e - |${e.getStackTrace.head} - """.stripMargin) - } - } - - def parseDdl(ddl: String): Seq[Attribute] = { - val tree = - try { - ParseUtils.findRootNonNullToken( - (new ParseDriver).parse(ddl, null /* no context required for parsing alone */)) - } catch { - case pe: org.apache.hadoop.hive.ql.parse.ParseException => - throw new RuntimeException(s"Failed to parse ddl: '$ddl'", pe) - } - assert(tree.asInstanceOf[ASTNode].getText == "TOK_CREATETABLE", "Only CREATE TABLE supported.") - val tableOps = tree.getChildren - val colList = - tableOps.asScala - .find(_.asInstanceOf[ASTNode].getText == "TOK_TABCOLLIST") - .getOrElse(sys.error("No columnList!")).getChildren - - colList.asScala.map(nodeToAttribute) - } - - /** Extractor for matching Hive's AST Tokens. */ - object Token { - /** @return matches of the form (tokenName, children). */ - def unapply(t: Any): Option[(String, Seq[ASTNode])] = t match { - case t: ASTNode => - CurrentOrigin.setPosition(t.getLine, t.getCharPositionInLine) - Some((t.getText, - Option(t.getChildren).map(_.asScala.toList).getOrElse(Nil).asInstanceOf[Seq[ASTNode]])) - case _ => None - } - } - - protected def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.getText.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } - - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(dumpTree(_)).mkString("\n")}. 
- |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses - } - - def getClause(clauseName: String, nodeList: Seq[Node]): Node = - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(dumpTree(_)).mkString("\n")}")) - - def getClauseOption(clauseName: String, nodeList: Seq[Node]): Option[Node] = { - nodeList.filter { case ast: ASTNode => ast.getText == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } - - protected def nodeToAttribute(node: Node): Attribute = node match { - case Token("TOK_TABCOL", Token(colName, Nil) :: dataType :: Nil) => - AttributeReference(colName, nodeToDataType(dataType), true)() - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } - - protected def nodeToDataType(node: Node): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.getText.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", - Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", - keyType :: - valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for DataType:\n ${dumpTree(a).toString} ") - } - - protected def nodeToStructField(node: Node): StructField = node match { - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", - Token(fieldName, Nil) :: - dataType :: - _ /* comment */:: Nil) => - StructField(fieldName, nodeToDataType(dataType), nullable = true) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for StructField:\n ${dumpTree(a).toString} ") - } - - protected def extractTableIdent(tableNameParts: Node): TableIdentifier = { - tableNameParts.getChildren.asScala.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } - } - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. 
- * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 0: k1, bit 1: k2, and bit 2: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 3 and 2 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition( n => n match { - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => false // grouping sets - case _ => true // grouping keys - }) - - val keys = keyASTs.map(nodeToExpr).toSeq - val keyMap = keyASTs.map(_.toStringTree).zipWithIndex.toMap - - val bitmasks: Seq[Int] = setASTs.map(set => set match { - case Token("TOK_GROUPING_SETS_EXPRESSION", null) => 0 - case Token("TOK_GROUPING_SETS_EXPRESSION", children) => - children.foldLeft(0)((bitmap, col) => { - val colString = col.asInstanceOf[ASTNode].toStringTree() - require(keyMap.contains(colString), s"$colString doens't show up in the GROUP BY list") - bitmap | 1 << keyMap(colString) - }) - case _ => sys.error("Expect GROUPING SETS clause") - }) - - (keys, bitmasks) - } - - protected def getProperties(node: Node): Seq[(String, String)] = node match { + protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { case Token("TOK_TABLEPROPLIST", list) => list.map { case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - (unquoteString(key) -> unquoteString(value)) + unquoteString(key) -> unquoteString(value) } } private def createView( view: ASTNode, - context: Context, viewNameParts: ASTNode, query: ASTNode, schema: Seq[HiveColumn], @@ -512,8 +206,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C replace: Boolean): CreateViewAsSelect = { val TableIdentifier(viewName, dbName) = extractTableIdent(viewNameParts) - val originalText = context.getTokenRewriteStream - .toString(query.getTokenStartIndex, query.getTokenStopIndex) + val originalText = query.source val tableDesc = HiveTable( specifiedDatabase = dbName, @@ -532,104 +225,67 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C // We need to keep the original SQL string so that if `spark.sql.nativeView` is // false, we can fall back to use hive native command later. // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = context.getTokenRewriteStream - .toString(view.getTokenStartIndex, view.getTokenStopIndex) - CreateViewAsSelect(tableDesc, nodeToPlan(query, context), allowExist, replace, sql) + val sql = view.source + CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) } - protected def nodeToPlan(node: ASTNode, context: Context): LogicalPlan = node match { - // Special drop table that also uncaches. 
- case Token("TOK_DROPTABLE", - Token("TOK_TABNAME", tableNameParts) :: - ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: - isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. - NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) + protected override def createPlan( + sql: String, + node: ASTNode): LogicalPlan = { + if (nativeCommands.contains(node.text)) { + HiveNativeCommand(sql) + } else { + nodeToPlan(node) match { + case NativePlaceholder => HiveNativeCommand(sql) + case plan => plan } - // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) - if noExplainCommands.contains(explainArgs.head.getText) => - ExplainCommand(OneRowRelation) - case Token("TOK_EXPLAIN", explainArgs) - if "TOK_CREATETABLE" == explainArgs.head.getText => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(crtTbl, context), - extended = extended.isDefined) - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand( - nodeToPlan(query, context), - extended = extended.isDefined) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - NativePlaceholder - } else { - tableType match { - case Token("TOK_TABTYPE", nameParts) if nameParts.size == 1 => { - nameParts.head match { - case Token(".", dbName :: tableName :: Nil) => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this issue. - val tableIdent = extractTableIdent(nameParts.head) - DescribeCommand( - UnresolvedRelation(tableIdent, None), isExtended = extended.isDefined) - case Token(".", dbName :: tableName :: colName :: Nil) => - // It is describing a column with the format like "describe db.table column". - NativePlaceholder - case tableName => - // It is describing a table with the format like "describe table". - DescribeCommand( - UnresolvedRelation(TableIdentifier(tableName.getText), None), - isExtended = extended.isDefined) - } - } - // All other cases. - case _ => NativePlaceholder + } + } + + protected override def isNoExplainCommand(command: String): Boolean = + noExplainCommands.contains(command) + + protected override def nodeToPlan(node: ASTNode): LogicalPlan = { + node match { + // Special drop table that also uncaches. 
+ case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + DropTable(tableName, ifExists.nonEmpty) + + // Support "ANALYZE TABLE tableNmae COMPUTE STATISTICS noscan" + case Token("TOK_ANALYZE", + Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => + // Reference: + // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables + if (partitionSpec.nonEmpty) { + // Analyze partitions will be treated as a Hive native command. + NativePlaceholder + } else if (isNoscan.isEmpty) { + // If users do not specify "noscan", it will be treated as a Hive native command. + NativePlaceholder + } else { + val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") + AnalyzeTable(tableName) } - } - case view @ Token("TOK_ALTERVIEW", children) => - val Some(viewNameParts) :: maybeQuery :: ignores = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) + case view @ Token("TOK_ALTERVIEW", children) => + val Some(nameParts) :: maybeQuery :: _ = + getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_ALTERVIEW_ADDPARTS", + "TOK_ALTERVIEW_DROPPARTS", + "TOK_ALTERVIEW_PROPERTIES", + "TOK_ALTERVIEW_RENAME"), children) - // if ALTER VIEW doesn't have query part, let hive to handle it. - maybeQuery.map { query => - createView(view, context, viewNameParts, query, Nil, Map(), false, true) - }.getOrElse(NativePlaceholder) + // if ALTER VIEW doesn't have query part, let hive to handle it. + maybeQuery.map { query => + createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true) + }.getOrElse(NativePlaceholder) - case view @ Token("TOK_CREATEVIEW", children) + case view @ Token("TOK_CREATEVIEW", children) if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( + val Seq( Some(viewNameParts), Some(query), maybeComment, @@ -638,1236 +294,467 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C maybeProperties, maybeColumns, maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - BaseSemanticAnalyzer.getColumns(cols, true).asScala.map { field => + ) = getClauses(Seq( + "TOK_TABNAME", + "TOK_QUERY", + "TOK_TABLECOMMENT", + "TOK_ORREPLACE", + "TOK_IFNOTEXISTS", + "TOK_TABLEPROPERTIES", + "TOK_TABCOLNAME", + "TOK_VIEWPARTCOLS"), children) + + // If the view is partitioned, we let hive handle it. + if (maybePartCols.isDefined) { + NativePlaceholder + } else { + val schema = maybeColumns.map { cols => // We can't specify column types when create view, so fill it with null first, and // update it after the schema has been resolved later. 
- HiveColumn(field.getName, null, field.getComment) - } - }.getOrElse(Seq.empty[HiveColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - if (comment ne null) { - properties += ("comment" -> comment) - } - } + nodeToColumns(cols, lowerCase = true).map(_.copy(hiveType = null)) + }.getOrElse(Seq.empty[HiveColumn]) - createView(view, context, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - ignores) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: HiveTable = HiveTable( - specifiedDatabase = dbName, - name = tblName, - schema = Seq.empty[HiveColumn], - partitionColumns = Seq.empty[HiveColumn], - properties = Map[String, String](), - serdeProperties = Map[String, String](), - tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, - location = None, - inputFormat = None, - outputFormat = None, - serde = None, - viewText = None) - - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
- val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - - hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) - hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) - hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) + val properties = scala.collection.mutable.Map.empty[String, String] - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = BaseSemanticAnalyzer.getColumns(list, true) - if (cols != null) { - tableDesc = tableDesc.copy( - schema = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = BaseSemanticAnalyzer.getColumns(list(0), false) - if (cols != null) { - tableDesc = tableDesc.copy( - partitionColumns = cols.asScala.map { field => - HiveColumn(field.getName, field.getType, field.getComment) - }) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = BaseSemanticAnalyzer.unescapeSQLString (rowChild1.getText()) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = BaseSemanticAnalyzer.unescapeSQLString (rowChild2(0).getText) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - SemanticAnalyzer.generateErrorMessage( - rowChild, - ErrorMsg.LINES_TERMINATED_BY_NON_NEWLINE.getMsg)) - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = BaseSemanticAnalyzer.unescapeSQLString(rowChild.getText) - // TODO support the nullFormat - case _ => assert(false) + maybeProperties.foreach { + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + properties ++= getProperties(list) } - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - var location = BaseSemanticAnalyzer.unescapeSQLString(child.getText) - location = EximUtil.relativeToAbsolutePath(hiveConf, location) - tableDesc = 
tableDesc.copy(location = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.copy( - serde = Option(BaseSemanticAnalyzer.unescapeSQLString(child.getChild(0).getText))) - if (child.getChildCount == 2) { - val serdeParams = new java.util.HashMap[String, String]() - BaseSemanticAnalyzer.readProps( - (child.getChild(1).getChild(0)).asInstanceOf[ASTNode], serdeParams) - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.getText().toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - case "parquet" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + maybeComment.foreach { + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + if (comment ne null) { + properties += ("comment" -> comment) } - - case "rcfile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - - case "textfile" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.copy( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.copy( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.serde.isEmpty) { - tableDesc = tableDesc.copy( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) - } - - case _ => - throw new SemanticException( - s"Unrecognized file format in STORED AS clause: ${child.getText}") } - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.copy( - serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) - case Nil => - } - - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", children) => - tableDesc = tableDesc.copy( - inputFormat = - Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(0).getText)), - outputFormat = - 
Option(BaseSemanticAnalyzer.unescapeSQLString(list.getChild(1).getText))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException(ErrorMsg.CREATE_NON_NATIVE_AS.getMsg()) - case _ => // Unsupport features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query, context), allowExisting != None) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", - Token("TOK_TABLE_PARTITION", table) :: Nil) => NativePlaceholder - - case Token("TOK_QUERY", queryArgs) - if Seq("TOK_FROM", "TOK_INSERT").contains(queryArgs.head.getText) => - - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_FROM", args: Seq[ASTNode]) :: insertClauses => - // check if has CTE - insertClauses.last match { - case Token("TOK_CTE", cteClauses) => - val cteRelations = cteClauses.map(node => { - val relation = nodeToRelation(node, context).asInstanceOf[Subquery] - (relation.alias, relation) - }).toMap - (Some(args.head), insertClauses.init, Some(cteRelations)) - - case _ => (Some(args.head), insertClauses, None) - } - - case Token("TOK_INSERT", _) :: Nil => (None, queryArgs, None) + createView(view, viewNameParts, query, schema, properties.toMap, + allowExisting.isDefined, replace.isDefined) } - // Return one query for each insert clause. - val queries = insertClauses.map { case Token("TOK_INSERT", singleInsert) => + case Token("TOK_CREATETABLE", children) + if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => + // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { + Some(tableNameParts) :: + _ /* likeTable */ :: + externalTable :: + Some(query) :: + allowExisting +: + _) = getClauses( Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) + "TOK_TABNAME", + "TOK_LIKETABLE", + "EXTERNAL", + "TOK_QUERY", + "TOK_IFNOTEXISTS", + "TOK_TABLECOMMENT", + "TOK_TABCOLLIST", + "TOK_TABLEPARTCOLS", // Partitioned by + "TOK_TABLEBUCKETS", // Clustered by + "TOK_TABLESKEWED", // Skewed by + "TOK_TABLEROWFORMAT", + "TOK_TABLESERIALIZER", + "TOK_FILEFORMAT_GENERIC", + "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat + "TOK_STORAGEHANDLER", // Storage handler + "TOK_TABLELOCATION", + "TOK_TABLEPROPERTIES"), + children) + val TableIdentifier(tblName, dbName) = extractTableIdent(tableNameParts) + + // TODO add bucket support + var tableDesc: HiveTable = HiveTable( + specifiedDatabase = dbName, + name = tblName, + schema = Seq.empty[HiveColumn], + partitionColumns = Seq.empty[HiveColumn], + properties = Map[String, String](), + serdeProperties = Map[String, String](), + tableType = if (externalTable.isDefined) ExternalTable else ManagedTable, + location = None, + inputFormat = None, + outputFormat = None, + serde = 
None, + viewText = None) + + // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) + val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + // handle the default format for the storage type abbreviation + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) } - val relations = fromClause match { - case Some(f) => nodeToRelation(f, context) - case None => OneRowRelation - } + hiveSerDe.inputFormat.foreach(f => tableDesc = tableDesc.copy(inputFormat = Some(f))) + hiveSerDe.outputFormat.foreach(f => tableDesc = tableDesc.copy(outputFormat = Some(f))) + hiveSerDe.serde.foreach(f => tableDesc = tableDesc.copy(serde = Some(f))) - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.getChildren.asScala - Filter(nodeToExpr(whereExpr), relations) - }.getOrElse(relations) - - val select = - (selectClause orElse selectDistinctClause).getOrElse(sys.error("No select clause.")) - - // Script transformations are expressed as a select clause with a single expression of type - // TOK_TRANSFORM - val transformation = select.getChildren.iterator().next() match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, - false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(name, nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) + children.collect { + case list @ Token("TOK_TABCOLLIST", _) => + val cols = nodeToColumns(list, lowerCase = true) + if (cols != null) { + tableDesc = tableDesc.copy(schema = cols) } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) + case Token("TOK_TABLECOMMENT", child :: Nil) => + val comment = unescapeSQLString(child.text) + // TODO support the sql text + tableDesc = tableDesc.copy(viewText = Option(comment)) + case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => + val cols = nodeToColumns(list.head, lowerCase = false) + if (cols != null) { + tableDesc = tableDesc.copy(partitionColumns = cols) + } + case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => + val serdeParams = new java.util.HashMap[String, String]() + child match { + case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => + val fieldDelim = unescapeSQLString (rowChild1.text) + 
serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) + serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) + if (rowChild2.length > 1) { + val fieldEscape = unescapeSQLString (rowChild2.head.text) + serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(BaseSemanticAnalyzer.unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (BaseSemanticAnalyzer.unescapeSQLString(name), - BaseSemanticAnalyzer.unescapeSQLString(value)) + case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => + val collItemDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) + case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => + val mapKeyDelim = unescapeSQLString(rowChild.text) + serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) + case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => + val lineDelim = unescapeSQLString(rowChild.text) + if (!(lineDelim == "\n") && !(lineDelim == "10")) { + throw new AnalysisException( + s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild") } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = BaseSemanticAnalyzer.unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = BaseSemanticAnalyzer.unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None + serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) + case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => + val nullFormat = unescapeSQLString(rowChild.text) + // TODO support the nullFormat + case _ => assert(false) } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ serdeParams.asScala) + case Token("TOK_TABLELOCATION", child :: Nil) => + val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) + tableDesc = tableDesc.copy(location = Option(location)) + case Token("TOK_TABLESERIALIZER", child :: Nil) => + tableDesc = tableDesc.copy( + serde = Option(unescapeSQLString(child.children.head.text))) + if (child.numChildren == 2) { + // This is based on the readProps(..) 
method in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + val serdeParams = child.children(1).children.head.children.map { + case Token(_, Token(prop, Nil) :: valueNode) => + val value = valueNode.headOption + .map(_.text) + .map(unescapeSQLString) + .orNull + (unescapeSQLString(prop), value) + }.toMap + tableDesc = tableDesc.copy(serdeProperties = tableDesc.serdeProperties ++ serdeParams) } + case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => + child.text.toLowerCase(Locale.ENGLISH) match { + case "orc" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - withWhere, schema)) - case _ => None - } - - val withLateralView = lateralViewClause.map { lv => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = lv.getChildren.iterator().next() - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = false, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - withWhere) - }.getOrElse(withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.getChildren.asScala.flatMap(selExprNodeToExpr).map(UnresolvedAlias) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Rollup(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Cube(children.map(nodeToExpr), withLateralView, selectExpressions) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withLateralView))).flatten.head - } + case "parquet" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) + } - // Handle HAVING clause. 
- val withHaving = havingClause.map { h => - val havingExpr = h.getChildren.asScala match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.getChildren.asScala.map(nodeToSortOrder), true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), - false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.getChildren.asScala.map(nodeToSortOrder), false, - RepartitionByExpression( - partitionExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.getChildren.asScala.map(nodeToExpr).map(SortOrder(_, Ascending)), - false, - RepartitionByExpression( - clusterExprs.getChildren.asScala.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } + case "rcfile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy(serde = + Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + } - val withLimit = - limitClause.map(l => nodeToExpr(l.getChildren.iterator().next())) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.getChildren.asScala.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } + case "textfile" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. 
- val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } + case "sequencefile" => + tableDesc = tableDesc.copy( + inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), + outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - // If there are multiple INSERTS just UNION them together into on query. - val query = queries.reduceLeft(Union) + case "avro" => + tableDesc = tableDesc.copy( + inputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = + Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) + if (tableDesc.serde.isEmpty) { + tableDesc = tableDesc.copy( + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) + } - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) + case _ => + throw new AnalysisException( + s"Unrecognized file format in STORED AS clause: ${child.text}") + } - // HIVE-9039 renamed TOK_UNION => TOK_UNIONALL while adding TOK_UNIONDISTINCT - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left, context), nodeToPlan(right, context)) + case Token("TOK_TABLESERIALIZER", + Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => + tableDesc = tableDesc.copy(serde = Option(unquoteString(serdeName))) - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for $node:\n ${dumpTree(a).toString} ") - } + otherProps match { + case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => + tableDesc = tableDesc.copy( + serdeProperties = tableDesc.serdeProperties ++ getProperties(list)) + case _ => + } - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - def nodeToRelation(node: Node, context: Context): LogicalPlan = node match { - case Token("TOK_SUBQUERY", - query :: Token(alias, Nil) :: Nil) => - Subquery(cleanIdentifier(alias), nodeToPlan(query, context)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - val Token("TOK_SELECT", - Token("TOK_SELEXPR", clauses) :: Nil) = selectClause - - val alias = getClause("TOK_TABALIAS", clauses).getChildren.iterator().next() - .asInstanceOf[ASTNode].getText - - val (generator, attributes) = nodesToGenerator(clauses) - Generate( - generator, - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - attributes.map(UnresolvedAttribute(_)), - nodeToRelation(relationClause, context)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. 
- val (nonAliasClauses, aliasClause) = - if (clauses.last.getText.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) + case Token("TOK_TABLEPROPERTIES", list :: Nil) => + tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) + case list @ Token("TOK_TABLEFILEFORMAT", _) => + tableDesc = tableDesc.copy( + inputFormat = + Option(unescapeSQLString(list.children.head.text)), + outputFormat = + Option(unescapeSQLString(list.children(1).text))) + case Token("TOK_STORAGEHANDLER", _) => + throw new AnalysisException( + "CREATE TABLE AS SELECT cannot be used for a non-native table") + case _ => // Unsupport features } - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: - Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: - Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, (math.random * 1000).toInt, - relation) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation) - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for sampling clause: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - }.getOrElse(relation) - - case Token("TOK_UNIQUEJOIN", joinArgs) => - val tableOrdinals = - joinArgs.zipWithIndex.filter { - case (arg, i) => arg.getText == "TOK_TABREF" - }.map(_._2) - - val isPreserved = tableOrdinals.map(i => (i - 1 < 0) || joinArgs(i - 1).getText == "PRESERVE") - val tables = tableOrdinals.map(i => nodeToRelation(joinArgs(i), context)) - val joinExpressions = - tableOrdinals.map(i => joinArgs(i + 1).getChildren.asScala.map(nodeToExpr)) - - val joinConditions = joinExpressions.sliding(2).map { - case Seq(c1, c2) => - val predicates = (c1, c2).zipped.map { case (e1, e2) => EqualTo(e1, e2): Expression } - predicates.reduceLeft(And) - }.toBuffer - - val joinType = isPreserved.sliding(2).map { - case Seq(true, true) => FullOuter - case Seq(true, false) => LeftOuter - case Seq(false, true) => RightOuter - case Seq(false, false) => Inner - }.toBuffer - - val joinedTables = tables.reduceLeft(Join(_, _, Inner, None)) - - // Must be transform down. 
- val joinedResult = joinedTables transform { - case j: Join => - j.copy( - condition = Some(joinConditions.remove(joinConditions.length - 1)), - joinType = joinType.remove(joinType.length - 1)) - } - - val groups = joinExpressions.head.indices.map(i => Coalesce(joinExpressions.map(_(i)))) - - // Unique join is not really the same as an outer join so we must group together results where - // the joinExpressions are the same, taking the First of each value is only okay because the - // user of a unique join is implicitly promising that there is only one result. - // TODO: This doesn't actually work since [[Star]] is not a valid aggregate expression. - // instead we should figure out how important supporting this feature is and whether it is - // worth the number of hacks that will be required to implement it. Namely, we need to add - // some sort of mapped star expansion that would expand all child output row to be similarly - // named output expressions where some aggregate expression has been applied (i.e. First). - // Aggregate(groups, Star(None, First(_)) :: Nil, joinedResult) - throw new UnsupportedOperationException - - case Token(allJoinTokens(joinToken), - relation1 :: - relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - } - Join(nodeToRelation(relation1, context), - nodeToRelation(relation2, context), - joinType, - other.headOption.map(nodeToExpr)) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) - def nodeToSortOrder(node: Node): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) + // If its not a "CTAS" like above then take it as a native command + case Token("TOK_CREATETABLE", _) => + NativePlaceholder - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for:\n ${dumpTree(a).toString} ") - } + // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" + case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) => + NativePlaceholder - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: Node, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. 
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.getChildren.asScala.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(PlanUtils.stripQuotes(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable(UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, true) - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName}:" + - s"\n ${dumpTree(a).toString} ") + case _ => + super.nodeToPlan(node) + } } - protected def selExprNodeToExpr(node: Node): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - var aliasNames = ArrayBuffer[String]() - aliasChildren.foreach { _ match { - case Token(name, Nil) => aliasNames += cleanIdentifier(name) + protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder + + protected override def nodeToTransformation( + node: ASTNode, + child: LogicalPlan): Option[ScriptTransformation] = node match { + case Token("TOK_SELEXPR", + Token("TOK_TRANSFORM", + Token("TOK_EXPLIST", inputExprs) :: + Token("TOK_SERDE", inputSerdeClause) :: + Token("TOK_RECORDWRITER", writerClause) :: + // TODO: Need to support other types of (in/out)put + Token(script, Nil) :: + Token("TOK_SERDE", outputSerdeClause) :: + Token("TOK_RECORDREADER", readerClause) :: + outputClause) :: Nil) => + + val (output, schemaLess) = outputClause match { + case Token("TOK_ALIASLIST", aliases) :: Nil => + (aliases.map { case Token(name, Nil) => AttributeReference(name, StringType)() }, + false) + case Token("TOK_TABCOLLIST", attributes) :: Nil => + (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => + AttributeReference(name, nodeToDataType(dataType))() }, false) + case Nil => + (List(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) case _ => - } + noParseRule("Transform", node) } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - - case a: ASTNode => - throw new NotImplementedError(s"No parse rules for ${a.getName }:" + - s"\n ${dumpTree(a).toString } ") - } - protected val escapedIdentifier = "`([^`]+)`".r - protected val doubleQuotedString = "\"([^\"]+)\"".r - protected val singleQuotedString = "'([^']+)'".r + type SerDeInfo = ( + Seq[(String, String)], // Input row format information + Option[String], // Optional input SerDe class + Seq[(String, String)], // Input SerDe properties + Boolean 
// Whether to use default record reader/writer + ) + + def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { + case Token("TOK_SERDEPROPS", propsClause) :: Nil => + val rowFormat = propsClause.map { + case Token(name, Token(value, Nil) :: Nil) => (name, value) + } + (rowFormat, None, Nil, false) - protected def unquoteString(str: String) = str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => + (Nil, Some(unescapeSQLString(serdeClass)), Nil, false) - /** Strips backticks from ident if present */ - protected def cleanIdentifier(ident: String): String = ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: + Token("TOK_TABLEPROPERTIES", + Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => + val serdeProps = propsClause.map { + case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => + (unescapeSQLString(name), unescapeSQLString(value)) + } - val numericAstTypes = Seq( - HiveParser.Number, - HiveParser.TinyintLiteral, - HiveParser.SmallintLiteral, - HiveParser.BigintLiteral, - HiveParser.DecimalLiteral) - - /* Case insensitive matches */ - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - - protected def nodeToExpr(node: Node): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", - Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(attr)) + // SPARK-10310: Special cases LazySimpleSerDe + // TODO Fully supports user-defined record reader/writer classes + val unescapedSerDeClass = unescapeSQLString(serdeClass) + val useDefaultRecordReaderWriter = + unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName + (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) + + case Nil => + // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here + val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") + (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) } - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. TOK_TABNAME will only - // has a single child which is tableName. 
- case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", Token(name, Nil) :: Nil) :: Nil) => - UnresolvedStar(Some(UnresolvedAttribute.parseAttributeName(name))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, scale.getText.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.getText.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => 
EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } + val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = + matchSerDe(inputSerdeClause) - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => And(nodeToExpr(left), nodeToExpr(right)) - case Token(OR(), left :: right:: Nil) => Or(nodeToExpr(left), nodeToExpr(right)) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token("TOK_FUNCTION", Token(name, Nil) +: args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = UnresolvedWindowFunction(name, args.map(nodeToExpr)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - // Safe to use Literal(1)? 
- val function = UnresolvedWindowFunction(name, Literal(1) :: Nil) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } + val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = + matchSerDe(outputSerdeClause) - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. - case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => BaseSemanticAnalyzer.unescapeSQLString(s.getText)).mkString) - - // This code is adapted from - // /ql/src/java/org/apache/hadoop/hive/ql/parse/TypeCheckProcFactory.java#L223 - case ast: ASTNode if numericAstTypes contains ast.getType => - var v: Literal = null - try { - if (ast.getText.endsWith("L")) { - // Literal bigint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toLong, LongType) - } else if (ast.getText.endsWith("S")) { - // Literal smallint. - v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toShort, ShortType) - } else if (ast.getText.endsWith("Y")) { - // Literal tinyint. 
- v = Literal.create(ast.getText.substring(0, ast.getText.length() - 1).toByte, ByteType) - } else if (ast.getText.endsWith("BD") || ast.getText.endsWith("D")) { - // Literal decimal - val strVal = ast.getText.stripSuffix("D").stripSuffix("B") - v = Literal(Decimal(strVal)) - } else { - v = Literal.create(ast.getText.toDouble, DoubleType) - v = Literal.create(ast.getText.toLong, LongType) - v = Literal.create(ast.getText.toInt, IntegerType) - } - } catch { - case nfe: NumberFormatException => // Do nothing - } + val unescapedScript = unescapeSQLString(script) - if (v == null) { - sys.error(s"Failed to parse number '${ast.getText}'.") + // TODO Adds support for user-defined record reader/writer classes + val recordReaderClass = if (useDefaultRecordReader) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) } else { - v + None } - case ast: ASTNode if ast.getType == HiveParser.StringLiteral => - Literal(BaseSemanticAnalyzer.unescapeSQLString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.getText.substring(1, ast.getText.length - 1))) - - case ast: ASTNode if ast.getType == HiveParser.TOK_CHARSETLITERAL => - Literal(BaseSemanticAnalyzer.charSetString(ast.getChild(0).getText, ast.getChild(1).getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_YEAR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("year", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MONTH_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("month", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_DAY_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("day", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_HOUR_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("hour", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_MINUTE_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("minute", ast.getText)) - - case ast: ASTNode if ast.getType == HiveParser.TOK_INTERVAL_SECOND_LITERAL => - Literal(CalendarInterval.fromSingleUnitString("second", ast.getText)) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText} : - |${dumpTree(a).toString}" + - """.stripMargin) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. 
- val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.getChildren.asScala.asInstanceOf[Seq[ASTNode]]) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), - orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.getChildren.asScala.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.getChildren.asScala.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.getChildren.asScala.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - throw new NotImplementedError( - s"""No parse rules for Node ${partitionAndOrdering.getName} - """.stripMargin) - } - }.getOrElse { - (Nil, Nil) + val recordWriterClass = if (useDefaultRecordWriter) { + Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) + } else { + None } - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: Node): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame Boundary based on Node ${node.getName} - """.stripMargin) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.getChildren.asScala.toList match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - throw new NotImplementedError( - s"""No parse rules for the Window Frame based on Node ${frame.getName} - """.stripMargin) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) + val schema = HiveScriptIOSchema( + inRowFormat, outRowFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + recordReaderClass, recordWriterClass, + schemaLess) + + Some( + ScriptTransformation( + inputExprs.map(nodeToExpr), + unescapedScript, + output, + child, schema)) + case _ => None } - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { - val function = nodes.head - - val attributes = nodes.flatMap { - case Token(a, Nil) => a.toLowerCase :: Nil - case _ => Nil - } - - function match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - (Explode(nodeToExpr(child)), attributes) - - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - (JsonTuple(children.map(nodeToExpr)), attributes) - - case Token("TOK_FUNCTION", 
Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - - (HiveGenericUDTF( - new HiveFunctionWrapper(functionClassName), - children.map(nodeToExpr)), attributes) - - case a: ASTNode => - throw new NotImplementedError( - s"""No parse rules for ASTNode type: ${a.getType}, text: ${a.getText}, tree: - |${dumpTree(a).toString} - """.stripMargin) - } + protected override def nodeToGenerator(node: ASTNode): Generator = node match { + case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => + val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( + sys.error(s"Couldn't find function $functionName")) + val functionClassName = functionInfo.getFunctionClass.getName + HiveGenericUDTF( + functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) + case other => super.nodeToGenerator(node) } - def dumpTree(node: Node, builder: StringBuilder = new StringBuilder, indent: Int = 0) - : StringBuilder = { - node match { - case a: ASTNode => builder.append( - (" " * indent) + a.getText + " " + - a.getLine + ", " + - a.getTokenStartIndex + "," + - a.getTokenStopIndex + ", " + - a.getCharPositionInLine + "\n") - case other => sys.error(s"Non ASTNode encountered: $other") + // This is based the getColumns methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java + protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[HiveColumn] = { + node.children.map(_.children).collect { + case Token(rawColName, Nil) :: colTypeNode :: comment => + val colName = if (!lowerCase) rawColName + else rawColName.toLowerCase + HiveColumn( + cleanIdentifier(colName), + nodeToTypeString(colTypeNode), + comment.headOption.map(n => unescapeSQLString(n.text)).orNull) } + } - Option(node.getChildren).map(_.asScala).getOrElse(Nil).foreach(dumpTree(_, builder, indent + 1)) - builder + // This is based on the following methods in + // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: + // getTypeStringFromAST + // getStructTypeStringFromAST + // getUnionTypeStringFromAST + protected def nodeToTypeString(node: ASTNode): String = node.tokenType match { + case SparkSqlParser.TOK_LIST => + val listType :: Nil = node.children + val listTypeString = nodeToTypeString(listType) + s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>" + + case SparkSqlParser.TOK_MAP => + val keyType :: valueType :: Nil = node.children + val keyTypeString = nodeToTypeString(keyType) + val valueTypeString = nodeToTypeString(valueType) + s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>" + + case SparkSqlParser.TOK_STRUCT => + val typeNode = node.children.head + require(typeNode.children.nonEmpty, "Struct must have one or more columns.") + val structColStrings = typeNode.children.map { columnNode => + val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children + cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode) + } + s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>" + + case SparkSqlParser.TOK_UNIONTYPE => + val typeNode = node.children.head + val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",") + s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>" + + case SparkSqlParser.TOK_CHAR => + val Token(size, Nil) :: Nil = 
node.children + s"${serdeConstants.CHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_VARCHAR => + val Token(size, Nil) :: Nil = node.children + s"${serdeConstants.VARCHAR_TYPE_NAME}($size)" + + case SparkSqlParser.TOK_DECIMAL => + val precisionAndScale = node.children match { + case Token(precision, Nil) :: Token(scale, Nil) :: Nil => + precision + "," + scale + case Token(precision, Nil) :: Nil => + precision + "," + HiveDecimal.USER_DEFAULT_SCALE + case Nil => + HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE + case _ => + noParseRule("Decimal", node) + } + s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)" + + // Simple data types. + case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME + case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME + case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME + case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME + case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME + case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME + case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME + case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME + case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME + case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME + case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME + case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME + case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME + case _ => null } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala index f0697613cff3..b8cced0b8096 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveShim.scala @@ -20,15 +20,13 @@ package org.apache.spark.sql.hive import java.io.{InputStream, OutputStream} import java.rmi.server.UID -import org.apache.avro.Schema - import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag import com.esotericsoftware.kryo.Kryo import com.esotericsoftware.kryo.io.{Input, Output} - +import org.apache.avro.Schema import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.exec.{UDF, Utilities} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index d38ad9127327..3687dd6f5a7a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -22,14 +22,13 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand, _} +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTableUsingAsSelect, DescribeCommand} import org.apache.spark.sql.hive.execution._ - private[hive] trait HiveStrategies { // Possibly being 
too clever with types here... or not clever enough. - self: SQLContext#SparkPlanner => + self: SparkPlanner => val hiveContext: HiveContext @@ -89,10 +88,9 @@ private[hive] trait HiveStrategies { tableIdent, userSpecifiedSchema, provider, opts, allowExisting, managedIfNoPath) ExecutedCommand(cmd) :: Nil - case CreateTableUsingAsSelect( - tableIdent, provider, false, partitionCols, mode, opts, query) => - val cmd = - CreateMetastoreDataSourceAsSelect(tableIdent, provider, partitionCols, mode, opts, query) + case c: CreateTableUsingAsSelect => + val cmd = CreateMetastoreDataSourceAsSelect(c.tableIdent, c.provider, c.partitionColumns, + c.bucketSpec, c.mode, c.options, c.child) ExecutedCommand(cmd) :: Nil case _ => Nil diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala new file mode 100644 index 000000000000..1c910051facc --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.util.concurrent.atomic.AtomicLong + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.optimizer.ProjectCollapsing +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} +import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.{DataFrame, SQLContext} + +/** + * A builder class used to convert a resolved logical plan into a SQL query string. Note that not + * all resolved logical plans are convertible. They either don't have corresponding SQL + * representations (e.g. logical plans that operate on local Scala collections), or are simply not + * supported by this builder (yet). 
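+ *
+ * A minimal usage sketch (illustrative only, assuming a `HiveContext` in scope as `hiveContext`
+ * and a Hive table `src(key INT, value STRING)`):
+ *
+ * {{{
+ *   val df = hiveContext.sql("SELECT key FROM src GROUP BY key HAVING max(value) > 'val_255'")
+ *   // Rebuilds a SQL string from the analyzed plan; yields None when the plan is not convertible.
+ *   val maybeSQL: Option[String] = new SQLBuilder(df).toSQL
+ * }}}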
+ */ +class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { + def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) + + def toSQL: Option[String] = { + val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val maybeSQL = try { + toSQL(canonicalizedPlan) + } catch { case cause: UnsupportedOperationException => + logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") + None + } + + if (maybeSQL.isDefined) { + logDebug( + s"""Built SQL query string successfully from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + |# Built SQL query string: + |${maybeSQL.get} + """.stripMargin) + } else { + logDebug( + s"""Failed to build SQL query string from given logical plan: + | + |# Original logical plan: + |${logicalPlan.treeString} + |# Canonicalized logical plan: + |${canonicalizedPlan.treeString} + """.stripMargin) + } + + maybeSQL + } + + private def projectToSQL( + projectList: Seq[NamedExpression], + child: LogicalPlan, + isDistinct: Boolean): Option[String] = { + for { + childSQL <- toSQL(child) + listSQL = projectList.map(_.sql).mkString(", ") + maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + distinct = if (isDistinct) " DISTINCT " else " " + } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" + } + + private def aggregateToSQL( + groupingExprs: Seq[Expression], + aggExprs: Seq[Expression], + child: LogicalPlan): Option[String] = { + val aggSQL = aggExprs.map(_.sql).mkString(", ") + val groupingSQL = groupingExprs.map(_.sql).mkString(", ") + val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " + val maybeFrom = child match { + case OneRowRelation => " " + case _ => " FROM " + } + + toSQL(child).map { childSQL => + s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" + } + } + + private def toSQL(node: LogicalPlan): Option[String] = node match { + case Distinct(Project(list, child)) => + projectToSQL(list, child, isDistinct = true) + + case Project(list, child) => + projectToSQL(list, child, isDistinct = false) + + case Aggregate(groupingExprs, aggExprs, child) => + aggregateToSQL(groupingExprs, aggExprs, child) + + case Limit(limit, child) => + for { + childSQL <- toSQL(child) + limitSQL = limit.sql + } yield s"$childSQL LIMIT $limitSQL" + + case Filter(condition, child) => + for { + childSQL <- toSQL(child) + whereOrHaving = child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + conditionSQL = condition.sql + } yield s"$childSQL $whereOrHaving $conditionSQL" + + case Union(left, right) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + } yield s"$leftSQL UNION ALL $rightSQL" + + // ParquetRelation converted from Hive metastore table + case Subquery(alias, LogicalRelation(r: ParquetRelation, _)) => + // There seems to be a bug related to `ParquetConversions` analysis rule. The problem is + // that, the metastore database name and table name are not always propagated to converted + // `ParquetRelation` instances via data source options. Here we use subquery alias as a + // workaround. 
+ Some(s"`$alias`") + + case Subquery(alias, child) => + toSQL(child).map(childSQL => s"($childSQL) AS $alias") + + case Join(left, right, joinType, condition) => + for { + leftSQL <- toSQL(left) + rightSQL <- toSQL(right) + joinTypeSQL = joinType.sql + conditionSQL = condition.map(" ON " + _.sql).getOrElse("") + } yield s"$leftSQL $joinTypeSQL JOIN $rightSQL$conditionSQL" + + case MetastoreRelation(database, table, alias) => + val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") + Some(s"`$database`.`$table`$aliasSQL") + + case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) + if orders.map(_.child) == partitionExprs => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL CLUSTER BY $partitionExprsSQL" + + case Sort(orders, global, child) => + for { + childSQL <- toSQL(child) + ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + orderOrSort = if (global) "ORDER" else "SORT" + } yield s"$childSQL $orderOrSort BY $ordersSQL" + + case RepartitionByExpression(partitionExprs, child, _) => + for { + childSQL <- toSQL(child) + partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") + } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + + case OneRowRelation => + Some("") + + case _ => None + } + + object Canonicalizer extends RuleExecutor[LogicalPlan] { + override protected def batches: Seq[Batch] = Seq( + Batch("Canonicalizer", FixedPoint(100), + // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over + // `Aggregate`s to perform type casting. This rule merges these `Project`s into + // `Aggregate`s. + ProjectCollapsing, + + // Used to handle other auxiliary `Project`s added by analyzer (e.g. + // `ResolveAggregateFunctions` rule) + RecoverScopingInfo + ) + ) + + object RecoverScopingInfo extends Rule[LogicalPlan] { + override def apply(tree: LogicalPlan): LogicalPlan = tree transform { + // This branch handles aggregate functions within HAVING clauses. For example: + // + // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255" + // + // This kind of query results in query plans of the following form because of analysis rule + // `ResolveAggregateFunctions`: + // + // Project ... + // +- Filter ... + // +- Aggregate ... 
+ // +- MetastoreRelation default, src, None + case plan @ Project(_, Filter(_, _: Aggregate)) => + wrapChildWithSubquery(plan) + + case plan @ Project(_, + _: Subquery | _: Filter | _: Join | _: MetastoreRelation | OneRowRelation | _: Limit + ) => plan + + case plan: Project => + wrapChildWithSubquery(plan) + } + + def wrapChildWithSubquery(project: Project): Project = project match { + case Project(projectList, child) => + val alias = SQLBuilder.newSubqueryName + val childAttributes = child.outputSet + val aliasedProjectList = projectList.map(_.transform { + case a: Attribute if childAttributes.contains(a) => + a.withQualifiers(alias :: Nil) + }.asInstanceOf[NamedExpression]) + + Project(aliasedProjectList, Subquery(alias, child)) + } + } + } +} + +object SQLBuilder { + private val nextSubqueryId = new AtomicLong(0) + + private def newSubqueryName: String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}" +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 70ee02823eeb..fd465e80a87e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -23,11 +23,11 @@ import org.apache.hadoop.fs.{Path, PathFilter} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants._ import org.apache.hadoop.hive.ql.exec.Utilities -import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable, Hive, HiveUtils, HiveStorageHandler} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveStorageHandler, HiveUtils, Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Deserializer -import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspectorConverters, StructObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.primitive._ import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala index 598ccdeee4ad..ce7a305d437a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/ClientWrapper.scala @@ -25,17 +25,16 @@ import scala.language.reflectiveCalls import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} import org.apache.hadoop.hive.metastore.{TableType => HTableType} +import org.apache.hadoop.hive.metastore.api.{Database, FieldSchema} +import org.apache.hadoop.hive.ql.{metadata, Driver} import org.apache.hadoop.hive.ql.metadata.Hive import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.ql.{Driver, metadata} import org.apache.hadoop.hive.shims.{HadoopShims, ShimLoader} import org.apache.hadoop.security.UserGroupInformation -import org.apache.hadoop.util.VersionInfo -import org.apache.spark.{SparkConf, SparkException, Logging} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.sql.catalyst.expressions.Expression import 
org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.util.{CircularBuffer, Utils} @@ -65,74 +64,6 @@ private[hive] class ClientWrapper( extends ClientInterface with Logging { - overrideHadoopShims() - - // !! HACK ALERT !! - // - // Internally, Hive `ShimLoader` tries to load different versions of Hadoop shims by checking - // major version number gathered from Hadoop jar files: - // - // - For major version number 1, load `Hadoop20SShims`, where "20S" stands for Hadoop 0.20 with - // security. - // - For major version number 2, load `Hadoop23Shims`, where "23" stands for Hadoop 0.23. - // - // However, APIs in Hadoop 2.0.x and 2.1.x versions were in flux due to historical reasons. It - // turns out that Hadoop 2.0.x versions should also be used together with `Hadoop20SShims`, but - // `Hadoop23Shims` is chosen because the major version number here is 2. - // - // To fix this issue, we try to inspect Hadoop version via `org.apache.hadoop.utils.VersionInfo` - // and load `Hadoop20SShims` for Hadoop 1.x and 2.0.x versions. If Hadoop version information is - // not available, we decide whether to override the shims or not by checking for existence of a - // probe method which doesn't exist in Hadoop 1.x or 2.0.x versions. - private def overrideHadoopShims(): Unit = { - val hadoopVersion = VersionInfo.getVersion - val VersionPattern = """(\d+)\.(\d+).*""".r - - hadoopVersion match { - case null => - logError("Failed to inspect Hadoop version") - - // Using "Path.getPathWithoutSchemeAndAuthority" as the probe method. - val probeMethod = "getPathWithoutSchemeAndAuthority" - if (!classOf[Path].getDeclaredMethods.exists(_.getName == probeMethod)) { - logInfo( - s"Method ${classOf[Path].getCanonicalName}.$probeMethod not found, " + - s"we are probably using Hadoop 1.x or 2.0.x") - loadHadoop20SShims() - } - - case VersionPattern(majorVersion, minorVersion) => - logInfo(s"Inspected Hadoop version: $hadoopVersion") - - // Loads Hadoop20SShims for 1.x and 2.0.x versions - val (major, minor) = (majorVersion.toInt, minorVersion.toInt) - if (major < 2 || (major == 2 && minor == 0)) { - loadHadoop20SShims() - } - } - - // Logs the actual loaded Hadoop shims class - val loadedShimsClassName = ShimLoader.getHadoopShims.getClass.getCanonicalName - logInfo(s"Loaded $loadedShimsClassName for Hadoop version $hadoopVersion") - } - - private def loadHadoop20SShims(): Unit = { - val hadoop20SShimsClassName = "org.apache.hadoop.hive.shims.Hadoop20SShims" - logInfo(s"Loading Hadoop shims $hadoop20SShimsClassName") - - try { - val shimsField = classOf[ShimLoader].getDeclaredField("hadoopShims") - // scalastyle:off classforname - val shimsClass = Class.forName(hadoop20SShimsClassName) - // scalastyle:on classforname - val shims = classOf[HadoopShims].cast(shimsClass.newInstance()) - shimsField.setAccessible(true) - shimsField.set(null, shims) - } catch { case cause: Throwable => - throw new RuntimeException(s"Failed to load $hadoop20SShimsClassName", cause) - } - } - // Circular buffer to hold what hive prints to STDOUT and ERR. Only printed when failures occur. 
private val outputBuffer = new CircularBuffer() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 346840079b85..ca636b0265d4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.hive.serde.serdeConstants import org.apache.spark.Logging import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{StringType, IntegralType} +import org.apache.spark.sql.types.{IntegralType, StringType} /** * A shim that defines the interface between ClientWrapper and the underlying Hive library used to diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index e72a60b42e65..4c0aae6c04bd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, MetastoreRelation} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** * Create table and insert the query result into it. 
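For reference, the version probe that the shim-override hack removed from ClientWrapper above relied on boils down to a single reflective check; a sketch using the probe method named in the deleted code (per that code, the method does not exist in Hadoop 1.x or 2.0.x, so its absence meant Hadoop20SShims had to be forced):

    import org.apache.hadoop.fs.Path
    // Absent in Hadoop 1.x/2.0.x, present from 2.1 on, per the removed comment.
    val needsHadoop20SShims =
      !classOf[Path].getDeclaredMethods.exists(_.getName == "getPathWithoutSchemeAndAuthority")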
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 2c81115ee4fe..6e288afbb4d2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveContext} -import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.execution.RunnableCommand +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} import org.apache.spark.sql.hive.client.{HiveColumn, HiveTable} /** diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala index 441b6b6033e1..dfa5a982b158 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/DescribeHiveTableCommand.scala @@ -21,10 +21,10 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.metastore.api.FieldSchema +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.MetastoreRelation -import org.apache.spark.sql.{Row, SQLContext} /** * Implementation for "describe [extended] table". diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 41b645b2c9c9..381fb61160ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.RunnableCommand import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.StringType -import org.apache.spark.sql.{Row, SQLContext} private[hive] case class HiveNativeCommand(sql: String) extends RunnableCommand { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 806d2b9b0b7d..1588728bdbaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -51,6 +51,9 @@ case class HiveTableScan( require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, "Partition pruning predicates only supported for partitioned tables.") + override def producedAttributes: AttributeSet = outputSet ++ + AttributeSet(partitionPruningPred.flatMap(_.references)) + // Retrieve the original attributes based on expression ID so that capitalization matches. 
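The `producedAttributes` override added to HiveTableScan above (and the analogous `outputSet -- inputSet` one on ScriptTransformation further down) matters because, assuming Catalyst's usual `QueryPlan` contract, missing inputs are derived roughly as

    missingInput = references -- inputSet -- producedAttributes

so attributes that appear only in a leaf scan's partition-pruning predicates must be declared as produced by the scan itself, or the plan would be reported as having dangling references.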
val attributes = requestedAttributes.map(relation.attributeMap) @@ -129,11 +132,17 @@ case class HiveTableScan( } } - protected override def doExecute(): RDD[InternalRow] = if (!relation.hiveQlTable.isPartitioned) { - hadoopReader.makeRDDForTable(relation.hiveQlTable) - } else { - hadoopReader.makeRDDForPartitionedTable( - prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + protected override def doExecute(): RDD[InternalRow] = { + val rdd = if (!relation.hiveQlTable.isPartitioned) { + hadoopReader.makeRDDForTable(relation.hiveQlTable) + } else { + hadoopReader.makeRDDForPartitionedTable( + prunePartitions(relation.getHiveQlPartitions(partitionPruningPred))) + } + rdd.mapPartitionsInternal { iter => + val proj = UnsafeProjection.create(schema) + iter.map(proj) + } } override def output: Seq[Attribute] = attributes diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f936cf565b2b..b02ace786c66 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -23,22 +23,21 @@ import scala.collection.JavaConverters._ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.ql.{Context, ErrorMsg} +import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.hive.serde2.Serializer -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption import org.apache.hadoop.hive.serde2.objectinspector._ -import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.ObjectInspectorCopyOption +import org.apache.hadoop.mapred.{FileOutputFormat, JobConf} +import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan} -import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} +import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection} +import org.apache.spark.sql.execution.{SparkPlan, UnaryNode} import org.apache.spark.sql.hive._ +import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.types.DataType -import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.util.SerializableJobConf private[hive] @@ -101,15 +100,17 @@ case class InsertIntoHiveTable( writerContainer.executorSideSetup(context.stageId, context.partitionId, context.attemptNumber) + val proj = FromUnsafeProjection(child.schema) iterator.foreach { row => var i = 0 + val safeRow = proj(row) while (i < fieldOIs.length) { - outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(row.get(i, dataTypes(i))) + outputData(i) = if (row.isNullAt(i)) null else wrappers(i)(safeRow.get(i, dataTypes(i))) i += 1 } writerContainer - .getLocalFileWriter(row, table.schema) + .getLocalFileWriter(safeRow, table.schema) .write(serializer.serialize(outputData, standardOI)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index b30117f0de99..5e6641693798 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -31,16 +31,16 @@ import org.apache.hadoop.hive.serde2.AbstractSerDe import org.apache.hadoop.hive.serde2.objectinspector._ import org.apache.hadoop.io.Writable +import org.apache.spark.{Logging, TaskContext} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} -import org.apache.spark.{Logging, TaskContext} /** * Transforms the input by forking and running the specified script. @@ -58,7 +58,9 @@ case class ScriptTransformation( ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) extends UnaryNode { - override def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil + + override def producedAttributes: AttributeSet = outputSet -- inputSet private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) @@ -211,7 +213,8 @@ case class ScriptTransformation( child.execute().mapPartitions { iter => if (iter.hasNext) { - processIterator(iter) + val proj = UnsafeProjection.create(schema) + processIterator(iter).map(proj) } else { // If the input iterator has no rows then do not launch the external script. Iterator.empty diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 94210a5394f9..612f01cda88b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.RunnableCommand -import org.apache.spark.sql.execution.datasources.{LogicalRelation, ResolvedDataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, LogicalRelation, ResolvedDataSource} import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -151,6 +151,7 @@ case class CreateMetastoreDataSource( tableIdent, userSpecifiedSchema, Array.empty[String], + bucketSpec = None, provider, optionsWithPath, isExternal) @@ -164,6 +165,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent: TableIdentifier, provider: String, partitionColumns: Array[String], + bucketSpec: Option[BucketSpec], mode: SaveMode, options: Map[String, String], query: LogicalPlan) extends RunnableCommand { @@ -254,8 +256,14 @@ case class CreateMetastoreDataSourceAsSelect( } // Create the relation based on the data of df. 
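A pattern that recurs throughout this patch: Hive-side operators now hand `UnsafeRow`s to downstream operators, while InsertIntoHiveTable converts back to safe rows before the Hive ObjectInspector machinery reads field values. A sketch of both directions, assuming `schema: StructType` and `iter: Iterator[InternalRow]` are in scope:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeProjection}

    // Generic rows -> UnsafeRow, applied per partition on scan/transform output.
    val toUnsafe = UnsafeProjection.create(schema)
    val unsafeIter: Iterator[InternalRow] = iter.map(toUnsafe)

    // UnsafeRow -> safe rows, before handing data to Hive serializers.
    val toSafe = FromUnsafeProjection(schema)
    val safeIter: Iterator[InternalRow] = unsafeIter.map(toSafe)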
- val resolved = - ResolvedDataSource(sqlContext, provider, partitionColumns, mode, optionsWithPath, df) + val resolved = ResolvedDataSource( + sqlContext, + provider, + partitionColumns, + bucketSpec, + mode, + optionsWithPath, + df) if (createMetastoreTable) { // We will use the schema of resolved.relation as the schema of the table (instead of @@ -265,6 +273,7 @@ case class CreateMetastoreDataSourceAsSelect( tableIdent, Some(resolved.relation.schema), partitionColumns, + bucketSpec, provider, optionsWithPath, isExternal) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 2e8c026259ef..e76c18fa528f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -17,31 +17,26 @@ package org.apache.spark.sql.hive -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer import scala.util.Try -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, ObjectInspector, ObjectInspectorFactory} import org.apache.spark.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.sequenceOption +import org.apache.spark.sql.catalyst.{InternalRow, analysis} import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.sql.types._ @@ -76,19 +71,19 @@ private[hive] class HiveFunctionRegistry( try { if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { HiveGenericUDF( - new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) + name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUDF(new 
HiveFunctionWrapper(functionClassName), children) + HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) } else if ( classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUDAFFunction(new HiveFunctionWrapper(functionClassName), children) + HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUDAFFunction( - new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) + name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) udtf.elementTypes // Force it to check input data types. udtf } else { @@ -138,7 +133,8 @@ private[hive] class HiveFunctionRegistry( } } -private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveSimpleUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def deterministic: Boolean = isUDFDeterministic @@ -192,6 +188,8 @@ private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, childre override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } // Adapter from Catalyst ExpressionResult to Hive DeferredObject @@ -206,7 +204,8 @@ private[hive] class DeferredObjectAdapter(oi: ObjectInspector, dataType: DataTyp override def get(): AnyRef = wrap(func(), oi, dataType) } -private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) +private[hive] case class HiveGenericUDF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { override def nullable: Boolean = true @@ -258,230 +257,8 @@ private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, childr override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } -} - -/** - * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. - */ -private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { - private def shouldResolveFunction( - unresolvedWindowFunction: UnresolvedWindowFunction, - windowSpec: WindowSpecDefinition): Boolean = { - unresolvedWindowFunction.childrenResolved && windowSpec.childrenResolved - } - def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case p: LogicalPlan if !p.childrenResolved => p - - // We are resolving WindowExpressions at here. When we get here, we have already - // replaced those WindowSpecReferences. - case p: LogicalPlan => - p transformExpressions { - // We will not start to resolve the function unless all arguments are resolved - // and all expressions in window spec are fixed. - case WindowExpression( - u @ UnresolvedWindowFunction(name, children), - windowSpec: WindowSpecDefinition) if shouldResolveFunction(u, windowSpec) => - // First, let's find the window function info. 
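With the registered function name threaded through each wrapper as above, `Expression.sql` can render a round-trippable call instead of leaking the implementing class; an illustrative sketch (the class and column names are hypothetical):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    val udf = HiveSimpleUDF("my_upper", new HiveFunctionWrapper("com.example.MyUpperUDF"), Seq('a.string))
    udf.sql       // my_upper(`a`)  -- uses the registered name
    udf.toString  // HiveSimpleUDF#com.example.MyUpperUDF(a)  -- unchanged, still shows the class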
- val windowFunctionInfo: WindowFunctionInfo = - Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"Couldn't find window function $name")) - - // Get the class of this function. - // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use - // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. - val functionClass = windowFunctionInfo.getFunctionClass() - val newChildren = - // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit - // input parameters and requires implicit parameters, which - // are expressions in Order By clause. - if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { - if (children.nonEmpty) { - throw new AnalysisException(s"$name does not take input parameters.") - } - windowSpec.orderSpec.map(_.child) - } else { - children - } - - // If the class is UDAF, we need to use UDAFBridge. - val isUDAFBridgeRequired = - if (classOf[UDAF].isAssignableFrom(functionClass)) { - true - } else { - false - } - - // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of - // HiveWindowFunction. - val windowFunction = - HiveWindowFunction( - new HiveFunctionWrapper(functionClass.getName), - windowFunctionInfo.isPivotResult, - isUDAFBridgeRequired, - newChildren) - - // Second, check if the specified window function can accept window definition. - windowSpec.frameSpecification match { - case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => - // This Hive window function does not support user-speficied window frame. - throw new AnalysisException( - s"Window function $name does not take a frame specification.") - case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && - windowFunctionInfo.isPivotResult => - // These two should not be true at the same time when a window frame is defined. - // If so, throw an exception. - throw new AnalysisException(s"Could not handle Hive window function $name because " + - s"it supports both a user specified window frame and pivot result.") - case _ => // OK - } - // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs - // a window frame specification to work. - val newWindowSpec = windowSpec.frameSpecification match { - case UnspecifiedFrame => - val newWindowFrame = - SpecifiedWindowFrame.defaultWindowFrame( - windowSpec.orderSpec.nonEmpty, - windowFunctionInfo.isSupportsWindow) - WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) - case _ => windowSpec - } - - // Finally, we create a WindowExpression with the resolved window function and - // specified window spec. - WindowExpression(windowFunction, newWindowSpec) - } - } -} - -/** - * A [[WindowFunction]] implementation wrapping Hive's window function. - * @param funcWrapper The wrapper for the Hive Window Function. - * @param pivotResult If it is true, the Hive function will return a list of values representing - * the values of the added columns. Otherwise, a single value is returned for - * current row. - * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's - * createFunction is UDAF, we need to use GenericUDAFBridge to wrap - * it as a GenericUDAFResolver2. - * @param children Input parameters. 
- */ -private[hive] case class HiveWindowFunction( - funcWrapper: HiveFunctionWrapper, - pivotResult: Boolean, - isUDAFBridgeRequired: Boolean, - children: Seq[Expression]) extends WindowFunction - with HiveInspectors with Unevaluable { - - // Hive window functions are based on GenericUDAFResolver2. - type UDFType = GenericUDAFResolver2 - - @transient - protected lazy val resolver: GenericUDAFResolver2 = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[GenericUDAFResolver2]() - } - - @transient - protected lazy val inputInspectors = children.map(toInspector).toArray - - // The GenericUDAFEvaluator used to evaluate the window function. - @transient - protected lazy val evaluator: GenericUDAFEvaluator = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - // The object inspector of values returned from the Hive window function. - @transient - protected lazy val returnInspector = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - override val dataType: DataType = - if (!pivotResult) { - inspectorToDataType(returnInspector) - } else { - // If pivotResult is true, we should take the element type out as the data type of this - // function. - inspectorToDataType(returnInspector) match { - case ArrayType(dt, _) => dt - case _ => - sys.error( - s"error resolve the data type of window function ${funcWrapper.functionClassName}") - } - } - - override def nullable: Boolean = true - - @transient - lazy val inputProjection = new InterpretedProjection(children) - - @transient - private var hiveEvaluatorBuffer: AggregationBuffer = _ - // Output buffer. - private var outputBuffer: Any = _ - - @transient - private lazy val inputDataTypes: Array[DataType] = children.map(_.dataType).toArray - - override def init(): Unit = { - evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) - } - - // Reset the hiveEvaluatorBuffer and outputPosition - override def reset(): Unit = { - // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. - // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. - // However, RowNumberBuffer.init does not really reset this buffer. - hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer - evaluator.reset(hiveEvaluatorBuffer) - } - - override def prepareInputParameters(input: InternalRow): AnyRef = { - wrap( - inputProjection(input), - inputInspectors, - new Array[AnyRef](children.length), - inputDataTypes) - } - - // Add input parameters for a single row. - override def update(input: AnyRef): Unit = { - evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) - } - - override def batchUpdate(inputs: Array[AnyRef]): Unit = { - var i = 0 - while (i < inputs.length) { - evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) - i += 1 - } - } - - override def evaluate(): Unit = { - outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) - } - - override def get(index: Int): Any = { - if (!pivotResult) { - // if pivotResult is false, we will get a single value for all rows in the frame. - outputBuffer - } else { - // if pivotResult is true, we will get a ArrayData having the same size with the size - // of the window frame. At here, we will return the result at the position of - // index in the output buffer. 
- outputBuffer.asInstanceOf[ArrayData].get(index, dataType) - } - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } - - override def newInstance(): WindowFunction = - new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -496,6 +273,7 @@ private[hive] case class HiveWindowFunction( * user defined aggregations, which have clean semantics even in a partitioned execution. */ private[hive] case class HiveGenericUDTF( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Generator with HiveInspectors with CodegenFallback { @@ -561,6 +339,8 @@ private[hive] case class HiveGenericUDTF( override def toString: String = { s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" } + + override def sql: String = s"$name(${children.map(_.sql).mkString(", ")})" } /** @@ -568,6 +348,7 @@ private[hive] case class HiveGenericUDTF( * performance a lot. */ private[hive] case class HiveUDAFFunction( + name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression], isUDAFBridgeRequired: Boolean = false, @@ -652,5 +433,9 @@ private[hive] case class HiveUDAFFunction( override def supportsPartial: Boolean = false override val dataType: DataType = inspectorToDataType(returnInspector) -} + override def sql(isDistinct: Boolean): String = { + val distinct = if (isDistinct) "DISTINCT " else " " + s"$name($distinct${children.map(_.sql).mkString(", ")})" + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala index 93c016b6c6c7..22182ba00986 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveWriterContainers.scala @@ -23,16 +23,17 @@ import java.util.Date import scala.collection.mutable import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.common.FileUtils import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.hadoop.hive.ql.exec.{FileSinkOperator, Utilities} import org.apache.hadoop.hive.ql.io.{HiveFileFormatUtils, HiveOutputFormat} import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.hadoop.io.Writable import org.apache.hadoop.mapred._ -import org.apache.hadoop.hive.common.FileUtils +import org.apache.hadoop.mapreduce.TaskType -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.{Logging, SerializableWritable, SparkHadoopWriter} +import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} @@ -46,9 +47,7 @@ import org.apache.spark.util.SerializableJobConf private[hive] class SparkHiveWriterContainer( jobConf: JobConf, fileSinkConf: FileSinkDesc) - extends Logging - with SparkHadoopMapRedUtil - with Serializable { + extends Logging with Serializable { private val now = new Date() private val tableDesc: TableDesc = fileSinkConf.getTableInfo @@ -68,8 +67,8 @@ private[hive] class SparkHiveWriterContainer( @transient private var writer: FileSinkOperator.RecordWriter = null @transient protected lazy val committer = conf.value.getOutputCommitter - @transient protected lazy val jobContext = newJobContext(conf.value, jID.value) - @transient 
private lazy val taskContext = newTaskAttemptContext(conf.value, taID.value) + @transient protected lazy val jobContext = new JobContextImpl(conf.value, jID.value) + @transient private lazy val taskContext = new TaskAttemptContextImpl(conf.value, taID.value) @transient private lazy val outputFormat = conf.value.getOutputFormat.asInstanceOf[HiveOutputFormat[AnyRef, Writable]] @@ -131,7 +130,7 @@ private[hive] class SparkHiveWriterContainer( jID = new SerializableWritable[JobID](SparkHadoopWriter.createJobID(now, jobId)) taID = new SerializableWritable[TaskAttemptID]( - new TaskAttemptID(new TaskID(jID.value, true, splitID), attemptID)) + new TaskAttemptID(new TaskID(jID.value, TaskType.MAP, splitID), attemptID)) } private def setConfParams() { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala index 0f9a1a6ef3b2..b91a14bdbcc4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileOperator.scala @@ -95,7 +95,7 @@ private[orc] object OrcFileOperator extends Logging { val fs = origPath.getFileSystem(conf) val path = origPath.makeQualified(fs.getUri, fs.getWorkingDirectory) val paths = SparkHadoopUtil.get.listLeafStatuses(fs, origPath) - .filterNot(_.isDir) + .filterNot(_.isDirectory) .map(_.getPath) .filterNot(_.getName.startsWith("_")) .filterNot(_.getName.startsWith(".")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala index 27193f54d3a9..99a232f74fac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFilters.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.orc import org.apache.hadoop.hive.common.`type`.{HiveChar, HiveDecimal, HiveVarchar} -import org.apache.hadoop.hive.ql.io.sarg.{SearchArgumentFactory, SearchArgument} +import org.apache.hadoop.hive.ql.io.sarg.{SearchArgument, SearchArgumentFactory} import org.apache.hadoop.hive.ql.io.sarg.SearchArgument.Builder import org.apache.hadoop.hive.serde2.io.DateWritable @@ -26,15 +26,47 @@ import org.apache.spark.Logging import org.apache.spark.sql.sources._ /** - * It may be optimized by push down partial filters. But we are conservative here. - * Because if some filters fail to be parsed, the tree may be corrupted, - * and cannot be used anymore. + * Helper object for building ORC `SearchArgument`s, which are used for ORC predicate push-down. + * + * Due to a limitation of the ORC `SearchArgument` builder, we ended up with a pretty weird double- + * checking pattern when converting `And`/`Or`/`Not` filters. + * + * An ORC `SearchArgument` must be built in one pass using a single builder. For example, you can't + * build `a = 1` and `b = 2` first, and then combine them into `a = 1 AND b = 2`. This is quite + * different from the cases in Spark SQL or Parquet, where complex filters can be easily built using + * existing simpler ones. + * + * The annoying part is that `SearchArgument` builder methods like `startAnd()`, `startOr()`, and + * `startNot()` mutate internal state of the builder instance. This forces us to translate all + * convertible filters with a single builder instance. However, before actually converting a filter, + * we have no idea whether it can be recognized by ORC or not. 
Thus, when an inconvertible filter is + * found, we may already end up with a builder whose internal state is inconsistent. + * + * For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and then + * try to convert its children. Say we convert `left` child successfully, but find that `right` + * child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is inconsistent + * now. + * + * The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their + * children with brand new builders, and only do the actual conversion with the right builder + * instance when the children are proven to be convertible. + * + * P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. Usage of + * builder methods mentioned above can only be found in test code, where all tested filters are + * known to be convertible. */ private[orc] object OrcFilters extends Logging { def createFilter(filters: Array[Filter]): Option[SearchArgument] = { + // First, tries to convert each filter individually to see whether it's convertible, and then + // collect all convertible ones to build the final `SearchArgument`. + val convertibleFilters = for { + filter <- filters + _ <- buildSearchArgument(filter, SearchArgumentFactory.newBuilder()) + } yield filter + for { - // Combines all filters with `And`s to produce a single conjunction predicate - conjunction <- filters.reduceOption(And) + // Combines all convertible filters using `And` to produce a single conjunction + conjunction <- convertibleFilters.reduceOption(And) // Then tries to build a single ORC `SearchArgument` for the conjunction predicate builder <- buildSearchArgument(conjunction, SearchArgumentFactory.newBuilder()) } yield builder.build() @@ -50,46 +82,22 @@ private[orc] object OrcFilters extends Logging { case _ => false } - // lian: I probably missed something here, and had to end up with a pretty weird double-checking - // pattern when converting `And`/`Or`/`Not` filters. - // - // The annoying part is that, `SearchArgument` builder methods like `startAnd()` `startOr()`, - // and `startNot()` mutate internal state of the builder instance. This forces us to translate - // all convertible filters with a single builder instance. However, before actually converting a - // filter, we've no idea whether it can be recognized by ORC or not. Thus, when an inconvertible - // filter is found, we may already end up with a builder whose internal state is inconsistent. - // - // For example, to convert an `And` filter with builder `b`, we call `b.startAnd()` first, and - // then try to convert its children. Say we convert `left` child successfully, but find that - // `right` child is inconvertible. Alas, `b.startAnd()` call can't be rolled back, and `b` is - // inconsistent now. - // - // The workaround employed here is that, for `And`/`Or`/`Not`, we first try to convert their - // children with brand new builders, and only do the actual conversion with the right builder - // instance when the children are proven to be convertible. - // - // P.S.: Hive seems to use `SearchArgument` together with `ExprNodeGenericFuncDesc` only. - // Usage of builder methods mentioned above can only be found in test code, where all tested - // filters are known to be convertible. 
- expression match { case And(left, right) => - val tryLeft = buildSearchArgument(left, newBuilder) - val tryRight = buildSearchArgument(right, newBuilder) - - val conjunction = for { - _ <- tryLeft - _ <- tryRight + // Here, it is not safe to just convert one side if we do not understand the + // other side. Here is an example used to explain the reason. + // Let's say we have NOT(a = 2 AND b in ('1')) and we do not understand how to + // convert b in ('1'). If we only convert a = 2, we will end up with a filter + // NOT(a = 2), which will generate wrong results. + // Pushing one side of AND down is only safe to do at the top level. + // You can see ParquetRelation's initializeLocalJobFunc method as an example. + for { + _ <- buildSearchArgument(left, newBuilder) + _ <- buildSearchArgument(right, newBuilder) lhs <- buildSearchArgument(left, builder.startAnd()) rhs <- buildSearchArgument(right, lhs) } yield rhs.end() - // For filter `left AND right`, we can still push down `left` even if `right` is not - // convertible, and vice versa. - conjunction - .orElse(tryLeft.flatMap(_ => buildSearchArgument(left, builder))) - .orElse(tryRight.flatMap(_ => buildSearchArgument(right, builder))) - case Or(left, right) => for { _ <- buildSearchArgument(left, newBuilder) @@ -104,6 +112,10 @@ private[orc] object OrcFilters extends Logging { negate <- buildSearchArgument(child, builder.startNot()) } yield negate.end() + // NOTE: For all case branches dealing with leaf predicates below, the additional `startAnd()` + // call is mandatory. ORC `SearchArgument` builder requires that all leaf predicates must be + // wrapped by a "parent" predicate (`And`, `Or`, or `Not`). + case EqualTo(attribute, value) if isSearchableLiteral(value) => Some(builder.startAnd().equals(attribute, value).end()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala index 1136670b7a0e..14fa152c2331 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcRelation.scala @@ -28,24 +28,22 @@ import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspect import org.apache.hadoop.hive.serde2.typeinfo.{StructTypeInfo, TypeInfoUtils} import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{InputFormat => MapRedInputFormat, JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat import org.apache.spark.Logging import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.mapred.SparkHadoopMapRedUtil import org.apache.spark.rdd.{HadoopRDD, RDD} +import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.datasources.PartitionSpec +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.hive.{HiveContext, HiveInspectors, HiveMetastoreTypes, HiveShim} import org.apache.spark.sql.sources.{Filter, _} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.util.SerializableConfiguration -private[sql] class DefaultSource extends
HadoopFsRelationProvider with DataSourceRegister { +private[sql] class DefaultSource extends BucketedHadoopFsRelationProvider with DataSourceRegister { override def shortName(): String = "orc" @@ -54,20 +52,22 @@ private[sql] class DefaultSource extends HadoopFsRelationProvider with DataSourc paths: Array[String], dataSchema: Option[StructType], partitionColumns: Option[StructType], + bucketSpec: Option[BucketSpec], parameters: Map[String, String]): HadoopFsRelation = { assert( sqlContext.isInstanceOf[HiveContext], "The ORC data source can only be used with HiveContext.") - new OrcRelation(paths, dataSchema, None, partitionColumns, parameters)(sqlContext) + new OrcRelation(paths, dataSchema, None, partitionColumns, bucketSpec, parameters)(sqlContext) } } private[orc] class OrcOutputWriter( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext) - extends OutputWriter with SparkHadoopMapRedUtil with HiveInspectors { + extends OutputWriter with HiveInspectors { private val serializer = { val table = new Properties() @@ -77,7 +77,7 @@ private[orc] class OrcOutputWriter( }.mkString(":")) val serde = new OrcSerde - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration serde.initialize(configuration, table) serde } @@ -99,11 +99,12 @@ private[orc] class OrcOutputWriter( private lazy val recordWriter: RecordWriter[NullWritable, Writable] = { recordWriterInstantiated = true - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val conf = context.getConfiguration val uniqueWriteJobId = conf.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val partition = taskAttemptId.getTaskID.getId - val filename = f"part-r-$partition%05d-$uniqueWriteJobId.orc" + val bucketString = bucketId.map(id => f"-$id%05d").getOrElse("") + val filename = f"part-r-$partition%05d-$uniqueWriteJobId$bucketString.orc" new OrcOutputFormat().getRecordWriter( new Path(path, filename).getFileSystem(conf), @@ -155,6 +156,7 @@ private[sql] class OrcRelation( maybeDataSchema: Option[StructType], maybePartitionSpec: Option[PartitionSpec], override val userDefinedPartitionColumns: Option[StructType], + override val bucketSpec: Option[BucketSpec], parameters: Map[String, String])( @transient val sqlContext: SQLContext) extends HadoopFsRelation(maybePartitionSpec, parameters) @@ -171,6 +173,7 @@ private[sql] class OrcRelation( maybeDataSchema, maybePartitionSpec, maybePartitionSpec.map(_.partitionColumns), + None, parameters)(sqlContext) } @@ -207,8 +210,8 @@ private[sql] class OrcRelation( OrcTableScan(output, this, filters, inputPaths).execute() } - override def prepareJobForWrite(job: Job): OutputWriterFactory = { - SparkHadoopUtil.get.getConfigurationFromJobContext(job) match { + override def prepareJobForWrite(job: Job): BucketedOutputWriterFactory = { + job.getConfiguration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) case conf => @@ -218,12 +221,13 @@ private[sql] class OrcRelation( classOf[MapRedOutputFormat[_, _]]) } - new OutputWriterFactory { + new BucketedOutputWriterFactory { override def newInstance( path: String, + bucketId: Option[Int], dataSchema: StructType, context: TaskAttemptContext): OutputWriter = { - new OrcOutputWriter(path, dataSchema, context) + new OrcOutputWriter(path, bucketId, dataSchema, context) } } } @@ -289,8 +293,8 @@ 
private[orc] case class OrcTableScan( } def execute(): RDD[InternalRow] = { - val job = new Job(sqlContext.sparkContext.hadoopConfiguration) - val conf = SparkHadoopUtil.get.getConfigurationFromJobContext(job) + val job = Job.getInstance(sqlContext.sparkContext.hadoopConfiguration) + val conf = job.getConfiguration // Tries to push down filters if ORC filter push-down is enabled if (sqlContext.conf.orcFilterPushDown) { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 97792549bb7a..d26cb4847906 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -29,14 +29,17 @@ import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.{SQLContext, SQLConf} +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.sql.{SQLConf, SQLContext} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.util.{ShutdownHookManager, Utils} -import org.apache.spark.{SparkConf, SparkContext} // SPARK-3729: Test key required to check for initialization errors with config. object TestHive @@ -410,7 +413,10 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { try { // HACK: Hive is too noisy by default. 
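The `TestHiveFunctionRegistry` introduced in the TestHive changes just below gives tests a way to temporarily hide a built-in function, forcing resolution through the Hive fallback path, and then restore it afterwards; a sketch (the choice of `hash` is illustrative):

    val registry = hiveContext.functionRegistry.asInstanceOf[TestHiveFunctionRegistry]
    registry.unregisterFunction("hash")
    try {
      // queries run here resolve `hash` through HiveFunctionRegistry's lookup instead
    } finally {
      registry.restore()
    }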
org.apache.log4j.LogManager.getCurrentLoggers.asScala.foreach { log => - log.asInstanceOf[org.apache.log4j.Logger].setLevel(org.apache.log4j.Level.WARN) + val logger = log.asInstanceOf[org.apache.log4j.Logger] + if (!logger.getName.contains("org.apache.spark")) { + logger.setLevel(org.apache.log4j.Level.WARN) + } } cacheManager.clearCache() @@ -448,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logError("FATAL ERROR: Failed to reset TestDB state.", e) } } + + @transient + override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry( + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) +} + +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) + extends HiveFunctionRegistry(fr, client) { + + private val removedFunctions = + collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] + + def unregisterFunction(name: String): Unit = { + fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + } + + def restore(): Unit = { + removedFunctions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + } } private[hive] object TestHiveContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 99478e82d419..9b37dd110376 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import java.io.File +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.execution.columnar.InMemoryColumnarTableScan import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.storage.RDDBlockId import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index cf737836939f..a2d283622ca5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -21,10 +21,10 @@ import scala.util.Try import org.scalatest.BeforeAndAfter +import org.apache.spark.sql.catalyst.parser.ParseDriver +import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{AnalysisException, QueryTest} - class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter { import hiveContext.implicits._ @@ -117,8 +117,9 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { + def ast = ParseDriver.parse(query, hiveContext.conf) def parseTree = - Try(quietly(HiveQl.dumpTree(HiveQl.getAst(query)))).getOrElse("") + Try(quietly(ast.treeString)).getOrElse("") test(name) { val error = intercept[AnalysisException] { @@ -140,10 +141,7 @@ class ErrorPositionSuite extends QueryTest with 
TestHiveSingleton with BeforeAnd val expectedStart = line.indexOf(token) val actualStart = error.startPosition.getOrElse { - fail( - s"start not returned for error on token $token\n" + - HiveQl.dumpTree(HiveQl.getAst(query)) - ) + fail(s"start not returned for error on token $token\n${ast.treeString}") } assert(expectedStart === actualStart, s"""Incorrect start position. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala new file mode 100644 index 000000000000..3a6eb57add4e --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import java.sql.Timestamp + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions.{If, Literal} + +class ExpressionSQLBuilderSuite extends SQLBuilderTest { + test("literal") { + checkSQL(Literal("foo"), "\"foo\"") + checkSQL(Literal("\"foo\""), "\"\\\"foo\\\"\"") + checkSQL(Literal(1: Byte), "CAST(1 AS TINYINT)") + checkSQL(Literal(2: Short), "CAST(2 AS SMALLINT)") + checkSQL(Literal(4: Int), "4") + checkSQL(Literal(8: Long), "CAST(8 AS BIGINT)") + checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") + checkSQL(Literal(2.5D), "2.5") + checkSQL( + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), + "TIMESTAMP('2016-01-01 00:00:00.0')") + // TODO tests for decimals + } + + test("binary comparisons") { + checkSQL('a.int === 'b.int, "(`a` = `b`)") + checkSQL('a.int <=> 'b.int, "(`a` <=> `b`)") + checkSQL('a.int !== 'b.int, "(NOT (`a` = `b`))") + + checkSQL('a.int < 'b.int, "(`a` < `b`)") + checkSQL('a.int <= 'b.int, "(`a` <= `b`)") + checkSQL('a.int > 'b.int, "(`a` > `b`)") + checkSQL('a.int >= 'b.int, "(`a` >= `b`)") + + checkSQL('a.int in ('b.int, 'c.int), "(`a` IN (`b`, `c`))") + checkSQL('a.int in (1, 2), "(`a` IN (1, 2))") + + checkSQL('a.int.isNull, "(`a` IS NULL)") + checkSQL('a.int.isNotNull, "(`a` IS NOT NULL)") + } + + test("logical operators") { + checkSQL('a.boolean && 'b.boolean, "(`a` AND `b`)") + checkSQL('a.boolean || 'b.boolean, "(`a` OR `b`)") + checkSQL(!'a.boolean, "(NOT `a`)") + checkSQL(If('a.boolean, 'b.int, 'c.int), "(IF(`a`, `b`, `c`))") + } + + test("arithmetic expressions") { + checkSQL('a.int + 'b.int, "(`a` + `b`)") + checkSQL('a.int - 'b.int, "(`a` - `b`)") + checkSQL('a.int * 'b.int, "(`a` * `b`)") + checkSQL('a.int / 'b.int, "(`a` / `b`)") + checkSQL('a.int % 'b.int, "(`a` % `b`)") + + checkSQL(-'a.int, "(-`a`)") + checkSQL(-('a.int + 'b.int), "(-(`a` + `b`))") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala index 9864acf76526..35e433964da9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameAnalyticsSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.hive +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.sql.{DataFrame, QueryTest, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.scalatest.BeforeAndAfterAll // TODO ideally we should put the test suite into the package `sql`, as // `hive` package is optional in compiling, however, `SQLContext.sql` doesn't diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala index f621367eb553..63cf5030ab8b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDataFrameJoinSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton class HiveDataFrameJoinSuite extends QueryTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala index 8bb9058cd74e..3b867bbfa181 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveInspectorSuite.scala @@ -30,7 +30,7 @@ import org.apache.hadoop.io.LongWritable import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.{MapData, GenericArrayData, ArrayBasedMapData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ import org.apache.spark.sql.Row diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index d63f3d399652..14a83d53904a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{QueryTest, Row, SaveMode, SQLConf} import org.apache.spark.sql.hive.client.{ExternalTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils} import org.apache.spark.sql.types.{DecimalType, StringType, StructType} -import org.apache.spark.sql.{SQLConf, QueryTest, Row, SaveMode} class HiveMetastoreCatalogSuite extends SparkFunSuite with TestHiveSingleton { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala index 5596ec6882ea..7841ffe5e03d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.execution.datasources.parquet.ParquetTest import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.{QueryTest, Row} case class Cases(lower: String, UPPER: String) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index a330362b4e1d..f4a1a1742248 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,15 +18,14 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.JsonTuple +import org.apache.spark.sql.catalyst.plans.logical.Generate import org.apache.spark.sql.hive.client.{ExternalTable, HiveColumn, HiveTable, ManagedTable} - class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { private def extractTableDesc(sql: String): (HiveTable, Boolean) = { HiveQl.createPlan(sql).collect { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 53185fd7751e..8932ce9503a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -29,7 +29,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.SpanSugar._ import org.apache.spark._ -import org.apache.spark.sql.{SQLContext, QueryTest} +import org.apache.spark.sql.{QueryTest, SQLContext} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -87,7 +87,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - test("SPARK-8489: MissingRequirementError during reflection") { + ignore("SPARK-8489: MissingRequirementError during reflection") { // This test uses a pre-built jar to test SPARK-8489. In a nutshell, this test creates // a HiveContext and uses it to create a data frame from an RDD using reflection. 
// Before the fix in SPARK-8470, this results in a MissingRequirementError because @@ -103,7 +103,7 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } - ignore("SPARK-9757 Persist Parquet relation with decimal column") { + test("SPARK-9757 Persist Parquet relation with decimal column") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( "--class", SPARK_9757.getClass.getName.stripSuffix("$"), @@ -363,7 +363,7 @@ object SPARK_11009 extends QueryTest { val df = sqlContext.range(1 << 20) val df2 = df.select((df("id") % 1000).alias("A"), (df("id") / 1000).alias("B")) val ws = Window.partitionBy(df2("A")).orderBy(df2("B")) - val df3 = df2.select(df2("A"), df2("B"), rowNumber().over(ws).alias("rn")).filter("rn < 0") + val df3 = df2.select(df2("A"), df2("B"), row_number().over(ws).alias("rn")).filter("rn < 0") if (df3.rdd.count() != 0) { throw new Exception("df3 should have 0 output row.") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 81ee9ba71beb..da7303c79106 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -22,8 +22,8 @@ import java.io.File import org.apache.hadoop.hive.conf.HiveConf import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql.execution.QueryExecutionException import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types._ import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 183aca29cf98..a94f7053c39f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import org.scalatest.BeforeAndAfterAll -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.QueryTest import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterAll { import hiveContext._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala new file mode 100644 index 000000000000..0e81acf532a0 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sqlContext.range(10).write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } + + private def checkHiveQl(hiveQl: String): Unit = { + val df = sql(hiveQl) + val convertedSQL = new SQLBuilder(df).toSQL + + if (convertedSQL.isEmpty) { + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + val sqlString = convertedSQL.get + try { + checkAnswer(sql(sqlString), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$sqlString + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("in") { + checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") + } + + test("aggregate function in having clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") + } + + test("aggregate function in order by clause") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key)") + } + + // TODO Fix the name collision introduced by the ResolveAggregateFunctions analysis rule + // When there are multiple aggregate functions in the ORDER BY clause, all of them are extracted + // into the Aggregate operator and aliased to the same name "aggOrder". This is OK for normal + // query execution since these aliases have different expression IDs, but it introduces a name + // collision when converting resolved plans back to SQL query strings, as expression IDs are + // stripped.
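+ // For illustration only (a hypothetical example, not from this suite): a query such as
+ // "SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY MAX(key), MIN(key)" is analyzed into an
+ // Aggregate carrying both `max(key) AS aggOrder` and `min(key) AS aggOrder`; once the
+ // distinguishing expression IDs are stripped, a regenerated ORDER BY clause that references
+ // `aggOrder` is ambiguous.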
+ ignore("aggregate function in order by clause with multiple order keys") { + checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key ORDER BY key, MAX(key)") + } + + test("type widening in union") { + checkHiveQl("SELECT id FROM t0 UNION ALL SELECT CAST(id AS INT) AS id FROM t0") + } + + test("case") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM t0") + } + + test("case with else") { + checkHiveQl("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM t0") + } + + test("case with key") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM t0") + } + + test("case with key and else") { + checkHiveQl("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM t0") + } + + test("select distinct without aggregate functions") { + checkHiveQl("SELECT DISTINCT id FROM t0") + } + + test("cluster by") { + checkHiveQl("SELECT id FROM t0 CLUSTER BY id") + } + + test("distribute by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id") + } + + test("distribute by with sort by") { + checkHiveQl("SELECT id FROM t0 DISTRIBUTE BY id SORT BY id") + } + + test("distinct aggregation") { + checkHiveQl("SELECT COUNT(DISTINCT id) FROM t0") + } + + // TODO Enable this + // Query plans transformed by DistinctAggregationRewriter are not recognized yet + ignore("distinct and non-distinct aggregation") { + checkHiveQl("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM t2 GROUP BY a") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index f74eb1500b98..202851ae1366 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -17,20 +17,20 @@ package org.apache.spark.sql.hive -import java.io.{IOException, File} +import java.io.{File, IOException} import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.fs.Path import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.datasources.LogicalRelation +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.client.{HiveTable, ManagedTable} import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.util.Utils /** @@ -707,6 +707,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv tableIdent = TableIdentifier("wide_schema"), userSpecifiedSchema = Some(schema), partitionColumns = Array.empty[String], + bucketSpec = None, provider = "json", options = Map("path" -> "just a dummy path"), isExternal = false) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index f16c257ab5ab..c2c896e5f61b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import 
org.apache.spark.sql.test.SQLTestUtils -import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { private lazy val df = sqlContext.range(10).coalesce(1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala index 49aab85cf1aa..4a73153a80ee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala @@ -21,8 +21,8 @@ import java.sql.Timestamp import org.apache.hadoop.hive.conf.HiveConf -import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.{Row, SQLConf} +import org.apache.spark.sql.execution.datasources.parquet.ParquetCompatibilityTest import org.apache.spark.sql.hive.test.TestHiveSingleton class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest with TestHiveSingleton { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala index f542a5a02508..f49ee690ac04 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/QueryPartitionSuite.scala @@ -19,10 +19,10 @@ package org.apache.spark.sql.hive import com.google.common.io.Files -import org.apache.spark.util.Utils -import org.apache.spark.sql.{QueryTest, _} +import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.Utils class QueryPartitionSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala new file mode 100644 index 000000000000..cf4a3fdd8880 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.hive.test.TestHiveSingleton + +abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { + protected def checkSQL(e: Expression, expectedSQL: String): Unit = { + val actualSQL = e.sql + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following expression: + | + |${e.prettyName} + | + |$cause + """.stripMargin) + } + } + + protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { + val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL + + if (maybeSQL.isEmpty) { + fail( + s"""Cannot convert the following logical query plan to SQL: + | + |${plan.treeString} + """.stripMargin) + } + + val actualSQL = maybeSQL.get + + try { + assert(actualSQL === expectedSQL) + } catch { + case cause: Throwable => + fail( + s"""Wrong SQL generated for the following logical query plan: + | + |${plan.treeString} + | + |$cause + """.stripMargin) + } + + checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + } + + protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { + checkSQL(df.queryExecution.analyzed, expectedSQL) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index f775f1e95587..78f74cdc19dd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive import scala.reflect.ClassTag -import org.apache.spark.sql.{Row, SQLConf, QueryTest} +import org.apache.spark.sql.{QueryTest, Row, SQLConf} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 502b240f3650..ff10a251f3b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -21,10 +21,10 @@ import java.io.File import org.apache.hadoop.util.VersionInfo -import org.apache.spark.sql.hive.HiveContext import org.apache.spark.{Logging, SparkFunSuite} -import org.apache.spark.sql.catalyst.expressions.{NamedExpression, Literal, AttributeReference, EqualTo} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, NamedExpression} import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.types.IntegerType import org.apache.spark.tags.ExtendedHiveTest import org.apache.spark.util.Utils diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala index e38d1eb5779f..f5cd73d45ed7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ConcurrentHiveSuite.scala @@ -17,9 +17,10 @@ package
org.apache.spark.sql.hive.execution +import org.scalatest.BeforeAndAfterAll + import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.sql.hive.test.TestHiveContext -import org.scalatest.BeforeAndAfterAll class ConcurrentHiveSuite extends SparkFunSuite with BeforeAndAfterAll { ignore("multiple instances not supported") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4455430aa727..57358a07840e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -27,9 +27,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution.{SetCommand, ExplainCommand} import org.apache.spark.sql.execution.datasources.DescribeCommand +import org.apache.spark.sql.execution.{ExplainCommand, SetCommand} import org.apache.spark.sql.hive.test.TestHive +import org.apache.spark.sql.hive.{InsertIntoHiveTable => LogicalInsertIntoHiveTable, SQLBuilder} /** * Allows the creations of tests that execute the same query against both hive @@ -130,6 +131,28 @@ abstract class HiveComparisonTest new java.math.BigInteger(1, digest.digest).toString(16) } + /** Used for testing [[SQLBuilder]] */ + private var numConvertibleQueries: Int = 0 + private var numTotalQueries: Int = 0 + + override protected def afterAll(): Unit = { + logInfo({ + val percentage = if (numTotalQueries > 0) { + numConvertibleQueries.toDouble / numTotalQueries * 100 + } else { + 0D + } + + s"""SQLBuilder statistics: + |- Total query number: $numTotalQueries + |- Number of convertible queries: $numConvertibleQueries + |- Percentage of convertible queries: $percentage% + """.stripMargin + }) + + super.afterAll() + } + protected def prepareAnswer( hiveQuery: TestHive.type#QueryExecution, answer: Seq[String]): Seq[String] = { @@ -372,8 +395,49 @@ abstract class HiveComparisonTest // Run w/ catalyst val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) => - val query = new TestHive.QueryExecution(queryString) - try { (query, prepareAnswer(query, query.stringResult())) } catch { + var query: TestHive.QueryExecution = null + try { + query = { + val originalQuery = new TestHive.QueryExecution(queryString) + val containsCommands = originalQuery.analyzed.collectFirst { + case _: Command => () + case _: LogicalInsertIntoHiveTable => () + }.nonEmpty + + if (containsCommands) { + originalQuery + } else { + numTotalQueries += 1 + new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + numConvertibleQueries += 1 + logInfo( + s""" + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} + """.stripMargin.trim) + new TestHive.QueryExecution(sql) + }.getOrElse { + logInfo( + s""" + |### Cannot convert the following logical plan back to SQL {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + |}}} + """.stripMargin.trim) + originalQuery + } + } + } + + (query, prepareAnswer(query, query.stringResult())) + } catch { case e: Throwable => val errorMessage = s""" diff --git
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index a7b7ad009391..b7ef5d1db729 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -18,8 +18,8 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.QueryTest -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils /** * A set of tests that validates support for Hive Explain command. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala index 0d4c7f86b315..9bdc24162b73 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveOperatorQueryableSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.hive.execution -import org.apache.spark.sql.{Row, QueryTest} +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8a5acaf3e10b..4659d745fe78 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -26,14 +26,14 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.scalatest.BeforeAndAfter +import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} -import org.apache.spark.sql.{AnalysisException, DataFrame, Row} -import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.sql.hive.test.TestHive._ case class TestData(a: Int, b: String) @@ -60,6 +60,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) sql("DROP TEMPORARY FUNCTION udtf_count2") + super.afterAll() } test("SPARK-4908: concurrent hive native commands") { @@ -387,9 +388,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("partitioned table scan", "SELECT ds, hr, key, value FROM srcpart") - createQueryTest("hash", - "SELECT hash('test') FROM src LIMIT 1") - createQueryTest("create table as", """ |CREATE TABLE createdtable AS SELECT * FROM src; @@ -790,6 +788,24 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assert(sql("select key from src having key > 490").collect().size < 100) } + test("union/except/intersect") { + assertResult(Array(Row(1), Row(1))) { + sql("select 1 as a union all select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union distinct select 1 as 
a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a union select 1 as a").collect() + } + assertResult(Array()) { + sql("select 1 as a except select 1 as a").collect() + } + assertResult(Array(Row(1))) { + sql("select 1 as a intersect select 1 as a").collect() + } + } + test("SPARK-5383 alias for udfs with multi output columns") { assert( sql("select stack(2, key, value, key, value) as (a, b) from src limit 5") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala index b08db6de2d2f..dd13b8392880 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveResolutionSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, jsonRDD, sql} +import org.apache.spark.sql.hive.test.TestHive.{read, sparkContext, sql} import org.apache.spark.sql.hive.test.TestHive.implicits._ case class Nested(a: Int, B: Int) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala index 2209fc2f30a3..b0c0dcbe5c25 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveTableScanSuite.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.test.TestHive import org.apache.spark.sql.hive.test.TestHive._ import org.apache.spark.sql.hive.test.TestHive.implicits._ - import org.apache.spark.util.Utils class HiveTableScanSuite extends HiveComparisonTest { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 9deb1a6db15a..c5ff8825abd7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -17,23 +17,23 @@ package org.apache.spark.sql.hive.execution -import java.io.{PrintWriter, File, DataInput, DataOutput} +import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.udf.UDAFPercentile -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDFOPAnd, GenericUDTFExplode, GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF, GenericUDFOPAnd, GenericUDTFExplode} import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory import org.apache.hadoop.io.Writable -import org.apache.spark.sql.test.SQLTestUtils + import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import 
org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils - case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) // Case classes for the custom UDF's. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 3427152b2da0..593fac2c3281 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -22,13 +22,13 @@ import java.sql.{Date, Timestamp} import scala.collection.JavaConverters._ import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{TableIdentifier, DefaultParserDialect} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, EliminateSubQueries} +import org.apache.spark.sql.catalyst.{DefaultParserDialect, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{EliminateSubQueries, FunctionRegistry} import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.hive.test.TestHiveSingleton -import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation +import org.apache.spark.sql.hive.{HiveContext, HiveQLDialect, MetastoreRelation} +import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -915,6 +915,27 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { ).map(i => Row(i._1, i._2, i._3, i._4))) } + test("window function: distinct should not be silently ignored") { + val data = Seq( + WindowData(1, "a", 5), + WindowData(2, "a", 6), + WindowData(3, "b", 7), + WindowData(4, "b", 8), + WindowData(5, "c", 9), + WindowData(6, "c", 10) + ) + sparkContext.parallelize(data).toDF().registerTempTable("windowData") + + val e = intercept[AnalysisException] { + sql( + """ + |select month, area, product, sum(distinct product + 1) over (partition by 1 order by 2) + |from windowData + """.stripMargin) + } + assert(e.getMessage.contains("Distinct window functions are not supported")) + } + test("window function: expressions in arguments of a window functions") { val data = Seq( WindowData(1, "a", 5), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala index 7cfdb886b585..8f163f27c94c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.{UnaryNode, SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryNode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.types.StringType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala new file mode 100644 index 000000000000..ea82b8c45969 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/WindowQuerySuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql._ +import org.apache.spark.sql.hive.test.{TestHive, TestHiveSingleton} +import org.apache.spark.sql.test.SQLTestUtils + +/** + * This suite contains a couple of Hive window tests which fail in the typical setup due to tiny + * numerical differences or due to semantic differences between Hive and Spark. + */ +class WindowQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + + override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + sql( + """ + |CREATE TABLE part( + | p_partkey INT, + | p_name STRING, + | p_mfgr STRING, + | p_brand STRING, + | p_type STRING, + | p_size INT, + | p_container STRING, + | p_retailprice DOUBLE, + | p_comment STRING) + """.stripMargin) + val testData1 = TestHive.getHiveFile("data/files/part_tiny.txt").getCanonicalPath + sql( + s""" + |LOAD DATA LOCAL INPATH '$testData1' overwrite into table part + """.stripMargin) + } + + override def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS part") + } + + test("windowing.q -- 15. testExpressions") { + // Moved because: + // - Spark uses a different default stddev (sample instead of pop) + // - Tiny numerical differences in stddev results.
+ // - Different StdDev behavior when n=1 (NaN instead of 0) + checkAnswer(sql(s""" + |select p_mfgr,p_name, p_size, + |rank() over(distribute by p_mfgr sort by p_name) as r, + |dense_rank() over(distribute by p_mfgr sort by p_name) as dr, + |cume_dist() over(distribute by p_mfgr sort by p_name) as cud, + |percent_rank() over(distribute by p_mfgr sort by p_name) as pr, + |ntile(3) over(distribute by p_mfgr sort by p_name) as nt, + |count(p_size) over(distribute by p_mfgr sort by p_name) as ca, + |avg(p_size) over(distribute by p_mfgr sort by p_name) as avg, + |stddev(p_size) over(distribute by p_mfgr sort by p_name) as st, + |first_value(p_size % 5) over(distribute by p_mfgr sort by p_name) as fv, + |last_value(p_size) over(distribute by p_mfgr sort by p_name) as lv, + |first_value(p_size) over w1 as fvW1 + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 1, 1, 0.3333333333333333, 0.0, 1, 2, 2.0, 0.0, 2, 2, 2), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 3, 2, 0.5, 0.4, 2, 3, 12.666666666666666, 18.475208614068027, 2, 34, 2), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 4, 3, 0.6666666666666666, 0.6, 2, 4, 11.0, 15.448840301675292, 2, 6, 2), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 5, 4, 0.8333333333333334, 0.8, 3, 5, 14.4, 15.388307249337076, 2, 28, 34), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 6, 5, 1.0, 1.0, 3, 6, 19.0, 17.787636155487327, 2, 42, 6), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 1, 1, 0.2, 0.0, 1, 1, 14.0, Double.NaN, 4, 14, 14), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 2, 2, 0.4, 0.25, 1, 2, 27.0, 18.384776310850235, 4, 40, 14), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 3, 3, 0.6, 0.5, 2, 3, 18.666666666666668, 19.42506971244462, 4, 2, 14), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 4, 4, 0.8, 0.75, 2, 4, 20.25, 16.17353805861084, 4, 25, 40), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 5, 5, 1.0, 1.0, 3, 5, 19.8, 14.042791745233567, 4, 18, 2), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 1, 1, 0.2, 0.0, 1, 1, 17.0,Double.NaN, 2, 17, 17), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 2, 2, 0.4, 0.25, 1, 2, 15.5, 2.1213203435596424, 2, 14, 17), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 3, 3, 0.6, 0.5, 2, 3, 16.666666666666668, 2.516611478423583, 2, 19, 17), + Row("Manufacturer#3", "almond antique misty red olive", 1, 4, 4, 0.8, 0.75, 2, 4, 12.75, 8.098353742170895, 2, 1, 14), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 5, 5, 1.0, 1.0, 3, 5, 19.2, 16.037456157383566, 2, 45, 19), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 1, 1, 0.2, 0.0, 1, 1, 10.0, Double.NaN, 0, 10, 10), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 2, 2, 0.4, 0.25, 1, 2, 24.5, 20.506096654409877, 0, 39, 10), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 3, 3, 0.6, 0.5, 2, 3, 25.333333333333332, 14.571661996262929, 0, 27, 10), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 
7, 4, 4, 0.8, 0.75, 2, 4, 20.75, 15.01943185787443, 0, 7, 39), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 5, 5, 1.0, 1.0, 3, 5, 19.0, 13.583077707206124, 0, 12, 27), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 1, 1, 0.2, 0.0, 1, 1, 31.0, Double.NaN, 1, 31, 31), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 2, 2, 0.4, 0.25, 1, 2, 18.5, 17.67766952966369, 1, 6, 31), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 3, 3, 0.6, 0.5, 2, 3, 13.0, 15.716233645501712, 1, 2, 31), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 4, 4, 0.8, 0.75, 2, 4, 21.25, 20.902551678363736, 1, 46, 6), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 5, 5, 1.0, 1.0, 3, 5, 21.6, 18.1190507477627, 1, 23, 2))) + // scalastyle:on + } + + test("windowing.q -- 20. testSTATs") { + // Moved because: + // - Spark uses a different default stddev/variance (sample instead of pop) + // - Tiny numerical differences in aggregation results. + checkAnswer(sql(""" + |select p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + |from ( + |select p_mfgr,p_name, p_size, + |stddev_pop(p_retailprice) over w1 as sdev, + |stddev_pop(p_retailprice) over w1 as sdev_pop, + |collect_set(p_size) over w1 as uniq_size, + |var_pop(p_retailprice) over w1 as var, + |corr(p_size, p_retailprice) over w1 as cor, + |covar_pop(p_size, p_retailprice) over w1 as covarp + |from part + |window w1 as (distribute by p_mfgr sort by p_mfgr, p_name + | rows between 2 preceding and 2 following) + |) t lateral view explode(uniq_size) d as uniq_data + |order by p_mfgr,p_name, p_size, sdev, sdev_pop, uniq_data, var, cor, covarp + """.stripMargin), + // scalastyle:off + Seq( + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 2, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 6, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 258.10677784349247, 258.10677784349247, 34, 66619.10876874997, 0.811328754177887, 2801.7074999999995), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 2, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique burnished rose metallic", 2, 273.70217881648085, 273.70217881648085, 34, 74912.88268888886, 1.0, 4128.782222222221), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 2, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 6, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 28, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique chartreuse lavender yellow", 34, 230.9015158547037, 230.9015158547037, 34, 53315.510023999974, 0.6956393773976641, 2210.7864), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 2, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 6, 
41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 28, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 34, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond antique salmon chartreuse burlywood", 6, 202.73109328368943, 202.73109328368943, 42, 41099.89618399999, 0.6307859771012139, 2009.9536000000007), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 6, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 28, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 34, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine burnished black steel", 28, 121.60645179738611, 121.60645179738611, 42, 14788.129118749992, 0.2036684720435979, 331.1337500000004), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 6, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 28, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#1", "almond aquamarine pink moccasin thistle", 42, 96.57515864168516, 96.57515864168516, 42, 9326.761266666656, -1.4442181184933883E-4, -0.20666666666708502), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 2, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 14, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet chocolate turquoise", 14, 142.23631697518977, 142.23631697518977, 40, 20231.16986666666, -0.4936952655452319, -1113.7466666666658), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 2, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 14, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 25, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond antique violet turquoise frosted", 40, 137.7630649884068, 137.7630649884068, 40, 18978.662074999997, -0.5205630897335946, -1004.4812499999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 2, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 14, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 18, 16910.329504000005, 
-0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 25, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine midnight light salmon", 2, 130.03972279269132, 130.03972279269132, 40, 16910.329504000005, -0.46908967495720255, -766.1791999999995), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 2, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 18, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 25, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine rose maroon antique", 25, 135.55100986344593, 135.55100986344593, 40, 18374.076275000018, -0.6091405874714462, -1128.1787499999987), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 2, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 18, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#2", "almond aquamarine sandy cyan gainsboro", 18, 156.44019460768035, 156.44019460768035, 25, 24473.534488888898, -0.9571686373491605, -1441.4466666666676), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 14, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 17, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique chartreuse khaki white", 17, 196.77422668858057, 196.77422668858057, 19, 38720.0962888889, 0.5557168646224995, 224.6944444444446), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 1, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 14, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 17, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique forest lavender goldenrod", 14, 275.1414418985261, 275.1414418985261, 19, 75702.81305000003, -0.6720833036576083, -1296.9000000000003), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 1, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 14, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 17, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange dim", 19, 260.23473614412046, 260.23473614412046, 19, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique metallic orange 
dim", 19, 260.23473614412046, 260.23473614412046, 45, 67722.11789600001, -0.5703526513979519, -2129.0664), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 1, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 14, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 19, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique misty red olive", 1, 275.913996235693, 275.913996235693, 45, 76128.53331875002, -0.5774768996448021, -2547.7868749999993), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 1, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 19, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#3", "almond antique olive coral navajo", 45, 260.58159187137954, 260.58159187137954, 45, 67902.7660222222, -0.8710736366736884, -4099.731111111111), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 10, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 27, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique gainsboro frosted violet", 10, 170.1301188959661, 170.1301188959661, 39, 28944.25735555556, -0.6656975320098423, -1347.4777777777779), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 7, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 10, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 27, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond antique violet mint lemon", 39, 242.26834609323197, 242.26834609323197, 39, 58693.95151875002, -0.8051852719193339, -2537.328125), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 7, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 10, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 12, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 27, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine floral ivory bisque", 27, 234.10001662537323, 234.10001662537323, 39, 54802.81778400003, -0.6046935574240581, -1719.8079999999995), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 7, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 
247.33427141977316, 247.33427141977316, 12, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 27, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond aquamarine yellow dodger mint", 7, 247.33427141977316, 247.33427141977316, 39, 61174.241818750015, -0.5508665654707869, -1719.0368749999975), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 7, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 12, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#4", "almond azure aquamarine papaya violet", 12, 283.33443305668936, 283.33443305668936, 27, 80278.4009555556, -0.7755740084632333, -1867.4888888888881), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 2, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 6, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique blue firebrick mint", 31, 83.69879024746344, 83.69879024746344, 31, 7005.487488888881, 0.3900430308728505, 418.9233333333353), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 2, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 6, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 31, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique medium spring khaki", 6, 316.68049612345885, 316.68049612345885, 46, 100286.53662500005, -0.7136129117761831, -4090.853749999999), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 2, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 6, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 23, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 31, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond antique sky peru orange", 2, 285.4050629824216, 285.4050629824216, 46, 81456.04997600004, -0.712858514567818, -3297.2011999999986), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 2, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 6, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger light gainsboro", 46, 285.43749038756283, 285.43749038756283, 23, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond aquamarine dodger 
light gainsboro", 46, 285.43749038756283, 285.43749038756283, 46, 81474.56091875004, -0.9841287871533909, -4871.028125000002), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 2, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 23, 99807.08486666666, -0.9978877469246935, -5664.856666666666), + Row("Manufacturer#5", "almond azure blanched chiffon midnight", 23, 315.9225931564038, 315.9225931564038, 46, 99807.08486666666, -0.9978877469246935, -5664.856666666666))) + // scalastyle:on + } + + test("null arguments") { + checkAnswer(sql(""" + |select p_mfgr, p_name, p_size, + |sum(null) over(distribute by p_mfgr sort by p_name) as sum, + |avg(null) over(distribute by p_mfgr sort by p_name) as avg + |from part + """.stripMargin), + sql(""" + |select p_mfgr, p_name, p_size, + |null as sum, + |null as avg + |from part + """.stripMargin)) + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala new file mode 100644 index 000000000000..5afc7e77ab77 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcFilterSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.orc + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.io.sarg.{PredicateLeaf, SearchArgument} + +import org.apache.spark.sql.{Column, DataFrame, QueryTest} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.planning.PhysicalOperation +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} + +/** + * A test suite that tests ORC filter API based filter pushdown optimization. 
+ */ +class OrcFilterSuite extends QueryTest with OrcTest { + private def checkFilterPredicate( + df: DataFrame, + predicate: Predicate, + checker: (SearchArgument) => Unit): Unit = { + val output = predicate.collect { case a: Attribute => a }.distinct + val query = df + .select(output.map(e => Column(e)): _*) + .where(Column(predicate)) + + var maybeRelation: Option[OrcRelation] = None + val maybeAnalyzedPredicate = query.queryExecution.optimizedPlan.collect { + case PhysicalOperation(_, filters, LogicalRelation(orcRelation: OrcRelation, _)) => + maybeRelation = Some(orcRelation) + filters + }.flatten.reduceLeftOption(_ && _) + assert(maybeAnalyzedPredicate.isDefined, "No filter is analyzed from the given query") + + val (_, selectedFilters) = + DataSourceStrategy.selectFilters(maybeRelation.get, maybeAnalyzedPredicate.toSeq) + assert(selectedFilters.nonEmpty, "No filter is pushed down") + + val maybeFilter = OrcFilters.createFilter(selectedFilters.toArray) + assert(maybeFilter.isDefined, s"Couldn't generate filter predicate for $selectedFilters") + checker(maybeFilter.get) + } + + private def checkFilterPredicate + (predicate: Predicate, filterOperator: PredicateLeaf.Operator) + (implicit df: DataFrame): Unit = { + def checkComparisonOperator(filter: SearchArgument) = { + val operator = filter.getLeaves.asScala.head.getOperator + assert(operator === filterOperator) + } + checkFilterPredicate(df, predicate, checkComparisonOperator) + } + + private def checkFilterPredicate + (predicate: Predicate, stringExpr: String) + (implicit df: DataFrame): Unit = { + def checkLogicalOperator(filter: SearchArgument) = { + assert(filter.toString == stringExpr) + } + checkFilterPredicate(df, predicate, checkLogicalOperator) + } + + test("filter pushdown - boolean") { + withOrcDataFrame((true :: false :: Nil).map(b => Tuple1.apply(Option(b)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + } + } + + test("filter pushdown - integer") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - long") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toLong)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, 
PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - float") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toFloat)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - double") { + withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i.toDouble)))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === 1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> 1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < 2, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > 3, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= 1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= 4, PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal(1) === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal(1) <=> '_1, PredicateLeaf.Operator.NULL_SAFE_EQUALS) + checkFilterPredicate(Literal(2) > '_1, PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate(Literal(3) < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(1) >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate(Literal(4) <= '_1, PredicateLeaf.Operator.LESS_THAN) + } + } + + test("filter pushdown - string") { + withOrcDataFrame((1 to 4).map(i => Tuple1(i.toString))) { implicit df => + checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL) + + checkFilterPredicate('_1 === "1", PredicateLeaf.Operator.EQUALS) + checkFilterPredicate('_1 <=> "1", PredicateLeaf.Operator.NULL_SAFE_EQUALS) + + checkFilterPredicate('_1 < "2", PredicateLeaf.Operator.LESS_THAN) + checkFilterPredicate('_1 > "3", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 <= "1", PredicateLeaf.Operator.LESS_THAN_EQUALS) + checkFilterPredicate('_1 >= "4", PredicateLeaf.Operator.LESS_THAN) + + checkFilterPredicate(Literal("1") === '_1, PredicateLeaf.Operator.EQUALS) + checkFilterPredicate(Literal("1") <=> '_1, 
PredicateLeaf.Operator.NULL_SAFE_EQUALS)
+      checkFilterPredicate(Literal("2") > '_1, PredicateLeaf.Operator.LESS_THAN)
+      checkFilterPredicate(Literal("3") < '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+      checkFilterPredicate(Literal("1") >= '_1, PredicateLeaf.Operator.LESS_THAN_EQUALS)
+      checkFilterPredicate(Literal("4") <= '_1, PredicateLeaf.Operator.LESS_THAN)
+    }
+  }
+
+  test("filter pushdown - binary") {
+    implicit class IntToBinary(int: Int) {
+      def b: Array[Byte] = int.toString.getBytes("UTF-8")
+    }
+
+    withOrcDataFrame((1 to 4).map(i => Tuple1(i.b))) { implicit df =>
+      checkFilterPredicate('_1.isNull, PredicateLeaf.Operator.IS_NULL)
+    }
+  }
+
+  test("filter pushdown - combinations with logical operators") {
+    withOrcDataFrame((1 to 4).map(i => Tuple1(Option(i)))) { implicit df =>
+      // Because `ExpressionTree` is not accessible in Hive 1.2.x, filter creation, including
+      // logical operators such as `and`, `or`, and `not`, has to be checked in string form.
+      // So this test uses `SearchArgument.toString()` to produce a string expression and then
+      // compares it with the expected string expression given below. This might have to be
+      // changed after the Hive version is upgraded.
+      checkFilterPredicate(
+        '_1.isNotNull,
+        """leaf-0 = (IS_NULL _1)
+          |expr = (not leaf-0)""".stripMargin.trim
+      )
+      checkFilterPredicate(
+        '_1 !== 1,
+        """leaf-0 = (EQUALS _1 1)
+          |expr = (not leaf-0)""".stripMargin.trim
+      )
+      checkFilterPredicate(
+        !('_1 < 4),
+        """leaf-0 = (LESS_THAN _1 4)
+          |expr = (not leaf-0)""".stripMargin.trim
+      )
+      checkFilterPredicate(
+        '_1 < 2 || '_1 > 3,
+        """leaf-0 = (LESS_THAN _1 2)
+          |leaf-1 = (LESS_THAN_EQUALS _1 3)
+          |expr = (or leaf-0 (not leaf-1))""".stripMargin.trim
+      )
+      checkFilterPredicate(
+        '_1 < 2 && '_1 > 3,
+        """leaf-0 = (LESS_THAN _1 2)
+          |leaf-1 = (LESS_THAN_EQUALS _1 3)
+          |expr = (and leaf-0 (not leaf-1))""".stripMargin.trim
+      )
+    }
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
index 92043d66c914..e8a61123d18b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcHadoopFsRelationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.hive.orc
 import org.apache.hadoop.fs.Path
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.{Row, SQLConf}
 import org.apache.spark.sql.sources.HadoopFsRelationTest
 import org.apache.spark.sql.types._
@@ -60,4 +61,23 @@ class OrcHadoopFsRelationSuite extends HadoopFsRelationTest {
           "dataSchema" -> dataSchemaWithPartition.json)).format(dataSourceName).load())
     }
   }
+
+  test("SPARK-12218: 'Not' is included in ORC filter pushdown") {
+    import testImplicits._
+
+    withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") {
+      withTempPath { dir =>
+        val path = s"${dir.getCanonicalPath}/table1"
+        (1 to 5).map(i => (i, (i % 2).toString)).toDF("a", "b").write.orc(path)
+
+        checkAnswer(
+          sqlContext.read.orc(path).where("not (a = 2) or not(b in ('1'))"),
+          (1 to 5).map(i => Row(i, (i % 2).toString)))
+
+        checkAnswer(
+          sqlContext.read.orc(path).where("not (a = 2 and b in ('1'))"),
+          (1 to 5).map(i => Row(i, (i % 2).toString)))
+      }
+    }
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala
index 52e09f9496f0..6161412a4977
100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcPartitionDiscoverySuite.scala @@ -22,8 +22,8 @@ import java.io.File import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag -import org.scalatest.BeforeAndAfterAll import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ import org.apache.spark.sql.hive.test.TestHiveSingleton diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 7a34cf731b4c..27ea3e804165 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -23,6 +23,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.sources._ case class OrcData(intField: Int, stringField: String) @@ -174,4 +175,33 @@ class OrcSourceSuite extends OrcSuite { |) """.stripMargin) } + + test("SPARK-12218 Converting conjunctions into ORC SearchArguments") { + // The `LessThan` should be converted while the `StringContains` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + StringContains("b", "prefix") + )).get.toString + } + + // The `LessThan` should be converted while the whole inner `And` shouldn't + assertResult( + """leaf-0 = (LESS_THAN a 10) + |expr = leaf-0 + """.stripMargin.trim + ) { + OrcFilters.createFilter(Array( + LessThan("a", 10), + Not(And( + GreaterThan("a", 1), + StringContains("b", "prefix") + )) + )).get.toString + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala index 88a0ed511749..637c10611afc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala @@ -23,8 +23,8 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql._ -import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils private[sql] trait OrcTest extends SQLTestUtils with TestHiveSingleton { import testImplicits._ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 905eb7a3925b..2ceb83668190 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -20,11 +20,11 @@ package org.apache.spark.sql.hive import java.io.File import org.apache.spark.sql._ -import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} import org.apache.spark.sql.execution.{ExecutedCommand, PhysicalRDD} +import org.apache.spark.sql.execution.datasources.{InsertIntoDataSource, InsertIntoHadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation import org.apache.spark.sql.hive.execution.HiveTableScan import 
org.apache.spark.sql.hive.test.TestHiveSingleton
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
 import org.apache.spark.sql.test.SQLTestUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
new file mode 100644
index 000000000000..579da0291f29
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources
+
+import java.io.File
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.hive.test.TestHiveSingleton
+import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{AnalysisException, QueryTest}
+
+class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
+  import testImplicits._
+
+  test("bucketed by non-existing column") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    intercept[AnalysisException](df.write.bucketBy(2, "k").saveAsTable("tt"))
+  }
+
+  test("numBuckets must be greater than 0 and less than 100000") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    intercept[IllegalArgumentException](df.write.bucketBy(0, "i").saveAsTable("tt"))
+    intercept[IllegalArgumentException](df.write.bucketBy(100000, "i").saveAsTable("tt"))
+  }
+
+  test("specify sorting columns without bucketing columns") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    intercept[IllegalArgumentException](df.write.sortBy("j").saveAsTable("tt"))
+  }
+
+  test("sorting by non-orderable column") {
+    val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j")
+    intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt"))
+  }
+
+  test("write bucketed data to unsupported data source") {
+    val df = Seq(Tuple1("a"), Tuple1("b")).toDF("i")
+    intercept[AnalysisException](df.write.bucketBy(3, "i").format("text").saveAsTable("tt"))
+  }
+
+  test("write bucketed data to non-hive-table or existing hive table") {
+    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+    intercept[IllegalArgumentException](df.write.bucketBy(2, "i").parquet("/tmp/path"))
+    intercept[IllegalArgumentException](df.write.bucketBy(2, "i").json("/tmp/path"))
+    intercept[IllegalArgumentException](df.write.bucketBy(2, "i").insertInto("tt"))
+  }
+
+  // Extract the bucket id appended to the end of a bucket file's name, which may or may not
+  // be followed by a file extension.
+  private val testFileName = """.*-(\d+)$""".r
+  private val otherFileName = """.*-(\d+)\..*""".r
+  private def getBucketId(fileName: String): Int = {
+    fileName match {
+      case testFileName(bucketId) => bucketId.toInt
+      case otherFileName(bucketId) =>
bucketId.toInt + } + } + + private def testBucketing( + dataDir: File, + source: String, + bucketCols: Seq[String], + sortCols: Seq[String] = Nil): Unit = { + val allBucketFiles = dataDir.listFiles().filterNot(f => + f.getName.startsWith(".") || f.getName.startsWith("_") + ) + val groupedBucketFiles = allBucketFiles.groupBy(f => getBucketId(f.getName)) + assert(groupedBucketFiles.size <= 8) + + for ((bucketId, bucketFiles) <- groupedBucketFiles) { + for (bucketFile <- bucketFiles) { + val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath) + .select((bucketCols ++ sortCols).map(col): _*) + + if (sortCols.nonEmpty) { + checkAnswer(df.sort(sortCols.map(col): _*), df.collect()) + } + + val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect() + + for (row <- rows) { + assert(row.isInstanceOf[UnsafeRow]) + val actualBucketId = (row.hashCode() % 8 + 8) % 8 + assert(actualBucketId == bucketId) + } + } + } + } + + private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") + + test("write bucketed data") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j", "k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j", "k")) + } + } + } + } + + test("write bucketed data with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .partitionBy("i") + .bucketBy(8, "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + for (i <- 0 until 5) { + testBucketing(new File(tableDir, s"i=$i"), source, Seq("j"), Seq("k")) + } + } + } + } + + test("write bucketed data without partitionBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j")) + } + } + } + + test("write bucketed data without partitionBy with sortBy") { + for (source <- Seq("parquet", "json", "orc")) { + withTable("bucketed_table") { + df.write + .format(source) + .bucketBy(8, "i", "j") + .sortBy("k") + .saveAsTable("bucketed_table") + + val tableDir = new File(hiveContext.warehousePath, "bucketed_table") + testBucketing(tableDir, source, Seq("i", "j"), Seq("k")) + } + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala index dc0531a6d4bc..64c61a509254 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala @@ -18,12 +18,12 @@ package org.apache.spark.sql.sources import org.apache.hadoop.fs.Path + import org.apache.spark.SparkException import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils - class CommitFailureTestRelationSuite extends SQLTestUtils with TestHiveSingleton { // When committing a task, `CommitFailureTestSource` throws an exception for testing purpose. 
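+  // The tests below rely on this to exercise the failure path of the write pipeline,
+  // e.g. verifying that a failed commit aborts the task without leaving stray output behind.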
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala index b554d135e4b5..058c101eebb0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextHadoopFsRelationSuite.scala @@ -21,14 +21,13 @@ import java.io.File import org.apache.hadoop.fs.Path -import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.sql.{execution, Column, DataFrame, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, PredicateHelper} import org.apache.spark.sql.execution.{LogicalRDD, PhysicalRDD} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Column, DataFrame, Row, execution} import org.apache.spark.util.Utils class SimpleTextHadoopFsRelationSuite extends HadoopFsRelationTest with PredicateHelper { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala index 01960fd2901b..9fc437bf8815 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/SimpleTextRelation.scala @@ -22,15 +22,14 @@ import java.text.NumberFormat import com.google.common.base.Objects import org.apache.hadoop.fs.{FileStatus, Path} import org.apache.hadoop.io.{NullWritable, Text} -import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} import org.apache.hadoop.mapreduce.{Job, RecordWriter, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat, TextOutputFormat} -import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{sources, Row, SQLContext} +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, expressions} import org.apache.spark.sql.types.{DataType, StructType} -import org.apache.spark.sql.{Row, SQLContext, sources} /** * A simple example [[HadoopFsRelationProvider]]. 
@@ -53,9 +52,9 @@ class AppendingTextOutputFormat(outputFile: Path) extends TextOutputFormat[NullW numberFormat.setGroupingUsed(false) override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = { - val configuration = SparkHadoopUtil.get.getConfigurationFromJobContext(context) + val configuration = context.getConfiguration val uniqueWriteJobId = configuration.get("spark.sql.sources.writeJobUUID") - val taskAttemptId = SparkHadoopUtil.get.getTaskAttemptIDFromTaskAttemptContext(context) + val taskAttemptId = context.getTaskAttemptID val split = taskAttemptId.getTaskID.getId val name = FileOutputFormat.getOutputName(context) new Path(outputFile, s"$name-${numberFormat.format(split)}-$uniqueWriteJobId") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala index 665e87e3e335..efbf9988ddc1 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala @@ -27,7 +27,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql._ -import org.apache.spark.sql.execution.ConvertToUnsafe import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -689,36 +688,6 @@ abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils with Tes sqlContext.sparkContext.conf.set("spark.speculation", speculationEnabled.toString) } } - - test("HadoopFsRelation produces UnsafeRow") { - withTempTable("test_unsafe") { - withTempPath { dir => - val path = dir.getCanonicalPath - sqlContext.range(3).write.format(dataSourceName).save(path) - sqlContext.read - .format(dataSourceName) - .option("dataSchema", new StructType().add("id", LongType, nullable = false).json) - .load(path) - .registerTempTable("test_unsafe") - - val df = sqlContext.sql( - """SELECT COUNT(*) - |FROM test_unsafe a JOIN test_unsafe b - |WHERE a.id = b.id - """.stripMargin) - - val plan = df.queryExecution.executedPlan - - assert( - plan.collect { case plan: ConvertToUnsafe => plan }.isEmpty, - s"""Query plan shouldn't have ${classOf[ConvertToUnsafe].getSimpleName} node(s): - |$plan - """.stripMargin) - - checkAnswer(df, Row(3)) - } - } - } } // This class is used to test SPARK-8578. 
We should not use any custom output committer when diff --git a/streaming/pom.xml b/streaming/pom.xml index 435e16db13ab..39cbd0d00f95 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index d0046afdeb44..b186d297610e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -21,16 +21,15 @@ import java.io._ import java.util.concurrent.Executors import java.util.concurrent.RejectedExecutionException -import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.spark.{SparkException, SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.io.CompressionCodec -import org.apache.spark.util.{MetadataCleaner, Utils} +import org.apache.spark.util.Utils import org.apache.spark.streaming.scheduler.JobGenerator - private[streaming] class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) extends Logging with Serializable { @@ -41,7 +40,6 @@ class Checkpoint(ssc: StreamingContext, val checkpointTime: Time) val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration val pendingTimes = ssc.scheduler.getPendingTimes().toArray - val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConfPairs = ssc.conf.getAll def createSparkConf(): SparkConf = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala index 7829f5e88799..eedb42c0611c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStreamGraph.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.ArrayBuffer -import java.io.{ObjectInputStream, IOException, ObjectOutputStream} + import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.{DStream, InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.scheduler.Job -import org.apache.spark.streaming.dstream.{DStream, ReceiverInputDStream, InputDStream} import org.apache.spark.util.Utils final private[streaming] class DStreamGraph extends Serializable with Logging { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala index 9f6f95223f61..f1114c1e5ac6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StateSpec.scala @@ -17,13 +17,12 @@ package org.apache.spark.streaming -import com.google.common.base.Optional +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaUtils} +import org.apache.spark.api.java.{JavaPairRDD, JavaUtils, Optional} import org.apache.spark.api.java.function.{Function3 => JFunction3, Function4 => JFunction4} import org.apache.spark.rdd.RDD import org.apache.spark.util.ClosureCleaner 
-import org.apache.spark.{HashPartitioner, Partitioner} /** * :: Experimental :: @@ -199,7 +198,11 @@ object StateSpec { StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (time: Time, k: KeyType, v: Option[ValueType], s: State[StateType]) => { val t = mappingFunction.call(time, k, JavaUtils.optionToOptional(v), s) - Option(t.orNull) + if (t.isPresent) { + Some(t.get) + } else { + None + } } StateSpec.function(wrappedFunc) } @@ -219,7 +222,7 @@ object StateSpec { mappingFunction: JFunction3[KeyType, Optional[ValueType], State[StateType], MappedType]): StateSpec[KeyType, ValueType, StateType, MappedType] = { val wrappedFunc = (k: KeyType, v: Option[ValueType], s: State[StateType]) => { - mappingFunction.call(k, Optional.fromNullable(v.get), s) + mappingFunction.call(k, Optional.ofNullable(v.get), s) } StateSpec.function(wrappedFunc) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index b24c0d067bb0..ba509a1030af 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -29,8 +29,8 @@ import akka.actor.{Props, SupervisorStrategy} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.io.{BytesWritable, LongWritable, Text} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.spark._ import org.apache.spark.annotation.{DeveloperApi, Experimental} @@ -41,7 +41,7 @@ import org.apache.spark.serializer.SerializationDebugger import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.StreamingContextState._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receiver.{ActorReceiver, ActorSupervisorStrategy, Receiver} +import org.apache.spark.streaming.receiver.{ActorReceiverSupervisor, ActorSupervisorStrategy, Receiver} import org.apache.spark.streaming.scheduler.{JobScheduler, StreamingListener} import org.apache.spark.streaming.ui.{StreamingJobProgressListener, StreamingTab} import org.apache.spark.util.{AsynchronousListenerBus, CallSite, ShutdownHookManager, ThreadUtils, Utils} @@ -226,7 +226,7 @@ class StreamingContext private[streaming] ( * Set the context to periodically checkpoint the DStream operations for driver * fault-tolerance. * @param directory HDFS-compatible directory where the checkpoint data will be reliably stored. - * Note that this must be a fault-tolerant file system like HDFS for + * Note that this must be a fault-tolerant file system like HDFS. */ def checkpoint(directory: String) { if (directory != null) { @@ -274,7 +274,7 @@ class StreamingContext private[streaming] ( * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver * - * @deprecated As of 1.0.0", replaced by `receiverStream`. + * @deprecated As of 1.0.0 replaced by `receiverStream`. */ @deprecated("Use receiverStream", "1.0.0") def networkStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -285,7 +285,7 @@ class StreamingContext private[streaming] ( /** * Create an input stream with any arbitrary user implemented receiver. 
- * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * Find more details at http://spark.apache.org/docs/latest/streaming-custom-receivers.html * @param receiver Custom implementation of Receiver */ def receiverStream[T: ClassTag](receiver: Receiver[T]): ReceiverInputDStream[T] = { @@ -312,7 +312,7 @@ class StreamingContext private[streaming] ( storageLevel: StorageLevel = StorageLevel.MEMORY_AND_DISK_SER_2, supervisorStrategy: SupervisorStrategy = ActorSupervisorStrategy.defaultStrategy ): ReceiverInputDStream[T] = withNamedScope("actor stream") { - receiverStream(new ActorReceiver[T](props, name, storageLevel, supervisorStrategy)) + receiverStream(new ActorReceiverSupervisor[T](props, name, storageLevel, supervisorStrategy)) } /** @@ -549,7 +549,7 @@ class StreamingContext private[streaming] ( // Verify whether the DStream checkpoint is serializable if (isCheckpointingEnabled) { - val checkpoint = new Checkpoint(this, Time.apply(0)) + val checkpoint = new Checkpoint(this, Time(0)) try { Checkpoint.serialize(checkpoint, conf) } catch { @@ -575,9 +575,9 @@ class StreamingContext private[streaming] ( * * Return the current state of the context. The context can be in three possible states - * - * - StreamingContextState.INTIALIZED - The context has been created, but not been started yet. + * - StreamingContextState.INITIALIZED - The context has been created, but not started yet. * Input DStreams, transformations and output operations can be created on the context. - * - StreamingContextState.ACTIVE - The context has been started, and been not stopped. + * - StreamingContextState.ACTIVE - The context has been started, and not stopped. * Input DStreams, transformations and output operations cannot be created on the context. * - StreamingContextState.STOPPED - The context has been stopped and cannot be used any more. */ @@ -902,3 +902,15 @@ object StreamingContext extends Logging { result } } + +private class StreamingContextPythonHelper { + + /** + * This is a private method only for Python to implement `getOrCreate`. 
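+   * It reads the checkpoint data at `checkpointPath` and, if a checkpoint is found, rebuilds
+   * the recovered StreamingContext from it; otherwise it returns `None`.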
+ */ + def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { + val checkpointOption = CheckpointReader.read( + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false) + checkpointOption.map(new StreamingContext(null, _, null)) + } +} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala index 01cdcb057404..a59f4efccb57 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.api.java.function.{Function => JFunction} -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.storage.StorageLevel -import org.apache.spark.rdd.RDD - import scala.language.implicitConversions import scala.reflect.ClassTag + +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.api.java.function.{Function => JFunction} +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.dstream.DStream /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala index 84acec7d8e33..733147f63ea2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong} +import java.{lang => jl} import java.util.{List => JList} import scala.collection.JavaConverters._ @@ -50,8 +50,8 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T def wrapRDD(in: RDD[T]): R - implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[JLong] = { - in.map(new JLong(_)) + implicit def scalaIntToJavaLong(in: DStream[Long]): JavaDStream[jl.Long] = { + in.map(jl.Long.valueOf) } /** @@ -74,14 +74,14 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * Return a new DStream in which each RDD has a single element generated by counting each RDD * of this DStream. */ - def count(): JavaDStream[JLong] = dstream.count() + def count(): JavaDStream[jl.Long] = dstream.count() /** * Return a new DStream in which each RDD contains the counts of each distinct value in * each RDD of this DStream. Hash partitioning is used to generate the RDDs with * Spark's default number of partitions. */ - def countByValue(): JavaPairDStream[T, JLong] = { + def countByValue(): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue()) } @@ -91,7 +91,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * partitions. * @param numPartitions number of partitions of each RDD in the new DStream. */ - def countByValue(numPartitions: Int): JavaPairDStream[T, JLong] = { + def countByValue(numPartitions: Int): JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong(dstream.countByValue(numPartitions)) } @@ -101,7 +101,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * of elements in a window over this DStream. 
windowDuration and slideDuration are as defined in * the window() operation. This is equivalent to window(windowDuration, slideDuration).count() */ - def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[JLong] = { + def countByWindow(windowDuration: Duration, slideDuration: Duration) : JavaDStream[jl.Long] = { dstream.countByWindow(windowDuration, slideDuration) } @@ -116,7 +116,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * DStream's batching interval */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration)) } @@ -133,7 +133,7 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T * @param numPartitions number of partitions of each RDD in the new DStream. */ def countByValueAndWindow(windowDuration: Duration, slideDuration: Duration, numPartitions: Int) - : JavaPairDStream[T, JLong] = { + : JavaPairDStream[T, jl.Long] = { JavaPairDStream.scalaToJavaLong( dstream.countByValueAndWindow(windowDuration, slideDuration, numPartitions)) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala index 42ddd63f0f06..d718f1d6fc43 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala @@ -17,21 +17,21 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong, Iterable => JIterable} +import java.{lang => jl} +import java.lang.{Iterable => JIterable} import java.util.{List => JList} import scala.collection.JavaConverters._ import scala.language.implicitConversions import scala.reflect.ClassTag -import com.google.common.base.Optional import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} import org.apache.spark.Partitioner import org.apache.spark.annotation.Experimental -import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils} +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaUtils, Optional} import org.apache.spark.api.java.JavaPairRDD._ import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} @@ -847,7 +847,7 @@ object JavaPairDStream { } def scalaToJavaLong[K: ClassTag](dstream: JavaPairDStream[K, Long]) - : JavaPairDStream[K, JLong] = { - DStream.toPairDStreamFunctions(dstream.dstream).mapValues(new JLong(_)) + : JavaPairDStream[K, jl.Long] = { + DStream.toPairDStreamFunctions(dstream.dstream).mapValues(jl.Long.valueOf) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala index e6ff8a0cb545..da0db02236a1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairInputDStream.scala @@ -17,11 +17,11 @@ package org.apache.spark.streaming.api.java -import org.apache.spark.streaming.dstream.InputDStream - import scala.language.implicitConversions 
import scala.reflect.ClassTag +import org.apache.spark.streaming.dstream.InputDStream + /** * A Java-friendly interface to [[org.apache.spark.streaming.dstream.InputDStream]] of * key-value pairs. diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 8f21c79a760c..00f9d8a9e881 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,14 +17,15 @@ package org.apache.spark.streaming.api.java -import java.lang.{Boolean => JBoolean} import java.io.{Closeable, InputStream} +import java.lang.{Boolean => JBoolean} import java.util.{List => JList, Map => JMap} import scala.collection.JavaConverters._ import scala.reflect.ClassTag import akka.actor.{Props, SupervisorStrategy} +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} @@ -37,10 +38,9 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.scheduler.StreamingListener import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver -import org.apache.hadoop.conf.Configuration +import org.apache.spark.streaming.scheduler.StreamingListener /** * A Java-friendly version of [[org.apache.spark.streaming.StreamingContext]] which is the main @@ -695,9 +695,9 @@ object JavaStreamingContext { * * @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ - @deprecated("use getOrCreate without JavaStreamingContextFactor", "1.4.0") + @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( checkpointPath: String, factory: JavaStreamingContextFactory @@ -718,7 +718,7 @@ object JavaStreamingContext { * @param factory JavaStreamingContextFactory object to create a new JavaStreamingContext * @param hadoopConf Hadoop configuration if necessary for reading from any HDFS compatible * file system - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. */ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( @@ -744,7 +744,7 @@ object JavaStreamingContext { * file system * @param createOnError Whether to create a new JavaStreamingContext if there is an * error in reading checkpoint data. - * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactor. + * @deprecated As of 1.4.0, replaced by `getOrCreate` without JavaStreamingContextFactory. 
*/ @deprecated("use getOrCreate without JavaStreamingContextFactory", "1.4.0") def getOrCreate( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala index 056248ccc7bc..953fe95177f0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala @@ -30,12 +30,11 @@ import org.apache.spark.SparkException import org.apache.spark.api.java._ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.{Interval, Duration, Time} -import org.apache.spark.streaming.dstream._ +import org.apache.spark.streaming.{Duration, Interval, Time} import org.apache.spark.streaming.api.java._ +import org.apache.spark.streaming.dstream._ import org.apache.spark.util.Utils - /** * Interface for Python callback function which is used to transform RDDs */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index 4eb92dd8b105..695384deb32d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.dstream import scala.reflect.ClassTag import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} /** * An input stream that always returns the same RDD on each timestep. Useful for testing. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index 1a6edf9473d8..1dfb4e7abc0e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -32,7 +32,7 @@ import org.apache.spark.streaming._ import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming.scheduler.Job import org.apache.spark.streaming.ui.UIUtils -import org.apache.spark.util.{CallSite, MetadataCleaner, Utils} +import org.apache.spark.util.{CallSite, Utils} /** * A Discretized Stream (DStream), the basic abstraction in Spark Streaming, is a continuous @@ -97,11 +97,13 @@ abstract class DStream[T: ClassTag] ( private[streaming] val mustCheckpoint = false private[streaming] var checkpointDuration: Duration = null private[streaming] val checkpointData = new DStreamCheckpointData(this) + @transient + private var restoredFromCheckpointData = false // Reference to whole DStream graph private[streaming] var graph: DStreamGraph = null - private[streaming] def isInitialized = (zeroTime != null) + private[streaming] def isInitialized = zeroTime != null // Duration for which the DStream requires its parent DStream to remember each RDD created private[streaming] def parentRememberDuration = rememberDuration @@ -187,15 +189,15 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def initialize(time: Time) { if (zeroTime != null && zeroTime != time) { - throw new SparkException("ZeroTime is already initialized to " + zeroTime - + ", cannot initialize it again to " + time) + throw new SparkException(s"ZeroTime is already initialized to $zeroTime" + + s", cannot initialize it again to $time") } zeroTime = time // Set the checkpoint interval to be slideDuration or 10 seconds, which ever is larger if (mustCheckpoint && checkpointDuration == null) { checkpointDuration = slideDuration * math.ceil(Seconds(10) / slideDuration).toInt - logInfo("Checkpoint interval automatically set to " + checkpointDuration) + logInfo(s"Checkpoint interval automatically set to $checkpointDuration") } // Set the minimum value of the rememberDuration if not already set @@ -232,7 +234,7 @@ abstract class DStream[T: ClassTag] ( require( !mustCheckpoint || checkpointDuration != null, - "The checkpoint interval for " + this.getClass.getSimpleName + " has not been set." + + s"The checkpoint interval for ${this.getClass.getSimpleName} has not been set." + " Please use DStream.checkpoint() to set the interval." ) @@ -243,65 +245,53 @@ abstract class DStream[T: ClassTag] ( require( checkpointDuration == null || checkpointDuration >= slideDuration, - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which is lower than its slide time (" + slideDuration + "). " + - "Please set it to at least " + slideDuration + "." + s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s"$checkpointDuration which is lower than its slide time ($slideDuration). " + + s"Please set it to at least $slideDuration." ) require( checkpointDuration == null || checkpointDuration.isMultipleOf(slideDuration), - "The checkpoint interval for " + this.getClass.getSimpleName + " has been set to " + - checkpointDuration + " which not a multiple of its slide time (" + slideDuration + "). " + - "Please set it to a multiple of " + slideDuration + "." 
+ s"The checkpoint interval for ${this.getClass.getSimpleName} has been set to " + + s" $checkpointDuration which not a multiple of its slide time ($slideDuration). " + + s"Please set it to a multiple of $slideDuration." ) require( checkpointDuration == null || storageLevel != StorageLevel.NONE, - "" + this.getClass.getSimpleName + " has been marked for checkpointing but the storage " + + s"${this.getClass.getSimpleName} has been marked for checkpointing but the storage " + "level has not been set to enable persisting. Please use DStream.persist() to set the " + "storage level to use memory for better checkpointing performance." ) require( checkpointDuration == null || rememberDuration > checkpointDuration, - "The remember duration for " + this.getClass.getSimpleName + " has been set to " + - rememberDuration + " which is not more than the checkpoint interval (" + - checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." - ) - - val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) - logInfo("metadataCleanupDelay = " + metadataCleanerDelay) - require( - metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, - "It seems you are doing some DStream window operation or setting a checkpoint interval " + - "which requires " + this.getClass.getSimpleName + " to remember generated RDDs for more " + - "than " + rememberDuration.milliseconds / 1000 + " seconds. But Spark's metadata cleanup" + - "delay is set to " + metadataCleanerDelay + " seconds, which is not sufficient. Please " + - "set the Java cleaner delay to more than " + - math.ceil(rememberDuration.milliseconds / 1000.0).toInt + " seconds." + s"The remember duration for ${this.getClass.getSimpleName} has been set to " + + s" $rememberDuration which is not more than the checkpoint interval" + + s" ($checkpointDuration). Please set it to higher than $checkpointDuration." 
) dependencies.foreach(_.validateAtStart()) - logInfo("Slide time = " + slideDuration) - logInfo("Storage level = " + storageLevel) - logInfo("Checkpoint interval = " + checkpointDuration) - logInfo("Remember duration = " + rememberDuration) - logInfo("Initialized and validated " + this) + logInfo(s"Slide time = $slideDuration") + logInfo(s"Storage level = ${storageLevel.description}") + logInfo(s"Checkpoint interval = $checkpointDuration") + logInfo(s"Remember duration = $rememberDuration") + logInfo(s"Initialized and validated $this") } private[streaming] def setContext(s: StreamingContext) { if (ssc != null && ssc != s) { - throw new SparkException("Context is already set in " + this + ", cannot set it again") + throw new SparkException(s"Context must not be set again for $this") } ssc = s - logInfo("Set context for " + this) + logInfo(s"Set context for $this") dependencies.foreach(_.setContext(ssc)) } private[streaming] def setGraph(g: DStreamGraph) { if (graph != null && graph != g) { - throw new SparkException("Graph is already set in " + this + ", cannot set it again") + throw new SparkException(s"Graph must not be set again for $this") } graph = g dependencies.foreach(_.setGraph(graph)) @@ -310,7 +300,7 @@ abstract class DStream[T: ClassTag] ( private[streaming] def remember(duration: Duration) { if (duration != null && (rememberDuration == null || duration > rememberDuration)) { rememberDuration = duration - logInfo("Duration for remembering RDDs set to " + rememberDuration + " for " + this) + logInfo(s"Duration for remembering RDDs set to $rememberDuration for $this") } dependencies.foreach(_.remember(parentRememberDuration)) } @@ -320,11 +310,11 @@ abstract class DStream[T: ClassTag] ( if (!isInitialized) { throw new SparkException (this + " has not been initialized") } else if (time <= zeroTime || ! (time - zeroTime).isMultipleOf(slideDuration)) { - logInfo("Time " + time + " is invalid as zeroTime is " + zeroTime + " and slideDuration is " + slideDuration + " and difference is " + (time - zeroTime)) + logInfo(s"Time $time is invalid as zeroTime is $zeroTime," + s" slideDuration is $slideDuration and difference is ${time - zeroTime}") false } else { - logDebug("Time " + time + " is valid") + logDebug(s"Time $time is valid") true } } @@ -462,20 +452,20 @@ abstract class DStream[T: ClassTag] ( oldRDDs.map(x => s"${x._1} -> ${x._2.id}").mkString(", ") + "]") generatedRDDs --= oldRDDs.keys if (unpersistData) { - logDebug("Unpersisting old RDDs: " + oldRDDs.values.map(_.id).mkString(", ")) + logDebug(s"Unpersisting old RDDs: ${oldRDDs.values.map(_.id).mkString(", ")}") oldRDDs.values.foreach { rdd => rdd.unpersist(false) // Explicitly remove blocks of BlockRDD rdd match { case b: BlockRDD[_] => - logInfo("Removing blocks of RDD " + b + " of time " + time) + logInfo(s"Removing blocks of RDD $b of time $time") b.removeBlocks() case _ => } } } - logDebug("Cleared " + oldRDDs.size + " RDDs that were older than " + - (time - rememberDuration) + ": " + oldRDDs.keys.mkString(", ")) + logDebug(s"Cleared ${oldRDDs.size} RDDs that were older than " + + s"${time - rememberDuration}: ${oldRDDs.keys.mkString(", ")}") dependencies.foreach(_.clearMetadata(time)) } @@ -487,10 +477,10 @@ abstract class DStream[T: ClassTag] ( * this method to save custom checkpoint data.
*/ private[streaming] def updateCheckpointData(currentTime: Time) { - logDebug("Updating checkpoint data for time " + currentTime) + logDebug(s"Updating checkpoint data for time $currentTime") checkpointData.update(currentTime) dependencies.foreach(_.updateCheckpointData(currentTime)) - logDebug("Updated checkpoint data for time " + currentTime + ": " + checkpointData) + logDebug(s"Updated checkpoint data for time $currentTime: $checkpointData") } private[streaming] def clearCheckpointData(time: Time) { @@ -507,22 +497,25 @@ abstract class DStream[T: ClassTag] ( * override the updateCheckpointData() method would also need to override this method. */ private[streaming] def restoreCheckpointData() { - // Create RDDs from the checkpoint data - logInfo("Restoring checkpoint data") - checkpointData.restore() - dependencies.foreach(_.restoreCheckpointData()) - logInfo("Restored checkpoint data") + if (!restoredFromCheckpointData) { + // Create RDDs from the checkpoint data + logInfo("Restoring checkpoint data") + checkpointData.restore() + dependencies.foreach(_.restoreCheckpointData()) + restoredFromCheckpointData = true + logInfo("Restored checkpoint data") + } } @throws(classOf[IOException]) private def writeObject(oos: ObjectOutputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".writeObject used") + logDebug(s"${this.getClass().getSimpleName}.writeObject used") if (graph != null) { graph.synchronized { if (graph.checkpointInProgress) { oos.defaultWriteObject() } else { - val msg = "Object of " + this.getClass.getName + " is being serialized " + + val msg = s"Object of ${this.getClass.getName} is being serialized " + " possibly as a part of closure of an RDD operation. This is because " + " the DStream object is being referred to from within the closure. " + " Please rewrite the RDD operation inside this DStream to avoid this. 
" + @@ -539,7 +532,7 @@ abstract class DStream[T: ClassTag] ( @throws(classOf[IOException]) private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { - logDebug(this.getClass().getSimpleName + ".readObject used") + logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() generatedRDDs = new HashMap[Time, RDD[T]] () } @@ -763,7 +756,7 @@ abstract class DStream[T: ClassTag] ( val firstNum = rdd.take(num + 1) // scalastyle:off println println("-------------------------------------------") - println("Time: " + time) + println(s"Time: $time") println("-------------------------------------------") firstNum.take(num).foreach(println) if (firstNum.length > num) println("...") @@ -910,21 +903,19 @@ abstract class DStream[T: ClassTag] ( val alignedToTime = if ((toTime - zeroTime).isMultipleOf(slideDuration)) { toTime } else { - logWarning("toTime (" + toTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") - toTime.floor(slideDuration, zeroTime) + logWarning(s"toTime ($toTime) is not a multiple of slideDuration ($slideDuration)") + toTime.floor(slideDuration, zeroTime) } val alignedFromTime = if ((fromTime - zeroTime).isMultipleOf(slideDuration)) { fromTime } else { - logWarning("fromTime (" + fromTime + ") is not a multiple of slideDuration (" - + slideDuration + ")") + logWarning(s"fromTime ($fromTime) is not a multiple of slideDuration ($slideDuration)") fromTime.floor(slideDuration, zeroTime) } - logInfo("Slicing from " + fromTime + " to " + toTime + - " (aligned to " + alignedFromTime + " and " + alignedToTime + ")") + logInfo(s"Slicing from $fromTime to $toTime" + + s" (aligned to $alignedFromTime and $alignedToTime)") alignedFromTime.to(alignedToTime, slideDuration).flatMap(time => { if (time >= zeroTime) getOrCompute(time) else None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 39fd21342813..3eff174c2b66 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -17,11 +17,13 @@ package org.apache.spark.streaming.dstream +import java.io.{IOException, ObjectInputStream, ObjectOutputStream} + import scala.collection.mutable.HashMap import scala.reflect.ClassTag -import java.io.{ObjectOutputStream, ObjectInputStream, IOException} -import org.apache.hadoop.fs.Path -import org.apache.hadoop.fs.FileSystem + +import org.apache.hadoop.fs.{FileSystem, Path} + import org.apache.spark.Logging import org.apache.spark.streaming.Time import org.apache.spark.util.Utils diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala index fcd5216f101a..43079880b235 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FilteredDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FilteredDStream[T: ClassTag]( parent: DStream[T], diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala index 9d09a3baf37c..778d556d2efb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMapValuedDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala index 475ea2d2d4f3..96a444a7baa5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlatMappedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class FlatMappedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala index 4410a9977c87..a0fadee8a984 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ForEachDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} import org.apache.spark.streaming.scheduler.Job -import scala.reflect.ClassTag /** * An internal DStream used to represent output operations like DStream.foreachRDD. 
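For orientation, ForEachDStream is the node that DStream.foreachRDD registers in the graph; a brief sketch of that public API (the stream value is assumed for illustration):

    // Each foreachRDD call adds one ForEachDStream output operation.
    stream.foreachRDD { (rdd, time) =>
      // The function runs on the driver once per batch; any RDD actions
      // inside it execute on the cluster as ordinary Spark jobs.
      println(s"Batch $time contained ${rdd.count()} records")
    }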
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala index dbb295fe54f7..9f1252f091a6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/GlommedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class GlommedDStream[T: ClassTag](parent: DStream[T]) extends DStream[Array[T]](parent.ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 95994c983c0c..d60f418e5c4d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -28,7 +28,8 @@ import org.apache.spark.util.Utils /** * This is the abstract base class for all input streams. This class provides methods - * start() and stop() which is called by Spark Streaming system to start and stop receiving data. + * start() and stop() which are called by the Spark Streaming system to start and stop + * receiving data, respectively. * Input streams that can generate RDDs from new data by running a service/thread only on * the driver node (that is, without running a receiver on worker nodes), can be * implemented by directly inheriting this InputDStream. For example, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala index 5994bc1e23f2..bcdf1752e61e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapPartitionedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapPartitionedDStream[T: ClassTag, U: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala index 954d2eb4a7b0..855c3dd096f4 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapValuedDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ import scala.reflect.ClassTag +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MapValuedDStream[K: ClassTag, V: ClassTag, U: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala
b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala index 706465d4e25d..36ff9c7e6182 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MapWithStateDStream.scala @@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.{EmptyRDD, RDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ -import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} import org.apache.spark.streaming.dstream.InternalMapWithStateDStream._ +import org.apache.spark.streaming.rdd.{MapWithStateRDD, MapWithStateRDDRecord} /** * :: Experimental :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala index fa14b2e897c3..e11d82697af8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/MappedDStream.scala @@ -17,10 +17,11 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD import scala.reflect.ClassTag +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.{Duration, Time} + private[streaming] class MappedDStream[T: ClassTag, U: ClassTag] ( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala index a64a1fe93f40..babc72270932 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PairDStreamFunctions.scala @@ -24,12 +24,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.{JobConf, OutputFormat} import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat} +import org.apache.spark.{HashPartitioner, Partitioner} import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.streaming._ +import org.apache.spark.streaming.StreamingContext.rddToFileName import org.apache.spark.util.{SerializableConfiguration, SerializableJobConf} -import org.apache.spark.{HashPartitioner, Partitioner} /** * Extra functions available on DStream of (key, value) pairs through an implicit conversion. 
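For reference, the functions in this file become available on any DStream of pairs through an implicit conversion; a short sketch assuming lines is a DStream[String] (for example from ssc.socketTextStream), with the window length chosen here only for illustration:

    import org.apache.spark.streaming.Seconds

    // Mapping to (key, value) pairs makes PairDStreamFunctions methods
    // such as reduceByKeyAndWindow available via the implicit conversion.
    val wordCounts = lines
      .flatMap(_.split(" "))
      .map(word => (word, 1))
      .reduceByKeyAndWindow((a: Int, b: Int) => a + b, Seconds(30))
    wordCounts.print()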
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala index 002aac9f4361..2442e4c01a0c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/PluggableInputDStream.scala @@ -17,8 +17,9 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.streaming.StreamingContext import scala.reflect.ClassTag + +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver private[streaming] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala index cd073646370d..a8d108de6c3e 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/QueueInputDStream.scala @@ -23,7 +23,7 @@ import scala.collection.mutable.{ArrayBuffer, Queue} import scala.reflect.ClassTag import org.apache.spark.rdd.{RDD, UnionRDD} -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} private[streaming] class QueueInputDStream[T: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala index 5a9eda7c1277..ac73dca05a67 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/RawInputDStream.scala @@ -17,19 +17,18 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.{Logging, SparkEnv} -import org.apache.spark.storage.{StorageLevel, StreamBlockId} -import org.apache.spark.streaming.StreamingContext - -import scala.reflect.ClassTag - +import java.io.EOFException import java.net.InetSocketAddress import java.nio.ByteBuffer import java.nio.channels.{ReadableByteChannel, SocketChannel} -import java.io.EOFException import java.util.concurrent.ArrayBlockingQueue -import org.apache.spark.streaming.receiver.Receiver +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.storage.{StorageLevel, StreamBlockId} +import org.apache.spark.streaming.StreamingContext +import org.apache.spark.streaming.receiver.Receiver /** * An input stream that reads blocks of serialized objects from a given network address. 
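A one-line usage sketch for the stream type above (ssc, host, and port are placeholders): the incoming blocks must already be serialized with Spark's configured serializer, which is what lets this path skip deserialization on the receiving side.

    // rawSocketStream creates a RawInputDStream; blocks are stored as-is.
    val raw = ssc.rawSocketStream[String]("localhost", 9999)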
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala index 87c20afd5c13..565b137228d0 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReceiverInputDStream.scala @@ -21,18 +21,18 @@ import scala.reflect.ClassTag import org.apache.spark.rdd.{BlockRDD, RDD} import org.apache.spark.storage.BlockId +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{RateController, ReceivedBlockInfo, StreamInputInfo} import org.apache.spark.streaming.scheduler.rate.RateEstimator -import org.apache.spark.streaming.scheduler.{ReceivedBlockInfo, RateController, StreamInputInfo} import org.apache.spark.streaming.util.WriteAheadLogUtils -import org.apache.spark.streaming.{StreamingContext, Time} /** * Abstract class for defining any [[org.apache.spark.streaming.dstream.InputDStream]] * that has to start a receiver on worker nodes to receive external data. * Specific implementations of ReceiverInputDStream must - * define `the getReceiver()` function that gets the receiver object of type + * define the [[getReceiver]] function that gets the receiver object of type * [[org.apache.spark.streaming.receiver.Receiver]] that will be sent * to the workers to receive data. * @param ssc_ Streaming context that will execute this input stream @@ -121,7 +121,7 @@ abstract class ReceiverInputDStream[T: ClassTag](ssc_ : StreamingContext) } if (validBlockIds.size != blockIds.size) { logWarning("Some blocks could not be recovered as they were not found in memory. 
" + - "To prevent such data loss, enabled Write Ahead Log (see programming guide " + + "To prevent such data loss, enable Write Ahead Log (see programming guide " + "for more details.") } new BlockRDD[T](ssc.sc, validBlockIds) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index 6a583bf2a362..535954908539 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -17,18 +17,15 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD} +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.{CoGroupedRDD, MapPartitionsRDD, RDD} import org.apache.spark.storage.StorageLevel - -import scala.collection.mutable.ArrayBuffer import org.apache.spark.streaming.{Duration, Interval, Time} -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - private[streaming] class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala index e0ffd5d86b43..0fe15440dd44 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala @@ -17,11 +17,12 @@ package org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.Partitioner -import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag private[streaming] class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag]( diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala index de84e0c9a498..e70fc87c39d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/SocketInputDStream.scala @@ -17,18 +17,17 @@ package org.apache.spark.streaming.dstream -import scala.util.control.NonFatal - -import org.apache.spark.streaming.StreamingContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.NextIterator +import java.io._ +import java.net.{ConnectException, Socket} import scala.reflect.ClassTag +import scala.util.control.NonFatal -import java.io._ -import java.net.{UnknownHostException, Socket} import org.apache.spark.Logging +import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.util.NextIterator private[streaming] class SocketInputDStream[T: ClassTag]( @@ -52,7 +51,20 @@ class SocketReceiver[T: ClassTag]( storageLevel: StorageLevel ) extends Receiver[T](storageLevel) with Logging { + private var socket: Socket = _ + def onStart() { + + logInfo(s"Connecting to $host:$port") + try { + socket = new Socket(host, port) + } 
catch { + case e: ConnectException => + restart(s"Error connecting to $host:$port", e) + return + } + logInfo(s"Connected to $host:$port") + // Start the thread that receives data over a connection new Thread("Socket Receiver") { setDaemon(true) @@ -61,20 +73,22 @@ class SocketReceiver[T: ClassTag]( } def onStop() { - // There is nothing much to do as the thread calling receive() - // is designed to stop by itself isStopped() returns false + // in case the restart thread closes it twice + synchronized { + if (socket != null) { + socket.close() + socket = null + logInfo(s"Closed socket to $host:$port") + } + } } /** Create a socket connection and receive data until receiver is stopped */ def receive() { - var socket: Socket = null try { - logInfo("Connecting to " + host + ":" + port) - socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) val iterator = bytesToObjects(socket.getInputStream()) while(!isStopped && iterator.hasNext) { - store(iterator.next) + store(iterator.next()) } if (!isStopped()) { restart("Socket data stream had no more data") @@ -82,16 +96,11 @@ class SocketReceiver[T: ClassTag]( logInfo("Stopped receiving") } } catch { - case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) case NonFatal(e) => logWarning("Error receiving data", e) restart("Error receiving data", e) } finally { - if (socket != null) { - socket.close() - logInfo("Closed socket to " + host + ":" + port) - } + onStop() } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 621d6dff788f..ebbe139a2cdf 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.dstream -import org.apache.spark.rdd.RDD +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.{Duration, Time} -import scala.reflect.ClassTag - private[streaming] class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( parent: DStream[(K, V)], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala index d73ffdfd84d2..2b07dd618586 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/UnionDStream.scala @@ -21,9 +21,8 @@ import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag import org.apache.spark.SparkException +import org.apache.spark.rdd.{RDD, UnionRDD} import org.apache.spark.streaming.{Duration, Time} -import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.UnionRDD private[streaming] class UnionDStream[T: ClassTag](parents: Array[DStream[T]]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala index 4efba039f895..ee50a8d024e1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala @@ -17,13 +17,13 @@ package 
org.apache.spark.streaming.dstream +import scala.reflect.ClassTag + import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming._ import org.apache.spark.streaming.Duration -import scala.reflect.ClassTag - private[streaming] class WindowedDStream[T: ClassTag]( parent: DStream[T], diff --git a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala index fdf61674a37f..1d2244eaf22b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/rdd/MapWithStateRDD.scala @@ -22,11 +22,11 @@ import java.io.{IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.mutable.ArrayBuffer import scala.reflect.ClassTag +import org.apache.spark._ import org.apache.spark.rdd.{MapPartitionsRDD, RDD} -import org.apache.spark.streaming.{Time, StateImpl, State} +import org.apache.spark.streaming.{State, StateImpl, Time} import org.apache.spark.streaming.util.{EmptyStateMap, StateMap} import org.apache.spark.util.Utils -import org.apache.spark._ /** * Record storing the keyed-state [[MapWithStateRDD]]. Each record contains a [[StateMap]] and a diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala index 7ec74016a1c2..0eabf3d260b2 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ActorReceiver.scala @@ -47,13 +47,12 @@ object ActorSupervisorStrategy { /** * :: DeveloperApi :: - * A receiver trait to be mixed in with your Actor to gain access to - * the API for pushing received data into Spark Streaming for being processed. + * A base Actor that provides APIs for pushing received data into Spark Streaming for processing. * * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html * * @example {{{ - * class MyActor extends Actor with ActorHelper{ + * class MyActor extends ActorReceiver { * def receive { * case anything: String => store(anything) * } @@ -69,13 +68,60 @@ object ActorSupervisorStrategy { * should be same. */ @DeveloperApi -trait ActorHelper extends Logging{ +abstract class ActorReceiver extends Actor { - self: Actor => // to ensure that this can be added to Actor classes only + /** Store an iterator of received data as a data block into Spark's memory. */ + def store[T](iter: Iterator[T]) { + context.parent ! IteratorData(iter) + } + + /** + * Store the bytes of received data as a data block into Spark's memory. Note + * that the data in the ByteBuffer must be serialized using the same serializer + * that Spark is configured to use. + */ + def store(bytes: ByteBuffer) { + context.parent ! ByteBufferData(bytes) + } + + /** + * Store a single item of received data to Spark's memory. + * These single items will be aggregated together into data blocks before + * being pushed into Spark's memory. + */ + def store[T](item: T) { + context.parent ! SingleItemData(item) + } +} + +/** + * :: DeveloperApi :: + * A Java UntypedActor that provides APIs for pushing received data into Spark Streaming for + * processing. 
+ * + * Find more details at: http://spark.apache.org/docs/latest/streaming-custom-receivers.html + * + * @example {{{ + * class MyActor extends JavaActorReceiver { + * def receive { + * case anything: String => store(anything) + * } + * } + * + * // Can be used with an actorStream as follows + * ssc.actorStream[String](Props(new MyActor),"MyActorReceiver") + * + * }}} + * + * @note Since the actor may exist outside the Spark framework, it is the user's responsibility + * to ensure type safety, i.e. the parametrized type of the pushed block and the InputDStream + * should be the same. + */ +@DeveloperApi +abstract class JavaActorReceiver extends UntypedActor { /** Store an iterator of received data as a data block into Spark's memory. */ def store[T](iter: Iterator[T]) { - logDebug("Storing iterator") context.parent ! IteratorData(iter) } @@ -85,7 +131,6 @@ trait ActorHelper extends Logging{ * that Spark is configured to use. */ def store(bytes: ByteBuffer) { - logDebug("Storing Bytes") context.parent ! ByteBufferData(bytes) } @@ -95,7 +140,6 @@ trait ActorHelper extends Logging{ * being pushed into Spark's memory. */ def store[T](item: T) { - logDebug("Storing item") context.parent ! SingleItemData(item) } } @@ -104,7 +148,7 @@ trait ActorHelper extends Logging{ * :: DeveloperApi :: * Statistics for querying the supervisor about state of workers. Used in * conjunction with `StreamingContext.actorStream` and - * [[org.apache.spark.streaming.receiver.ActorHelper]]. + * [[org.apache.spark.streaming.receiver.ActorReceiver]]. */ @DeveloperApi case class Statistics(numberOfMsgs: Int, @@ -137,7 +181,7 @@ private[streaming] case class ByteBufferData(bytes: ByteBuffer) extends ActorRec * context.parent ! Props(new Worker, "Worker") * }}} */ -private[streaming] class ActorReceiver[T: ClassTag]( +private[streaming] class ActorReceiverSupervisor[T: ClassTag]( props: Props, name: String, storageLevel: StorageLevel, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index cc7c04bfc9f6..109af32cf4bb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -21,7 +21,7 @@ import java.util.concurrent.{ArrayBlockingQueue, TimeUnit} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkException, Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.util.RecurringTimer import org.apache.spark.util.{Clock, SystemClock} diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 5f6c5b024085..faa5aca1d8f7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -17,18 +17,18 @@ package org.apache.spark.streaming.receiver -import scala.concurrent.duration._ import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration._ import scala.language.{existentials, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf, SparkException} import org.apache.spark.storage._ import 
org.apache.spark.streaming.receiver.WriteAheadLogBasedBlockHandler._ import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} import org.apache.spark.util.{Clock, SystemClock, ThreadUtils} -import org.apache.spark.{Logging, SparkConf, SparkException} /** Trait that represents the metadata related to storage of blocks */ private[streaming] trait ReceivedBlockStoreResult { @@ -69,7 +69,7 @@ private[streaming] class BlockManagerBasedBlockHandler( def storeBlock(blockId: StreamBlockId, block: ReceivedBlock): ReceivedBlockStoreResult = { - var numRecords = None: Option[Long] + var numRecords: Option[Long] = None val putResult: Seq[(BlockId, BlockStatus)] = block match { case ArrayBufferBlock(arrayBuffer) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 2252e28f22af..639f4259e2e7 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -22,8 +22,8 @@ import java.nio.ByteBuffer import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ -import org.apache.spark.storage.StorageLevel import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.storage.StorageLevel /** * :: DeveloperApi :: @@ -103,7 +103,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable /** * This method is called by the system when the receiver is stopped. All resources - * (threads, buffers, etc.) setup in `onStart()` must be cleaned up in this method. + * (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method. */ def onStop() @@ -273,7 +273,7 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable /** Get the attached supervisor. */ private[streaming] def supervisor: ReceiverSupervisor = { assert(_supervisor != null, - "A ReceiverSupervisor have not been attached to the receiver yet. Maybe you are starting " + + "A ReceiverSupervisor has not been attached to the receiver yet. Maybe you are starting " + "some computation in the receiver before the Receiver.onStart() has been called.") _supervisor } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index 158d1ba2f183..d0195fb14f0a 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent._ import scala.util.control.NonFatal -import org.apache.spark.{SparkEnv, Logging, SparkConf} +import org.apache.spark.{Logging, SparkConf, SparkEnv} import org.apache.spark.storage.StreamBlockId -import org.apache.spark.util.{Utils, ThreadUtils} +import org.apache.spark.util.{ThreadUtils, Utils} /** * Abstract class that is responsible for supervising a Receiver in the worker. 
@@ -143,10 +143,10 @@ private[streaming] abstract class ReceiverSupervisor( def startReceiver(): Unit = synchronized { try { if (onReceiverStart()) { - logInfo("Starting receiver") + logInfo(s"Starting receiver $streamId") receiverState = Started receiver.onStart() - logInfo("Called receiver onStart") + logInfo(s"Called receiver $streamId onStart") } else { // The driver refused us stop("Registered unsuccessfully because Driver refused to start receiver " + streamId, None) @@ -218,11 +218,9 @@ private[streaming] abstract class ReceiverSupervisor( stopLatch.await() if (stoppingError != null) { logError("Stopped receiver with error: " + stoppingError) + throw stoppingError } else { logInfo("Stopped receiver without error") } - if (stoppingError != null) { - throw stoppingError - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala index 167f56aa4228..b774b6b9a55d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisorImpl.scala @@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer import com.google.common.base.Throwables import org.apache.hadoop.conf.Configuration +import org.apache.spark.{Logging, SparkEnv, SparkException} import org.apache.spark.rpc.{RpcEnv, ThreadSafeRpcEndpoint} import org.apache.spark.storage.StreamBlockId import org.apache.spark.streaming.Time import org.apache.spark.streaming.scheduler._ import org.apache.spark.streaming.util.WriteAheadLogUtils import org.apache.spark.util.RpcUtils -import org.apache.spark.{Logging, SparkEnv, SparkException} /** * Concrete implementation of [[org.apache.spark.streaming.receiver.ReceiverSupervisor]] diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala index deb15d075975..92da0ced28fb 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/InputInfoTracker.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.Logging import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala index ab1b3565fcc1..7050d7ef4524 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/Job.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Try} import org.apache.spark.streaming.Time -import org.apache.spark.util.{Utils, CallSite} +import org.apache.spark.util.{CallSite, Utils} /** * Class representing a Spark computation. It may contain multiple Spark jobs. 
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala index 8dfdc1f57b40..a5a01e77639c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala @@ -19,10 +19,10 @@ package org.apache.spark.streaming.scheduler import scala.util.{Failure, Success, Try} -import org.apache.spark.{SparkEnv, Logging} +import org.apache.spark.{Logging, SparkEnv} import org.apache.spark.streaming.{Checkpoint, CheckpointWriter, Time} import org.apache.spark.streaming.util.RecurringTimer -import org.apache.spark.util.{Utils, Clock, EventLoop, ManualClock} +import org.apache.spark.util.{Clock, EventLoop, ManualClock, Utils} /** Event classes for JobGenerator */ private[scheduler] sealed trait JobGeneratorEvent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index f76300351e3c..6e7232a2a088 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -59,17 +59,15 @@ case class JobSet( // Time taken to process all the jobs from the time they were submitted // (i.e. including the time they wait in the streaming scheduler queue) - def totalDelay: Long = { - processingEndTime - time.milliseconds - } + def totalDelay: Long = processingEndTime - time.milliseconds def toBatchInfo: BatchInfo = { BatchInfo( time, streamIdToInputInfo, submissionTime, - if (processingStartTime >= 0) Some(processingStartTime) else None, - if (processingEndTime >= 0) Some(processingEndTime) else None, + if (hasStarted) Some(processingStartTime) else None, + if (hasCompleted) Some(processingEndTime) else None, jobs.map { job => (job.outputOpId, job.toOutputOperationInfo) }.toMap ) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 4dab64d696b3..60b5c838e973 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -27,11 +27,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.network.util.JavaUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.util.{BatchedWriteAheadLog, WriteAheadLog, WriteAheadLogUtils} import org.apache.spark.util.{Clock, Utils} -import org.apache.spark.{Logging, SparkConf} /** Trait representing any event in the ReceivedBlockTracker that updates its state. 
*/ private[streaming] sealed trait ReceivedBlockTrackerLogEvent diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index ea5d12b50fcc..678f1dc950ad 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -20,14 +20,14 @@ package org.apache.spark.streaming.scheduler import java.util.concurrent.{CountDownLatch, TimeUnit} import scala.collection.mutable.HashMap -import scala.concurrent.{Future, ExecutionContext} +import scala.concurrent.{ExecutionContext, Future} import scala.language.existentials import scala.util.{Failure, Success} import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.rpc._ -import org.apache.spark.scheduler.{TaskLocation, ExecutorCacheTaskLocation} +import org.apache.spark.scheduler.{ExecutorCacheTaskLocation, TaskLocation} import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util.WriteAheadLogUtils @@ -435,10 +435,6 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false /** RpcEndpoint to receive messages from the receivers. */ private class ReceiverTrackerEndpoint(override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint { - // TODO Remove this thread pool after https://github.com/apache/spark/issues/7385 is merged - private val submitJobThreadPool = ExecutionContext.fromExecutorService( - ThreadUtils.newDaemonCachedThreadPool("submit-job-thread-pool")) - private val walBatchingThreadPool = ExecutionContext.fromExecutorService( ThreadUtils.newDaemonCachedThreadPool("wal-batching-thread-pool")) @@ -610,12 +606,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false logInfo(s"Restarting Receiver $receiverId") self.send(RestartReceiver(receiver)) } - }(submitJobThreadPool) + }(ThreadUtils.sameThread) logInfo(s"Receiver ${receiver.streamId} started") } override def onStop(): Unit = { - submitJobThreadPool.shutdownNow() active = false walBatchingThreadPool.shutdown() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala index d19bdbb443c5..58fc78d55210 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/StreamingListener.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.scheduler import scala.collection.mutable.Queue -import org.apache.spark.util.Distribution import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.util.Distribution /** * :: DeveloperApi :: diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala index bc1711930d3a..7635f79a3d2d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/BatchPage.scala @@ -25,8 +25,8 @@ import org.apache.commons.lang3.StringEscapeUtils import org.apache.spark.streaming.Time import org.apache.spark.streaming.ui.StreamingJobProgressListener.{OutputOpId, SparkJobId} -import org.apache.spark.ui.jobs.UIData.JobUIData import 
org.apache.spark.ui.{UIUtils => SparkUIUtils, WebUIPage} +import org.apache.spark.ui.jobs.UIData.JobUIData private[ui] case class SparkJobIdWithUIData(sparkJobId: SparkJobId, jobUIData: Option[JobUIData]) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala index f6cc6edf2569..4908be053635 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingJobProgressListener.scala @@ -17,19 +17,13 @@ package org.apache.spark.streaming.ui -import java.util.LinkedHashMap -import java.util.{Map => JMap} -import java.util.Properties +import java.util.{LinkedHashMap, Map => JMap, Properties} -import scala.collection.mutable.{ArrayBuffer, Queue, HashMap, SynchronizedBuffer} +import scala.collection.mutable.{ArrayBuffer, HashMap, Queue, SynchronizedBuffer} import org.apache.spark.scheduler._ -import org.apache.spark.streaming.{Time, StreamingContext} +import org.apache.spark.streaming.{StreamingContext, Time} import org.apache.spark.streaming.scheduler._ -import org.apache.spark.streaming.scheduler.StreamingListenerReceiverStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchStarted -import org.apache.spark.streaming.scheduler.StreamingListenerBatchSubmitted - private[streaming] class StreamingJobProgressListener(ssc: StreamingContext) extends StreamingListener with SparkListener { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala index bc53f2a31f6d..0662c64a0ce9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/StreamingTab.scala @@ -21,14 +21,14 @@ import org.apache.spark.{Logging, SparkException} import org.apache.spark.streaming.StreamingContext import org.apache.spark.ui.{SparkUI, SparkUITab} -import StreamingTab._ - /** * Spark Web UI tab that shows statistics of a streaming job. * This assumes the given SparkContext has enabled its SparkUI. 
*/ private[spark] class StreamingTab(val ssc: StreamingContext) - extends SparkUITab(getSparkUI(ssc), "streaming") with Logging { + extends SparkUITab(StreamingTab.getSparkUI(ssc), "streaming") with Logging { + + import StreamingTab._ private val STATIC_RESOURCE_DIR = "org/apache/spark/streaming/ui/static" diff --git a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala index d89f7ad3e16b..a485a46937f3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/ui/UIUtils.scala @@ -17,14 +17,14 @@ package org.apache.spark.streaming.ui -import scala.xml.Node - -import org.apache.commons.lang3.StringEscapeUtils - import java.text.SimpleDateFormat import java.util.TimeZone import java.util.concurrent.TimeUnit +import scala.xml.Node + +import org.apache.commons.lang3.StringEscapeUtils + private[streaming] object UIUtils { /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala index b2cd524f28b7..8cb45cdffa5d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/BatchedWriteAheadLog.scala @@ -18,8 +18,8 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.LinkedBlockingQueue import java.util.{Iterator => JIterator} +import java.util.concurrent.LinkedBlockingQueue import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index a99b57083583..15ad2e27d372 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer -import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import java.util.{Iterator => JIterator} +import java.util.concurrent.{RejectedExecutionException, ThreadPoolExecutor} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -29,8 +29,8 @@ import scala.language.postfixOps import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.util.{CompletionIterator, ThreadUtils} import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.util.{CompletionIterator, ThreadUtils} /** * This class manages write ahead log files. 
@@ -224,7 +224,8 @@ private[streaming] class FileBasedWriteAheadLog( val logDirectoryPath = new Path(logDirectory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { val logFileInfo = logFilesTologInfo(fileSystem.listStatus(logDirectoryPath).map { _.getPath }) pastLogs.clear() pastLogs ++= logFileInfo @@ -253,7 +254,7 @@ private[streaming] object FileBasedWriteAheadLog { def getCallerName(): Option[String] = { val stackTraceClasses = Thread.currentThread.getStackTrace().map(_.getClassName) - stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split(".").lastOption) + stackTraceClasses.find(!_.contains("WriteAheadLog")).flatMap(_.split("\\.").lastOption) } /** Convert a sequence of files to a sequence of sorted LogInfo objects */ diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala index a375c0729534..e79b139bdd03 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogReader.scala @@ -16,10 +16,11 @@ */ package org.apache.spark.streaming.util -import java.io.{IOException, Closeable, EOFException} +import java.io.{Closeable, EOFException, IOException} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration + import org.apache.spark.Logging /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala index 1185f30265f6..1f5c1d4369b5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLogWriter.scala @@ -19,10 +19,7 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer -import scala.util.Try - import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.FSDataOutputStream import org.apache.spark.util.Utils @@ -34,11 +31,6 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: private lazy val stream = HdfsUtils.getOutputStream(path, hadoopConf) - private lazy val hadoopFlushMethod = { - // Use reflection to get the right flush operation - val cls = classOf[FSDataOutputStream] - Try(cls.getMethod("hflush")).orElse(Try(cls.getMethod("sync"))).toOption - } private var nextOffset = stream.getPos() private var closed = false @@ -62,7 +54,7 @@ private[streaming] class FileBasedWriteAheadLogWriter(path: String, hadoopConf: } private def flush() { - hadoopFlushMethod.foreach { _.invoke(stream) } + stream.hflush() // Useful for local file system where hflush/sync does not work (HADOOP-7844) stream.getWrappedStream.flush() } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala index a96e2924a0b4..5c3c7a6bf1b3 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RateLimitedOutputStream.scala @@ 
-17,13 +17,12 @@ package org.apache.spark.streaming.util -import scala.annotation.tailrec - import java.io.OutputStream import java.util.concurrent.TimeUnit._ -import org.apache.spark.Logging +import scala.annotation.tailrec +import org.apache.spark.Logging private[streaming] class RateLimitedOutputStream(out: OutputStream, desiredBytesPerSec: Int) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index 6addb9675203..e48eaf7913b1 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -23,7 +23,7 @@ import java.nio.ByteBuffer import scala.io.Source -import org.apache.spark.{SparkConf, Logging} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.serializer.KryoSerializer import org.apache.spark.util.IntParam diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala index 3f139ad138c8..4e5baebaae04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/StateMap.scala @@ -17,16 +17,20 @@ package org.apache.spark.streaming.util -import java.io.{ObjectInputStream, ObjectOutputStream} +import java.io._ import scala.reflect.ClassTag +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Input, Output} + import org.apache.spark.SparkConf +import org.apache.spark.serializer.{KryoOutputObjectOutputBridge, KryoInputObjectInputBridge} import org.apache.spark.streaming.util.OpenHashMapBasedStateMap._ import org.apache.spark.util.collection.OpenHashMap /** Internal interface for defining the map that keeps track of sessions. 
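The import reshuffles in these hunks all enforce the same grouping, which Spark's Scala style guide mandates: java.* first, then scala.*, then third-party libraries, then org.apache.spark.*, each group alphabetized and separated by a blank line. A schematic example of the target ordering:

import java.io.OutputStream
import java.util.concurrent.TimeUnit._

import scala.annotation.tailrec

import com.esotericsoftware.kryo.Kryo

import org.apache.spark.Logging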
*/ -private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Serializable { +private[streaming] abstract class StateMap[K, S] extends Serializable { /** Get the state for a key if it exists */ def get(key: K): Option[S] @@ -54,7 +58,7 @@ private[streaming] abstract class StateMap[K: ClassTag, S: ClassTag] extends Ser /** Companion object for [[StateMap]], with utility methods */ private[streaming] object StateMap { - def empty[K: ClassTag, S: ClassTag]: StateMap[K, S] = new EmptyStateMap[K, S] + def empty[K, S]: StateMap[K, S] = new EmptyStateMap[K, S] def create[K: ClassTag, S: ClassTag](conf: SparkConf): StateMap[K, S] = { val deltaChainThreshold = conf.getInt("spark.streaming.sessionByKey.deltaChainThreshold", @@ -64,7 +68,7 @@ private[streaming] object StateMap { } /** Implementation of StateMap interface representing an empty map */ -private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMap[K, S] { +private[streaming] class EmptyStateMap[K, S] extends StateMap[K, S] { override def put(key: K, session: S, updateTime: Long): Unit = { throw new NotImplementedError("put() should not be called on an EmptyStateMap") } @@ -77,21 +81,26 @@ private[streaming] class EmptyStateMap[K: ClassTag, S: ClassTag] extends StateMa } /** Implementation of StateMap based on Spark's [[org.apache.spark.util.collection.OpenHashMap]] */ -private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( +private[streaming] class OpenHashMapBasedStateMap[K, S]( @transient @volatile var parentStateMap: StateMap[K, S], - initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, - deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD - ) extends StateMap[K, S] { self => + private var initialCapacity: Int = DEFAULT_INITIAL_CAPACITY, + private var deltaChainThreshold: Int = DELTA_CHAIN_LENGTH_THRESHOLD + )(implicit private var keyClassTag: ClassTag[K], private var stateClassTag: ClassTag[S]) + extends StateMap[K, S] with KryoSerializable { self => - def this(initialCapacity: Int, deltaChainThreshold: Int) = this( + def this(initialCapacity: Int, deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( new EmptyStateMap[K, S], initialCapacity = initialCapacity, deltaChainThreshold = deltaChainThreshold) - def this(deltaChainThreshold: Int) = this( + def this(deltaChainThreshold: Int) + (implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = this( initialCapacity = DEFAULT_INITIAL_CAPACITY, deltaChainThreshold = deltaChainThreshold) - def this() = this(DELTA_CHAIN_LENGTH_THRESHOLD) + def this()(implicit keyClassTag: ClassTag[K], stateClassTag: ClassTag[S]) = { + this(DELTA_CHAIN_LENGTH_THRESHOLD) + } require(initialCapacity >= 1, "Invalid initial capacity") require(deltaChainThreshold >= 1, "Invalid delta chain threshold") @@ -206,11 +215,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( * Serialize the map data. Besides serialization, this method actually compact the deltas * (if needed) in a single pass over all the data in the map. */ - - private def writeObject(outputStream: ObjectOutputStream): Unit = { - // Write all the non-transient fields, especially class tags, etc. 
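The constructor surgery above is what makes Kryo support possible: a context bound like [K: ClassTag] desugars into an immutable implicit constructor argument, which cannot be repopulated when Kryo instantiates the map during read(). Promoting the tags to implicit private vars keeps call sites unchanged while letting read() reassign them. A hypothetical class sketching just that pattern:

import scala.reflect.ClassTag

// Tags are still implicit for callers, but now mutable fields for Kryo's read().
class KryoFriendlyMap[K, S](private var initialCapacity: Int)(
    implicit private var keyTag: ClassTag[K], private var stateTag: ClassTag[S]) {
  // inside read(): keyTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]]
}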
- outputStream.defaultWriteObject() - + private def writeObjectInternal(outputStream: ObjectOutput): Unit = { // Write the data in the delta of this state map outputStream.writeInt(deltaMap.size) val deltaMapIterator = deltaMap.iterator @@ -262,11 +267,7 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } /** Deserialize the map data. */ - private def readObject(inputStream: ObjectInputStream): Unit = { - - // Read the non-transient fields, especially class tags, etc. - inputStream.defaultReadObject() - + private def readObjectInternal(inputStream: ObjectInput): Unit = { // Read the data of the delta val deltaMapSize = inputStream.readInt() deltaMap = if (deltaMapSize != 0) { @@ -309,6 +310,34 @@ private[streaming] class OpenHashMapBasedStateMap[K: ClassTag, S: ClassTag]( } parentStateMap = newParentSessionStore } + + private def writeObject(outputStream: ObjectOutputStream): Unit = { + // Write all the non-transient fields, especially class tags, etc. + outputStream.defaultWriteObject() + writeObjectInternal(outputStream) + } + + private def readObject(inputStream: ObjectInputStream): Unit = { + // Read the non-transient fields, especially class tags, etc. + inputStream.defaultReadObject() + readObjectInternal(inputStream) + } + + override def write(kryo: Kryo, output: Output): Unit = { + output.writeInt(initialCapacity) + output.writeInt(deltaChainThreshold) + kryo.writeClassAndObject(output, keyClassTag) + kryo.writeClassAndObject(output, stateClassTag) + writeObjectInternal(new KryoOutputObjectOutputBridge(kryo, output)) + } + + override def read(kryo: Kryo, input: Input): Unit = { + initialCapacity = input.readInt() + deltaChainThreshold = input.readInt() + keyClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[K]] + stateClassTag = kryo.readClassAndObject(input).asInstanceOf[ClassTag[S]] + readObjectInternal(new KryoInputObjectInputBridge(kryo, input)) + } } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala index 7f9e2c973497..ed616d8e810b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/WriteAheadLogUtils.scala @@ -21,8 +21,8 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.spark.util.Utils import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.util.Utils /** A helper class with utility functions related to the WriteAheadLog interface */ private[streaming] object WriteAheadLogUtils extends Logging { diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java index 9722c60bba1c..4dbcef293487 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaAPISuite.java @@ -33,7 +33,6 @@ import org.junit.Assert; import org.junit.Test; -import com.google.common.base.Optional; import com.google.common.io.Files; import com.google.common.collect.Sets; @@ -43,6 +42,7 @@ import org.apache.spark.api.java.JavaPairRDD; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.*; import org.apache.spark.storage.StorageLevel; import 
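The split into writeObjectInternal/readObjectInternal above gives Java serialization and Kryo a single shared body, with ObjectOutput/ObjectInput as the common denominator; KryoOutputObjectOutputBridge and KryoInputObjectInputBridge adapt Kryo's Output/Input to those interfaces. A stripped-down sketch of the shape, using a hypothetical Payload class whose Kryo hooks write the field directly instead of going through the (Spark-internal) bridges:

import java.io.{ObjectInput, ObjectInputStream, ObjectOutput, ObjectOutputStream}

import com.esotericsoftware.kryo.{Kryo, KryoSerializable}
import com.esotericsoftware.kryo.io.{Input, Output}

class Payload(private var n: Int) extends Serializable with KryoSerializable {
  private def writeBody(out: ObjectOutput): Unit = out.writeInt(n)
  private def readBody(in: ObjectInput): Unit = { n = in.readInt() }

  // Java serialization: default fields first, then the shared body.
  private def writeObject(out: ObjectOutputStream): Unit = {
    out.defaultWriteObject()
    writeBody(out)
  }
  private def readObject(in: ObjectInputStream): Unit = {
    in.defaultReadObject()
    readBody(in)
  }

  // Kryo hooks; the real StateMap routes these through the bridge classes.
  override def write(kryo: Kryo, output: Output): Unit = output.writeInt(n)
  override def read(kryo: Kryo, input: Input): Unit = { n = input.readInt() }
}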
org.apache.spark.streaming.api.java.*; @@ -772,8 +772,8 @@ public Iterable call(String x) { @SuppressWarnings("unchecked") @Test public void testForeachRDD() { - final Accumulator accumRdd = ssc.sc().accumulator(0); - final Accumulator accumEle = ssc.sc().accumulator(0); + final Accumulator accumRdd = ssc.sparkContext().accumulator(0); + final Accumulator accumEle = ssc.sparkContext().accumulator(0); List> inputData = Arrays.asList( Arrays.asList(1,1,1), Arrays.asList(1,1,1)); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java index bc4bc2eb4223..9b7701003d8d 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaMapWithStateSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.streaming; import java.io.Serializable; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -25,11 +26,10 @@ import scala.Tuple2; -import com.google.common.base.Optional; -import com.google.common.collect.Lists; import com.google.common.collect.Sets; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.function.Function; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.util.ManualClock; import org.junit.Assert; @@ -37,6 +37,7 @@ import org.apache.spark.HashPartitioner; import org.apache.spark.api.java.JavaPairRDD; +import org.apache.spark.api.java.Optional; import org.apache.spark.api.java.function.Function3; import org.apache.spark.api.java.function.Function4; import org.apache.spark.streaming.api.java.JavaPairDStream; @@ -51,10 +52,8 @@ public void testAPI() { JavaPairRDD initialRDD = null; JavaPairDStream wordsDstream = null; - final Function4, State, Optional> - mappingFunc = + Function4, State, Optional> mappingFunc = new Function4, State, Optional>() { - @Override public Optional call( Time time, String word, Optional one, State state) { @@ -76,11 +75,10 @@ public Optional call( .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream stateSnapshots = stateDstream.stateSnapshots(); + stateDstream.stateSnapshots(); - final Function3, State, Double> mappingFunc2 = + Function3, State, Double> mappingFunc2 = new Function3, State, Double>() { - @Override public Double call(String key, Optional one, State state) { // Use all State's methods here @@ -95,13 +93,13 @@ public Double call(String key, Optional one, State state) { JavaMapWithStateDStream stateDstream2 = wordsDstream.mapWithState( - StateSpec.function(mappingFunc2) + StateSpec.function(mappingFunc2) .initialState(initialRDD) .numPartitions(10) .partitioner(new HashPartitioner(10)) .timeout(Durations.seconds(10))); - JavaPairDStream stateSnapshots2 = stateDstream2.stateSnapshots(); + stateDstream2.stateSnapshots(); } @Test @@ -126,41 +124,29 @@ public void testBasicFunction() { Collections.emptySet() ); + @SuppressWarnings("unchecked") List>> stateData = Arrays.asList( Collections.>emptySet(), - Sets.newHashSet(new Tuple2("a", 1)), - Sets.newHashSet(new Tuple2("a", 2), new Tuple2("b", 1)), - Sets.newHashSet( - new Tuple2("a", 3), - new Tuple2("b", 2), - new Tuple2("c", 1)), - Sets.newHashSet( - new Tuple2("a", 4), - new Tuple2("b", 3), - new Tuple2("c", 1)), - Sets.newHashSet( - new Tuple2("a", 5), - new Tuple2("b", 3), - new Tuple2("c", 1)), - 
Sets.newHashSet( - new Tuple2("a", 5), - new Tuple2("b", 3), - new Tuple2("c", 1)) + Sets.newHashSet(new Tuple2<>("a", 1)), + Sets.newHashSet(new Tuple2<>("a", 2), new Tuple2<>("b", 1)), + Sets.newHashSet(new Tuple2<>("a", 3), new Tuple2<>("b", 2), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 4), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)), + Sets.newHashSet(new Tuple2<>("a", 5), new Tuple2<>("b", 3), new Tuple2<>("c", 1)) ); Function3, State, Integer> mappingFunc = new Function3, State, Integer>() { - @Override - public Integer call(String key, Optional value, State state) throws Exception { - int sum = value.or(0) + (state.exists() ? state.get() : 0); + public Integer call(String key, Optional value, State state) { + int sum = value.orElse(0) + (state.exists() ? state.get() : 0); state.update(sum); return sum; } }; testOperation( inputData, - StateSpec.function(mappingFunc), + StateSpec.function(mappingFunc), outputData, stateData); } @@ -175,27 +161,25 @@ private void testOperation( JavaMapWithStateDStream mapWithStateDStream = JavaPairDStream.fromJavaDStream(inputStream.map(new Function>() { @Override - public Tuple2 call(K x) throws Exception { - return new Tuple2(x, 1); + public Tuple2 call(K x) { + return new Tuple2<>(x, 1); } })).mapWithState(mapWithStateSpec); final List> collectedOutputs = - Collections.synchronizedList(Lists.>newArrayList()); - mapWithStateDStream.foreachRDD(new Function, Void>() { + Collections.synchronizedList(new ArrayList>()); + mapWithStateDStream.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) throws Exception { + public void call(JavaRDD rdd) { collectedOutputs.add(Sets.newHashSet(rdd.collect())); - return null; } }); final List>> collectedStateSnapshots = - Collections.synchronizedList(Lists.>>newArrayList()); - mapWithStateDStream.stateSnapshots().foreachRDD(new Function, Void>() { + Collections.synchronizedList(new ArrayList>>()); + mapWithStateDStream.stateSnapshots().foreachRDD(new VoidFunction>() { @Override - public Void call(JavaPairRDD rdd) throws Exception { + public void call(JavaPairRDD rdd) { collectedStateSnapshots.add(Sets.newHashSet(rdd.collect())); - return null; } }); BatchCounter batchCounter = new BatchCounter(ssc.ssc()); diff --git a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java index 7a8ef9d14784..d09258e0e4a8 100644 --- a/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java +++ b/streaming/src/test/java/org/apache/spark/streaming/JavaReceiverAPISuite.java @@ -18,13 +18,14 @@ package org.apache.spark.streaming; import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.api.java.function.VoidFunction; import org.apache.spark.streaming.api.java.JavaDStream; import org.apache.spark.streaming.api.java.JavaReceiverInputDStream; import org.apache.spark.streaming.api.java.JavaStreamingContext; -import static org.junit.Assert.*; import com.google.common.io.Closeables; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; @@ -68,12 +69,11 @@ public String call(String v1) { return v1 + "."; } }); - mapped.foreachRDD(new Function, Void>() { + mapped.foreachRDD(new VoidFunction>() { @Override - public Void call(JavaRDD rdd) { + public void call(JavaRDD rdd) { long count = rdd.count(); dataCounter.addAndGet(count); - return null; } }); @@ 
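Two API migrations drive most of the Java test churn here. First, Guava's Optional is replaced by Spark's own org.apache.spark.api.java.Optional, which follows java.util.Optional naming, hence value.or(0) becoming value.orElse(0). Second, foreachRDD now takes a VoidFunction instead of Function<T, Void>, removing the boilerplate return null. A small sketch of the new Optional surface, driven from Scala and assuming the mirrored method names:

import org.apache.spark.api.java.Optional

val one = Optional.of(1)
val none = Optional.empty[Int]()
one.orElse(0)   // 1
none.orElse(0)  // 0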
-90,7 +90,7 @@ public Void call(JavaRDD rdd) { Thread.sleep(100); } ssc.stop(); - assertTrue(dataCounter.get() > 0); + Assert.assertTrue(dataCounter.get() > 0); } finally { server.stop(); } @@ -98,8 +98,8 @@ public Void call(JavaRDD rdd) { private static class JavaSocketReceiver extends Receiver { - String host = null; - int port = -1; + private String host = null; + private int port = -1; JavaSocketReceiver(String host_ , int port_) { super(StorageLevel.MEMORY_AND_DISK()); diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index cd28d3cf408d..4d04138da01f 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{ObjectOutputStream, ByteArrayOutputStream, ByteArrayInputStream, File} +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File, ObjectOutputStream} import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag @@ -34,9 +34,31 @@ import org.scalatest.concurrent.Eventually._ import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite, TestUtils} -import org.apache.spark.streaming.dstream.{DStream, FileInputDStream} +import org.apache.spark.rdd.RDD +import org.apache.spark.streaming.dstream._ import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{MutableURLClassLoader, Clock, ManualClock, Utils} +import org.apache.spark.util.{Clock, ManualClock, MutableURLClassLoader, ResetSystemProperties, + Utils} + +/** + * An input stream that records the number of times restore() is invoked + */ +private[streaming] +class CheckpointInputDStream(ssc_ : StreamingContext) extends InputDStream[Int](ssc_) { + protected[streaming] override val checkpointData = new FileInputDStreamCheckpointData + override def start(): Unit = { } + override def stop(): Unit = { } + override def compute(time: Time): Option[RDD[Int]] = Some(ssc.sc.makeRDD(Seq(1))) + private[streaming] + class FileInputDStreamCheckpointData extends DStreamCheckpointData(this) { + @transient + var restoredTimes = 0 + override def restore() { + restoredTimes += 1 + super.restore() + } + } +} /** * A trait that can be mixed in to get methods for testing DStream operations under @@ -110,7 +132,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => new StreamingContext(SparkContext.getOrCreate(conf), batchDuration) } - private def generateOutput[V: ClassTag]( + protected def generateOutput[V: ClassTag]( ssc: StreamingContext, targetBatchTime: Time, checkpointDir: String, @@ -175,7 +197,8 @@ trait DStreamCheckpointTester { self: SparkFunSuite => * the checkpointing of a DStream's RDDs as well as the checkpointing of * the whole DStream graph. 
*/ -class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { +class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester + with ResetSystemProperties { var ssc: StreamingContext = null @@ -187,9 +210,12 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { } override def afterFunction() { - super.afterFunction() - if (ssc != null) { ssc.stop() } - Utils.deleteRecursively(new File(checkpointDir)) + try { + if (ssc != null) { ssc.stop() } + Utils.deleteRecursively(new File(checkpointDir)) + } finally { + super.afterFunction() + } } test("basic rdd checkpoints + dstream graph checkpoint recovery") { @@ -715,6 +741,33 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester { } } + test("DStreamCheckpointData.restore invoking times") { + withStreamingContext(new StreamingContext(conf, batchDuration)) { ssc => + ssc.checkpoint(checkpointDir) + val inputDStream = new CheckpointInputDStream(ssc) + val checkpointData = inputDStream.checkpointData + val mappedDStream = inputDStream.map(_ + 100) + val outputStream = new TestOutputStreamWithPartitions(mappedDStream) + outputStream.register() + // register two more output operations + mappedDStream.foreachRDD(rdd => rdd.count()) + mappedDStream.foreachRDD(rdd => rdd.count()) + assert(checkpointData.restoredTimes === 0) + val batchDurationMillis = ssc.progressListener.batchDuration + generateOutput(ssc, Time(batchDurationMillis * 3), checkpointDir, stopSparkContext = true) + assert(checkpointData.restoredTimes === 0) + } + logInfo("*********** RESTARTING ************") + withStreamingContext(new StreamingContext(checkpointDir)) { ssc => + val checkpointData = + ssc.graph.getInputStreams().head.asInstanceOf[CheckpointInputDStream].checkpointData + assert(checkpointData.restoredTimes === 1) + ssc.start() + ssc.stop() + assert(checkpointData.restoredTimes === 1) + } + } + // This tests whether spark can deserialize array object // refer to SPARK-5569 test("recovery from checkpoint contains array object") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala index 9b5e4dc819a2..e897de3cba6d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamClosureSuite.scala @@ -33,13 +33,18 @@ class DStreamClosureSuite extends SparkFunSuite with BeforeAndAfterAll { private var ssc: StreamingContext = null override def beforeAll(): Unit = { + super.beforeAll() val sc = new SparkContext("local", "test") ssc = new StreamingContext(sc, Seconds(1)) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) - ssc = null + try { + ssc.stop(stopSparkContext = true) + ssc = null + } finally { + super.afterAll() + } } test("user provided closures are actually cleaned") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala index bc223e648a41..94f1bcebc3a3 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/DStreamScopeSuite.scala @@ -21,11 +21,11 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.rdd.{RDD, RDDOperationScope} import 
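The afterFunction/afterAll rewrites repeated throughout these suites share one idea: suite-local cleanup can throw, and when it does, super's teardown must still run so that mixed-in traits (ResetSystemProperties here) can undo their own setup. Sketched generically:

import org.scalatest.{BeforeAndAfterAll, FunSuite}

abstract class ExampleSuite extends FunSuite with BeforeAndAfterAll {
  override def afterAll(): Unit = {
    try {
      // stop contexts, delete temp directories, etc.
    } finally {
      super.afterAll() // runs even if the cleanup above throws
    }
  }
}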
org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.ui.UIUtils import org.apache.spark.util.ManualClock -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} /** * Tests whether scope information is passed from DStream operations to RDDs correctly. @@ -35,13 +35,18 @@ class DStreamScopeSuite extends SparkFunSuite with BeforeAndAfter with BeforeAnd private val batchDuration: Duration = Seconds(1) override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("test") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) ssc = new StreamingContext(new SparkContext(conf), batchDuration) } override def afterAll(): Unit = { - ssc.stop(stopSparkContext = true) + try { + ssc.stop(stopSparkContext = true) + } finally { + super.afterAll() + } } before { assertPropertiesNotSet() } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala index e82c2fa4e72a..6a0b0a1d47bc 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/FailureSuite.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.BeforeAndAfter -import org.apache.spark.{SparkFunSuite, Logging} +import org.apache.spark.{Logging, SparkFunSuite} import org.apache.spark.util.Utils /** diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index 3a3176b91b1e..2e231601c395 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -17,30 +17,30 @@ package org.apache.spark.streaming -import java.io.{File, BufferedWriter, OutputStreamWriter} -import java.net.{Socket, SocketException, ServerSocket} +import java.io.{BufferedWriter, File, OutputStreamWriter} +import java.net.{ServerSocket, Socket, SocketException} import java.nio.charset.Charset -import java.util.concurrent.{CountDownLatch, Executors, TimeUnit, ArrayBlockingQueue} +import java.util.concurrent.{ArrayBlockingQueue, CountDownLatch, Executors, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import scala.collection.mutable.{SynchronizedBuffer, ArrayBuffer, SynchronizedQueue} +import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer, SynchronizedQueue} import scala.language.postfixOps import com.google.common.io.Files -import org.apache.hadoop.io.{Text, LongWritable} -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.apache.hadoop.fs.Path +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapreduce.lib.input.TextInputFormat import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.Eventually._ import org.apache.spark.Logging import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel -import org.apache.spark.streaming.scheduler.{StreamingListenerBatchCompleted, StreamingListener} -import org.apache.spark.util.{ManualClock, Utils} import org.apache.spark.streaming.dstream.{InputDStream, ReceiverInputDStream} import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.Receiver +import org.apache.spark.streaming.scheduler.{StreamingListener, StreamingListenerBatchCompleted} +import org.apache.spark.util.{ManualClock, Utils} class 
InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala index 6b21433f1781..2984fd2b298d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MapWithStateSuite.scala @@ -22,12 +22,12 @@ import java.io.File import scala.collection.mutable.{ArrayBuffer, SynchronizedBuffer} import scala.reflect.ClassTag -import org.scalatest.PrivateMethodTester._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll} +import org.scalatest.PrivateMethodTester._ +import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} import org.apache.spark.streaming.dstream.{DStream, InternalMapWithStateDStream, MapWithStateDStream, MapWithStateDStreamImpl} import org.apache.spark.util.{ManualClock, Utils} -import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} class MapWithStateSuite extends SparkFunSuite with DStreamCheckpointTester with BeforeAndAfterAll with BeforeAndAfter { @@ -49,14 +49,19 @@ class MapWithStateSuite extends SparkFunSuite } override def beforeAll(): Unit = { + super.beforeAll() val conf = new SparkConf().setMaster("local").setAppName("MapWithStateSuite") conf.set("spark.streaming.clock", classOf[ManualClock].getName()) sc = new SparkContext(conf) } override def afterAll(): Unit = { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + } finally { + super.afterAll() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 0e64b57e0ffd..4e56dfbd424b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -17,23 +17,21 @@ package org.apache.spark.streaming -import org.apache.spark.Logging -import org.apache.spark.streaming.dstream.DStream -import org.apache.spark.util.Utils - -import scala.util.Random -import scala.collection.mutable.ArrayBuffer -import scala.reflect.ClassTag - import java.io.{File, IOException} import java.nio.charset.Charset import java.util.UUID -import com.google.common.io.Files +import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag +import scala.util.Random -import org.apache.hadoop.fs.Path +import com.google.common.io.Files import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path +import org.apache.spark.Logging +import org.apache.spark.streaming.dstream.DStream +import org.apache.spark.util.Utils private[streaming] object MasterFailureTest extends Logging { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index c17fb7238151..dd16fc3ecaf5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -39,8 +39,6 @@ import org.apache.spark.storage._ import org.apache.spark.streaming.receiver._ import org.apache.spark.streaming.util._ import org.apache.spark.util.{ManualClock, Utils} -import WriteAheadLogBasedBlockHandler._ -import WriteAheadLogSuite._ class ReceivedBlockHandlerSuite extends SparkFunSuite @@ -48,6 +46,9 @@ class 
ReceivedBlockHandlerSuite with Matchers with Logging { + import WriteAheadLogBasedBlockHandler._ + import WriteAheadLogSuite._ + val conf = new SparkConf() .set("spark.streaming.receiver.writeAheadLog.rollingIntervalSecs", "1") .set("spark.app.id", "streaming-test") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala index 6d388d9624d9..a4871b460eb4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverInputDStreamSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.scalatest.BeforeAndAfterAll +import org.apache.spark.{SparkConf, SparkEnv} import org.apache.spark.rdd.BlockRDD import org.apache.spark.storage.{StorageLevel, StreamBlockId} import org.apache.spark.streaming.dstream.ReceiverInputDStream @@ -28,12 +29,15 @@ import org.apache.spark.streaming.rdd.WriteAheadLogBackedBlockRDD import org.apache.spark.streaming.receiver.{BlockManagerBasedStoreResult, Receiver, WriteAheadLogBasedStoreResult} import org.apache.spark.streaming.scheduler.ReceivedBlockInfo import org.apache.spark.streaming.util.{WriteAheadLogRecordHandle, WriteAheadLogUtils} -import org.apache.spark.{SparkConf, SparkEnv} class ReceiverInputDStreamSuite extends TestSuiteBase with BeforeAndAfterAll { override def afterAll(): Unit = { - StreamingContext.getActive().map { _.stop() } + try { + StreamingContext.getActive().map { _.stop() } + } finally { + super.afterAll() + } } testWithoutWAL("createBlockRDD creates empty BlockRDD when no block info") { receiverStream => diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala index 01279b34f73d..917232c9cdd6 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceiverSuite.scala @@ -24,8 +24,8 @@ import java.util.concurrent.Semaphore import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.scalatest.concurrent.Timeouts import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkConf diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala index c4a01eaea739..ea32bbf95ce5 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StateMapSuite.scala @@ -17,15 +17,23 @@ package org.apache.spark.streaming +import org.apache.spark.streaming.rdd.MapWithStateRDDRecord + import scala.collection.{immutable, mutable, Map} +import scala.reflect.ClassTag import scala.util.Random -import org.apache.spark.SparkFunSuite +import com.esotericsoftware.kryo.{Kryo, KryoSerializable} +import com.esotericsoftware.kryo.io.{Output, Input} + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.serializer._ import org.apache.spark.streaming.util.{EmptyStateMap, OpenHashMapBasedStateMap, StateMap} -import org.apache.spark.util.Utils class StateMapSuite extends SparkFunSuite { + private val conf = new SparkConf() + test("EmptyStateMap") { val map = new EmptyStateMap[Int, Int] intercept[scala.NotImplementedError] { @@ -128,17 +136,17 @@ 
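Moving the two wildcard imports from the top of ReceivedBlockHandlerSuite.scala into the class body relies on Scala's scoped imports: an import is legal at any nesting level and is visible only from that point to the end of the enclosing scope, which keeps the file-level import list to stable package names. A self-contained sketch:

object SomeHelpers {
  val answer = 42
}

class SomeSuite {
  import SomeHelpers._ // visible only inside SomeSuite
  def check(): Boolean = answer == 42
}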
class StateMapSuite extends SparkFunSuite { map1.put(2, 200, 2) testSerialization(map1, "error deserializing and serialized map with data + no delta") - val map2 = map1.copy() + val map2 = map1.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] // Do not test compaction - assert(map2.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + assert(map2.shouldCompact === false) testSerialization(map2, "error deserializing and serialized map with 1 delta + no new data") map2.put(3, 300, 3) map2.put(4, 400, 4) testSerialization(map2, "error deserializing and serialized map with 1 delta + new data") - val map3 = map2.copy() - assert(map3.asInstanceOf[OpenHashMapBasedStateMap[_, _]].shouldCompact === false) + val map3 = map2.copy().asInstanceOf[OpenHashMapBasedStateMap[Int, Int]] + assert(map3.shouldCompact === false) testSerialization(map3, "error deserializing and serialized map with 2 delta + no new data") map3.put(3, 600, 3) map3.remove(2) @@ -267,18 +275,25 @@ class StateMapSuite extends SparkFunSuite { assertMap(stateMap, refMap.toMap, time, "Final state map does not match reference map") } - private def testSerialization[MapType <: StateMap[Int, Int]]( - map: MapType, msg: String): MapType = { - val deserMap = Utils.deserialize[MapType]( - Utils.serialize(map), Thread.currentThread().getContextClassLoader) + private def testSerialization[T: ClassTag]( + map: OpenHashMapBasedStateMap[T, T], msg: String): OpenHashMapBasedStateMap[T, T] = { + testSerialization(new JavaSerializer(conf), map, msg) + testSerialization(new KryoSerializer(conf), map, msg) + } + + private def testSerialization[T : ClassTag]( + serializer: Serializer, + map: OpenHashMapBasedStateMap[T, T], + msg: String): OpenHashMapBasedStateMap[T, T] = { + val deserMap = serializeAndDeserialize(serializer, map) assertMap(deserMap, map, 1, msg) deserMap } // Assert whether all the data and operations on a state map match those of a reference state map - private def assertMap( - mapToTest: StateMap[Int, Int], - refMapToTestWith: StateMap[Int, Int], + private def assertMap[T]( + mapToTest: StateMap[T, T], + refMapToTestWith: StateMap[T, T], time: Long, msg: String): Unit = { withClue(msg) { @@ -321,4 +336,59 @@ class StateMapSuite extends SparkFunSuite { } } } + + test("OpenHashMapBasedStateMap - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + testSerialization( + new KryoSerializer(conf), map, "error deserializing and serialized KryoSerializable states") + } + + test("EmptyStateMap - serializing and deserializing") { + val map = StateMap.empty[KryoState, KryoState] + // Since EmptyStateMap doesn't contain any data, KryoState won't break JavaSerializer. + assert(serializeAndDeserialize(new JavaSerializer(conf), map). + isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + assert(serializeAndDeserialize(new KryoSerializer(conf), map). 
+ isInstanceOf[EmptyStateMap[KryoState, KryoState]]) + } + + test("MapWithStateRDDRecord - serializing and deserializing with KryoSerializable states") { + val map = new OpenHashMapBasedStateMap[KryoState, KryoState]() + map.put(new KryoState("a"), new KryoState("b"), 1) + + val record = + MapWithStateRDDRecord[KryoState, KryoState, KryoState](map, Seq(new KryoState("c"))) + val deserRecord = serializeAndDeserialize(new KryoSerializer(conf), record) + assert(!(record eq deserRecord)) + assert(record.stateMap.getAll().toSeq === deserRecord.stateMap.getAll().toSeq) + assert(record.mappedData === deserRecord.mappedData) + } + + private def serializeAndDeserialize[T: ClassTag](serializer: Serializer, t: T): T = { + val serializerInstance = serializer.newInstance() + serializerInstance.deserialize[T]( + serializerInstance.serialize(t), Thread.currentThread().getContextClassLoader) + } +} + +/** A class that only supports Kryo serialization. */ +private[streaming] final class KryoState(var state: String) extends KryoSerializable { + + override def write(kryo: Kryo, output: Output): Unit = { + kryo.writeClassAndObject(output, state) + } + + override def read(kryo: Kryo, input: Input): Unit = { + state = kryo.readClassAndObject(input).asInstanceOf[String] + } + + override def equals(other: Any): Boolean = other match { + case that: KryoState => state == that.state + case _ => false + } + + override def hashCode(): Int = { + if (state == null) 0 else state.hashCode() + } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala index 860fac29c0ee..0ae4c4598803 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingContextSuite.scala @@ -81,9 +81,9 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from conf with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from existing SparkContext") { @@ -93,26 +93,27 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("from existing SparkContext with settings") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") ssc = new StreamingContext(myConf, batchDuration) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("from checkpoint") { val myConf = SparkContext.updatedConf(new SparkConf(false), master, appName) - myConf.set("spark.cleaner.ttl", "10s") + myConf.set("spark.dummyTimeConfig", "10s") val ssc1 = new StreamingContext(myConf, batchDuration) addInputStream(ssc1).register() ssc1.start() val cp = new Checkpoint(ssc1, Time(1000)) assert( Utils.timeStringAsSeconds(cp.sparkConfPairs - .toMap.getOrElse("spark.cleaner.ttl", "-1")) === 10) + .toMap.getOrElse("spark.dummyTimeConfig", "-1")) === 10) ssc1.stop() val newCp = Utils.deserialize[Checkpoint](Utils.serialize(cp)) - 
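The serializeAndDeserialize helper is the crux of the new coverage: running every map through both JavaSerializer and KryoSerializer catches classes (like KryoState) that survive only one of the two paths. A free-standing sketch of the same round trip, assuming only public Spark serializer APIs:

import scala.reflect.ClassTag

import org.apache.spark.SparkConf
import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, Serializer}

def roundTrip[T: ClassTag](serializer: Serializer, t: T): T = {
  val instance = serializer.newInstance()
  instance.deserialize[T](instance.serialize(t))
}

val conf = new SparkConf(false)
roundTrip(new JavaSerializer(conf), Seq(1, 2, 3)) // Seq(1, 2, 3)
roundTrip(new KryoSerializer(conf), Seq(1, 2, 3)) // Seq(1, 2, 3)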
assert(newCp.createSparkConf().getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert( + newCp.createSparkConf().getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) ssc = new StreamingContext(null, newCp, null) - assert(ssc.conf.getTimeAsSeconds("spark.cleaner.ttl", "-1") === 10) + assert(ssc.conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") === 10) } test("checkPoint from conf") { @@ -288,7 +289,7 @@ class StreamingContextSuite extends SparkFunSuite with BeforeAndAfter with Timeo test("stop gracefully") { val conf = new SparkConf().setMaster(master).setAppName(appName) - conf.set("spark.cleaner.ttl", "3600s") + conf.set("spark.dummyTimeConfig", "3600s") sc = new SparkContext(conf) for (i <- 1 to 4) { logInfo("==================================\n\n\n") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala index 04cd5bdc26be..628a5082074d 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/StreamingListenerSuite.scala @@ -18,20 +18,20 @@ package org.apache.spark.streaming import scala.collection.mutable.{ArrayBuffer, HashMap, SynchronizedBuffer, SynchronizedMap} -import scala.concurrent.Future import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +import org.scalatest.Matchers +import org.scalatest.concurrent.Eventually._ +import org.scalatest.time.SpanSugar._ +import org.apache.spark.Logging import org.apache.spark.SparkException import org.apache.spark.storage.StorageLevel import org.apache.spark.streaming.dstream.DStream import org.apache.spark.streaming.receiver.Receiver import org.apache.spark.streaming.scheduler._ -import org.scalatest.Matchers -import org.scalatest.concurrent.Eventually._ -import org.scalatest.time.SpanSugar._ -import org.apache.spark.Logging - class StreamingListenerSuite extends TestSuiteBase with Matchers { val input = (1 to 4).map(Seq(_)).toSeq diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index be0f4636a6cb..54eff2b21429 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming -import java.io.{ObjectInputStream, IOException} +import java.io.{IOException, ObjectInputStream} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.SynchronizedBuffer @@ -25,13 +25,13 @@ import scala.language.implicitConversions import scala.reflect.ClassTag import org.scalatest.BeforeAndAfter -import org.scalatest.time.{Span, Seconds => ScalaTestSeconds} import org.scalatest.concurrent.Eventually.timeout import org.scalatest.concurrent.PatienceConfiguration +import org.scalatest.time.{Seconds => ScalaTestSeconds, Span} import org.apache.spark.{Logging, SparkConf, SparkFunSuite} import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.dstream.{DStream, InputDStream, ForEachDStream} +import org.apache.spark.streaming.dstream.{DStream, ForEachDStream, InputDStream} import org.apache.spark.streaming.scheduler._ import org.apache.spark.util.{ManualClock, Utils} diff --git a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala 
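The spark.cleaner.ttl key these tests used to piggyback on is being retired in 2.0, and the swap works because getTimeAsSeconds parses the value as a duration string regardless of what the key is called, so any placeholder key exercises the conf round trip. A sketch:

import org.apache.spark.SparkConf

val conf = new SparkConf(false).set("spark.dummyTimeConfig", "10s")
conf.getTimeAsSeconds("spark.dummyTimeConfig", "-1") // 10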
index a5744a9009c1..c4ecebcacf3c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/UISeleniumSuite.scala @@ -38,14 +38,19 @@ class UISeleniumSuite implicit var webDriver: WebDriver = _ override def beforeAll(): Unit = { + super.beforeAll() webDriver = new HtmlUnitDriver { getWebClient.setCssErrorHandler(new SparkUICssErrorHandler) } } override def afterAll(): Unit = { - if (webDriver != null) { - webDriver.quit() + try { + if (webDriver != null) { + webDriver.quit() + } + } finally { + super.afterAll() } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala index c39ad05f4152..c7d085ec0799 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.streaming -import org.apache.spark.streaming.dstream.DStream import org.apache.spark.storage.StorageLevel +import org.apache.spark.streaming.dstream.DStream class WindowOperationsSuite extends TestSuiteBase { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala index aa95bd33dda9..5b13fd6ad611 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/MapWithStateRDDSuite.scala @@ -26,8 +26,8 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.streaming.{State, Time} +import org.apache.spark.streaming.util.OpenHashMapBasedStateMap import org.apache.spark.util.Utils class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with BeforeAndAfterAll { @@ -36,6 +36,7 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B private var checkpointDir: File = _ override def beforeAll(): Unit = { + super.beforeAll() sc = new SparkContext( new SparkConf().setMaster("local").setAppName("MapWithStateRDDSuite")) checkpointDir = Utils.createTempDir() @@ -43,10 +44,14 @@ class MapWithStateRDDSuite extends SparkFunSuite with RDDCheckpointTester with B } override def afterAll(): Unit = { - if (sc != null) { - sc.stop() + try { + if (sc != null) { + sc.stop() + } + Utils.deleteRecursively(checkpointDir) + } finally { + super.afterAll() } - Utils.deleteRecursively(checkpointDir) } override def sparkContext: SparkContext = sc diff --git a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala index cb017b798b2a..79ac833c1846 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/rdd/WriteAheadLogBackedBlockRDDSuite.scala @@ -23,10 +23,10 @@ import scala.util.Random import org.apache.hadoop.conf.Configuration import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} +import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} import org.apache.spark.storage.{BlockId, BlockManager, StorageLevel, StreamBlockId} import 
org.apache.spark.streaming.util.{FileBasedWriteAheadLogSegment, FileBasedWriteAheadLogWriter} import org.apache.spark.util.Utils -import org.apache.spark.{SparkConf, SparkContext, SparkException, SparkFunSuite} class WriteAheadLogBackedBlockRDDSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfterEach { @@ -42,22 +42,32 @@ class WriteAheadLogBackedBlockRDDSuite var dir: File = null override def beforeEach(): Unit = { + super.beforeEach() dir = Utils.createTempDir() } override def afterEach(): Unit = { - Utils.deleteRecursively(dir) + try { + Utils.deleteRecursively(dir) + } finally { + super.afterEach() + } } override def beforeAll(): Unit = { + super.beforeAll() sparkContext = new SparkContext(conf) blockManager = sparkContext.env.blockManager } override def afterAll(): Unit = { // Copied from LocalSparkContext, simpler than to introduced test dependencies to core tests. - sparkContext.stop() - System.clearProperty("spark.driver.port") + try { + sparkContext.stop() + System.clearProperty("spark.driver.port") + } finally { + super.afterAll() + } } test("Read data available in both block manager and write ahead log") { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala index 92ad9fe52b77..f5ec0ff60aa2 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/receiver/BlockGeneratorSuite.scala @@ -22,13 +22,13 @@ import scala.language.reflectiveCalls import org.scalatest.BeforeAndAfter import org.scalatest.Matchers._ -import org.scalatest.concurrent.Timeouts._ import org.scalatest.concurrent.Eventually._ +import org.scalatest.concurrent.Timeouts._ import org.scalatest.time.SpanSugar._ +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.storage.StreamBlockId import org.apache.spark.util.ManualClock -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} class BlockGeneratorSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala index f5248acf712b..a7e365649d3e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/scheduler/InputInfoTrackerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.scheduler import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.streaming.{Time, Duration, StreamingContext} +import org.apache.spark.streaming.{Duration, StreamingContext, Time} class InputInfoTrackerSuite extends SparkFunSuite with BeforeAndAfter { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala index ef1e89df3130..b5d6a24ce8dd 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogSuite.scala @@ -19,8 +19,8 @@ package org.apache.spark.streaming.util import java.io._ import java.nio.ByteBuffer import java.util.{Iterator => JIterator} +import java.util.concurrent.{CountDownLatch, 
RejectedExecutionException, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicInteger -import java.util.concurrent.{RejectedExecutionException, TimeUnit, CountDownLatch, ThreadPoolExecutor} import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer @@ -31,17 +31,16 @@ import scala.language.{implicitConversions, postfixOps} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.mockito.ArgumentCaptor -import org.mockito.Matchers.{eq => meq} -import org.mockito.Matchers._ +import org.mockito.Matchers.{eq => meq, _} import org.mockito.Mockito._ +import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach, PrivateMethodTester} import org.scalatest.concurrent.Eventually import org.scalatest.concurrent.Eventually._ -import org.scalatest.{PrivateMethodTester, BeforeAndAfterEach, BeforeAndAfter} import org.scalatest.mock.MockitoSugar -import org.apache.spark.streaming.scheduler._ -import org.apache.spark.util.{CompletionIterator, ThreadUtils, ManualClock, Utils} import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.streaming.scheduler._ +import org.apache.spark.util.{CompletionIterator, ManualClock, ThreadUtils, Utils} /** Common tests for WriteAheadLogs that we would like to test with different configurations. */ abstract class CommonWriteAheadLogTests( @@ -432,6 +431,7 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( private val queueLength = PrivateMethod[Int]('getQueueLength) override def beforeEach(): Unit = { + super.beforeEach() wal = mock[WriteAheadLog] walHandle = mock[WriteAheadLogRecordHandle] walBatchingThreadPool = ThreadUtils.newDaemonFixedThreadPool(8, "wal-test-thread-pool") @@ -439,8 +439,12 @@ class BatchedWriteAheadLogSuite extends CommonWriteAheadLogTests( } override def afterEach(): Unit = { - if (walBatchingExecutionContext != null) { - walBatchingExecutionContext.shutdownNow() + try { + if (walBatchingExecutionContext != null) { + walBatchingExecutionContext.shutdownNow() + } + } finally { + super.afterEach() } } @@ -700,7 +704,8 @@ object WriteAheadLogSuite { val logDirectoryPath = new Path(directory) val fileSystem = HdfsUtils.getFileSystemForPath(logDirectoryPath, hadoopConf) - if (fileSystem.exists(logDirectoryPath) && fileSystem.getFileStatus(logDirectoryPath).isDir) { + if (fileSystem.exists(logDirectoryPath) && + fileSystem.getFileStatus(logDirectoryPath).isDirectory) { fileSystem.listStatus(logDirectoryPath).map { _.getPath() }.sortBy { _.getName().split("-")(1).toLong }.map { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala index bfc5b0cf60fb..2a41177a5e63 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/util/WriteAheadLogUtilsSuite.scala @@ -24,7 +24,7 @@ import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration -import org.apache.spark.{SparkException, SparkConf, SparkFunSuite} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite} import org.apache.spark.util.Utils class WriteAheadLogUtilsSuite extends SparkFunSuite { diff --git a/tags/README.md b/tags/README.md new file mode 100644 index 000000000000..01e5126945eb --- /dev/null +++ b/tags/README.md @@ -0,0 +1 @@ +This module includes annotations in Java that are used to annotate test suites. 
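One idiom in the BatchedWriteAheadLogSuite changes above deserves a note: ScalaTest's PrivateMethodTester lets the suite poke at the private getQueueLength method without widening its visibility. The handle is declared once and then applied with invokePrivate. A sketch of the mechanics, with the target object left hypothetical:

import org.scalatest.{FunSuite, PrivateMethodTester}

class Example extends FunSuite with PrivateMethodTester {
  private val queueLength = PrivateMethod[Int]('getQueueLength)
  // usage against a real instance: val n = batchedWal invokePrivate queueLength()
}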
diff --git a/tags/pom.xml b/tags/pom.xml index ca93722e7334..9e4610dae7a6 100644 --- a/tags/pom.xml +++ b/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 1e64f280e5be..30cbb6a5a59c 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 5155daa6d17b..a947fac1d751 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -23,8 +23,8 @@ import java.util.jar.JarFile import scala.collection.mutable import scala.collection.JavaConverters._ -import scala.reflect.runtime.universe.runtimeMirror import scala.reflect.runtime.{universe => unv} +import scala.reflect.runtime.universe.runtimeMirror import scala.util.Try /** diff --git a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala index 856ea177a9a1..ccd8fd3969f6 100644 --- a/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala +++ b/tools/src/main/scala/org/apache/spark/tools/JavaAPICompletenessChecker.scala @@ -17,16 +17,16 @@ package org.apache.spark.tools -import java.lang.reflect.{Type, Method} +import java.lang.reflect.{Method, Type} import scala.collection.mutable.ArrayBuffer import scala.language.existentials import org.apache.spark._ import org.apache.spark.api.java._ -import org.apache.spark.rdd.{RDD, DoubleRDDFunctions, PairRDDFunctions, OrderedRDDFunctions} +import org.apache.spark.rdd.{DoubleRDDFunctions, OrderedRDDFunctions, PairRDDFunctions, RDD} import org.apache.spark.streaming.StreamingContext -import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaDStream, JavaStreamingContext} +import org.apache.spark.streaming.api.java.{JavaDStream, JavaPairDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.{DStream, PairDStreamFunctions} @@ -161,7 +161,7 @@ object JavaAPICompletenessChecker { } case "scala.Option" => { if (isReturnType) { - ParameterizedType("com.google.common.base.Optional", parameters.map(applySubs)) + ParameterizedType("org.apache.spark.api.java.Optional", parameters.map(applySubs)) } else { applySubs(parameters(0)) } diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala index 0dc2861253f1..8a5c7c0e730e 100644 --- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala +++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala @@ -20,8 +20,8 @@ package org.apache.spark.tools import java.util.concurrent.{CountDownLatch, Executors} import java.util.concurrent.atomic.AtomicLong -import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.util.Utils diff --git a/unsafe/pom.xml b/unsafe/pom.xml index a1c1111364ee..21fef3415adc 100644 --- a/unsafe/pom.xml +++ b/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.10 - 1.6.0-SNAPSHOT + 2.0.0-SNAPSHOT ../pom.xml 
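The JavaAPICompletenessChecker change above encodes the new contract for the Java API: a Scala method returning Option[T] should surface to Java callers as Spark's own Optional rather than Guava's. The conversion is mechanical; a sketch, assuming Spark's Optional keeps the java.util-style factory names:

import org.apache.spark.api.java.Optional

def toJavaOptional[T](opt: Option[T]): Optional[T] = opt match {
  case Some(v) => Optional.of(v)
  case None => Optional.empty[T]()
}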
diff --git a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
index 12a002befa0a..b3bbd68827b0 100644
--- a/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
+++ b/unsafe/src/test/scala/org/apache/spark/unsafe/types/UTF8StringPropertyCheckSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.unsafe.types
 
 import org.apache.commons.lang3.StringUtils
-
 import org.scalacheck.{Arbitrary, Gen}
 import org.scalatest.prop.GeneratorDrivenPropertyChecks
 // scalastyle:off
diff --git a/yarn/pom.xml b/yarn/pom.xml
index 989b820bec9e..a8c122fd40a1 100644
--- a/yarn/pom.xml
+++ b/yarn/pom.xml
@@ -20,7 +20,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent_2.10</artifactId>
-    <version>1.6.0-SNAPSHOT</version>
+    <version>2.0.0-SNAPSHOT</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
index 56e4741b9387..b8daa501af7f 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/AMDelegationTokenRenewer.scala
@@ -24,9 +24,9 @@ import scala.language.postfixOps
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.UserGroupInformation
-import org.apache.spark.deploy.SparkHadoopUtil
 
 import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.ThreadUtils
 
 /*
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
index fc742df73d73..cccc061647a7 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala
@@ -17,23 +17,22 @@
 
 package org.apache.spark.deploy.yarn
 
-import scala.util.control.NonFatal
-
 import java.io.{File, IOException}
 import java.lang.reflect.InvocationTargetException
 import java.net.{Socket, URL}
 import java.util.concurrent.atomic.AtomicReference
 
+import scala.util.control.NonFatal
+
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.yarn.api._
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 
-import org.apache.spark.rpc._
-import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkEnv,
-  SparkException, SparkUserAppException}
+import org.apache.spark._
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.deploy.history.HistoryServer
+import org.apache.spark.rpc._
 import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, YarnSchedulerBackend}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util._
@@ -281,10 +280,10 @@ private[spark] class ApplicationMaster(
       .getOrElse("")
 
     val _sparkConf = if (sc != null) sc.getConf else sparkConf
-    val driverUrl = _rpcEnv.uriOf(
-      SparkEnv.driverActorSystemName,
-      RpcAddress(_sparkConf.get("spark.driver.host"), _sparkConf.get("spark.driver.port").toInt),
-      CoarseGrainedSchedulerBackend.ENDPOINT_NAME)
+    val driverUrl = RpcEndpointAddress(
+      _sparkConf.get("spark.driver.host"),
+      _sparkConf.get("spark.driver.port").toInt,
+      CoarseGrainedSchedulerBackend.ENDPOINT_NAME).toString
     allocator = client.register(driverUrl,
       driverRef,
       yarnConf,
@@ -310,7 +309,6 @@ private[spark] class ApplicationMaster(
       port: String,
       isClusterMode: Boolean): RpcEndpointRef = {
     val driverEndpoint = rpcEnv.setupEndpointRef(
-      SparkEnv.driverActorSystemName,
       RpcAddress(host, port.toInt),
       YarnSchedulerBackend.ENDPOINT_NAME)
     amEndpoint =
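
The ApplicationMaster hunks above drop the Akka actor-system name from RPC addressing: the driver URL is now built from host, port, and endpoint name alone, and setupEndpointRef no longer takes a system name. A rough stand-in for what such an address type renders, assuming the spark://name@host:port layout implied by the surrounding code (this is a sketch, not Spark's private RpcEndpointAddress class):

    // Illustrative only: an RPC endpoint address reduced to host + port + name.
    case class EndpointAddress(host: String, port: Int, name: String) {
      override def toString: String = s"spark://$name@$host:$port"
    }

    // EndpointAddress("driver-host", 36027, "CoarseGrainedScheduler").toString
    //   == "spark://CoarseGrainedScheduler@driver-host:36027"
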
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
index 17d9943c795e..5af3941c6023 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMasterArguments.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.deploy.yarn
 
-import org.apache.spark.util.{MemoryParam, IntParam}
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
-import collection.mutable.ArrayBuffer
+import org.apache.spark.util.{IntParam, MemoryParam}
 
 class ApplicationMasterArguments(val args: Array[String]) {
   var userJar: String = null
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 7742ec92eb4e..8cf438be587d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -28,22 +28,20 @@ import java.util.zip.{ZipEntry, ZipOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, ListBuffer, Map}
 import scala.reflect.runtime.universe
-import scala.util.{Try, Success, Failure}
+import scala.util.{Failure, Success, Try}
 import scala.util.control.NonFatal
 
 import com.google.common.base.Charsets.UTF_8
 import com.google.common.base.Objects
 import com.google.common.io.Files
-
-import org.apache.hadoop.io.DataOutputBuffer
 import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
 import org.apache.hadoop.fs._
 import org.apache.hadoop.fs.permission.FsPermission
-import org.apache.hadoop.io.Text
+import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
+import org.apache.hadoop.io.{DataOutputBuffer, Text}
 import org.apache.hadoop.mapreduce.MRJobConfig
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-import org.apache.hadoop.security.token.{TokenIdentifier, Token}
+import org.apache.hadoop.security.token.{Token, TokenIdentifier}
 import org.apache.hadoop.util.StringUtils
 import org.apache.hadoop.yarn.api._
 import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
@@ -55,8 +53,8 @@ import org.apache.hadoop.yarn.exceptions.ApplicationNotFoundException
 import org.apache.hadoop.yarn.util.Records
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkContext, SparkException}
-import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
 import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle, YarnCommandBuilderUtils}
 import org.apache.spark.util.Utils
 
 private[spark] class Client(
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
index 3d3a966960e9..4ef05c5a846d 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileStatus, FileSystem, Path}
 import org.apache.hadoop.fs.permission.FsAction
 import org.apache.hadoop.yarn.api.records._
-import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
 
 import org.apache.spark.Logging
 
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
index 94feb6393fd6..9d99c0d93fd1 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorDelegationTokenUpdater.scala
@@ -18,16 +18,16 @@ package org.apache.spark.deploy.yarn
 
 import java.util.concurrent.{Executors, TimeUnit}
 
+import scala.util.control.NonFatal
+
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.{FileSystem, Path}
 import org.apache.hadoop.security.{Credentials, UserGroupInformation}
-import org.apache.spark.deploy.SparkHadoopUtil
 
 import org.apache.spark.{Logging, SparkConf}
+import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.{ThreadUtils, Utils}
 
-import scala.util.control.NonFatal
-
 private[spark] class ExecutorDelegationTokenUpdater(
     sparkConf: SparkConf,
     hadoopConf: Configuration) extends Logging {
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
index 2232ffba473b..31fa53e24b50 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ExecutorRunnable.scala
@@ -25,12 +25,12 @@ import java.util.Collections
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{HashMap, ListBuffer}
 
-import org.apache.hadoop.fs.Path
-import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
 import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.io.DataOutputBuffer
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.NMClient
 import org.apache.hadoop.yarn.conf.YarnConfiguration
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
index 4e044aa4788d..11426eb07c7e 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala
@@ -30,7 +30,6 @@ import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
 import org.apache.hadoop.yarn.util.RackResolver
-
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
index d2a211f6711f..af83cf6a77d1 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnRMClient.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.yarn
 
 import java.util.{List => JList}
 
-import scala.collection.JavaConverters._
 import scala.collection.{Map, Set}
+import scala.collection.JavaConverters._
 import scala.util.Try
 
 import org.apache.hadoop.conf.Configuration
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
index 36a2d6142988..e286aed9f978 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala
@@ -30,19 +30,19 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.hdfs.security.token.delegation.DelegationTokenIdentifier
 import org.apache.hadoop.io.Text
-import org.apache.hadoop.mapred.{Master, JobConf}
+import org.apache.hadoop.mapred.{JobConf, Master}
 import org.apache.hadoop.security.Credentials
 import org.apache.hadoop.security.UserGroupInformation
 import org.apache.hadoop.security.token.{Token, TokenIdentifier}
-import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.api.ApplicationConstants
 import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
 import org.apache.hadoop.yarn.api.records.{ApplicationAccessType, ContainerId, Priority}
+import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.util.ConverterUtils
 
+import org.apache.spark.{SecurityManager, SparkConf, SparkException}
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.launcher.YarnCommandBuilderUtils
-import org.apache.spark.{SecurityManager, SparkConf, SparkException}
 import org.apache.spark.util.Utils
 
 /**
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
index 0e27a2665e93..20e2030fce08 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable.ArrayBuffer
 
 import org.apache.hadoop.yarn.api.records.YarnApplicationState
 
-import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.{Logging, SparkContext, SparkException}
 import org.apache.spark.deploy.yarn.{Client, ClientArguments, YarnSparkHadoopUtil}
 import org.apache.spark.launcher.SparkAppHandle
 import org.apache.spark.scheduler.TaskSchedulerImpl
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
index 4ebf3af12b38..029382133ddf 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnScheduler.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.scheduler.cluster
 
 import org.apache.hadoop.yarn.util.RackResolver
-
 import org.apache.log4j.{Level, Logger}
 
 import org.apache.spark._
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
index 12494b01054b..cd24c704ece5 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala
@@ -27,6 +27,7 @@ import scala.language.postfixOps
 
 import com.google.common.base.Charsets.UTF_8
 import com.google.common.io.Files
+import org.apache.commons.lang3.SerializationUtils
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.apache.hadoop.yarn.server.MiniYARNCluster
 import org.scalatest.{BeforeAndAfterAll, Matchers}
@@ -59,10 +60,13 @@ abstract class BaseYarnClusterSuite
   protected var hadoopConfDir: File = _
   private var logConfDir: File = _
 
+  var oldSystemProperties: Properties = null
+
   def newYarnConfig(): YarnConfiguration
 
   override def beforeAll() {
     super.beforeAll()
+    oldSystemProperties = SerializationUtils.clone(System.getProperties)
 
     tempDir = Utils.createTempDir()
     logConfDir = new File(tempDir, "log4j")
@@ -115,9 +119,12 @@
   }
 
   override def afterAll() {
-    yarnCluster.stop()
-    System.clearProperty("SPARK_YARN_MODE")
-    super.afterAll()
+    try {
+      yarnCluster.stop()
+    } finally {
+      System.setProperties(oldSystemProperties)
+      super.afterAll()
+    }
   }
 
   protected def runSpark(
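
The BaseYarnClusterSuite hunk above introduces the cleanup pattern that the remaining test changes repeat: snapshot the JVM-wide system properties before the suite mutates them (SPARK_YARN_MODE and friends), then restore the snapshot in a finally block so a failing test cannot leak state into later suites. A self-contained sketch of that pattern (the trait name here is invented; the patch itself uses a field plus Spark's ResetSystemProperties trait):

    import java.util.Properties

    import org.apache.commons.lang3.SerializationUtils
    import org.scalatest.{BeforeAndAfterAll, Suite}

    trait SnapshotsSystemProperties extends BeforeAndAfterAll { this: Suite =>
      private var saved: Properties = _

      override def beforeAll(): Unit = {
        super.beforeAll()
        // Deep-copy the global properties before any test mutates them.
        saved = SerializationUtils.clone(System.getProperties)
      }

      override def afterAll(): Unit = {
        try {
          System.setProperties(saved)  // undo every setProperty() from the suite
        } finally {
          super.afterAll()             // parent cleanup runs even if restore throws
        }
      }
    }
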
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
index 804dfecde786..4cffbb2e9b96 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala
@@ -19,8 +19,8 @@ package org.apache.spark.deploy.yarn
 
 import java.net.URI
 
-import org.scalatest.mock.MockitoSugar
-import org.mockito.Mockito.when
+import scala.collection.mutable.HashMap
+import scala.collection.mutable.Map
 
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.FileStatus
@@ -28,16 +28,14 @@ import org.apache.hadoop.fs.FileSystem
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.fs.permission.FsAction
 import org.apache.hadoop.yarn.api.records.LocalResource
-import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
 import org.apache.hadoop.yarn.api.records.LocalResourceType
-import org.apache.hadoop.yarn.util.{Records, ConverterUtils}
-
-import scala.collection.mutable.HashMap
-import scala.collection.mutable.Map
+import org.apache.hadoop.yarn.api.records.LocalResourceVisibility
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import org.mockito.Mockito.when
+import org.scalatest.mock.MockitoSugar
 
 import org.apache.spark.SparkFunSuite
 
-
 class ClientDistributedCacheManagerSuite extends SparkFunSuite with MockitoSugar {
 
   class MockClientDistributedCacheManager extends ClientDistributedCacheManager {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
index e7f2501e7899..998bd1377d56 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientSuite.scala
@@ -19,12 +19,14 @@ package org.apache.spark.deploy.yarn
 
 import java.io.File
 import java.net.URI
+import java.util.Properties
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{HashMap => MutableHashMap}
 import scala.reflect.ClassTag
 import scala.util.Try
 
+import org.apache.commons.lang3.SerializationUtils
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.fs.Path
 import org.apache.hadoop.mapreduce.MRJobConfig
@@ -39,16 +41,26 @@ import org.mockito.Mockito._
 import org.scalatest.{BeforeAndAfterAll, Matchers}
 
 import org.apache.spark.{SparkConf, SparkFunSuite}
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ResetSystemProperties, Utils}
 
-class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll {
+class ClientSuite extends SparkFunSuite with Matchers with BeforeAndAfterAll
+  with ResetSystemProperties {
+
+  var oldSystemProperties: Properties = null
 
   override def beforeAll(): Unit = {
+    super.beforeAll()
+    oldSystemProperties = SerializationUtils.clone(System.getProperties)
     System.setProperty("SPARK_YARN_MODE", "true")
   }
 
   override def afterAll(): Unit = {
-    System.clearProperty("SPARK_YARN_MODE")
+    try {
+      System.setProperties(oldSystemProperties)
+      oldSystemProperties = null
+    } finally {
+      super.afterAll()
+    }
   }
 
   test("default Yarn application classpath") {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
index bd80036c5cfa..1dd2f93bb708 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnAllocatorSuite.scala
@@ -25,15 +25,12 @@ import org.apache.hadoop.net.DNSToSwitchMapping
 import org.apache.hadoop.yarn.api.records._
 import org.apache.hadoop.yarn.client.api.AMRMClient
 import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest
-import org.scalatest.{BeforeAndAfterEach, Matchers}
-
-import org.scalatest.{BeforeAndAfterEach, Matchers}
 import org.mockito.Mockito._
+import org.scalatest.{BeforeAndAfterEach, Matchers}
 
-import org.apache.spark.{SecurityManager, SparkFunSuite}
-import org.apache.spark.SparkConf
-import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
+import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.deploy.yarn.YarnAllocator._
+import org.apache.spark.deploy.yarn.YarnSparkHadoopUtil._
 import org.apache.spark.rpc.RpcEndpointRef
 import org.apache.spark.scheduler.SplitInfo
 
@@ -72,13 +69,18 @@ class YarnAllocatorSuite extends SparkFunSuite with Matchers with BeforeAndAfter
   var containerNum = 0
 
   override def beforeEach() {
+    super.beforeEach()
     rmClient = AMRMClient.createAMRMClient()
    rmClient.init(conf)
     rmClient.start()
   }
 
   override def afterEach() {
-    rmClient.stop()
+    try {
+      rmClient.stop()
+    } finally {
+      super.afterEach()
+    }
   }
 
   class MockSplitInfo(host: String) extends SplitInfo(null, host, null, 1, null) {
diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
index 3fafc91a166a..d3acaf229cc8 100644
--- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala
@@ -27,17 +27,16 @@ import org.apache.hadoop.hive.ql.metadata.HiveException
 import org.apache.hadoop.io.Text
 import org.apache.hadoop.yarn.api.ApplicationConstants
 import org.apache.hadoop.yarn.api.ApplicationConstants.Environment
+import org.apache.hadoop.yarn.api.records.ApplicationAccessType
 import org.apache.hadoop.yarn.conf.YarnConfiguration
 import org.scalatest.Matchers
 
-import org.apache.hadoop.yarn.api.records.ApplicationAccessType
-
 import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException, SparkFunSuite}
 import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.spark.util.Utils
-
+import org.apache.spark.util.{ResetSystemProperties, Utils}
 
-class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging {
+class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging
+  with ResetSystemProperties {
 
   val hasBash =
     try {
diff --git a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
index 94bf579dc824..d6902c7bb073 100644
--- a/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
+++ b/yarn/src/test/scala/org/apache/spark/network/shuffle/ShuffleTestAccessor.scala
@@ -16,7 +16,7 @@
  */
 package org.apache.spark.network.shuffle
 
-import java.io.{IOException, File}
+import java.io.{File, IOException}
 import java.util.concurrent.ConcurrentMap
 
 import org.apache.hadoop.yarn.api.records.ApplicationId
diff --git a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
index 6aa8c814cd4f..5a426b86d10e 100644
--- a/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
+++ b/yarn/src/test/scala/org/apache/spark/network/yarn/YarnShuffleServiceSuite.scala
@@ -34,6 +34,7 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
   private[yarn] var yarnConfig: YarnConfiguration = new YarnConfiguration
 
   override def beforeEach(): Unit = {
+    super.beforeEach()
     yarnConfig.set(YarnConfiguration.NM_AUX_SERVICES, "spark_shuffle")
     yarnConfig.set(YarnConfiguration.NM_AUX_SERVICE_FMT.format("spark_shuffle"),
       classOf[YarnShuffleService].getCanonicalName)
@@ -54,17 +55,21 @@ class YarnShuffleServiceSuite extends SparkFunSuite with Matchers with BeforeAnd
   var s3: YarnShuffleService = null
 
   override def afterEach(): Unit = {
-    if (s1 != null) {
-      s1.stop()
-      s1 = null
-    }
-    if (s2 != null) {
-      s2.stop()
-      s2 = null
-    }
-    if (s3 != null) {
-      s3.stop()
-      s3 = null
+    try {
+      if (s1 != null) {
+        s1.stop()
+        s1 = null
+      }
+      if (s2 != null) {
+        s2.stop()
+        s2 = null
+      }
+      if (s3 != null) {
+        s3.stop()
+        s3 = null
+      }
+    } finally {
+      super.afterEach()
     }
   }
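
The per-test hooks get the same discipline: every afterEach that stops services is wrapped in try/finally so super.afterEach() still runs when a stop() throws. A minimal sketch of the shape these suites now share (the trait name and service field are hypothetical stand-ins for rmClient and s1/s2/s3):

    import org.scalatest.{BeforeAndAfterEach, Suite}

    trait StopsServiceAfterEach extends BeforeAndAfterEach { this: Suite =>
      protected var service: AutoCloseable = _  // stand-in for the resource under test

      override def afterEach(): Unit = {
        try {
          if (service != null) {
            service.close()
            service = null
          }
        } finally {
          super.afterEach()  // always runs, so shared state is restored on failure
        }
      }
    }
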