diff --git a/R/install-dev.sh b/R/install-dev.sh
index d613552718307..9fbc999f2e805 100755
--- a/R/install-dev.sh
+++ b/R/install-dev.sh
@@ -28,6 +28,7 @@
 
 set -o pipefail
 set -e
+set -x
 
 FWDIR="$(cd "`dirname "${BASH_SOURCE[0]}"`"; pwd)"
 LIB_DIR="$FWDIR/lib"
diff --git a/R/pkg/.lintr b/R/pkg/.lintr
index ae50b28ec6166..c83ad2adfe0ef 100644
--- a/R/pkg/.lintr
+++ b/R/pkg/.lintr
@@ -1,2 +1,2 @@
-linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
+linters: with_defaults(line_length_linter(100), multiple_dots_linter = NULL, object_name_linter = NULL, camel_case_linter = NULL, open_curly_linter(allow_single_line = TRUE), closed_curly_linter(allow_single_line = TRUE))
 exclusions: list("inst/profile/general.R" = 1, "inst/profile/shell.R")
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 0728141fa483e..aaa3349d57506 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1191,6 +1191,9 @@ setMethod("collect",
               vec <- do.call(c, col)
               stopifnot(class(vec) != "list")
               class(vec) <- PRIMITIVE_TYPES[[colType]]
+              if (is.character(vec) && stringsAsFactors) {
+                vec <- as.factor(vec)
+              }
               df[[colIndex]] <- vec
             } else {
               df[[colIndex]] <- col
@@ -1923,13 +1926,15 @@ setMethod("[", signature(x = "SparkDataFrame"),
 #' @param i,subset (Optional) a logical expression to filter on rows.
 #'                 For extract operator [[ and replacement operator [[<-, the indexing parameter for
 #'                 a single Column.
-#' @param j,select expression for the single Column or a list of columns to select from the SparkDataFrame.
+#' @param j,select expression for the single Column or a list of columns to select from the
+#'                 SparkDataFrame.
 #' @param drop if TRUE, a Column will be returned if the resulting dataset has only one column.
 #'             Otherwise, a SparkDataFrame will always be returned.
 #' @param value a Column or an atomic vector in the length of 1 as literal value, or \code{NULL}.
 #'              If \code{NULL}, the specified Column is dropped.
 #' @param ... currently not used.
-#' @return A new SparkDataFrame containing only the rows that meet the condition with selected columns.
+#' @return A new SparkDataFrame containing only the rows that meet the condition with selected
+#'         columns.
 #' @export
 #' @family SparkDataFrame functions
 #' @aliases subset,SparkDataFrame-method
@@ -2608,12 +2613,12 @@ setMethod("merge",
             } else {
               # if by or both by.x and by.y have length 0, use Cartesian Product
               joinRes <- crossJoin(x, y)
-              return (joinRes)
+              return(joinRes)
             }
 
             # sets alias for making colnames unique in dataframes 'x' and 'y'
-            colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1])
-            colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2])
+            colsX <- genAliasesForIntersectedCols(x, by, suffixes[1])
+            colsY <- genAliasesForIntersectedCols(y, by, suffixes[2])
 
             # selects columns with their aliases from dataframes
             # in case same column names are present in both data frames
@@ -2661,9 +2666,8 @@ setMethod("merge",
 #' @param intersectedColNames a list of intersected column names of the SparkDataFrame
 #' @param suffix a suffix for the column name
 #' @return list of columns
-#'
-#' @note generateAliasesForIntersectedCols since 1.6.0
-generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
+#' @noRd
+genAliasesForIntersectedCols <- function(x, intersectedColNames, suffix) {
   allColNames <- names(x)
   # sets alias for making colnames unique in dataframe 'x'
   cols <- lapply(allColNames, function(colName) {
@@ -2671,7 +2675,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
     if (colName %in% intersectedColNames) {
       newJoin <- paste(colName, suffix, sep = "")
       if (newJoin %in% allColNames){
-        stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.",
+        stop("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.",
             "Please use different suffixes for the intersected columns.")
       }
       col <- alias(col, newJoin)
@@ -3058,7 +3062,8 @@ setMethod("describe",
 #' summary(select(df, "age", "height"))
 #' }
 #' @note summary(SparkDataFrame) since 1.5.0
-#' @note The statistics provided by \code{summary} were change in 2.3.0 use \link{describe} for previous defaults.
+#' @note The statistics provided by \code{summary} were changed in 2.3.0; use \link{describe} for
+#'       previous defaults.
 #' @seealso \link{describe}
 setMethod("summary",
           signature(object = "SparkDataFrame"),
@@ -3765,8 +3770,8 @@ setMethod("checkpoint",
 #'
 #' Create a multi-dimensional cube for the SparkDataFrame using the specified columns.
 #'
-#' If grouping expression is missing \code{cube} creates a single global aggregate and is equivalent to
-#' direct application of \link{agg}.
+#' If grouping expression is missing \code{cube} creates a single global aggregate and is
+#' equivalent to direct application of \link{agg}.
 #'
 #' @param x a SparkDataFrame.
 #' @param ... character name(s) or Column(s) to group on.
@@ -3800,8 +3805,8 @@ setMethod("cube",
 #'
 #' Create a multi-dimensional rollup for the SparkDataFrame using the specified columns.
 #'
-#' If grouping expression is missing \code{rollup} creates a single global aggregate and is equivalent to
-#' direct application of \link{agg}.
+#' If grouping expression is missing \code{rollup} creates a single global aggregate and is
+#' equivalent to direct application of \link{agg}.
 #'
 #' @param x a SparkDataFrame.
 #' @param ... character name(s) or Column(s) to group on.
diff --git a/R/pkg/R/RDD.R b/R/pkg/R/RDD.R
index 15ca212acf87f..6e89b4bb4d964 100644
--- a/R/pkg/R/RDD.R
+++ b/R/pkg/R/RDD.R
@@ -131,7 +131,7 @@ PipelinedRDD <- function(prev, func) {
 # Return the serialization mode for an RDD.
 setGeneric("getSerializedMode", function(rdd, ...)
{ standardGeneric("getSerializedMode") }) # For normal RDDs we can directly read the serializedMode -setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode ) +setMethod("getSerializedMode", signature(rdd = "RDD"), function(rdd) rdd@env$serializedMode) # For pipelined RDDs if jrdd_val is set then serializedMode should exist # if not we return the defaultSerialization mode of "byte" as we don't know the serialization # mode at this point in time. @@ -145,7 +145,7 @@ setMethod("getSerializedMode", signature(rdd = "PipelinedRDD"), }) # The jrdd accessor function. -setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd ) +setMethod("getJRDD", signature(rdd = "RDD"), function(rdd) rdd@jrdd) setMethod("getJRDD", signature(rdd = "PipelinedRDD"), function(rdd, serializedMode = "byte") { if (!is.null(rdd@env$jrdd_val)) { @@ -893,7 +893,7 @@ setMethod("sampleRDD", if (withReplacement) { count <- stats::rpois(1, fraction) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/WindowSpec.R b/R/pkg/R/WindowSpec.R index 81beac9ea9925..debc7cbde55e7 100644 --- a/R/pkg/R/WindowSpec.R +++ b/R/pkg/R/WindowSpec.R @@ -73,7 +73,7 @@ setMethod("show", "WindowSpec", setMethod("partitionBy", signature(x = "WindowSpec"), function(x, col, ...) { - stopifnot (class(col) %in% c("character", "Column")) + stopifnot(class(col) %in% c("character", "Column")) if (class(col) == "character") { windowSpec(callJMethod(x@sws, "partitionBy", col, list(...))) diff --git a/R/pkg/R/column.R b/R/pkg/R/column.R index a5c2ea81f2490..3095adb918b67 100644 --- a/R/pkg/R/column.R +++ b/R/pkg/R/column.R @@ -238,8 +238,10 @@ setMethod("between", signature(x = "Column"), #' @param x a Column. #' @param dataType a character object describing the target data type. #' See +# nolint start #' \href{https://spark.apache.org/docs/latest/sparkr.html#data-type-mapping-between-r-and-spark}{ #' Spark Data Types} for available data types. +# nolint end #' @rdname cast #' @name cast #' @family colum_func diff --git a/R/pkg/R/context.R b/R/pkg/R/context.R index 8349b57a30a93..443c2ff8f9ace 100644 --- a/R/pkg/R/context.R +++ b/R/pkg/R/context.R @@ -329,7 +329,7 @@ spark.addFile <- function(path, recursive = FALSE) { #' spark.getSparkFilesRootDirectory() #'} #' @note spark.getSparkFilesRootDirectory since 2.1.0 -spark.getSparkFilesRootDirectory <- function() { +spark.getSparkFilesRootDirectory <- function() { # nolint if (Sys.getenv("SPARKR_IS_RUNNING_ON_WORKER") == "") { # Running on driver. callJStatic("org.apache.spark.SparkFiles", "getRootDirectory") diff --git a/R/pkg/R/deserialize.R b/R/pkg/R/deserialize.R index 0e99b171cabeb..a90f7d381026b 100644 --- a/R/pkg/R/deserialize.R +++ b/R/pkg/R/deserialize.R @@ -43,7 +43,7 @@ readObject <- function(con) { } readTypedObject <- function(con, type) { - switch (type, + switch(type, "i" = readInt(con), "c" = readString(con), "b" = readBoolean(con), diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9f286263c2162..0143a3e63ba61 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -38,7 +38,8 @@ NULL #' #' Date time functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{window}, it must be a time Column of \code{TimestampType}. +#' @param x Column to compute on. In \code{window}, it must be a time Column of +#' \code{TimestampType}. 
#' @param format For \code{to_date} and \code{to_timestamp}, it is the string to use to parse #' Column \code{x} to DateType or TimestampType. For \code{trunc}, it is the string #' to use to specify the truncation method. For example, "year", "yyyy", "yy" for @@ -90,8 +91,8 @@ NULL #' #' Math functions defined for \code{Column}. #' -#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and \code{shiftRightUnsigned}, -#' this is the number of bits to shift. +#' @param x Column to compute on. In \code{shiftLeft}, \code{shiftRight} and +#' \code{shiftRightUnsigned}, this is the number of bits to shift. #' @param y Column to compute on. #' @param ... additional argument(s). #' @name column_math_functions @@ -480,7 +481,7 @@ setMethod("ceiling", setMethod("coalesce", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -676,7 +677,7 @@ setMethod("crc32", setMethod("hash", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -1310,9 +1311,9 @@ setMethod("round", #' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. #' bround(2.5, 0) = 2, bround(3.5, 0) = 4. #' -#' @param scale round to \code{scale} digits to the right of the decimal point when \code{scale} > 0, -#' the nearest even number when \code{scale} = 0, and \code{scale} digits to the left -#' of the decimal point when \code{scale} < 0. +#' @param scale round to \code{scale} digits to the right of the decimal point when +#' \code{scale} > 0, the nearest even number when \code{scale} = 0, and \code{scale} digits +#' to the left of the decimal point when \code{scale} < 0. #' @rdname column_math_functions #' @aliases bround bround,Column-method #' @export @@ -2005,8 +2006,9 @@ setMethod("months_between", signature(y = "Column"), }) #' @details -#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column (\code{x}) if -#' the first column is NaN. Both inputs should be floating point columns (DoubleType or FloatType). +#' \code{nanvl}: Returns the first column (\code{y}) if it is not NaN, or the second column +#' (\code{x}) if the first column is NaN. Both inputs should be floating point columns +#' (DoubleType or FloatType). #' #' @rdname column_nonaggregate_functions #' @aliases nanvl nanvl,Column-method @@ -2061,7 +2063,7 @@ setMethod("approxCountDistinct", setMethod("countDistinct", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(...), function (x) { + jcols <- lapply(list(...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2090,7 +2092,7 @@ setMethod("countDistinct", setMethod("concat", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2110,7 +2112,7 @@ setMethod("greatest", signature(x = "Column"), function(x, ...) { stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2130,7 +2132,7 @@ setMethod("least", signature(x = "Column"), function(x, ...) 
{ stopifnot(length(list(...)) > 0) - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -2406,8 +2408,8 @@ setMethod("shiftLeft", signature(y = "Column", x = "numeric"), }) #' @details -#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long value, -#' it will return a long value else it will return an integer value. +#' \code{shiftRight}: (Signed) shifts the given value numBits right. If the given value is a long +#' value, it will return a long value else it will return an integer value. #' #' @rdname column_math_functions #' @aliases shiftRight shiftRight,Column,numeric-method @@ -2505,9 +2507,10 @@ setMethod("format_string", signature(format = "character", x = "Column"), }) #' @details -#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) to a -#' string representing the timestamp of that moment in the current system time zone in the JVM in the -#' given format. See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ +#' \code{from_unixtime}: Converts the number of seconds from unix epoch (1970-01-01 00:00:00 UTC) +#' to a string representing the timestamp of that moment in the current system time zone in the JVM +#' in the given format. +#' See \href{http://docs.oracle.com/javase/tutorial/i18n/format/simpleDateFormat.html}{ #' Customizing Formats} for available options. #' #' @rdname column_datetime_functions @@ -2634,8 +2637,8 @@ setMethod("lpad", signature(x = "Column", len = "numeric", pad = "character"), }) #' @details -#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) samples -#' from U[0.0, 1.0]. +#' \code{rand}: Generates a random column with independent and identically distributed (i.i.d.) +#' samples from U[0.0, 1.0]. #' #' @rdname column_nonaggregate_functions #' @param seed a random seed. Can be missing. @@ -2664,8 +2667,8 @@ setMethod("rand", signature(seed = "numeric"), }) #' @details -#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples from -#' the standard normal distribution. +#' \code{randn}: Generates a column with independent and identically distributed (i.i.d.) samples +#' from the standard normal distribution. #' #' @rdname column_nonaggregate_functions #' @aliases randn randn,missing-method @@ -2831,8 +2834,8 @@ setMethod("unix_timestamp", signature(x = "Column", format = "character"), }) #' @details -#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result expressions. -#' For unmatched expressions null is returned. +#' \code{when}: Evaluates a list of conditions and returns one of multiple possible result +#' expressions. For unmatched expressions null is returned. #' #' @rdname column_nonaggregate_functions #' @param condition the condition to test on. Must be a Column expression. @@ -2859,8 +2862,8 @@ setMethod("when", signature(condition = "Column", value = "ANY"), }) #' @details -#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are satisfied. -#' Otherwise \code{no} is returned for unmatched conditions. +#' \code{ifelse}: Evaluates a list of conditions and returns \code{yes} if the conditions are +#' satisfied. Otherwise \code{no} is returned for unmatched conditions. #' #' @rdname column_nonaggregate_functions #' @param test a Column expression that describes the condition. 
@@ -2990,7 +2993,8 @@ setMethod("ntile", }) #' @details -#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window partition. +#' \code{percent_rank}: Returns the relative rank (i.e. percentile) of rows within a window +#' partition. #' This is computed by: (rank of row in its partition - 1) / (number of rows in the partition - 1). #' This is equivalent to the \code{PERCENT_RANK} function in SQL. #' The method should be used with no argument. @@ -3160,7 +3164,8 @@ setMethod("posexplode", }) #' @details -#' \code{create_array}: Creates a new array column. The input columns must all have the same data type. +#' \code{create_array}: Creates a new array column. The input columns must all have the same data +#' type. #' #' @rdname column_nonaggregate_functions #' @aliases create_array create_array,Column-method @@ -3169,7 +3174,7 @@ setMethod("posexplode", setMethod("create_array", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3178,8 +3183,8 @@ setMethod("create_array", }) #' @details -#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value pairs, -#' e.g. (key1, value1, key2, value2, ...). +#' \code{create_map}: Creates a new map column. The input columns must be grouped as key-value +#' pairs, e.g. (key1, value1, key2, value2, ...). #' The key columns must all have the same data type, and can't be null. #' The value columns must all have the same data type. #' @@ -3190,7 +3195,7 @@ setMethod("create_array", setMethod("create_map", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) @@ -3352,9 +3357,9 @@ setMethod("not", }) #' @details -#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or not, -#' returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} in SQL -#' and \code{grouping} function in Scala. +#' \code{grouping_bit}: Indicates whether a specified column in a GROUP BY list is aggregated or +#' not, returns 1 for aggregated or 0 for not aggregated in the result set. Same as \code{GROUPING} +#' in SQL and \code{grouping} function in Scala. #' #' @rdname column_aggregate_functions #' @aliases grouping_bit grouping_bit,Column-method @@ -3412,7 +3417,7 @@ setMethod("grouping_bit", setMethod("grouping_id", signature(x = "Column"), function(x, ...) { - jcols <- lapply(list(x, ...), function (x) { + jcols <- lapply(list(x, ...), function(x) { stopifnot(class(x) == "Column") x@jc }) diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 0fe8f0453b064..4e427489f6860 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -385,7 +385,7 @@ setGeneric("value", function(bcast) { standardGeneric("value") }) #' @return A SparkDataFrame. #' @rdname summarize #' @export -setGeneric("agg", function (x, ...) { standardGeneric("agg") }) +setGeneric("agg", function(x, ...) { standardGeneric("agg") }) #' alias #' @@ -731,7 +731,7 @@ setGeneric("schema", function(x) { standardGeneric("schema") }) #' @rdname select #' @export -setGeneric("select", function(x, col, ...) { standardGeneric("select") } ) +setGeneric("select", function(x, col, ...) 
{ standardGeneric("select") }) #' @rdname selectExpr #' @export diff --git a/R/pkg/R/group.R b/R/pkg/R/group.R index 0a7be0e993975..54ef9f07d6fae 100644 --- a/R/pkg/R/group.R +++ b/R/pkg/R/group.R @@ -133,8 +133,8 @@ setMethod("summarize", # Aggregate Functions by name methods <- c("avg", "max", "mean", "min", "sum") -# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", "stddev_pop", -# "variance", "var_samp", "var_pop" +# These are not exposed on GroupedData: "kurtosis", "skewness", "stddev", "stddev_samp", +# "stddev_pop", "variance", "var_samp", "var_pop" #' Pivot a column of the GroupedData and perform the specified aggregation. #' diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 15af8298ba484..7cd072a1d6f89 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -58,22 +58,25 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @param regParam The regularization parameter. Only supports L2 regularization currently. #' @param maxIter Maximum iteration number. #' @param tol Convergence tolerance of iterations. -#' @param standardization Whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. +#' @param standardization Whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. #' @param threshold The threshold in binary classification applied to the linear model prediction. #' This threshold can be any real number, where Inf will make all predictions 0.0 #' and -Inf will make all predictions 1.0. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. #' This is an expert parameter. Default value should be good for most cases. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.svmLinear} returns a fitted linear SVM model. #' @rdname spark.svmLinear @@ -175,62 +178,80 @@ function(object, path, overwrite = FALSE) { #' Logistic Regression Model #' -#' Fits an logistic regression model against a SparkDataFrame. 
It supports "binomial": Binary logistic regression -#' with pivoting; "multinomial": Multinomial logistic (softmax) regression without pivoting, similar to glmnet. -#' Users can print, make predictions on the produced model and save the model to the input path. +#' Fits an logistic regression model against a SparkDataFrame. It supports "binomial": Binary +#' logistic regression with pivoting; "multinomial": Multinomial logistic (softmax) regression +#' without pivoting, similar to glmnet. Users can print, make predictions on the produced model +#' and save the model to the input path. #' #' @param data SparkDataFrame for training. #' @param formula A symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param regParam the regularization parameter. -#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 penalty. -#' For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, the penalty is a combination -#' of L1 and L2. Default is 0.0 which is an L2 penalty. +#' @param elasticNetParam the ElasticNet mixing parameter. For alpha = 0.0, the penalty is an L2 +#' penalty. For alpha = 1.0, it is an L1 penalty. For 0.0 < alpha < 1.0, +#' the penalty is a combination of L1 and L2. Default is 0.0 which is an +#' L2 penalty. #' @param maxIter maximum iteration number. #' @param tol convergence tolerance of iterations. -#' @param family the name of family which is a description of the label distribution to be used in the model. +#' @param family the name of family which is a description of the label distribution to be used +#' in the model. #' Supported options: #' \itemize{ #' \item{"auto": Automatically select the family based on the number of classes: #' If number of classes == 1 || number of classes == 2, set to "binomial". #' Else, set to "multinomial".} #' \item{"binomial": Binary logistic regression with pivoting.} -#' \item{"multinomial": Multinomial logistic (softmax) regression without pivoting.} +#' \item{"multinomial": Multinomial logistic (softmax) regression without +#' pivoting.} #' } -#' @param standardization whether to standardize the training features before fitting the model. The coefficients -#' of models will be always returned on the original scale, so it will be transparent for -#' users. Note that with/without standardization, the models should be always converged -#' to the same solution when no regularization is applied. Default is TRUE, same as glmnet. -#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of class label 1 -#' is > threshold, then predict 1, else 0. A high threshold encourages the model to predict 0 -#' more often; a low threshold encourages the model to predict 1 more often. Note: Setting this with -#' threshold p is equivalent to setting thresholds c(1-p, p). In multiclass (or binary) classification to adjust the probability of -#' predicting each class. Array must have length equal to the number of classes, with values > 0, -#' excepting that at most one value may be 0. The class with largest value p/t is predicted, where p -#' is the original probability of that class and t is the class's threshold. +#' @param standardization whether to standardize the training features before fitting the model. +#' The coefficients of models will be always returned on the original scale, +#' so it will be transparent for users. 
Note that with/without +#' standardization, the models should be always converged to the same +#' solution when no regularization is applied. Default is TRUE, same as +#' glmnet. +#' @param thresholds in binary classification, in range [0, 1]. If the estimated probability of +#' class label 1 is > threshold, then predict 1, else 0. A high threshold +#' encourages the model to predict 0 more often; a low threshold encourages the +#' model to predict 1 more often. Note: Setting this with threshold p is +#' equivalent to setting thresholds c(1-p, p). In multiclass (or binary) +#' classification to adjust the probability of predicting each class. Array must +#' have length equal to the number of classes, with values > 0, excepting that +#' at most one value may be 0. The class with largest value p/t is predicted, +#' where p is the original probability of that class and t is the class's +#' threshold. #' @param weightCol The weight column name. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. -#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this param +#' could be adjusted to a larger size. This is an expert parameter. Default +#' value should be good for most cases. +#' @param lowerBoundsOnCoefficients The lower bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound constrained optimization. -#' The bound matrix must be compatible with the shape (1, number of features) for binomial -#' regression, or (number of classes, number of features) for multinomial regression. +#' @param upperBoundsOnCoefficients The upper bounds on coefficients if fitting under bound +#' constrained optimization. +#' The bound matrix must be compatible with the shape (1, number +#' of features) for binomial regression, or (number of classes, +#' number of features) for multinomial regression. #' It is a R matrix. -#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained optimization. -#' The bounds vector size must be equal to 1 for binomial regression, or the number -#' of classes for multinomial regression. -#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained optimization. -#' The bound vector size must be equal to 1 for binomial regression, or the number +#' @param lowerBoundsOnIntercepts The lower bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bounds vector size must be equal to 1 for binomial regression, +#' or the number #' of classes for multinomial regression. 
-#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param upperBoundsOnIntercepts The upper bounds on intercepts if fitting under bound constrained +#' optimization. +#' The bound vector size must be equal to 1 for binomial regression, +#' or the number of classes for multinomial regression. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.logit} returns a fitted logistic regression model. #' @rdname spark.logit @@ -412,11 +433,12 @@ setMethod("write.ml", signature(object = "LogisticRegressionModel", path = "char #' @param seed seed parameter for weights initialization. #' @param initialWeights initialWeights parameter for weights initialization, it should be a #' numeric vector. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @return \code{spark.mlp} returns a fitted Multilayer Perceptron Classification Model. #' @rdname spark.mlp @@ -452,11 +474,11 @@ setMethod("spark.mlp", signature(data = "SparkDataFrame", formula = "formula"), handleInvalid = c("error", "keep", "skip")) { formula <- paste(deparse(formula), collapse = "") if (is.null(layers)) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } layers <- as.integer(na.omit(layers)) if (length(layers) <= 1) { - stop ("layers must be a integer vector with length > 1.") + stop("layers must be a integer vector with length > 1.") } if (!is.null(seed)) { seed <- as.character(as.integer(seed)) @@ -538,11 +560,12 @@ setMethod("write.ml", signature(object = "MultilayerPerceptronClassificationMode #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', '.', ':', '+', and '-'. #' @param smoothing smoothing parameter. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... 
additional argument(s) passed to the method. Currently only \code{smoothing}. #' @return \code{spark.naiveBayes} returns a fitted naive Bayes model. #' @rdname spark.naiveBayes diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 97c9fa1b45840..a25bf81c6d977 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -60,9 +60,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @param maxIter maximum iteration number. #' @param seed the random seed. #' @param minDivisibleClusterSize The minimum number of points (if greater than or equal to 1.0) -#' or the minimum proportion of points (if less than 1.0) of a divisible cluster. -#' Note that it is an expert parameter. The default value should be good enough -#' for most cases. +#' or the minimum proportion of points (if less than 1.0) of a +#' divisible cluster. Note that it is an expert parameter. The +#' default value should be good enough for most cases. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.bisectingKmeans} returns a fitted bisecting k-means model. #' @rdname spark.bisectingKmeans @@ -325,10 +325,11 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' Note that the response variable of formula is empty in spark.kmeans. #' @param k number of centers. #' @param maxIter maximum iteration number. -#' @param initMode the initialization algorithm choosen to fit the model. +#' @param initMode the initialization algorithm chosen to fit the model. #' @param seed the random seed for cluster initialization. #' @param initSteps the number of steps for the k-means|| initialization mode. -#' This is an advanced setting, the default of 2 is almost always enough. Must be > 0. +#' This is an advanced setting, the default of 2 is almost always enough. +#' Must be > 0. #' @param tol convergence tolerance of iterations. #' @param ... additional argument(s) passed to the method. #' @return \code{spark.kmeans} returns a fitted k-means model. @@ -548,8 +549,8 @@ setMethod("spark.lda", signature(data = "SparkDataFrame"), #' \item{\code{topics}}{top 10 terms and their weights of all topics} #' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file #' used as training set} -#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the training set, -#' given the current parameter estimates: +#' \item{\code{trainingLogLikelihood}}{Log likelihood of the observed tokens in the +#' training set, given the current parameter estimates: #' log P(docs | topics, topic distributions for docs, Dirichlet hyperparameters) #' It is only for distributed LDA model (i.e., optimizer = "em")} #' \item{\code{logPrior}}{Log probability of the current parameter estimate: diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ebaeae970218a..f734a0865ec3b 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -58,8 +58,8 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' Note that there are two ways to specify the tweedie family. #' \itemize{ #' \item Set \code{family = "tweedie"} and specify the var.power and link.power; -#' \item When package \code{statmod} is loaded, the tweedie family is specified using the -#' family definition therein, i.e., \code{tweedie(var.power, link.power)}. +#' \item When package \code{statmod} is loaded, the tweedie family is specified +#' using the family definition therein, i.e., \code{tweedie(var.power, link.power)}. 
#' } #' @param tol positive convergence tolerance of iterations. #' @param maxIter integer giving the maximal number of IRLS iterations. @@ -71,13 +71,15 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' applicable to the Tweedie family. #' @param link.power the index in the power link function. Only applicable to the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @param ... additional arguments passed to the method. #' @aliases spark.glm,SparkDataFrame,formula-method #' @return \code{spark.glm} returns a fitted generalized linear model. @@ -197,13 +199,15 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @param var.power the index of the power variance function in the Tweedie family. #' @param link.power the index of the power link function in the Tweedie family. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. -#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance offsets -#' as 0.0. The feature specified as offset has a constant coefficient of 1.0. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. +#' @param offsetCol the offset column name. If this is not set or empty, we treat all instance +#' offsets as 0.0. The feature specified as offset has a constant coefficient of +#' 1.0. #' @return \code{glm} returns a fitted generalized linear model. #' @rdname glm #' @export @@ -233,11 +237,11 @@ setMethod("glm", signature(formula = "formula", family = "ANY", data = "SparkDat #' @param object a fitted generalized linear model. #' @return \code{summary} returns summary information of the fitted model, which is a list. 
-#' The list of components includes at least the \code{coefficients} (coefficients matrix, which includes -#' coefficients, standard error of coefficients, t value and p value), +#' The list of components includes at least the \code{coefficients} (coefficients matrix, +#' which includes coefficients, standard error of coefficients, t value and p value), #' \code{null.deviance} (null/residual degrees of freedom), \code{aic} (AIC) -#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in the data, -#' the coefficients matrix only provides coefficients. +#' and \code{iter} (number of iterations IRLS takes). If there are collinear columns in +#' the data, the coefficients matrix only provides coefficients. #' @rdname spark.glm #' @export #' @note summary(GeneralizedLinearRegressionModel) since 2.0.0 @@ -457,15 +461,17 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char #' @param formula a symbolic description of the model to be fitted. Currently only a few formula #' operators are supported, including '~', ':', '+', and '-'. #' Note that operator '.' is not supported currently. -#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the dimensions of features -#' or the number of partitions are large, this param could be adjusted to a larger size. -#' This is an expert parameter. Default value should be good for most cases. +#' @param aggregationDepth The depth for treeAggregate (greater than or equal to 2). If the +#' dimensions of features or the number of partitions are large, this +#' param could be adjusted to a larger size. This is an expert parameter. +#' Default value should be good for most cases. #' @param stringIndexerOrderType how to order categories of a string feature column. This is used to -#' decide the base level of a string feature as the last category after -#' ordering is dropped when encoding strings. Supported options are -#' "frequencyDesc", "frequencyAsc", "alphabetDesc", and "alphabetAsc". -#' The default value is "frequencyDesc". When the ordering is set to -#' "alphabetDesc", this drops the same category as R when encoding strings. +#' decide the base level of a string feature as the last category +#' after ordering is dropped when encoding strings. Supported options +#' are "frequencyDesc", "frequencyAsc", "alphabetDesc", and +#' "alphabetAsc". The default value is "frequencyDesc". When the +#' ordering is set to "alphabetDesc", this drops the same category +#' as R when encoding strings. #' @param ... additional arguments passed to the method. #' @return \code{spark.survreg} returns a fitted AFT survival regression model. #' @rdname spark.survreg diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 33c4653f4c184..89a58bf0aadae 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -132,10 +132,12 @@ print.summary.decisionTree <- function(x) { #' Gradient Boosted Tree model, \code{predict} to make predictions on new data, and #' \code{write.ml}/\code{read.ml} to save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-regression}{ #' GBT Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#gradient-boosted-tree-classifier}{ #' GBT Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. 
Currently only a few formula @@ -164,11 +166,12 @@ print.summary.decisionTree <- function(x) { #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.gbt,SparkDataFrame,formula-method #' @return \code{spark.gbt} returns a fitted Gradient Boosted Tree model. @@ -352,10 +355,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. #' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-regression}{ #' Random Forest Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#random-forest-classifier}{ #' Random Forest Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -382,11 +387,12 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.randomForest,SparkDataFrame,formula-method #' @return \code{spark.randomForest} returns a fitted Random Forest model. @@ -567,10 +573,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' model, \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml} to #' save/load fitted models. 
#' For more details, see +# nolint start #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-regression}{ #' Decision Tree Regression} and #' \href{http://spark.apache.org/docs/latest/ml-classification-regression.html#decision-tree-classifier}{ #' Decision Tree Classification} +# nolint end #' #' @param data a SparkDataFrame for training. #' @param formula a symbolic description of the model to be fitted. Currently only a few formula @@ -592,11 +600,12 @@ setMethod("write.ml", signature(object = "RandomForestClassificationModel", path #' nodes. If TRUE, the algorithm will cache node IDs for each instance. Caching #' can speed up training of deeper trees. Users can set how often should the #' cache be checkpointed or disable it by setting checkpointInterval. -#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and label -#' column of string type in classification model. +#' @param handleInvalid How to handle invalid data (unseen labels or NULL values) in features and +#' label column of string type in classification model. #' Supported options: "skip" (filter out rows with invalid data), -#' "error" (throw an error), "keep" (put invalid data in a special additional -#' bucket, at index numLabels). Default is "error". +#' "error" (throw an error), "keep" (put invalid data in +#' a special additional bucket, at index numLabels). Default +#' is "error". #' @param ... additional arguments passed to the method. #' @aliases spark.decisionTree,SparkDataFrame,formula-method #' @return \code{spark.decisionTree} returns a fitted Decision Tree model. @@ -671,7 +680,8 @@ setMethod("spark.decisionTree", signature(data = "SparkDataFrame", formula = "fo #' @return \code{summary} returns summary information of the fitted model, which is a list. #' The list of components includes \code{formula} (formula), #' \code{numFeatures} (number of features), \code{features} (list of features), -#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of trees). +#' \code{featureImportances} (feature importances), and \code{maxDepth} (max depth of +#' trees). #' @rdname spark.decisionTree #' @aliases summary,DecisionTreeRegressionModel-method #' @export diff --git a/R/pkg/R/pairRDD.R b/R/pkg/R/pairRDD.R index 8fa21be3076b5..9c2e57d3067db 100644 --- a/R/pkg/R/pairRDD.R +++ b/R/pkg/R/pairRDD.R @@ -860,7 +860,7 @@ setMethod("subtractByKey", other, numPartitions = numPartitions), filterFunction), - function (v) { v[[1]] }) + function(v) { v[[1]] }) }) #' Return a subset of this RDD sampled by key. @@ -925,7 +925,7 @@ setMethod("sampleByKey", if (withReplacement) { count <- stats::rpois(1, frac) if (count > 0) { - res[ (len + 1) : (len + count) ] <- rep(list(elem), count) + res[(len + 1) : (len + count)] <- rep(list(elem), count) len <- len + count } } else { diff --git a/R/pkg/R/schema.R b/R/pkg/R/schema.R index d1ed6833d5d02..65f418740c643 100644 --- a/R/pkg/R/schema.R +++ b/R/pkg/R/schema.R @@ -155,7 +155,7 @@ checkType <- function(type) { } else { # Check complex types firstChar <- substr(type, 1, 1) - switch (firstChar, + switch(firstChar, a = { # Array type m <- regexec("^array<(.+)>$", type) diff --git a/R/pkg/R/stats.R b/R/pkg/R/stats.R index 9a9fa84044ce6..c8af798830b30 100644 --- a/R/pkg/R/stats.R +++ b/R/pkg/R/stats.R @@ -29,9 +29,9 @@ setOldClass("jobj") #' @param col1 name of the first column. Distinct items will make the first item of each row. #' @param col2 name of the second column. 
Distinct items will make the column names of the output. #' @return a local R data.frame representing the contingency table. The first column of each row -#' will be the distinct values of \code{col1} and the column names will be the distinct values -#' of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". Pairs -#' that have no occurrences will have zero as their counts. +#' will be the distinct values of \code{col1} and the column names will be the distinct +#' values of \code{col2}. The name of the first column will be "\code{col1}_\code{col2}". +#' Pairs that have no occurrences will have zero as their counts. #' #' @rdname crosstab #' @name crosstab @@ -53,8 +53,8 @@ setMethod("crosstab", }) #' @details -#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two numerical -#' columns of \emph{one} SparkDataFrame. +#' \code{cov}: When applied to SparkDataFrame, this calculates the sample covariance of two +#' numerical columns of \emph{one} SparkDataFrame. #' #' @param colName1 the name of the first column #' @param colName2 the name of the second column @@ -159,8 +159,8 @@ setMethod("freqItems", signature(x = "SparkDataFrame", cols = "character"), #' @param relativeError The relative target precision to achieve (>= 0). If set to zero, #' the exact quantiles are computed, which could be very expensive. #' Note that values greater than 1 are accepted but give the same result as 1. -#' @return The approximate quantiles at the given probabilities. If the input is a single column name, -#' the output is a list of approximate quantiles in that column; If the input is +#' @return The approximate quantiles at the given probabilities. If the input is a single column +#' name, the output is a list of approximate quantiles in that column; If the input is #' multiple column names, the output should be a list, and each element in it is a list of #' numeric values which represents the approximate quantiles in corresponding column. 
#' diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index 91483a4d23d9b..4b716995f2c46 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -625,7 +625,7 @@ appendPartitionLengths <- function(x, other) { x <- lapplyPartition(x, appendLength) other <- lapplyPartition(other, appendLength) } - list (x, other) + list(x, other) } # Perform zip or cartesian between elements from two RDDs in each partition @@ -657,7 +657,7 @@ mergePartitions <- function(rdd, zip) { keys <- list() } if (lengthOfValues > 1) { - values <- part[ (lengthOfKeys + 1) : (len - 1) ] + values <- part[(lengthOfKeys + 1) : (len - 1)] } else { values <- list() } diff --git a/R/pkg/inst/worker/worker.R b/R/pkg/inst/worker/worker.R index 03e7450147865..00789d815bba8 100644 --- a/R/pkg/inst/worker/worker.R +++ b/R/pkg/inst/worker/worker.R @@ -68,7 +68,7 @@ compute <- function(mode, partition, serializer, deserializer, key, } else { output <- computeFunc(partition, inputData) } - return (output) + return(output) } outputResult <- function(serializer, output, outputCon) { diff --git a/R/pkg/tests/fulltests/test_binary_function.R b/R/pkg/tests/fulltests/test_binary_function.R index 442bed509bb1d..c5d240f3e7344 100644 --- a/R/pkg/tests/fulltests/test_binary_function.R +++ b/R/pkg/tests/fulltests/test_binary_function.R @@ -73,7 +73,7 @@ test_that("zipPartitions() on RDDs", { rdd2 <- parallelize(sc, 1:4, 2L) # 1:2, 3:4 rdd3 <- parallelize(sc, 1:6, 2L) # 1:3, 4:6 actual <- collectRDD(zipPartitions(rdd1, rdd2, rdd3, - func = function(x, y, z) { list(list(x, y, z))} )) + func = function(x, y, z) { list(list(x, y, z))})) expect_equal(actual, list(list(1, c(1, 2), c(1, 2, 3)), list(2, c(3, 4), c(4, 5, 6)))) diff --git a/R/pkg/tests/fulltests/test_rdd.R b/R/pkg/tests/fulltests/test_rdd.R index 6ee1fceffd822..0c702ea897f7c 100644 --- a/R/pkg/tests/fulltests/test_rdd.R +++ b/R/pkg/tests/fulltests/test_rdd.R @@ -698,14 +698,14 @@ test_that("fullOuterJoin() on pairwise RDDs", { }) test_that("sortByKey() on pairwise RDDs", { - numPairsRdd <- map(rdd, function(x) { list (x, x) }) + numPairsRdd <- map(rdd, function(x) { list(x, x) }) sortedRdd <- sortByKey(numPairsRdd, ascending = FALSE) actual <- collectRDD(sortedRdd) - numPairs <- lapply(nums, function(x) { list (x, x) }) + numPairs <- lapply(nums, function(x) { list(x, x) }) expect_equal(actual, sortKeyValueList(numPairs, decreasing = TRUE)) rdd2 <- parallelize(sc, sort(nums, decreasing = TRUE), 2L) - numPairsRdd2 <- map(rdd2, function(x) { list (x, x) }) + numPairsRdd2 <- map(rdd2, function(x) { list(x, x) }) sortedRdd2 <- sortByKey(numPairsRdd2) actual <- collectRDD(sortedRdd2) expect_equal(actual, numPairs) diff --git a/R/pkg/tests/fulltests/test_sparkSQL.R b/R/pkg/tests/fulltests/test_sparkSQL.R index 4e62be9b4d619..0c8118a7c73f3 100644 --- a/R/pkg/tests/fulltests/test_sparkSQL.R +++ b/R/pkg/tests/fulltests/test_sparkSQL.R @@ -499,6 +499,12 @@ test_that("create DataFrame with different data types", { expect_equal(collect(df), data.frame(l, stringsAsFactors = FALSE)) }) +test_that("SPARK-17902: collect() with stringsAsFactors enabled", { + df <- suppressWarnings(collect(createDataFrame(iris), stringsAsFactors = TRUE)) + expect_equal(class(iris$Species), class(df$Species)) + expect_equal(iris$Species, df$Species) +}) + test_that("SPARK-17811: can create DataFrame containing NA as date and time", { df <- data.frame( id = 1:2, @@ -560,9 +566,9 @@ test_that("Collect DataFrame with complex types", { expect_equal(nrow(ldf), 3) expect_equal(ncol(ldf), 3) expect_equal(names(ldf), c("c1", "c2", "c3")) 
- expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list (7, 8, 9))) - expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list ("g", "h", "i"))) - expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list (7.0, 8.0, 9.0))) + expect_equal(ldf$c1, list(list(1, 2, 3), list(4, 5, 6), list(7, 8, 9))) + expect_equal(ldf$c2, list(list("a", "b", "c"), list("d", "e", "f"), list("g", "h", "i"))) + expect_equal(ldf$c3, list(list(1.0, 2.0, 3.0), list(4.0, 5.0, 6.0), list(7.0, 8.0, 9.0))) # MapType schema <- structType(structField("name", "string"), @@ -1524,7 +1530,7 @@ test_that("column functions", { expect_equal(ncol(s), 1) expect_equal(nrow(s), 3) expect_is(s[[1]][[1]], "struct") - expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 }))) } # passing option @@ -2538,7 +2544,7 @@ test_that("describe() and summary() on a DataFrame", { stats2 <- summary(df) expect_equal(collect(stats2)[5, "summary"], "25%") - expect_equal(collect(stats2)[5, "age"], "30") + expect_equal(collect(stats2)[5, "age"], "19") stats3 <- summary(df, "min", "max", "55.1%") @@ -2710,7 +2716,7 @@ test_that("freqItems() on a DataFrame", { input <- 1:1000 rdf <- data.frame(numbers = input, letters = as.character(input), negDoubles = input * -1.0, stringsAsFactors = F) - rdf[ input %% 3 == 0, ] <- c(1, "1", -1) + rdf[input %% 3 == 0, ] <- c(1, "1", -1) df <- createDataFrame(rdf) multiColResults <- freqItems(df, c("numbers", "letters"), support = 0.1) expect_true(1 %in% multiColResults$numbers[[1]]) @@ -2738,7 +2744,7 @@ test_that("sampleBy() on a DataFrame", { }) test_that("approxQuantile() on a DataFrame", { - l <- lapply(c(0:99), function(i) { list(i, 99 - i) }) + l <- lapply(c(0:100), function(i) { list(i, 100 - i) }) df <- createDataFrame(l, list("a", "b")) quantiles <- approxQuantile(df, "a", c(0.5, 0.8), 0.0) expect_equal(quantiles, list(50, 80)) @@ -2749,8 +2755,8 @@ test_that("approxQuantile() on a DataFrame", { dfWithNA <- createDataFrame(data.frame(a = c(NA, 30, 19, 11, 28, 15), b = c(-30, -19, NA, -11, -28, -15))) quantiles3 <- approxQuantile(dfWithNA, c("a", "b"), c(0.5), 0.0) - expect_equal(quantiles3[[1]], list(28)) - expect_equal(quantiles3[[2]], list(-15)) + expect_equal(quantiles3[[1]], list(19)) + expect_equal(quantiles3[[2]], list(-19)) }) test_that("SQL error message is returned from JVM", { @@ -3064,7 +3070,7 @@ test_that("coalesce, repartition, numPartitions", { }) test_that("gapply() and gapplyCollect() on a DataFrame", { - df <- createDataFrame ( + df <- createDataFrame( list(list(1L, 1, "1", 0.1), list(1L, 2, "1", 0.2), list(3L, 3, "3", 0.3)), c("a", "b", "c", "d")) expected <- collect(df) @@ -3075,6 +3081,11 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { df1Collect <- gapplyCollect(df, list("a"), function(key, x) { x }) expect_identical(df1Collect, expected) + # gapply on empty grouping columns. 
+ df1 <- gapply(df, c(), function(key, x) { x }, schema(df)) + actual <- collect(df1) + expect_identical(actual, expected) + # Computes the sum of second column by grouping on the first and third columns # and checks if the sum is larger than 2 schemas <- list(structType(structField("a", "integer"), structField("e", "boolean")), @@ -3135,7 +3146,7 @@ test_that("gapply() and gapplyCollect() on a DataFrame", { actual <- df3Collect[order(df3Collect$a), ] expect_identical(actual$avg, expected$avg) - irisDF <- suppressWarnings(createDataFrame (iris)) + irisDF <- suppressWarnings(createDataFrame(iris)) schema <- structType(structField("Sepal_Length", "double"), structField("Avg", "double")) # Groups by `Sepal_Length` and computes the average for `Sepal_Width` df4 <- gapply( diff --git a/bin/beeline.cmd b/bin/beeline.cmd index 02464bd088792..288059a28cd74 100644 --- a/bin/beeline.cmd +++ b/bin/beeline.cmd @@ -17,4 +17,6 @@ rem See the License for the specific language governing permissions and rem limitations under the License. rem -cmd /V /E /C "%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class.cmd" org.apache.hive.beeline.BeeLine %*" diff --git a/bin/pyspark.cmd b/bin/pyspark.cmd index 72d046a4ba2cf..3dcf1d45a8189 100644 --- a/bin/pyspark.cmd +++ b/bin/pyspark.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running PySpark. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0pyspark2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0pyspark2.cmd" %*" diff --git a/bin/run-example.cmd b/bin/run-example.cmd index f9b786e92b823..efa5f81d08f7f 100644 --- a/bin/run-example.cmd +++ b/bin/run-example.cmd @@ -19,4 +19,7 @@ rem set SPARK_HOME=%~dp0.. set _SPARK_CMD_USAGE=Usage: ./bin/run-example [options] example-class [example args] -cmd /V /E /C "%~dp0spark-submit.cmd" run-example %* + +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit.cmd" run-example %*" diff --git a/bin/spark-class.cmd b/bin/spark-class.cmd index 3bf3d20cb57b5..b22536ab6f458 100644 --- a/bin/spark-class.cmd +++ b/bin/spark-class.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running a Spark class. To avoid polluting rem the environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-class2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-class2.cmd" %*" diff --git a/bin/spark-shell.cmd b/bin/spark-shell.cmd index 991423da6ab99..e734f13097d61 100644 --- a/bin/spark-shell.cmd +++ b/bin/spark-shell.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark shell. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-shell2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. 
+cmd /V /E /C ""%~dp0spark-shell2.cmd" %*" diff --git a/bin/spark-submit.cmd b/bin/spark-submit.cmd index f301606933a95..da62a8777524d 100644 --- a/bin/spark-submit.cmd +++ b/bin/spark-submit.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running Spark submit. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0spark-submit2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0spark-submit2.cmd" %*" diff --git a/bin/sparkR.cmd b/bin/sparkR.cmd index 1e5ea6a623219..fcd172b083e1e 100644 --- a/bin/sparkR.cmd +++ b/bin/sparkR.cmd @@ -20,4 +20,6 @@ rem rem This is the entry point for running SparkR. To avoid polluting the rem environment, it just launches a new cmd to do the real work. -cmd /V /E /C "%~dp0sparkR2.cmd" %* +rem The outermost quotes are used to prevent Windows command line parse error +rem when there are some quotes in parameters, see SPARK-21877. +cmd /V /E /C ""%~dp0sparkR2.cmd" %*" diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java index a2b077e4531ee..870b484f99068 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/KVTypeInfo.java @@ -46,6 +46,7 @@ public KVTypeInfo(Class type) throws Exception { KVIndex idx = f.getAnnotation(KVIndex.class); if (idx != null) { checkIndex(idx, indices); + f.setAccessible(true); indices.put(idx.value(), idx); f.setAccessible(true); accessors.put(idx.value(), new FieldAccessor(f)); @@ -58,6 +59,7 @@ public KVTypeInfo(Class type) throws Exception { checkIndex(idx, indices); Preconditions.checkArgument(m.getParameterTypes().length == 0, "Annotated method %s::%s should not have any parameters.", type.getName(), m.getName()); + m.setAccessible(true); indices.put(idx.value(), idx); m.setAccessible(true); accessors.put(idx.value(), new MethodAccessor(m)); diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java index 310febc352ef8..4f9e10ca20066 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDB.java @@ -76,7 +76,7 @@ public LevelDB(File path, KVStoreSerializer serializer) throws Exception { this.types = new ConcurrentHashMap<>(); Options options = new Options(); - options.createIfMissing(!path.exists()); + options.createIfMissing(true); this._db = new AtomicReference<>(JniDBFactory.factory.open(path, options)); byte[] versionData = db().get(STORE_VERSION_KEY); @@ -213,17 +213,32 @@ public long count(Class type, String index, Object indexedValue) throws Excep @Override public void close() throws IOException { - DB _db = this._db.getAndSet(null); - if (_db == null) { - return; + synchronized (this._db) { + DB _db = this._db.getAndSet(null); + if (_db == null) { + return; + } + + try { + _db.close(); + } catch (IOException ioe) { + throw ioe; + } catch (Exception e) { + throw new IOException(e.getMessage(), e); + } } + } - try { - _db.close(); - } catch (IOException ioe) { - throw ioe; - } catch (Exception e) { - throw new IOException(e.getMessage(), e); + /** + * Closes the given iterator if the DB is still open. 
Trying to close a JNI LevelDB handle + * with a closed DB can cause JVM crashes, so this ensures that situation does not happen. + */ + void closeIterator(LevelDBIterator it) throws IOException { + synchronized (this._db) { + DB _db = this._db.get(); + if (_db != null) { + it.close(); + } } } diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index a2181f3874f86..b3ba76ba58052 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -191,6 +191,16 @@ public synchronized void close() throws IOException { } } + /** + * Because it's tricky to expose closeable iterators through many internal APIs, especially + * when Scala wrappers are used, this makes sure that, hopefully, the JNI resources held by + * the iterator will eventually be released. + */ + @Override + protected void finalize() throws Throwable { + db.closeIterator(this); + } + private byte[] loadNext() { if (count >= max) { return null; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index d2d008f8a3d35..7253101f41df6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -47,12 +47,11 @@ public ShuffleSecretManager() { * fetching shuffle files written by other executors in this application. */ public void registerApp(String appId, String shuffleSecret) { - if (!shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.put(appId, shuffleSecret); - logger.info("Registered shuffle secret for application {}", appId); - } else { - logger.debug("Application {} already registered", appId); - } + // Always put the new secret information to make sure it's the most up to date. + // Otherwise we have to specifically look at the application attempt in addition + // to the applicationId since the secrets change between application attempts on yarn. + shuffleSecretMap.put(appId, shuffleSecret); + logger.info("Registered shuffle secret for application {}", appId); } /** @@ -67,12 +66,8 @@ public void registerApp(String appId, ByteBuffer shuffleSecret) { * This is called when the application terminates. 
*/ public void unregisterApp(String appId) { - if (shuffleSecretMap.containsKey(appId)) { - shuffleSecretMap.remove(appId); - logger.info("Unregistered shuffle secret for application {}", appId); - } else { - logger.warn("Attempted to unregister application {} when it is not registered", appId); - } + shuffleSecretMap.remove(appId); + logger.info("Unregistered shuffle secret for application {}", appId); } /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java index 77702447edb88..510017fee2db5 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java @@ -91,7 +91,7 @@ public void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { checkInit(); logger.debug("External shuffle fetch from {}:{} (executor id {})", host, port, execId); try { @@ -99,7 +99,7 @@ public void fetchBlocks( (blockIds1, listener1) -> { TransportClient client = clientFactory.createClient(host, port); new OneForOneBlockFetcher(client, appId, execId, - blockIds1, listener1, conf, tempShuffleFileManager).start(); + blockIds1, listener1, conf, tempFileManager).start(); }; int maxRetries = conf.maxIORetries(); diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java index 66b67e282c80d..3f2f20b4149f1 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java @@ -58,7 +58,7 @@ public class OneForOneBlockFetcher { private final BlockFetchingListener listener; private final ChunkReceivedCallback chunkCallback; private final TransportConf transportConf; - private final TempShuffleFileManager tempShuffleFileManager; + private final TempFileManager tempFileManager; private StreamHandle streamHandle = null; @@ -79,14 +79,14 @@ public OneForOneBlockFetcher( String[] blockIds, BlockFetchingListener listener, TransportConf transportConf, - TempShuffleFileManager tempShuffleFileManager) { + TempFileManager tempFileManager) { this.client = client; this.openMessage = new OpenBlocks(appId, execId, blockIds); this.blockIds = blockIds; this.listener = listener; this.chunkCallback = new ChunkCallback(); this.transportConf = transportConf; - this.tempShuffleFileManager = tempShuffleFileManager; + this.tempFileManager = tempFileManager; } /** Callback invoked on receipt of each chunk. We equate a single chunk to a single block. */ @@ -125,7 +125,7 @@ public void onSuccess(ByteBuffer response) { // Immediately request all chunks -- we expect that the total size of the request is // reasonable due to higher level chunking in [[ShuffleBlockFetcherIterator]]. 
for (int i = 0; i < streamHandle.numChunks; i++) { - if (tempShuffleFileManager != null) { + if (tempFileManager != null) { client.stream(OneForOneStreamManager.genStreamChunkId(streamHandle.streamId, i), new DownloadCallback(i)); } else { @@ -164,7 +164,7 @@ private class DownloadCallback implements StreamCallback { private int chunkIndex; DownloadCallback(int chunkIndex) throws IOException { - this.targetFile = tempShuffleFileManager.createTempShuffleFile(); + this.targetFile = tempFileManager.createTempFile(); this.channel = Channels.newChannel(Files.newOutputStream(targetFile.toPath())); this.chunkIndex = chunkIndex; } @@ -180,7 +180,7 @@ public void onComplete(String streamId) throws IOException { ManagedBuffer buffer = new FileSegmentManagedBuffer(transportConf, targetFile, 0, targetFile.length()); listener.onBlockFetchSuccess(blockIds[chunkIndex], buffer); - if (!tempShuffleFileManager.registerTempShuffleFileToClean(targetFile)) { + if (!tempFileManager.registerTempFileToClean(targetFile)) { targetFile.delete(); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java index 5bd4412b75275..18b04fedcac5b 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleClient.java @@ -43,10 +43,10 @@ public void init(String appId) { } * @param execId the executor id. * @param blockIds block ids to fetch. * @param listener the listener to receive block fetching status. - * @param tempShuffleFileManager TempShuffleFileManager to create and clean temp shuffle files. - * If it's not null, the remote blocks will be streamed - * into temp shuffle files to reduce the memory usage, otherwise, - * they will be kept in memory. + * @param tempFileManager TempFileManager to create and clean temp files. + * If it's not null, the remote blocks will be streamed + * into temp shuffle files to reduce the memory usage, otherwise, + * they will be kept in memory. */ public abstract void fetchBlocks( String host, @@ -54,7 +54,7 @@ public abstract void fetchBlocks( String execId, String[] blockIds, BlockFetchingListener listener, - TempShuffleFileManager tempShuffleFileManager); + TempFileManager tempFileManager); /** * Get the shuffle MetricsSet from ShuffleClient, this will be used in MetricsSystem to diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java similarity index 74% rename from common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java rename to common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java index 84a5ed6a276bd..552364d274f19 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/TempFileManager.java @@ -20,17 +20,17 @@ import java.io.File; /** - * A manager to create temp shuffle block files to reduce the memory usage and also clean temp + * A manager to create temp block files to reduce the memory usage and also clean temp * files when they won't be used any more. 
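The renamed interface described here exposes just two methods, createTempFile() and registerTempFileToClean(File). As a rough illustration only, the following Scala sketch shows one possible implementation backed by a scratch directory; the class name, directory handling, and stop() method are invented for this example and are not part of the patch.

    import java.io.File
    import java.nio.file.Files
    import java.util.concurrent.ConcurrentLinkedQueue
    import scala.collection.JavaConverters._
    import org.apache.spark.network.shuffle.TempFileManager

    // Invented example class: a scratch directory plus a queue of files to delete on stop().
    class ScratchDirTempFileManager(dir: File) extends TempFileManager {
      private val registered = new ConcurrentLinkedQueue[File]()
      @volatile private var stopped = false

      // Create a temp block file under the scratch directory.
      override def createTempFile(): File =
        Files.createTempFile(dir.toPath, "fetched-block-", ".tmp").toFile

      // Track the file for later cleanup; once stopped, return false so the
      // caller deletes the file itself, as the interface contract requires.
      override def registerTempFileToClean(file: File): Boolean =
        !stopped && registered.add(file)

      def stop(): Unit = {
        stopped = true
        registered.asScala.foreach(_.delete())
      }
    }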
*/ -public interface TempShuffleFileManager { +public interface TempFileManager { - /** Create a temp shuffle block file. */ - File createTempShuffleFile(); + /** Create a temp block file. */ + File createTempFile(); /** - * Register a temp shuffle file to clean up when it won't be used any more. Return whether the + * Register a temp file to clean up when it won't be used any more. Return whether the * file is registered successfully. If `false`, the caller should clean up the file by itself. */ - boolean registerTempShuffleFileToClean(File file); + boolean registerTempFileToClean(File file); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java new file mode 100644 index 0000000000000..46c4c33865eea --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/sasl/ShuffleSecretManagerSuite.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.sasl; + +import java.nio.ByteBuffer; + +import org.junit.Test; +import static org.junit.Assert.*; + +public class ShuffleSecretManagerSuite { + static String app1 = "app1"; + static String app2 = "app2"; + static String pw1 = "password1"; + static String pw2 = "password2"; + static String pw1update = "password1update"; + static String pw2update = "password2update"; + + @Test + public void testMultipleRegisters() { + ShuffleSecretManager secretManager = new ShuffleSecretManager(); + secretManager.registerApp(app1, pw1); + assertEquals(pw1, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2.getBytes())); + assertEquals(pw2, secretManager.getSecretKey(app2)); + + // now update the password for the apps and make sure it takes effect + secretManager.registerApp(app1, pw1update); + assertEquals(pw1update, secretManager.getSecretKey(app1)); + secretManager.registerApp(app2, ByteBuffer.wrap(pw2update.getBytes())); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app1); + assertNull(secretManager.getSecretKey(app1)); + assertEquals(pw2update, secretManager.getSecretKey(app2)); + + secretManager.unregisterApp(app2); + assertNull(secretManager.getSecretKey(app2)); + assertNull(secretManager.getSecretKey(app1)); + } +} diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java index 9c551ab19e9aa..f121b1cd745b8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/array/ByteArrayMethods.java @@ -40,6 +40,13 @@ public static int roundNumberOfBytesToNearestWord(int numBytes) { } } + + // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat smaller. + // Be conservative and lower the cap a little. + // Refer to "http://hg.openjdk.java.net/jdk8/jdk8/jdk/file/tip/src/share/classes/java/util/ArrayList.java#l229" + // This value is word rounded. Use this value if the allocated byte arrays are used to store other + // types rather than bytes. + public static int MAX_ROUNDED_ARRAY_LENGTH = Integer.MAX_VALUE - 15; + private static final boolean unaligned = Platform.unaligned(); /** * Optimized byte array equality check for byte arrays.
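The new MAX_ROUNDED_ARRAY_LENGTH constant above is intended to replace hand-rolled "Integer.MAX_VALUE - 8" style caps (as the HashMapGrowthStrategy change later in this patch does). A rough sketch of the intended usage, assuming only the constant itself; the object and method names are invented for illustration.

    import org.apache.spark.unsafe.array.ByteArrayMethods

    // Invented helper: double the capacity until doubling would pass the cap, then clamp.
    object CappedDoubling {
      private val ArrayMax = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH

      def nextCapacity(currentCapacity: Int): Int = {
        require(currentCapacity > 0, "capacity must be positive")
        if (currentCapacity < ArrayMax / 2) currentCapacity * 2 else ArrayMax
      }
    }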
diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java index 355748238540b..cc9cc429643ad 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/memory/HeapMemoryAllocator.java @@ -56,6 +56,9 @@ public MemoryBlock allocate(long size) throws OutOfMemoryError { final MemoryBlock memory = blockReference.get(); if (memory != null) { assert (memory.size() == size); + if (MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED) { + memory.fill(MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + } return memory; } } diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index ce4a06bde80c4..b0d0c44823e68 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -498,17 +498,16 @@ private UTF8String copyUTF8String(int start, int end) { public UTF8String trim() { int s = 0; - int e = this.numBytes - 1; // skip all of the space (0x20) in the left side while (s < this.numBytes && getByte(s) == 0x20) s++; - // skip all of the space (0x20) in the right side - while (e >= 0 && getByte(e) == 0x20) e--; - if (s > e) { + if (s == this.numBytes) { // empty string return EMPTY_UTF8; - } else { - return copyUTF8String(s, e); } + // skip all of the space (0x20) in the right side + int e = this.numBytes - 1; + while (e > s && getByte(e) == 0x20) e--; + return copyUTF8String(s, e); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java index 4ae49d82efa29..4b141339ec816 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/PlatformUtilSuite.java @@ -66,10 +66,21 @@ public void overlappingCopyMemory() { public void memoryDebugFillEnabledInTest() { Assert.assertTrue(MemoryAllocator.MEMORY_DEBUG_FILL_ENABLED); MemoryBlock onheap = MemoryAllocator.HEAP.allocate(1); - MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(onheap.getBaseObject(), onheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock onheap1 = MemoryAllocator.HEAP.allocate(1024 * 1024); + MemoryAllocator.HEAP.free(onheap1); + Assert.assertEquals( + Platform.getByte(onheap1.getBaseObject(), onheap1.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_FREED_VALUE); + MemoryBlock onheap2 = MemoryAllocator.HEAP.allocate(1024 * 1024); + Assert.assertEquals( + Platform.getByte(onheap2.getBaseObject(), onheap2.getBaseOffset()), + MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); + + MemoryBlock offheap = MemoryAllocator.UNSAFE.allocate(1); Assert.assertEquals( Platform.getByte(offheap.getBaseObject(), offheap.getBaseOffset()), MemoryAllocator.MEMORY_DEBUG_FILL_CLEAN_VALUE); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 7b03d2c650fc9..9b303fa5bc6c5 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ 
-222,10 +222,13 @@ public void substring() { @Test public void trims() { + assertEquals(fromString("1"), fromString("1").trim()); + assertEquals(fromString("hello"), fromString(" hello ").trim()); assertEquals(fromString("hello "), fromString(" hello ").trimLeft()); assertEquals(fromString(" hello"), fromString(" hello ").trimRight()); + assertEquals(EMPTY_UTF8, EMPTY_UTF8.trim()); assertEquals(EMPTY_UTF8, fromString(" ").trim()); assertEquals(EMPTY_UTF8, fromString(" ").trimLeft()); assertEquals(EMPTY_UTF8, fromString(" ").trimRight()); diff --git a/core/pom.xml b/core/pom.xml index da68abd855c7c..54f7a34a6c37e 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -67,6 +67,11 @@ spark-launcher_${scala.binary.version} ${project.version} + + org.apache.spark + spark-kvstore_${scala.binary.version} + ${project.version} + org.apache.spark spark-network-common_${scala.binary.version} @@ -494,7 +499,7 @@ - ..${file.separator}R${file.separator}install-dev${script.extension} + ${project.basedir}${file.separator}..${file.separator}R${file.separator}install-dev${script.extension} diff --git a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java index ea5f1a9abf69b..f6d1288cb263d 100644 --- a/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java +++ b/core/src/main/java/org/apache/spark/io/NioBufferedFileInputStream.java @@ -130,10 +130,8 @@ public synchronized void close() throws IOException { StorageUtils.dispose(byteBuffer); } - //checkstyle.off: NoFinalizer @Override protected void finalize() throws IOException { close(); } - //checkstyle.on: NoFinalizer } diff --git a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java index 9dbb565aab707..40b5f627369d5 100644 --- a/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java +++ b/core/src/main/java/org/apache/spark/status/api/v1/StageStatus.java @@ -23,7 +23,8 @@ public enum StageStatus { ACTIVE, COMPLETE, FAILED, - PENDING; + PENDING, + SKIPPED; public static StageStatus fromString(String str) { return EnumUtil.parseIgnoreCase(StageStatus.class, str); diff --git a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java index b8c2294c7b7ab..ee6d9f75ac5aa 100644 --- a/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java +++ b/core/src/main/java/org/apache/spark/unsafe/map/HashMapGrowthStrategy.java @@ -17,6 +17,8 @@ package org.apache.spark.unsafe.map; +import org.apache.spark.unsafe.array.ByteArrayMethods; + /** * Interface that defines how we can grow the size of a hash map when it is over a threshold. */ @@ -31,9 +33,7 @@ public interface HashMapGrowthStrategy { class Doubling implements HashMapGrowthStrategy { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. 
- private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; @Override public int nextCapacity(int currentCapacity) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index 39eda00dd7efb..e749f7ba87c6e 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -480,6 +480,10 @@ public UnsafeSorterIterator getSortedIterator() throws IOException { } } + @VisibleForTesting boolean hasSpaceForAnotherRecord() { + return inMemSorter.hasSpaceForAnotherRecord(); + } + private static void spillIterator(UnsafeSorterIterator inMemIterator, UnsafeSorterSpillWriter spillWriter) throws IOException { while (inMemIterator.hasNext()) { diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java index c14c12664f5ab..3bb87a6ed653d 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorter.java @@ -162,7 +162,9 @@ private int getUsableCapacity() { */ public void free() { if (consumer != null) { - consumer.freeArray(array); + if (array != null) { + consumer.freeArray(array); + } array = null; } } @@ -170,6 +172,15 @@ public void free() { public void reset() { if (consumer != null) { consumer.freeArray(array); + // the call to consumer.allocateArray may trigger a spill which in turn access this instance + // and eventually re-enter this method and try to free the array again. by setting the array + // to null and its length to 0 we effectively make the spill code-path a no-op. setting the + // array to null also indicates that it has already been de-allocated which prevents a double + // de-allocation in free(). + array = null; + usableCapacity = 0; + pos = 0; + nullBoundaryPos = 0; array = consumer.allocateArray(initialSize); usableCapacity = getUsableCapacity(); } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1034fdcae8e8c..036c9a60630ea 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -89,7 +89,11 @@ trait FutureAction[T] extends Future[T] { */ override def value: Option[Try[T]] - // These two methods must be implemented in Scala 2.12, but won't be used by Spark + // These two methods must be implemented in Scala 2.12. They're implemented as a no-op here + // and then filled in with a real implementation in the two subclasses below. The no-op exists + // here so that those implementations can declare "override", necessary in 2.12, while working + // in 2.11, where the method doesn't exist in the superclass. + // After 2.11 support goes away, remove these two: def transform[S](f: (Try[T]) => Try[S])(implicit executor: ExecutionContext): Future[S] = throw new UnsupportedOperationException() @@ -113,6 +117,42 @@ trait FutureAction[T] extends Future[T] { } +/** + * Scala 2.12 defines the two new transform/transformWith methods mentioned above. 
Implementing + * these for 2.12 in the Spark class here requires delegating to these same methods in an + * underlying Future object. But that only exists in 2.12, and these methods are only called + * in 2.12. So define helper shims to access these methods on a Future by reflection. + */ +private[spark] object FutureAction { + + private val transformTryMethod = + try { + classOf[Future[_]].getMethod("transform", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private val transformWithTryMethod = + try { + classOf[Future[_]].getMethod("transformWith", classOf[(_) => _], classOf[ExecutionContext]) + } catch { + case _: NoSuchMethodException => null // Would fail later in 2.11, but not called in 2.11 + } + + private[spark] def transform[T, S]( + future: Future[T], + f: (Try[T]) => Try[S], + executor: ExecutionContext): Future[S] = + transformTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + + private[spark] def transformWith[T, S]( + future: Future[T], + f: (Try[T]) => Future[S], + executor: ExecutionContext): Future[S] = + transformWithTryMethod.invoke(future, f, executor).asInstanceOf[Future[S]] + +} + /** * A [[FutureAction]] holding the result of an action that triggers a single job. Examples include @@ -153,6 +193,18 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: jobWaiter.completionFuture.value.map {res => res.map(_ => resultFunc)} def jobIds: Seq[Int] = Seq(jobWaiter.jobId) + + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith( + jobWaiter.completionFuture, + (u: Try[Unit]) => f(u.map(_ => resultFunc)), + e) } @@ -246,6 +298,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T]) def jobIds: Seq[Int] = subActions.flatMap(_.jobIds) + override def transform[S](f: (Try[T]) => Try[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transform(p.future, f, e) + + override def transformWith[S](f: (Try[T]) => Future[S])(implicit e: ExecutionContext): Future[S] = + FutureAction.transformWith(p.future, f, e) } diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 1484f29525a4e..debbd8d7c26c9 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -108,11 +108,21 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K : Ordering : ClassTag, V]( partitions: Int, rdd: RDD[_ <: Product2[K, V]], - private var ascending: Boolean = true) + private var ascending: Boolean = true, + val samplePointsPerPartitionHint: Int = 20) extends Partitioner { + // A constructor declared in order to maintain backward compatibility for Java, when we add the + // 4th constructor parameter samplePointsPerPartitionHint. See SPARK-22160. + // This is added to make sure from a bytecode point of view, there is still a 3-arg ctor. + def this(partitions: Int, rdd: RDD[_ <: Product2[K, V]], ascending: Boolean) = { + this(partitions, rdd, ascending, samplePointsPerPartitionHint = 20) + } + // We allow partitions = 0, which happens when sorting an empty RDD under the default settings.
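To make the new samplePointsPerPartitionHint knob concrete, here is a small usage sketch; it assumes an existing SparkContext named sc and arbitrary data, and is not taken from the patch. A larger hint samples more points per partition when computing range bounds, which helps the partitioner track skewed key distributions at the cost of a bigger sample.

    import org.apache.spark.RangePartitioner

    // `sc` is assumed to be an existing SparkContext; the key space and sizes are arbitrary.
    val pairs = sc.parallelize(1 to 1000000).map(i => (i % 1000, i))
    val partitioner = new RangePartitioner(
      partitions = 8,
      rdd = pairs,
      ascending = true,
      samplePointsPerPartitionHint = 100) // the default stays at 20
    val ranged = pairs.partitionBy(partitioner)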
require(partitions >= 0, s"Number of partitions cannot be negative but found $partitions.") + require(samplePointsPerPartitionHint > 0, + s"Sample points per partition must be greater than 0 but found $samplePointsPerPartitionHint") private var ordering = implicitly[Ordering[K]] @@ -122,7 +132,8 @@ class RangePartitioner[K : Ordering : ClassTag, V]( Array.empty } else { // This is the sample size we need to have roughly balanced output partitions, capped at 1M. - val sampleSize = math.min(20.0 * partitions, 1e6) + // Cast to double to avoid overflowing ints or longs + val sampleSize = math.min(samplePointsPerPartitionHint.toDouble * partitions, 1e6) // Assume the input partitions are roughly balanced and over-sample a little bit. val sampleSizePerPartition = math.ceil(3.0 * sampleSize / rdd.partitions.length).toInt val (numItems, sketched) = RangePartitioner.sketch(rdd.map(_._1), sampleSizePerPartition) diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index e61f943af49f2..57b3744e9c30a 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -662,7 +662,9 @@ private[spark] object SparkConf extends Logging { "spark.yarn.jars" -> Seq( AlternateConfig("spark.yarn.jar", "2.0")), "spark.yarn.access.hadoopFileSystems" -> Seq( - AlternateConfig("spark.yarn.access.namenodes", "2.2")) + AlternateConfig("spark.yarn.access.namenodes", "2.2")), + "spark.maxRemoteBlockSizeFetchToMem" -> Seq( + AlternateConfig("spark.reducer.maxReqSizeShuffleToMem", "2.3")) ) /** diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cec61d85ccf38..6362b730c0045 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -54,6 +54,7 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, StandaloneSchedulerBackend} import org.apache.spark.scheduler.local.LocalSchedulerBackend +import org.apache.spark.status.AppStatusStore import org.apache.spark.storage._ import org.apache.spark.storage.BlockManagerMessages.TriggerThreadDump import org.apache.spark.ui.{ConsoleProgressBar, SparkUI} @@ -213,6 +214,7 @@ class SparkContext(config: SparkConf) extends Logging { private var _jars: Seq[String] = _ private var _files: Seq[String] = _ private var _shutdownHookRef: AnyRef = _ + private var _statusStore: AppStatusStore = _ /* ------------------------------------------------------------------------------------- * | Accessors and public fields. These provide access to the internal state of the | @@ -421,6 +423,10 @@ class SparkContext(config: SparkConf) extends Logging { _jobProgressListener = new JobProgressListener(_conf) listenerBus.addToStatusQueue(jobProgressListener) + // Initialize the app state store and listener before SparkEnv is created so that it gets + // all events. 
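For the spark.maxRemoteBlockSizeFetchToMem alias registered above, a brief sketch of how SparkConf's alternate-config handling behaves (the value is arbitrary): setting only the deprecated key logs a deprecation warning and still surfaces through the new key.

    import org.apache.spark.SparkConf

    // Arbitrary value; the point is only the key translation added above.
    val conf = new SparkConf()
      .set("spark.reducer.maxReqSizeShuffleToMem", "200m") // deprecated name as of 2.3
    // Reads of the new key fall back to the deprecated alternate when the new key is unset.
    assert(conf.get("spark.maxRemoteBlockSizeFetchToMem") == "200m")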
+ _statusStore = AppStatusStore.createLiveStore(conf, listenerBus) + // Create the Spark execution environment (cache, map output tracker, etc) _env = createSparkEnv(_conf, isLocal, listenerBus) SparkEnv.set(_env) @@ -434,7 +440,7 @@ class SparkContext(config: SparkConf) extends Logging { _statusTracker = new SparkStatusTracker(this) _progressBar = - if (_conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) { + if (_conf.get(UI_SHOW_CONSOLE_PROGRESS) && !log.isInfoEnabled) { Some(new ConsoleProgressBar(this)) } else { None @@ -442,8 +448,12 @@ class SparkContext(config: SparkConf) extends Logging { _ui = if (conf.getBoolean("spark.ui.enabled", true)) { - Some(SparkUI.createLiveUI(this, _conf, _jobProgressListener, - _env.securityManager, appName, startTime = startTime)) + Some(SparkUI.create(Some(this), _statusStore, _conf, + l => listenerBus.addToStatusQueue(l), + _env.securityManager, + appName, + "", + startTime)) } else { // For tests, do not enable the UI None @@ -1939,6 +1949,9 @@ class SparkContext(config: SparkConf) extends Logging { } SparkEnv.set(null) } + if (_statusStore != null) { + _statusStore.close() + } // Clear this `InheritableThreadLocal`, or it will still be inherited in child threads even this // `SparkContext` is stopped. localProperties.remove() @@ -2344,41 +2357,13 @@ class SparkContext(config: SparkConf) extends Logging { * (e.g. after the web UI and event logging listeners have been registered). */ private def setupAndStartListenerBus(): Unit = { - // Use reflection to instantiate listeners specified via `spark.extraListeners` try { - val listenerClassNames: Seq[String] = - conf.get("spark.extraListeners", "").split(',').map(_.trim).filter(_ != "") - for (className <- listenerClassNames) { - // Use reflection to find the right constructor - val constructors = { - val listenerClass = Utils.classForName(className) - listenerClass - .getConstructors - .asInstanceOf[Array[Constructor[_ <: SparkListenerInterface]]] - } - val constructorTakingSparkConf = constructors.find { c => - c.getParameterTypes.sameElements(Array(classOf[SparkConf])) - } - lazy val zeroArgumentConstructor = constructors.find { c => - c.getParameterTypes.isEmpty - } - val listener: SparkListenerInterface = { - if (constructorTakingSparkConf.isDefined) { - constructorTakingSparkConf.get.newInstance(conf) - } else if (zeroArgumentConstructor.isDefined) { - zeroArgumentConstructor.get.newInstance() - } else { - throw new SparkException( - s"$className did not have a zero-argument constructor or a" + - " single-argument constructor that accepts SparkConf. 
Note: if the class is" + - " defined inside of another Scala class, then its constructors may accept an" + - " implicit parameter that references the enclosing class; in this case, you must" + - " define the listener as a top-level class in order to prevent this extra" + - " parameter from breaking Spark's ability to find a valid constructor.") - } + conf.get(EXTRA_LISTENERS).foreach { classNames => + val listeners = Utils.loadExtensions(classOf[SparkListenerInterface], classNames, conf) + listeners.foreach { listener => + listenerBus.addToSharedQueue(listener) + logInfo(s"Registered listener ${listener.getClass().getName()}") } - listenerBus.addToSharedQueue(listener) - logInfo(s"Registered listener $className") } } catch { case e: Exception => diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 86d0405c678a7..f6293c0dc5091 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -48,7 +48,7 @@ private[spark] class PythonRDD( extends RDD[Array[Byte]](parent) { val bufferSize = conf.getInt("spark.buffer.size", 65536) - val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true) + val reuseWorker = conf.getBoolean("spark.python.worker.reuse", true) override def getPartitions: Array[Partition] = firstParent.partitions @@ -59,7 +59,7 @@ private[spark] class PythonRDD( val asJavaRDD: JavaRDD[Array[Byte]] = JavaRDD.fromRDD(this) override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { - val runner = PythonRunner(func, bufferSize, reuse_worker) + val runner = PythonRunner(func, bufferSize, reuseWorker) runner.compute(firstParent.iterator(split, context), split.index, context) } } @@ -83,318 +83,9 @@ private[spark] case class PythonFunction( */ private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction]) -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - val SQL_BATCHED_UDF = 1 - val SQL_PANDAS_UDF = 2 -} - -private[spark] object PythonRunner { - def apply(func: PythonFunction, bufferSize: Int, reuse_worker: Boolean): PythonRunner = { - new PythonRunner( - Seq(ChainedPythonFunctions(Seq(func))), - bufferSize, - reuse_worker, - PythonEvalType.NON_UDF, - Array(Array(0))) - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). - */ -private[spark] class PythonRunner( - funcs: Seq[ChainedPythonFunctions], - bufferSize: Int, - reuse_worker: Boolean, - evalType: Int, - argOffsets: Array[Array[Int]]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - // All the Python functions should have the same exec, version and envvars. 
- private val envVars = funcs.head.funcs.head.envVars - private val pythonExec = funcs.head.funcs.head.pythonExec - private val pythonVer = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - private val accumulator = funcs.head.funcs.head.accumulator - - def compute( - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext): Iterator[Array[Byte]] = { - val startTime = System.currentTimeMillis - val env = SparkEnv.get - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuse_worker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool - @volatile var released = false - - // Start a thread to feed the process input from our parent's iterator - val writerThread = new WriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener { context => - writerThread.shutdownOnTaskCompletion() - if (!reuse_worker || !released) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new MonitorThread(env, worker, context).start() - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - val stdoutIterator = new Iterator[Array[Byte]] { - override def next(): Array[Byte] = { - val obj = _nextObj - if (hasNext) { - _nextObj = read() - } - obj - } - - private def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - obj - case 0 => Array.empty[Byte] - case SpecialLengths.TIMING_DATA => - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - throw new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.getOrElse(null)) - case SpecialLengths.END_OF_DATA_SECTION => - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - accumulator.add(update) - } - // Check whether the worker is ready to be re-used. 
- if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuse_worker) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - released = true - } - } - null - } - } catch { - - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if env.isStopped => - logDebug("Exception thrown after context is stopped", e) - null // exit silently - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - var _nextObj = read() - - override def hasNext: Boolean = _nextObj != null - } - new InterruptibleIterator(context, stdoutIterator) - } - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[_], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Exception = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the exception thrown while writing the parent iterator to the Python process. */ - def exception: Option[Exception] = Option(_exception) - - /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. 
*/ - def shutdownOnTaskCompletion() { - assert(context.isCompleted) - this.interrupt() - } - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Write out the TaskContextInfo - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - // sparkFilesDir - PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val cnt = toRemove.size + newBids.diff(oldBids).size - dataOut.writeInt(cnt) - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(- bid - 1) // bid >= 0 - oldBids.remove(bid) - } - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - dataOut.flush() - // Serialized command: - dataOut.writeInt(evalType) - if (evalType != PythonEvalType.NON_UDF) { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command) - } - } - } else { - val command = funcs.head.funcs.head.command - dataOut.writeInt(command.length) - dataOut.write(command) - } - // Data values - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case e: Exception if context.isCompleted || context.isInterrupted => - logDebug("Exception thrown after task completion (likely due to cleanup)", e) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - - case e: Exception => - // We must avoid throwing exceptions here, because the thread uncaught exception handler - // will kill the whole executor (see org.apache.spark.executor.Executor). - _exception = e - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - setDaemon(true) - - override def run() { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. 
- while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - try { - logWarning("Incomplete task interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } -} - /** Thrown for exceptions in user Python code. */ -private class PythonException(msg: String, cause: Exception) extends RuntimeException(msg, cause) +private[spark] class PythonException(msg: String, cause: Exception) + extends RuntimeException(msg, cause) /** * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python. @@ -411,14 +102,6 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends RDD[(Long, Array[Byte] val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private object SpecialLengths { - val END_OF_DATA_SECTION = -1 - val PYTHON_EXCEPTION_THROWN = -2 - val TIMING_DATA = -3 - val END_OF_STREAM = -4 - val NULL = -5 -} - private[spark] object PythonRDD extends Logging { // remember the broadcasts sent to each worker diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala new file mode 100644 index 0000000000000..d417303bb147d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.python + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.spark._ +import org.apache.spark.internal.Logging +import org.apache.spark.util._ + + +/** + * Enumerate the type of command that will be sent to the Python worker + */ +private[spark] object PythonEvalType { + val NON_UDF = 0 + val SQL_BATCHED_UDF = 1 + val SQL_PANDAS_UDF = 2 + val SQL_PANDAS_GROUPED_UDF = 3 +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). + */ +private[spark] abstract class BasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + // All the Python functions should have the same exec, version and envvars. 
+ protected val envVars = funcs.head.funcs.head.envVars + protected val pythonExec = funcs.head.funcs.head.pythonExec + protected val pythonVer = funcs.head.funcs.head.pythonVer + + // TODO: support accumulator in multiple UDF + protected val accumulator = funcs.head.funcs.head.accumulator + + def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool + val released = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || !released.get) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new MonitorThread(env, worker, context).start() + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, writerThread, startTime, env, worker, released, context) + new InterruptibleIterator(context, stdoutIterator) + } + + protected def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[OUT] + + /** + * The thread responsible for writing the data from the PythonRDD's parent iterator to the + * Python process. + */ + abstract class WriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { + + @volatile private var _exception: Exception = null + + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the exception thrown while writing the parent iterator to the Python process. */ + def exception: Option[Exception] = Option(_exception) + + /** Terminates the writer thread, ignoring any exceptions that may occur due to cleanup. */ + def shutdownOnTaskCompletion() { + assert(context.isCompleted) + this.interrupt() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. 
+ */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Write out the TaskContextInfo + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + // sparkFilesDir + PythonRDD.writeUTF(SparkFiles.getRootDirectory(), dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val cnt = toRemove.size + newBids.diff(oldBids).size + dataOut.writeInt(cnt) + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(- bid - 1) // bid >= 0 + oldBids.remove(bid) + } + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + writeIteratorToStream(dataOut) + + dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case e: Exception if context.isCompleted || context.isInterrupted => + logDebug("Exception thrown after task completion (likely due to cleanup)", e) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + + case e: Exception => + // We must avoid throwing exceptions here, because the thread uncaught exception handler + // will kill the whole executor (see org.apache.spark.executor.Executor). + _exception = e + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + abstract class ReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { + + private var nextObj: OUT = _ + private var eos = false + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + /** + * Reads next object from the stream. + * When the stream reaches end of data, needs to process the following sections, + * and then returns null. 
+ */ + protected def read(): OUT + + protected def handleTimingData(): Unit = { + // Timing data from worker + val bootTime = stream.readLong() + val initTime = stream.readLong() + val finishTime = stream.readLong() + val boot = bootTime - startTime + val init = initTime - bootTime + val finish = finishTime - initTime + val total = finishTime - startTime + logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, + init, finish)) + val memoryBytesSpilled = stream.readLong() + val diskBytesSpilled = stream.readLong() + context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) + context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) + } + + protected def handlePythonException(): PythonException = { + // Signals that an exception has been thrown in python + val exLength = stream.readInt() + val obj = new Array[Byte](exLength) + stream.readFully(obj) + new PythonException(new String(obj, StandardCharsets.UTF_8), + writerThread.exception.getOrElse(null)) + } + + protected def handleEndOfDataSection(): Unit = { + // We've finished the data section of the output, but we can still + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) + stream.readFully(update) + accumulator.add(update) + } + // Check whether the worker is ready to be re-used. + if (stream.readInt() == SpecialLengths.END_OF_STREAM) { + if (reuseWorker) { + env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) + released.set(true) + } + } + eos = true + } + + protected val handleException: PartialFunction[Throwable, OUT] = { + case e: Exception if context.isInterrupted => + logDebug("Exception thrown after task interruption", e) + throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) + + case e: Exception if env.isStopped => + logDebug("Exception thrown after context is stopped", e) + null.asInstanceOf[OUT] // exit silently + + case e: Exception if writerThread.exception.isDefined => + logError("Python worker exited unexpectedly (crashed)", e) + logError("This may have been caused by a prior exception:", writerThread.exception.get) + throw writerThread.exception.get + + case eof: EOFException => + throw new SparkException("Python worker exited unexpectedly (crashed)", eof) + } + } + + /** + * It is necessary to have a monitor thread for python workers if the user cancels with + * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the + * threads can block indefinitely. + */ + class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) + extends Thread(s"Worker Monitor for $pythonExec") { + + setDaemon(true) + + override def run() { + // Kill the worker if it is interrupted, checking until task completion. + // TODO: This has a race condition if interruption occurs, as completed may still become true. 
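
The ReaderIterator above dispatches on a signed length prefix: a positive value announces a payload of that many bytes, zero an empty record, and the negative SpecialLengths codes mark timing data, a Python exception, the end of the data section, and so on. A dependency-free sketch of that framing, assuming only java.io (the two constants are copied from SpecialLengths; the object and helper names are not Spark API):

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

object FramingSketch {
  val END_OF_DATA_SECTION = -1
  val PYTHON_EXCEPTION_THROWN = -2

  def main(args: Array[String]): Unit = {
    // Writer side: two length-prefixed records, then the end-of-data marker.
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    Seq("hello", "world").foreach { s =>
      val payload = s.getBytes(StandardCharsets.UTF_8)
      out.writeInt(payload.length)
      out.write(payload)
    }
    out.writeInt(END_OF_DATA_SECTION)
    out.flush()

    // Reader side: dispatch on the signed length prefix, as ReaderIterator.read() does.
    val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
    var done = false
    while (!done) {
      in.readInt() match {
        case length if length > 0 =>
          val buf = new Array[Byte](length)
          in.readFully(buf)
          println(new String(buf, StandardCharsets.UTF_8))
        case 0 => println("<empty record>")
        case END_OF_DATA_SECTION =>
          done = true // the real reader then drains accumulator updates and END_OF_STREAM
        case PYTHON_EXCEPTION_THROWN =>
          done = true // the real reader reads a UTF-8 message and throws PythonException
      }
    }
  }
}
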
+ while (!context.isInterrupted && !context.isCompleted) { + Thread.sleep(2000) + } + if (!context.isCompleted) { + try { + logWarning("Incomplete task interrupted: Attempting to kill Python Worker") + env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } +} + +private[spark] object PythonRunner { + + def apply(func: PythonFunction, bufferSize: Int, reuseWorker: Boolean): PythonRunner = { + new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), bufferSize, reuseWorker) + } +} + +/** + * A helper class to run Python mapPartition in Spark. + */ +private[spark] class PythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, PythonEvalType.NON_UDF, Array(Array(0))) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + val command = funcs.head.funcs.head.command + dataOut.writeInt(command.length) + dataOut.write(command) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +private[spark] object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 + val END_OF_STREAM = -4 + val NULL = -5 + val START_ARROW_STREAM = -6 +} diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index bf6093236d92b..7acb5c55bb252 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -93,19 +93,19 @@ private class ClientEndpoint( driverArgs.cores, driverArgs.supervise, command) - ayncSendToMasterAndForwardReply[SubmitDriverResponse]( + asyncSendToMasterAndForwardReply[SubmitDriverResponse]( RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - ayncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) + asyncSendToMasterAndForwardReply[KillDriverResponse](RequestKillDriver(driverId)) } } /** * Send the message to master and forward the reply to self asynchronously. 
*/ - private def ayncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { + private def asyncSendToMasterAndForwardReply[T: ClassTag](message: Any): Unit = { for (masterEndpoint <- masterEndpoints) { masterEndpoint.ask[T](message).onComplete { case Success(v) => self.send(v) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala new file mode 100644 index 0000000000000..118b4605675b0 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/SparkApplication.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy + +import java.lang.reflect.Modifier + +import org.apache.spark.SparkConf + +/** + * Entry point for a Spark application. Implementations must provide a no-argument constructor. + */ +private[spark] trait SparkApplication { + + def start(args: Array[String], conf: SparkConf): Unit + +} + +/** + * Implementation of SparkApplication that wraps a standard Java class with a "main" method. + * + * Configuration is propagated to the application via system properties, so running multiple + * of these in the same JVM may lead to undefined behavior due to configuration leaks. 
+ */ +private[deploy] class JavaMainApplication(klass: Class[_]) extends SparkApplication { + + override def start(args: Array[String], conf: SparkConf): Unit = { + val mainMethod = klass.getMethod("main", new Array[String](0).getClass) + if (!Modifier.isStatic(mainMethod.getModifiers)) { + throw new IllegalStateException("The main method in the given main class must be static") + } + + val sysProps = conf.getAll.toMap + sysProps.foreach { case (k, v) => + sys.props(k) = v + } + + mainMethod.invoke(null, args) + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 53775db251bc6..1fa10ab943f34 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -61,13 +61,17 @@ class SparkHadoopUtil extends Logging { * do a FileSystem.closeAllForUGI in order to avoid leaking Filesystems */ def runAsSparkUser(func: () => Unit) { + createSparkUser().doAs(new PrivilegedExceptionAction[Unit] { + def run: Unit = func() + }) + } + + def createSparkUser(): UserGroupInformation = { val user = Utils.getCurrentUserName() - logDebug("running as user: " + user) + logDebug("creating UGI for user: " + user) val ugi = UserGroupInformation.createRemoteUser(user) transferCredentials(UserGroupInformation.getCurrentUser(), ugi) - ugi.doAs(new PrivilegedExceptionAction[Unit] { - def run: Unit = func() - }) + ugi } def transferCredentials(source: UserGroupInformation, dest: UserGroupInformation) { @@ -417,6 +421,11 @@ class SparkHadoopUtil extends Logging { creds.readTokenStorageStream(new DataInputStream(tokensBuf)) creds } + + def isProxyUser(ugi: UserGroupInformation): Boolean = { + ugi.getAuthenticationMethod() == UserGroupInformation.AuthenticationMethod.PROXY + } + } object SparkHadoopUtil { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 286a4379d2040..73b956ef3e470 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -158,7 +158,7 @@ object SparkSubmit extends CommandLineUtils with Logging { */ @tailrec private def submit(args: SparkSubmitArguments, uninitLog: Boolean): Unit = { - val (childArgs, childClasspath, sysProps, childMainClass) = prepareSubmitEnvironment(args) + val (childArgs, childClasspath, sparkConf, childMainClass) = prepareSubmitEnvironment(args) def doRunMain(): Unit = { if (args.proxyUser != null) { @@ -167,7 +167,7 @@ object SparkSubmit extends CommandLineUtils with Logging { try { proxyUser.doAs(new PrivilegedExceptionAction[Unit]() { override def run(): Unit = { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } }) } catch { @@ -185,7 +185,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } } } else { - runMain(childArgs, childClasspath, sysProps, childMainClass, args.verbose) + runMain(childArgs, childClasspath, sparkConf, childMainClass, args.verbose) } } @@ -235,11 +235,11 @@ object SparkSubmit extends CommandLineUtils with Logging { private[deploy] def prepareSubmitEnvironment( args: SparkSubmitArguments, conf: Option[HadoopConfiguration] = None) - : (Seq[String], Seq[String], Map[String, String], String) = { + : (Seq[String], Seq[String], SparkConf, String) = { // Return values val 
childArgs = new ArrayBuffer[String]() val childClasspath = new ArrayBuffer[String]() - val sysProps = new HashMap[String, String]() + val sparkConf = new SparkConf() var childMainClass = "" // Set the cluster manager @@ -337,34 +337,50 @@ object SparkSubmit extends CommandLineUtils with Logging { } } - val sparkConf = new SparkConf(false) args.sparkProperties.foreach { case (k, v) => sparkConf.set(k, v) } val hadoopConf = conf.getOrElse(SparkHadoopUtil.newConfiguration(sparkConf)) val targetDir = Utils.createTempDir() + // assure a keytab is available from any place in a JVM + if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { + if (args.principal != null) { + if (args.keytab != null) { + require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") + // Add keytab and principal configurations in sysProps to make them available + // for later use; e.g. in spark sql, the isolated class loader used to talk + // to HiveMetastore will use these settings. They will be set as Java system + // properties and then loaded by SparkConf + sparkConf.set(KEYTAB, args.keytab) + sparkConf.set(PRINCIPAL, args.principal) + UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) + } + } + } + // Resolve glob path for different resources. args.jars = Option(args.jars).map(resolveGlobPaths(_, hadoopConf)).orNull args.files = Option(args.files).map(resolveGlobPaths(_, hadoopConf)).orNull args.pyFiles = Option(args.pyFiles).map(resolveGlobPaths(_, hadoopConf)).orNull args.archives = Option(args.archives).map(resolveGlobPaths(_, hadoopConf)).orNull + // This security manager will not need an auth secret, but set a dummy value in case + // spark.authenticate is enabled, otherwise an exception is thrown. + lazy val downloadConf = sparkConf.clone().set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") + lazy val secMgr = new SecurityManager(downloadConf) + // In client mode, download remote files. var localPrimaryResource: String = null var localJars: String = null var localPyFiles: String = null if (deployMode == CLIENT) { - // This security manager will not need an auth secret, but set a dummy value in case - // spark.authenticate is enabled, otherwise an exception is thrown. 
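
The dummy-secret trick above now operates on a clone, so the throwaway value never reaches the SparkConf handed to the application. A small sketch of that behavior, using the literal config keys rather than the SecurityManager constant:

import org.apache.spark.SparkConf

object DownloadConfSketch {
  def main(args: Array[String]): Unit = {
    val sparkConf = new SparkConf(false).set("spark.authenticate", "true")

    // Same shape as the patch: a lazily built copy carrying a throwaway secret,
    // used only for downloading remote dependencies.
    lazy val downloadConf = sparkConf.clone().set("spark.authenticate.secret", "unused")

    require(downloadConf.contains("spark.authenticate.secret"))
    require(!sparkConf.contains("spark.authenticate.secret"))
    println("download-only settings do not leak into the application conf")
  }
}
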
- sparkConf.set(SecurityManager.SPARK_AUTH_SECRET_CONF, "unused") - val secMgr = new SecurityManager(sparkConf) localPrimaryResource = Option(args.primaryResource).map { - downloadFile(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localJars = Option(args.jars).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull localPyFiles = Option(args.pyFiles).map { - downloadFileList(_, targetDir, sparkConf, hadoopConf, secMgr) + downloadFileList(_, targetDir, downloadConf, hadoopConf, secMgr) }.orNull } @@ -393,7 +409,7 @@ object SparkSubmit extends CommandLineUtils with Logging { if (file.exists()) { file.toURI.toString } else { - downloadFile(resource, targetDir, sparkConf, hadoopConf, secMgr) + downloadFile(resource, targetDir, downloadConf, hadoopConf, secMgr) } case _ => uri.toString } @@ -433,7 +449,7 @@ object SparkSubmit extends CommandLineUtils with Logging { args.files = mergeFileLists(args.files, args.pyFiles) } if (localPyFiles != null) { - sysProps("spark.submit.pyFiles") = localPyFiles + sparkConf.set("spark.submit.pyFiles", localPyFiles) } } @@ -499,69 +515,69 @@ object SparkSubmit extends CommandLineUtils with Logging { } // Special flag to avoid deprecation warnings at the client - sysProps("SPARK_SUBMIT") = "true" + sys.props("SPARK_SUBMIT") = "true" // A list of rules to map each argument to system properties or command-line options in // each deploy mode; we iterate through these below val options = List[OptionAssigner]( // All cluster managers - OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.master"), + OptionAssigner(args.master, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.master"), OptionAssigner(args.deployMode, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.submit.deployMode"), - OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, sysProp = "spark.app.name"), - OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.jars.ivy"), + confKey = "spark.submit.deployMode"), + OptionAssigner(args.name, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, confKey = "spark.app.name"), + OptionAssigner(args.ivyRepoPath, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.jars.ivy"), OptionAssigner(args.driverMemory, ALL_CLUSTER_MGRS, CLIENT, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverExtraClassPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraClassPath"), + confKey = "spark.driver.extraClassPath"), OptionAssigner(args.driverExtraJavaOptions, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraJavaOptions"), + confKey = "spark.driver.extraJavaOptions"), OptionAssigner(args.driverExtraLibraryPath, ALL_CLUSTER_MGRS, ALL_DEPLOY_MODES, - sysProp = "spark.driver.extraLibraryPath"), + confKey = "spark.driver.extraLibraryPath"), // Propagate attributes for dependency resolution at the driver side - OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.packages"), + OptionAssigner(args.packages, STANDALONE | MESOS, CLUSTER, confKey = "spark.jars.packages"), OptionAssigner(args.repositories, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.jars.repositories"), - OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.jars.repositories"), + OptionAssigner(args.ivyRepoPath, STANDALONE | MESOS, CLUSTER, confKey = 
"spark.jars.ivy"), OptionAssigner(args.packagesExclusions, STANDALONE | MESOS, - CLUSTER, sysProp = "spark.jars.excludes"), + CLUSTER, confKey = "spark.jars.excludes"), // Yarn only - OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.queue"), + OptionAssigner(args.queue, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.queue"), OptionAssigner(args.numExecutors, YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.instances"), - OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.pyFiles"), - OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.jars"), - OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.files"), - OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.dist.archives"), - OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.principal"), - OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, sysProp = "spark.yarn.keytab"), + confKey = "spark.executor.instances"), + OptionAssigner(args.pyFiles, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.pyFiles"), + OptionAssigner(args.jars, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.jars"), + OptionAssigner(args.files, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.files"), + OptionAssigner(args.archives, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.dist.archives"), + OptionAssigner(args.principal, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.principal"), + OptionAssigner(args.keytab, YARN, ALL_DEPLOY_MODES, confKey = "spark.yarn.keytab"), // Other options OptionAssigner(args.executorCores, STANDALONE | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.cores"), + confKey = "spark.executor.cores"), OptionAssigner(args.executorMemory, STANDALONE | MESOS | YARN, ALL_DEPLOY_MODES, - sysProp = "spark.executor.memory"), + confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.cores.max"), + confKey = "spark.cores.max"), OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, - sysProp = "spark.files"), - OptionAssigner(args.jars, LOCAL, CLIENT, sysProp = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, sysProp = "spark.jars"), + confKey = "spark.files"), + OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.memory"), + confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN, CLUSTER, - sysProp = "spark.driver.cores"), + confKey = "spark.driver.cores"), OptionAssigner(args.supervise.toString, STANDALONE | MESOS, CLUSTER, - sysProp = "spark.driver.supervise"), - OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, sysProp = "spark.jars.ivy"), + confKey = "spark.driver.supervise"), + OptionAssigner(args.ivyRepoPath, STANDALONE, CLUSTER, confKey = "spark.jars.ivy"), // An internal option used only for spark-shell to add user jars to repl's classloader, // previously it uses "spark.jars" or "spark.yarn.dist.jars" which now may be pointed to // remote jars, so adding a new option to only specify local jars for spark-shell internally. 
- OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, sysProp = "spark.repl.local.jars") + OptionAssigner(localJars, ALL_CLUSTER_MGRS, CLIENT, confKey = "spark.repl.local.jars") ) // In client mode, launch the application main class directly @@ -594,19 +610,24 @@ object SparkSubmit extends CommandLineUtils with Logging { (deployMode & opt.deployMode) != 0 && (clusterManager & opt.clusterManager) != 0) { if (opt.clOption != null) { childArgs += (opt.clOption, opt.value) } - if (opt.sysProp != null) { sysProps.put(opt.sysProp, opt.value) } + if (opt.confKey != null) { sparkConf.set(opt.confKey, opt.value) } } } + // In case of shells, spark.ui.showConsoleProgress can be true by default or by user. + if (isShell(args.primaryResource) && !sparkConf.contains(UI_SHOW_CONSOLE_PROGRESS)) { + sparkConf.set(UI_SHOW_CONSOLE_PROGRESS, true) + } + // Add the application jar automatically so the user doesn't have to call sc.addJar // For YARN cluster mode, the jar is already distributed on each node as "app.jar" // For python and R files, the primary resource is already distributed as a regular file if (!isYarnCluster && !args.isPython && !args.isR) { - var jars = sysProps.get("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) + var jars = sparkConf.getOption("spark.jars").map(x => x.split(",").toSeq).getOrElse(Seq.empty) if (isUserJar(args.primaryResource)) { jars = jars ++ Seq(args.primaryResource) } - sysProps.put("spark.jars", jars.mkString(",")) + sparkConf.set("spark.jars", jars.mkString(",")) } // In standalone cluster mode, use the REST client to submit the application (Spark 1.3+). @@ -632,28 +653,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Let YARN know it's a pyspark app, so it distributes needed libraries. if (clusterManager == YARN) { if (args.isPython) { - sysProps.put("spark.yarn.isPython", "true") - } - } - - // assure a keytab is available from any place in a JVM - if (clusterManager == YARN || clusterManager == LOCAL || clusterManager == MESOS) { - if (args.principal != null) { - if (args.keytab != null) { - require(new File(args.keytab).exists(), s"Keytab file: ${args.keytab} does not exist") - // Add keytab and principal configurations in sysProps to make them available - // for later use; e.g. in spark sql, the isolated class loader used to talk - // to HiveMetastore will use these settings. They will be set as Java system - // properties and then loaded by SparkConf - sysProps.put("spark.yarn.keytab", args.keytab) - sysProps.put("spark.yarn.principal", args.principal) - UserGroupInformation.loginUserFromKeytab(args.principal, args.keytab) - } + sparkConf.set("spark.yarn.isPython", "true") } } if (clusterManager == MESOS && UserGroupInformation.isSecurityEnabled) { - setRMPrincipal(sysProps) + setRMPrincipal(sparkConf) } // In yarn-cluster mode, use yarn.Client as a wrapper around the user class @@ -684,7 +689,7 @@ object SparkSubmit extends CommandLineUtils with Logging { // Second argument is main class childArgs += (args.primaryResource, "") if (args.pyFiles != null) { - sysProps("spark.submit.pyFiles") = args.pyFiles + sparkConf.set("spark.submit.pyFiles", args.pyFiles) } } else if (args.isR) { // Second argument is main class @@ -699,12 +704,12 @@ object SparkSubmit extends CommandLineUtils with Logging { // Load any properties specified through --conf and the default properties file for ((k, v) <- args.sparkProperties) { - sysProps.getOrElseUpdate(k, v) + sparkConf.setIfMissing(k, v) } // Ignore invalid spark.driver.host in cluster modes. 
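
The OptionAssigner table and the loop above reduce to a declarative rule set: each entry names the cluster managers and deploy modes it applies to (as bitmasks) and the conf key it fills. A condensed, self-contained model of that dispatch; the constants and rules below are illustrative, not Spark's full table:

import org.apache.spark.SparkConf

object OptionAssignerSketch {
  val YARN = 1; val STANDALONE = 2; val MESOS = 4; val LOCAL = 8
  val ALL_CLUSTER_MGRS = YARN | STANDALONE | MESOS | LOCAL
  val CLIENT = 1; val CLUSTER = 2; val ALL_DEPLOY_MODES = CLIENT | CLUSTER

  case class Rule(value: String, clusterManager: Int, deployMode: Int, confKey: String)

  def main(args: Array[String]): Unit = {
    val rules = Seq(
      Rule("4g", ALL_CLUSTER_MGRS, CLIENT, "spark.driver.memory"),
      Rule("2", STANDALONE | YARN, ALL_DEPLOY_MODES, "spark.executor.cores"))

    val sparkConf = new SparkConf(false)
    val clusterManager = YARN
    val deployMode = CLIENT
    // Apply every rule whose bitmasks match the current cluster manager and deploy mode.
    for (rule <- rules if rule.value != null &&
        (rule.deployMode & deployMode) != 0 &&
        (rule.clusterManager & clusterManager) != 0) {
      sparkConf.set(rule.confKey, rule.value)
    }
    println(sparkConf.getAll.toSeq.sorted.mkString("\n"))
  }
}
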
if (deployMode == CLUSTER) { - sysProps -= "spark.driver.host" + sparkConf.remove("spark.driver.host") } // Resolve paths in certain spark properties @@ -716,15 +721,15 @@ object SparkSubmit extends CommandLineUtils with Logging { "spark.yarn.dist.jars") pathConfigs.foreach { config => // Replace old URIs with resolved URIs, if they exist - sysProps.get(config).foreach { oldValue => - sysProps(config) = Utils.resolveURIs(oldValue) + sparkConf.getOption(config).foreach { oldValue => + sparkConf.set(config, Utils.resolveURIs(oldValue)) } } // Resolve and format python file paths properly before adding them to the PYTHONPATH. // The resolving part is redundant in the case of --py-files, but necessary if the user // explicitly sets `spark.submit.pyFiles` in his/her default properties file. - sysProps.get("spark.submit.pyFiles").foreach { pyFiles => + sparkConf.getOption("spark.submit.pyFiles").foreach { pyFiles => val resolvedPyFiles = Utils.resolveURIs(pyFiles) val formattedPyFiles = if (!isYarnCluster && !isMesosCluster) { PythonRunner.formatPaths(resolvedPyFiles).mkString(",") @@ -734,22 +739,22 @@ object SparkSubmit extends CommandLineUtils with Logging { // locally. resolvedPyFiles } - sysProps("spark.submit.pyFiles") = formattedPyFiles + sparkConf.set("spark.submit.pyFiles", formattedPyFiles) } - (childArgs, childClasspath, sysProps, childMainClass) + (childArgs, childClasspath, sparkConf, childMainClass) } // [SPARK-20328]. HadoopRDD calls into a Hadoop library that fetches delegation tokens with // renewer set to the YARN ResourceManager. Since YARN isn't configured in Mesos mode, we // must trick it into thinking we're YARN. - private def setRMPrincipal(sysProps: HashMap[String, String]): Unit = { + private def setRMPrincipal(sparkConf: SparkConf): Unit = { val shortUserName = UserGroupInformation.getCurrentUser.getShortUserName val key = s"spark.hadoop.${YarnConfiguration.RM_PRINCIPAL}" // scalastyle:off println printStream.println(s"Setting ${key} to ${shortUserName}") // scalastyle:off println - sysProps.put(key, shortUserName) + sparkConf.set(key, shortUserName) } /** @@ -761,7 +766,7 @@ object SparkSubmit extends CommandLineUtils with Logging { private def runMain( childArgs: Seq[String], childClasspath: Seq[String], - sysProps: Map[String, String], + sparkConf: SparkConf, childMainClass: String, verbose: Boolean): Unit = { // scalastyle:off println @@ -769,14 +774,14 @@ object SparkSubmit extends CommandLineUtils with Logging { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") // sysProps may contain sensitive information, so redact before printing - printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") + printStream.println(s"Spark config:\n${Utils.redact(sparkConf.getAll.toMap).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } // scalastyle:on println val loader = - if (sysProps.getOrElse("spark.driver.userClassPathFirst", "false").toBoolean) { + if (sparkConf.get(DRIVER_USER_CLASS_PATH_FIRST)) { new ChildFirstURLClassLoader(new Array[URL](0), Thread.currentThread.getContextClassLoader) } else { @@ -789,10 +794,6 @@ object SparkSubmit extends CommandLineUtils with Logging { addJarToClasspath(jar, loader) } - for ((key, value) <- sysProps) { - System.setProperty(key, value) - } - var mainClass: Class[_] = null try { @@ -818,14 +819,14 @@ object SparkSubmit extends CommandLineUtils with Logging { 
System.exit(CLASS_NOT_FOUND_EXIT_STATUS) } - // SPARK-4170 - if (classOf[scala.App].isAssignableFrom(mainClass)) { - printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") - } - - val mainMethod = mainClass.getMethod("main", new Array[String](0).getClass) - if (!Modifier.isStatic(mainMethod.getModifiers)) { - throw new IllegalStateException("The main method in the given main class must be static") + val app: SparkApplication = if (classOf[SparkApplication].isAssignableFrom(mainClass)) { + mainClass.newInstance().asInstanceOf[SparkApplication] + } else { + // SPARK-4170 + if (classOf[scala.App].isAssignableFrom(mainClass)) { + printWarning("Subclasses of scala.App may not work correctly. Use a main() method instead.") + } + new JavaMainApplication(mainClass) } @tailrec @@ -839,7 +840,7 @@ object SparkSubmit extends CommandLineUtils with Logging { } try { - mainMethod.invoke(null, childArgs.toArray) + app.start(childArgs.toArray, sparkConf) } catch { case t: Throwable => findCause(t) match { @@ -1266,4 +1267,4 @@ private case class OptionAssigner( clusterManager: Int, deployMode: Int, clOption: String = null, - sysProp: String = null) + confKey: String = null) diff --git a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala index 757c930b84eb2..34ade4ce6f39b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/StandaloneAppClient.scala @@ -170,7 +170,7 @@ private[spark] class StandaloneAppClient( case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id - logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, + logInfo("Executor added: %s on %s (%s) with %d core(s)".format(fullId, workerId, hostPort, cores)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala index a370526c46f3d..ccf3437451edc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationCache.scala @@ -26,6 +26,7 @@ import scala.util.control.NonFatal import com.codahale.metrics.{Counter, MetricRegistry, Timer} import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalListener, RemovalNotification} +import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.FilterHolder import org.apache.spark.internal.Logging @@ -40,11 +41,6 @@ import org.apache.spark.util.Clock * Incompleted applications have their update time checked on every * retrieval; if the cached entry is out of date, it is refreshed. * - * @note there must be only one instance of [[ApplicationCache]] in a - * JVM at a time. This is because a static field in [[ApplicationCacheCheckFilterRelay]] - * keeps a reference to the cache so that HTTP requests on the attempt-specific web UIs - * can probe the current cache to see if the attempts have changed. - * * Creating multiple instances will break this routing. 
* @param operations implementation of record access operations * @param retainedApplications number of retained applications @@ -80,7 +76,7 @@ private[history] class ApplicationCache( metrics.evictionCount.inc() val key = rm.getKey logDebug(s"Evicting entry ${key}") - operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().ui) + operations.detachSparkUI(key.appId, key.attemptId, rm.getValue().loadedUI.ui) } } @@ -89,7 +85,7 @@ private[history] class ApplicationCache( * * Tagged as `protected` so as to allow subclasses in tests to access it directly */ - protected val appCache: LoadingCache[CacheKey, CacheEntry] = { + private val appCache: LoadingCache[CacheKey, CacheEntry] = { CacheBuilder.newBuilder() .maximumSize(retainedApplications) .removalListener(removalListener) @@ -101,130 +97,46 @@ private[history] class ApplicationCache( */ val metrics = new CacheMetrics("history.cache") - init() - - /** - * Perform any startup operations. - * - * This includes declaring this instance as the cache to use in the - * [[ApplicationCacheCheckFilterRelay]]. - */ - private def init(): Unit = { - ApplicationCacheCheckFilterRelay.setApplicationCache(this) - } - - /** - * Stop the cache. - * This will reset the relay in [[ApplicationCacheCheckFilterRelay]]. - */ - def stop(): Unit = { - ApplicationCacheCheckFilterRelay.resetApplicationCache() - } - - /** - * Get an entry. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def get(appAndAttempt: String): SparkUI = { - val parts = splitAppAndAttemptKey(appAndAttempt) - get(parts._1, parts._2) - } - - /** - * Get the Spark UI, converting a lookup failure from an exception to `None`. - * @param appAndAttempt application to look up in the format needed by the history server web UI, - * `appId/attemptId` or `appId`. - * @return the entry - */ - def getSparkUI(appAndAttempt: String): Option[SparkUI] = { + def get(appId: String, attemptId: Option[String] = None): CacheEntry = { try { - val ui = get(appAndAttempt) - Some(ui) + appCache.get(new CacheKey(appId, attemptId)) } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - None - case cause: Exception => throw cause - } + case e: UncheckedExecutionException => + throw Option(e.getCause()).getOrElse(e) } } - /** - * Get the associated spark UI. - * - * Cache fetch/refresh will have taken place by the time this method returns. - * @param appId application ID - * @param attemptId optional attempt ID - * @return the entry - */ - def get(appId: String, attemptId: Option[String]): SparkUI = { - lookupAndUpdate(appId, attemptId)._1.ui - } + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + var entry = get(appId, attemptId) - /** - * Look up the entry; update it if needed. 
- * @param appId application ID - * @param attemptId optional attempt ID - * @return the underlying cache entry -which can have its timestamp changed, and a flag to - * indicate that the entry has changed - */ - private def lookupAndUpdate(appId: String, attemptId: Option[String]): (CacheEntry, Boolean) = { - metrics.lookupCount.inc() - val cacheKey = CacheKey(appId, attemptId) - var entry = appCache.getIfPresent(cacheKey) - var updated = false - if (entry == null) { - // no entry, so fetch without any post-fetch probes for out-of-dateness - // this will trigger a callback to loadApplicationEntry() - entry = appCache.get(cacheKey) - } else if (!entry.completed) { - val now = clock.getTimeMillis() - log.debug(s"Probing at time $now for updated application $cacheKey -> $entry") - metrics.updateProbeCount.inc() - updated = time(metrics.updateProbeTimer) { - entry.updateProbe() + // If the entry exists, we need to make sure we run the closure with a valid entry. So + // we need to re-try until we can lock a valid entry for read. + entry.loadedUI.lock.readLock().lock() + try { + while (!entry.loadedUI.valid) { + entry.loadedUI.lock.readLock().unlock() + entry = null + try { + appCache.invalidate(new CacheKey(appId, attemptId)) + entry = get(appId, attemptId) + if (entry == null) { + metrics.lookupFailureCount.inc() + throw new NoSuchElementException() + } + metrics.loadCount.inc() + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().lock() + } + } } - if (updated) { - logDebug(s"refreshing $cacheKey") - metrics.updateTriggeredCount.inc() - appCache.refresh(cacheKey) - // and repeat the lookup - entry = appCache.get(cacheKey) - } else { - // update the probe timestamp to the current time - entry.probeTime = now + + fn(entry.loadedUI.ui) + } finally { + if (entry != null) { + entry.loadedUI.lock.readLock().unlock() } } - (entry, updated) - } - - /** - * This method is visible for testing. - * - * It looks up the cached entry *and returns a clone of it*. - * This ensures that the cached entries never leak - * @param appId application ID - * @param attemptId optional attempt ID - * @return a new entry with shared SparkUI, but copies of the other fields. - */ - def lookupCacheEntry(appId: String, attemptId: Option[String]): CacheEntry = { - val entry = lookupAndUpdate(appId, attemptId)._1 - new CacheEntry(entry.ui, entry.completed, entry.updateProbe, entry.probeTime) - } - - /** - * Probe for an application being updated. 
- * @param appId application ID - * @param attemptId attempt ID - * @return true if an update has been triggered - */ - def checkForUpdates(appId: String, attemptId: Option[String]): Boolean = { - val (entry, updated) = lookupAndUpdate(appId, attemptId) - updated } /** @@ -272,27 +184,15 @@ private[history] class ApplicationCache( * @throws NoSuchElementException if there is no matching element */ @throws[NoSuchElementException] - def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { - + private def loadApplicationEntry(appId: String, attemptId: Option[String]): CacheEntry = { logDebug(s"Loading application Entry $appId/$attemptId") metrics.loadCount.inc() - time(metrics.loadTimer) { + val loadedUI = time(metrics.loadTimer) { + metrics.lookupCount.inc() operations.getAppUI(appId, attemptId) match { - case Some(LoadedAppUI(ui, updateState)) => - val completed = ui.getApplicationInfoList.exists(_.attempts.last.completed) - if (completed) { - // completed spark UIs are attached directly - operations.attachSparkUI(appId, attemptId, ui, completed) - } else { - // incomplete UIs have the cache-check filter put in front of them. - ApplicationCacheCheckFilterRelay.registerFilter(ui, appId, attemptId) - operations.attachSparkUI(appId, attemptId, ui, completed) - } - // build the cache entry - val now = clock.getTimeMillis() - val entry = new CacheEntry(ui, completed, updateState, now) - logDebug(s"Loaded application $appId/$attemptId -> $entry") - entry + case Some(loadedUI) => + logDebug(s"Loaded application $appId/$attemptId") + loadedUI case None => metrics.lookupFailureCount.inc() // guava's cache logs via java.util log, so is of limited use. Hence: our own message @@ -301,32 +201,20 @@ private[history] class ApplicationCache( attemptId.map { id => s" attemptId '$id'" }.getOrElse(" and no attempt Id")) } } - } - - /** - * Split up an `applicationId/attemptId` or `applicationId` key into the separate pieces. - * - * @param appAndAttempt combined key - * @return a tuple of the application ID and, if present, the attemptID - */ - def splitAppAndAttemptKey(appAndAttempt: String): (String, Option[String]) = { - val parts = appAndAttempt.split("/") - require(parts.length == 1 || parts.length == 2, s"Invalid app key $appAndAttempt") - val appId = parts(0) - val attemptId = if (parts.length > 1) Some(parts(1)) else None - (appId, attemptId) - } - - /** - * Merge an appId and optional attempt Id into a key of the form `applicationId/attemptId`. - * - * If there is an `attemptId`; `applicationId` if not. - * @param appId application ID - * @param attemptId optional attempt ID - * @return a unified string - */ - def mergeAppAndAttemptToKey(appId: String, attemptId: Option[String]): String = { - appId + attemptId.map { id => s"/$id" }.getOrElse("") + try { + val completed = loadedUI.ui.getApplicationInfoList.exists(_.attempts.last.completed) + if (!completed) { + // incomplete UIs have the cache-check filter put in front of them. 
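
withSparkUI above is built on a read/write lock plus a validity flag: a request holds the read lock while it uses a UI and re-checks validity, and invalidation takes the write lock so it cannot race an in-flight request. A stripped-down model of that pattern follows; unlike the real method, it returns None instead of reloading from the cache and retrying, and the class names are stand-ins:

import java.util.concurrent.locks.ReentrantReadWriteLock

class InvalidatableResource[T](value: T) {
  val lock = new ReentrantReadWriteLock()
  @volatile private var _valid = true
  def valid: Boolean = _valid

  def invalidate(): Unit = {
    lock.writeLock().lock()
    try { _valid = false } finally { lock.writeLock().unlock() }
  }

  // Runs `fn` only if the resource is still valid; returns None if it has been invalidated.
  def withResource[R](fn: T => R): Option[R] = {
    lock.readLock().lock()
    try {
      if (valid) Some(fn(value)) else None
    } finally {
      lock.readLock().unlock()
    }
  }
}

object InvalidatableResourceSketch {
  def main(args: Array[String]): Unit = {
    val res = new InvalidatableResource("cached UI")
    println(res.withResource(ui => s"serving $ui"))   // Some(serving cached UI)
    res.invalidate()
    println(res.withResource(ui => s"serving $ui"))   // None; the caller would reload and retry
  }
}
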
+ registerFilter(new CacheKey(appId, attemptId), loadedUI, this) + } + operations.attachSparkUI(appId, attemptId, loadedUI.ui, completed) + new CacheEntry(loadedUI, completed) + } catch { + case e: Exception => + logWarning(s"Failed to initialize application UI for $appId/$attemptId", e) + operations.detachSparkUI(appId, attemptId, loadedUI.ui) + throw e + } } /** @@ -347,6 +235,26 @@ private[history] class ApplicationCache( sb.append("----\n") sb.toString() } + + /** + * Register a filter for the web UI which checks for updates to the given app/attempt + * @param ui Spark UI to attach filters to + * @param appId application ID + * @param attemptId attempt ID + */ + def registerFilter(key: CacheKey, loadedUI: LoadedAppUI, cache: ApplicationCache): Unit = { + require(loadedUI != null) + val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) + val filter = new ApplicationCacheCheckFilter(key, loadedUI, cache) + val holder = new FilterHolder(filter) + require(loadedUI.ui.getHandlers != null, "null handlers") + loadedUI.ui.getHandlers.foreach { handler => + handler.addFilter(holder, "/*", enumDispatcher) + } + } + + def invalidate(key: CacheKey): Unit = appCache.invalidate(key) + } /** @@ -360,14 +268,12 @@ private[history] class ApplicationCache( * @param probeTime Times in milliseconds when the probe was last executed. */ private[history] final class CacheEntry( - val ui: SparkUI, - val completed: Boolean, - val updateProbe: () => Boolean, - var probeTime: Long) { + val loadedUI: LoadedAppUI, + val completed: Boolean) { /** string value is for test assertions */ override def toString: String = { - s"UI $ui, completed=$completed, probeTime=$probeTime" + s"UI ${loadedUI.ui}, completed=$completed" } } @@ -396,23 +302,17 @@ private[history] class CacheMetrics(prefix: String) extends Source { val evictionCount = new Counter() val loadCount = new Counter() val loadTimer = new Timer() - val updateProbeCount = new Counter() - val updateProbeTimer = new Timer() - val updateTriggeredCount = new Counter() /** all the counters: for registration and string conversion. */ private val counters = Seq( ("lookup.count", lookupCount), ("lookup.failure.count", lookupFailureCount), ("eviction.count", evictionCount), - ("load.count", loadCount), - ("update.probe.count", updateProbeCount), - ("update.triggered.count", updateTriggeredCount)) + ("load.count", loadCount)) /** all metrics, including timers */ private val allMetrics = counters ++ Seq( - ("load.timer", loadTimer), - ("update.probe.timer", updateProbeTimer)) + ("load.timer", loadTimer)) /** * Name of metric source @@ -498,23 +398,11 @@ private[history] trait ApplicationCacheOperations { * Implementation note: there's some abuse of a shared global entry here because * the configuration data passed to the servlet is just a string:string map. */ -private[history] class ApplicationCacheCheckFilter() extends Filter with Logging { - - import ApplicationCacheCheckFilterRelay._ - var appId: String = _ - var attemptId: Option[String] = _ - - /** - * Bind the app and attempt ID, throwing an exception if no application ID was provided. 
- * @param filterConfig configuration - */ - override def init(filterConfig: FilterConfig): Unit = { - - appId = Option(filterConfig.getInitParameter(APP_ID)) - .getOrElse(throw new ServletException(s"Missing Parameter $APP_ID")) - attemptId = Option(filterConfig.getInitParameter(ATTEMPT_ID)) - logDebug(s"initializing filter $this") - } +private[history] class ApplicationCacheCheckFilter( + key: CacheKey, + loadedUI: LoadedAppUI, + cache: ApplicationCache) + extends Filter with Logging { /** * Filter the request. @@ -543,123 +431,24 @@ private[history] class ApplicationCacheCheckFilter() extends Filter with Logging // if the request is for an attempt, check to see if it is in need of delete/refresh // and have the cache update the UI if so - if (operation=="HEAD" || operation=="GET" - && checkForUpdates(requestURI, appId, attemptId)) { - // send a redirect back to the same location. This will be routed - // to the *new* UI - logInfo(s"Application Attempt $appId/$attemptId updated; refreshing") + loadedUI.lock.readLock().lock() + if (loadedUI.valid) { + try { + chain.doFilter(request, response) + } finally { + loadedUI.lock.readLock.unlock() + } + } else { + loadedUI.lock.readLock.unlock() + cache.invalidate(key) val queryStr = Option(httpRequest.getQueryString).map("?" + _).getOrElse("") val redirectUrl = httpResponse.encodeRedirectURL(requestURI + queryStr) httpResponse.sendRedirect(redirectUrl) - } else { - chain.doFilter(request, response) } } - override def destroy(): Unit = { - } - - override def toString: String = s"ApplicationCacheCheckFilter for $appId/$attemptId" -} - -/** - * Global state for the [[ApplicationCacheCheckFilter]] instances, so that they can relay cache - * probes to the cache. - * - * This is an ugly workaround for the limitation of servlets and filters in the Java servlet - * API; they are still configured on the model of a list of classnames and configuration - * strings in a `web.xml` field, rather than a chain of instances wired up by hand or - * via an injection framework. There is no way to directly configure a servlet filter instance - * with a reference to the application cache which is must use: some global state is needed. - * - * Here, [[ApplicationCacheCheckFilter]] is that global state; it relays all requests - * to the singleton [[ApplicationCache]] - * - * The field `applicationCache` must be set for the filters to work - - * this is done during the construction of [[ApplicationCache]], which requires that there - * is only one cache serving requests through the WebUI. - * - * *Important* In test runs, if there is more than one [[ApplicationCache]], the relay logic - * will break: filters may not find instances. Tests must not do that. - * - */ -private[history] object ApplicationCacheCheckFilterRelay extends Logging { - // name of the app ID entry in the filter configuration. Mandatory. - val APP_ID = "appId" - - // name of the attempt ID entry in the filter configuration. Optional. - val ATTEMPT_ID = "attemptId" - - // name of the filter to register - val FILTER_NAME = "org.apache.spark.deploy.history.ApplicationCacheCheckFilter" - - /** the application cache to relay requests to */ - @volatile - private var applicationCache: Option[ApplicationCache] = None - - /** - * Set the application cache. 
Logs a warning if it is overwriting an existing value - * @param cache new cache - */ - def setApplicationCache(cache: ApplicationCache): Unit = { - applicationCache.foreach( c => logWarning(s"Overwriting application cache $c")) - applicationCache = Some(cache) - } - - /** - * Reset the application cache - */ - def resetApplicationCache(): Unit = { - applicationCache = None - } - - /** - * Check to see if there has been an update - * @param requestURI URI the request came in on - * @param appId application ID - * @param attemptId attempt ID - * @return true if an update was loaded for the app/attempt - */ - def checkForUpdates(requestURI: String, appId: String, attemptId: Option[String]): Boolean = { - - logDebug(s"Checking $appId/$attemptId from $requestURI") - applicationCache match { - case Some(cache) => - try { - cache.checkForUpdates(appId, attemptId) - } catch { - case ex: Exception => - // something went wrong. Keep going with the existing UI - logWarning(s"When checking for $appId/$attemptId from $requestURI", ex) - false - } - - case None => - logWarning("No application cache instance defined") - false - } - } + override def init(config: FilterConfig): Unit = { } + override def destroy(): Unit = { } - /** - * Register a filter for the web UI which checks for updates to the given app/attempt - * @param ui Spark UI to attach filters to - * @param appId application ID - * @param attemptId attempt ID - */ - def registerFilter( - ui: SparkUI, - appId: String, - attemptId: Option[String] ): Unit = { - require(ui != null) - val enumDispatcher = java.util.EnumSet.of(DispatcherType.ASYNC, DispatcherType.REQUEST) - val holder = new FilterHolder() - holder.setClassName(FILTER_NAME) - holder.setInitParameter(APP_ID, appId) - attemptId.foreach( id => holder.setInitParameter(ATTEMPT_ID, id)) - require(ui.getHandlers != null, "null handlers") - ui.getHandlers.foreach { handler => - handler.addFilter(holder, "/*", enumDispatcher) - } - } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala index 5cb48ca3e60b0..96a80c9a6665c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/ApplicationHistoryProvider.scala @@ -17,6 +17,7 @@ package org.apache.spark.deploy.history +import java.util.concurrent.locks.ReentrantReadWriteLock import java.util.zip.ZipOutputStream import scala.xml.Node @@ -47,31 +48,30 @@ private[spark] case class ApplicationHistoryInfo( } } -/** - * A probe which can be invoked to see if a loaded Web UI has been updated. - * The probe is expected to be relative purely to that of the UI returned - * in the same [[LoadedAppUI]] instance. That is, whenever a new UI is loaded, - * the probe returned with it is the one that must be used to check for it - * being out of date; previous probes must be discarded. - */ -private[history] abstract class HistoryUpdateProbe { - /** - * Return true if the history provider has a later version of the application - * attempt than the one against this probe was constructed. - * @return - */ - def isUpdated(): Boolean -} - /** * All the information returned from a call to `getAppUI()`: the new UI * and any required update state. 
* @param ui Spark UI * @param updateProbe probe to call to check on the update state of this application attempt */ -private[history] case class LoadedAppUI( - ui: SparkUI, - updateProbe: () => Boolean) +private[history] case class LoadedAppUI(ui: SparkUI) { + + val lock = new ReentrantReadWriteLock() + + @volatile private var _valid = true + + def valid: Boolean = _valid + + def invalidate(): Unit = { + lock.writeLock().lock() + try { + _valid = false + } finally { + lock.writeLock().unlock() + } + } + +} private[history] abstract class ApplicationHistoryProvider { @@ -145,4 +145,10 @@ private[history] abstract class ApplicationHistoryProvider { * @return html text to display when the application list is empty */ def getEmptyListingHtml(): Seq[Node] = Seq.empty + + /** + * Called when an application UI is unloaded from the history server. + */ + def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { } + } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index 910121e9878b9..a0764c9e4c42b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -17,14 +17,16 @@ package org.apache.spark.deploy.history -import java.io.{FileNotFoundException, IOException, OutputStream} -import java.util.UUID -import java.util.concurrent.{ConcurrentHashMap, Executors, ExecutorService, Future, TimeUnit} +import java.io.{File, FileNotFoundException, IOException} +import java.util.{Date, ServiceLoader, UUID} +import java.util.concurrent.{Executors, ExecutorService, Future, TimeUnit} import java.util.zip.{ZipEntry, ZipOutputStream} +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.xml.Node +import com.fasterxml.jackson.annotation.JsonIgnore import com.google.common.io.ByteStreams import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} import org.apache.hadoop.fs.{FileStatus, Path} @@ -35,11 +37,16 @@ import org.apache.hadoop.security.AccessControlException import org.apache.spark.{SecurityManager, SparkConf, SparkException} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ import org.apache.spark.scheduler.ReplayListenerBus._ +import org.apache.spark.status.{AppStatusListener, AppStatusStore, AppStatusStoreMetadata, KVUtils} +import org.apache.spark.status.KVUtils._ +import org.apache.spark.status.api.v1 import org.apache.spark.ui.SparkUI import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} +import org.apache.spark.util.kvstore._ /** * A class that provides application history from event logs stored in the file system. @@ -50,11 +57,10 @@ import org.apache.spark.util.{Clock, SystemClock, ThreadUtils, Utils} * * - New attempts are detected in [[checkForLogs]]: the log dir is scanned, and any * entries in the log dir whose modification time is greater than the last scan time - * are considered new or updated. These are replayed to create a new [[FsApplicationAttemptInfo]] - * entry and update or create a matching [[FsApplicationHistoryInfo]] element in the list - * of applications. + * are considered new or updated. These are replayed to create a new attempt info entry + * and update or create a matching application info element in the list of applications. 
* - Updated attempts are also found in [[checkForLogs]] -- if the attempt's log file has grown, the - * [[FsApplicationAttemptInfo]] is replaced by another one with a larger log size. + * attempt is replaced by another one with a larger log size. * - When [[updateProbe()]] is invoked to check if a loaded [[SparkUI]] * instance is out of date, the log size of the cached instance is checked against the app last * loaded by [[checkForLogs]]. @@ -78,6 +84,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) this(conf, new SystemClock()) } + import config._ import FsHistoryProvider._ // Interval between safemode checks. @@ -94,8 +101,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) private val NUM_PROCESSING_THREADS = conf.getInt(SPARK_HISTORY_FS_NUM_REPLAY_THREADS, Math.ceil(Runtime.getRuntime.availableProcessors() / 4f).toInt) - private val logDir = conf.getOption("spark.history.fs.logDirectory") - .getOrElse(DEFAULT_LOG_DIR) + private val logDir = conf.get(EVENT_LOG_DIR) private val HISTORY_UI_ACLS_ENABLE = conf.getBoolean("spark.history.ui.acls.enable", false) private val HISTORY_UI_ADMIN_ACLS = conf.get("spark.history.ui.admin.acls", "") @@ -117,17 +123,31 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // used for logging msgs (logs are re-scanned based on file size, rather than modtime) private val lastScanTime = new java.util.concurrent.atomic.AtomicLong(-1) - // Mapping of application IDs to their metadata, in descending end time order. Apps are inserted - // into the map in order, so the LinkedHashMap maintains the correct ordering. - @volatile private var applications: mutable.LinkedHashMap[String, FsApplicationHistoryInfo] - = new mutable.LinkedHashMap() + private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) - val fileToAppInfo = new ConcurrentHashMap[Path, FsApplicationAttemptInfo]() + private val storePath = conf.get(LOCAL_STORE_DIR).map(new File(_)) - // List of application logs to be deleted by event log cleaner. - private var attemptsToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] + // Visible for testing. + private[history] val listing: KVStore = storePath.map { path => + require(path.isDirectory(), s"Configured store directory ($path) does not exist.") + val dbPath = new File(path, "listing.ldb") + val metadata = new FsHistoryProviderMetadata(CURRENT_LISTING_VERSION, + AppStatusStore.CURRENT_VERSION, logDir.toString()) - private val pendingReplayTasksCount = new java.util.concurrent.atomic.AtomicInteger(0) + try { + open(new File(path, "listing.ldb"), metadata) + } catch { + // If there's an error, remove the listing database and any existing UI database + // from the store directory, since it's extremely likely that they'll all contain + // incompatible information. + case _: UnsupportedStoreVersionException | _: MetadataMismatchException => + logInfo("Detected incompatible DB versions, deleting...") + path.listFiles().foreach(Utils.deleteRecursively) + open(new File(path, "listing.ldb"), metadata) + } + }.getOrElse(new InMemoryStore()) + + private val activeUIs = new mutable.HashMap[(String, Option[String]), LoadedAppUI]() /** * Return a runnable that performs the given operation on the event logs. @@ -152,7 +172,6 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - // Conf option used for testing the initialization code. 
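
The listing store above is opened with a metadata check and, when the on-disk database is incompatible or corrupted, wiped and rebuilt rather than migrated. A generic sketch of that open-or-recreate flow, with a hypothetical FakeStore and a plain version file standing in for KVStore/LevelDB and its metadata:

import java.io.File
import java.nio.charset.StandardCharsets
import java.nio.file.Files

object ListingStoreSketch {
  val CurrentVersion = 1L

  case class FakeStore(dir: File, version: Long)

  private def open(dir: File): FakeStore = {
    val versionFile = new File(dir, "VERSION")
    if (versionFile.exists()) {
      val onDisk =
        new String(Files.readAllBytes(versionFile.toPath), StandardCharsets.UTF_8).trim.toLong
      if (onDisk != CurrentVersion) throw new IllegalStateException(s"incompatible version $onDisk")
    } else {
      Files.write(versionFile.toPath, CurrentVersion.toString.getBytes(StandardCharsets.UTF_8))
    }
    FakeStore(dir, CurrentVersion)
  }

  def openOrRecreate(dir: File): FakeStore = {
    try open(dir) catch {
      case _: IllegalStateException =>
        // Incompatible on-disk data: delete everything and build a fresh store.
        dir.listFiles().foreach(_.delete())
        open(dir)
    }
  }

  def main(args: Array[String]): Unit = {
    val dir = Files.createTempDirectory("listing-sketch").toFile
    println(openOrRecreate(dir))
  }
}
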
val initThread = initialize() private[history] def initialize(): Thread = { @@ -231,10 +250,23 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - override def getListing(): Iterator[FsApplicationHistoryInfo] = applications.values.iterator + override def getListing(): Iterator[ApplicationHistoryInfo] = { + // Return the listing in end time descending order. + listing.view(classOf[ApplicationInfoWrapper]) + .index("endTime") + .reverse() + .iterator() + .asScala + .map(_.toAppHistoryInfo()) + } - override def getApplicationInfo(appId: String): Option[FsApplicationHistoryInfo] = { - applications.get(appId) + override def getApplicationInfo(appId: String): Option[ApplicationHistoryInfo] = { + try { + Some(load(appId).toAppHistoryInfo()) + } catch { + case e: NoSuchElementException => + None + } } override def getEventLogsUnderProcess(): Int = pendingReplayTasksCount.get() @@ -242,44 +274,100 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) override def getLastUpdatedTime(): Long = lastScanTime.get() override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { - try { - applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId).flatMap { attempt => - val replayBus = new ReplayListenerBus() - val ui = { - val conf = this.conf.clone() - val appSecManager = new SecurityManager(conf) - SparkUI.createHistoryUI(conf, replayBus, appSecManager, appInfo.name, - HistoryServer.getAttemptURI(appId, attempt.attemptId), - Some(attempt.lastUpdated), attempt.startTime) - // Do not call ui.bind() to avoid creating a new server for each application - } + val app = try { + load(appId) + } catch { + case _: NoSuchElementException => + return None + } - val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) - - val appListener = replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) - - if (appListener.appId.isDefined) { - ui.appSparkVersion = appListener.appSparkVersion.getOrElse("") - ui.getSecurityManager.setAcls(HISTORY_UI_ACLS_ENABLE) - // make sure to set admin acls before view acls so they are properly picked up - val adminAcls = HISTORY_UI_ADMIN_ACLS + "," + appListener.adminAcls.getOrElse("") - ui.getSecurityManager.setAdminAcls(adminAcls) - ui.getSecurityManager.setViewAcls(attempt.sparkUser, appListener.viewAcls.getOrElse("")) - val adminAclsGroups = HISTORY_UI_ADMIN_ACLS_GROUPS + "," + - appListener.adminAclsGroups.getOrElse("") - ui.getSecurityManager.setAdminAclsGroups(adminAclsGroups) - ui.getSecurityManager.setViewAclsGroups(appListener.viewAclsGroups.getOrElse("")) - Some(LoadedAppUI(ui, () => updateProbe(appId, attemptId, attempt.fileSize))) - } else { - None - } + val attempt = app.attempts.find(_.info.attemptId == attemptId).orNull + if (attempt == null) { + return None + } + + val conf = this.conf.clone() + val secManager = new SecurityManager(conf) + secManager.setAcls(HISTORY_UI_ACLS_ENABLE) + // make sure to set admin acls before view acls so they are properly picked up + secManager.setAdminAcls(HISTORY_UI_ADMIN_ACLS + "," + attempt.adminAcls.getOrElse("")) + secManager.setViewAcls(attempt.info.sparkUser, attempt.viewAcls.getOrElse("")) + secManager.setAdminAclsGroups(HISTORY_UI_ADMIN_ACLS_GROUPS + "," + + attempt.adminAclsGroups.getOrElse("")) + secManager.setViewAclsGroups(attempt.viewAclsGroups.getOrElse("")) + + val replayBus = new ReplayListenerBus() + + val uiStorePath = storePath.map { path => getStorePath(path, appId, attemptId) } + + val 
(kvstore, needReplay) = uiStorePath match { + case Some(path) => + try { + val _replay = !path.isDirectory() + (createDiskStore(path, conf), _replay) + } catch { + case e: Exception => + // Get rid of the old data and re-create it. The store is either old or corrupted. + logWarning(s"Failed to load disk store $uiStorePath for $appId.", e) + Utils.deleteRecursively(path) + (createDiskStore(path, conf), true) } + + case _ => + (new InMemoryStore(), true) + } + + val listener = if (needReplay) { + val _listener = new AppStatusListener(kvstore, conf, false) + replayBus.addListener(_listener) + Some(_listener) + } else { + None + } + + val loadedUI = { + val ui = SparkUI.create(None, new AppStatusStore(kvstore), conf, + l => replayBus.addListener(l), + secManager, + app.info.name, + HistoryServer.getAttemptURI(appId, attempt.info.attemptId), + attempt.info.startTime.getTime(), + appSparkVersion = attempt.info.appSparkVersion) + LoadedAppUI(ui) + } + + try { + val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], + Utils.getContextOrSparkClassLoader).asScala + listenerFactories.foreach { listenerFactory => + val listeners = listenerFactory.createListeners(conf, loadedUI.ui) + listeners.foreach(replayBus.addListener) } + + val fileStatus = fs.getFileStatus(new Path(logDir, attempt.logPath)) + replay(fileStatus, isApplicationCompleted(fileStatus), replayBus) + listener.foreach(_.flush()) } catch { - case e: FileNotFoundException => None + case e: Exception => + try { + kvstore.close() + } catch { + case _e: Exception => logInfo("Error closing store.", _e) + } + uiStorePath.foreach(Utils.deleteRecursively) + if (e.isInstanceOf[FileNotFoundException]) { + return None + } else { + throw e + } } + + synchronized { + activeUIs((appId, attemptId)) = loadedUI + } + + Some(loadedUI) } override def getEmptyListingHtml(): Seq[Node] = { @@ -303,9 +391,42 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } override def stop(): Unit = { - if (initThread != null && initThread.isAlive()) { - initThread.interrupt() - initThread.join() + try { + if (initThread != null && initThread.isAlive()) { + initThread.interrupt() + initThread.join() + } + Seq(pool, replayExecutor).foreach { executor => + executor.shutdown() + if (!executor.awaitTermination(5, TimeUnit.SECONDS)) { + executor.shutdownNow() + } + } + } finally { + activeUIs.foreach { case (_, loadedUI) => loadedUI.ui.store.close() } + activeUIs.clear() + listing.close() + } + } + + override def onUIDetached(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { + val uiOption = synchronized { + activeUIs.remove((appId, attemptId)) + } + uiOption.foreach { loadedUI => + loadedUI.lock.writeLock().lock() + try { + loadedUI.ui.store.close() + } finally { + loadedUI.lock.writeLock().unlock() + } + + // If the UI is not valid, delete its files from disk, if any. This relies on the fact that + // ApplicationCache will never call this method concurrently with getAppUI() for the same + // appId / attemptId. 
+ if (!loadedUI.valid && storePath.isDefined) { + Utils.deleteRecursively(getStorePath(storePath.get, appId, attemptId)) + } } } @@ -318,25 +439,20 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) try { val newLastScanTime = getNewLastScanTime() logDebug(s"Scanning $logDir with lastScanTime==$lastScanTime") - val statusList = Option(fs.listStatus(new Path(logDir))).map(_.toSeq) - .getOrElse(Seq.empty[FileStatus]) // scan for modified applications, replay and merge them - val logInfos: Seq[FileStatus] = statusList + val logInfos = Option(fs.listStatus(new Path(logDir))).map(_.toSeq).getOrElse(Nil) .filter { entry => - val fileInfo = fileToAppInfo.get(entry.getPath()) - val prevFileSize = if (fileInfo != null) fileInfo.fileSize else 0L !entry.isDirectory() && // FsHistoryProvider generates a hidden file which can't be read. Accidentally // reading a garbage file is safe, but we would log an error which can be scary to // the end-user. !entry.getPath().getName().startsWith(".") && - prevFileSize < entry.getLen() && - SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) + SparkHadoopUtil.get.checkAccessPermission(entry, FsAction.READ) && + recordedFileSize(entry.getPath()) < entry.getLen() } - .flatMap { entry => Some(entry) } .sortWith { case (entry1, entry2) => - entry1.getModificationTime() >= entry2.getModificationTime() - } + entry1.getModificationTime() > entry2.getModificationTime() + } if (logInfos.nonEmpty) { logDebug(s"New/updated attempts found: ${logInfos.size} ${logInfos.map(_.getPath)}") @@ -424,218 +540,127 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) } } - applications.get(appId) match { - case Some(appInfo) => - try { - // If no attempt is specified, or there is no attemptId for attempts, return all attempts - appInfo.attempts.filter { attempt => - attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get - }.foreach { attempt => - val logPath = new Path(logDir, attempt.logPath) - zipFileToStream(logPath, attempt.logPath, zipStream) - } - } finally { - zipStream.close() + val app = try { + load(appId) + } catch { + case _: NoSuchElementException => + throw new SparkException(s"Logs for $appId not found.") + } + + try { + // If no attempt is specified, or there is no attemptId for attempts, return all attempts + attemptId + .map { id => app.attempts.filter(_.info.attemptId == Some(id)) } + .getOrElse(app.attempts) + .map(_.logPath) + .foreach { log => + zipFileToStream(new Path(logDir, log), log, zipStream) } - case None => throw new SparkException(s"Logs for $appId not found.") + } finally { + zipStream.close() } } /** - * Replay the log files in the list and merge the list of old applications with new ones + * Replay the given log file, saving the application in the listing db. */ protected def mergeApplicationListing(fileStatus: FileStatus): Unit = { - val newAttempts = try { - val eventsFilter: ReplayEventsFilter = { eventString => - eventString.startsWith(APPL_START_EVENT_PREFIX) || - eventString.startsWith(APPL_END_EVENT_PREFIX) || - eventString.startsWith(LOG_START_EVENT_PREFIX) - } - - val logPath = fileStatus.getPath() - val appCompleted = isApplicationCompleted(fileStatus) - - // Use loading time as lastUpdated since some filesystems don't update modifiedTime - // each time file is updated. However use modifiedTime for completed jobs so lastUpdated - // won't change whenever HistoryServer restarts and reloads the file. 
- val lastUpdated = if (appCompleted) fileStatus.getModificationTime else clock.getTimeMillis() - - val appListener = replay(fileStatus, appCompleted, new ReplayListenerBus(), eventsFilter) - - // Without an app ID, new logs will render incorrectly in the listing page, so do not list or - // try to show their UI. - if (appListener.appId.isDefined) { - val attemptInfo = new FsApplicationAttemptInfo( - logPath.getName(), - appListener.appName.getOrElse(NOT_STARTED), - appListener.appId.getOrElse(logPath.getName()), - appListener.appAttemptId, - appListener.startTime.getOrElse(-1L), - appListener.endTime.getOrElse(-1L), - lastUpdated, - appListener.sparkUser.getOrElse(NOT_STARTED), - appCompleted, - fileStatus.getLen(), - appListener.appSparkVersion.getOrElse("") - ) - fileToAppInfo.put(logPath, attemptInfo) - logDebug(s"Application log ${attemptInfo.logPath} loaded successfully: $attemptInfo") - Some(attemptInfo) - } else { - logWarning(s"Failed to load application log ${fileStatus.getPath}. " + - "The application may have not started.") - None - } - - } catch { - case e: Exception => - logError( - s"Exception encountered when attempting to load application log ${fileStatus.getPath}", - e) - None + val eventsFilter: ReplayEventsFilter = { eventString => + eventString.startsWith(APPL_START_EVENT_PREFIX) || + eventString.startsWith(APPL_END_EVENT_PREFIX) || + eventString.startsWith(LOG_START_EVENT_PREFIX) || + eventString.startsWith(ENV_UPDATE_EVENT_PREXFIX) } - if (newAttempts.isEmpty) { - return - } - - // Build a map containing all apps that contain new attempts. The app information in this map - // contains both the new app attempt, and those that were already loaded in the existing apps - // map. If an attempt has been updated, it replaces the old attempt in the list. - val newAppMap = new mutable.HashMap[String, FsApplicationHistoryInfo]() - - applications.synchronized { - newAttempts.foreach { attempt => - val appInfo = newAppMap.get(attempt.appId) - .orElse(applications.get(attempt.appId)) - .map { app => - val attempts = - app.attempts.filter(_.attemptId != attempt.attemptId) ++ List(attempt) - new FsApplicationHistoryInfo(attempt.appId, attempt.name, - attempts.sortWith(compareAttemptInfo)) - } - .getOrElse(new FsApplicationHistoryInfo(attempt.appId, attempt.name, List(attempt))) - newAppMap(attempt.appId) = appInfo - } - - // Merge the new app list with the existing one, maintaining the expected ordering (descending - // end time). Maintaining the order is important to avoid having to sort the list every time - // there is a request for the log list. 
- val newApps = newAppMap.values.toSeq.sortWith(compareAppInfo) - val mergedApps = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - def addIfAbsent(info: FsApplicationHistoryInfo): Unit = { - if (!mergedApps.contains(info.id)) { - mergedApps += (info.id -> info) - } - } + val logPath = fileStatus.getPath() + logInfo(s"Replaying log path: $logPath") - val newIterator = newApps.iterator.buffered - val oldIterator = applications.values.iterator.buffered - while (newIterator.hasNext && oldIterator.hasNext) { - if (newAppMap.contains(oldIterator.head.id)) { - oldIterator.next() - } else if (compareAppInfo(newIterator.head, oldIterator.head)) { - addIfAbsent(newIterator.next()) - } else { - addIfAbsent(oldIterator.next()) + val bus = new ReplayListenerBus() + val listener = new AppListingListener(fileStatus, clock) + bus.addListener(listener) + + replay(fileStatus, isApplicationCompleted(fileStatus), bus, eventsFilter) + listener.applicationInfo.foreach { app => + // Invalidate the existing UI for the reloaded app attempt, if any. Note that this does + // not remove the UI from the active list; that has to be done in onUIDetached, so that + // cleanup of files can be done in a thread-safe manner. It does mean the UI will remain + // in memory for longer than it should. + synchronized { + activeUIs.get((app.info.id, app.attempts.head.info.attemptId)).foreach { ui => + ui.invalidate() + ui.ui.store.close() } } - newIterator.foreach(addIfAbsent) - oldIterator.foreach(addIfAbsent) - applications = mergedApps + addListing(app) } + listing.write(new LogInfo(logPath.toString(), fileStatus.getLen())) } /** * Delete event logs from the log directory according to the clean policy defined by the user. */ private[history] def cleanLogs(): Unit = { + var iterator: Option[KVStoreIterator[ApplicationInfoWrapper]] = None try { - val maxAge = conf.getTimeAsSeconds("spark.history.fs.cleaner.maxAge", "7d") * 1000 - - val now = clock.getTimeMillis() - val appsToRetain = new mutable.LinkedHashMap[String, FsApplicationHistoryInfo]() - - def shouldClean(attempt: FsApplicationAttemptInfo): Boolean = { - now - attempt.lastUpdated > maxAge - } + val maxTime = clock.getTimeMillis() - conf.get(MAX_LOG_AGE_S) * 1000 + + // Iterate descending over all applications whose oldest attempt happened before maxTime. + iterator = Some(listing.view(classOf[ApplicationInfoWrapper]) + .index("oldestAttempt") + .reverse() + .first(maxTime) + .closeableIterator()) + + iterator.get.asScala.foreach { app => + // Applications may have multiple attempts, some of which may not need to be deleted yet. + val (remaining, toDelete) = app.attempts.partition { attempt => + attempt.info.lastUpdated.getTime() >= maxTime + } - // Scan all logs from the log directory. - // Only completed applications older than the specified max age will be deleted. 
- applications.values.foreach { app => - val (toClean, toRetain) = app.attempts.partition(shouldClean) - attemptsToClean ++= toClean - - if (toClean.isEmpty) { - appsToRetain += (app.id -> app) - } else if (toRetain.nonEmpty) { - appsToRetain += (app.id -> - new FsApplicationHistoryInfo(app.id, app.name, toRetain.toList)) + if (remaining.nonEmpty) { + val newApp = new ApplicationInfoWrapper(app.info, remaining) + listing.write(newApp) } - } - applications = appsToRetain + toDelete.foreach { attempt => + val logPath = new Path(logDir, attempt.logPath) + try { + listing.delete(classOf[LogInfo], logPath.toString()) + } catch { + case _: NoSuchElementException => + logDebug(s"Log info entry for $logPath not found.") + } + try { + fs.delete(logPath, true) + } catch { + case e: AccessControlException => + logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") + case t: IOException => + logError(s"IOException in cleaning ${attempt.logPath}", t) + } + } - val leftToClean = new mutable.ListBuffer[FsApplicationAttemptInfo] - attemptsToClean.foreach { attempt => - try { - fs.delete(new Path(logDir, attempt.logPath), true) - } catch { - case e: AccessControlException => - logInfo(s"No permission to delete ${attempt.logPath}, ignoring.") - case t: IOException => - logError(s"IOException in cleaning ${attempt.logPath}", t) - leftToClean += attempt + if (remaining.isEmpty) { + listing.delete(app.getClass(), app.id) } } - - attemptsToClean = leftToClean } catch { - case t: Exception => logError("Exception in cleaning logs", t) + case t: Exception => logError("Exception while cleaning logs", t) + } finally { + iterator.foreach(_.close()) } } /** - * Comparison function that defines the sort order for the application listing. - * - * @return Whether `i1` should precede `i2`. - */ - private def compareAppInfo( - i1: FsApplicationHistoryInfo, - i2: FsApplicationHistoryInfo): Boolean = { - val a1 = i1.attempts.head - val a2 = i2.attempts.head - if (a1.endTime != a2.endTime) a1.endTime >= a2.endTime else a1.startTime >= a2.startTime - } - - /** - * Comparison function that defines the sort order for application attempts within the same - * application. Order is: attempts are sorted by descending start time. - * Most recent attempt state matches with current state of the app. - * - * Normally applications should have a single running attempt; but failure to call sc.stop() - * may cause multiple running attempts to show up. - * - * @return Whether `a1` should precede `a2`. - */ - private def compareAttemptInfo( - a1: FsApplicationAttemptInfo, - a2: FsApplicationAttemptInfo): Boolean = { - a1.startTime >= a2.startTime - } - - /** - * Replays the events in the specified log file on the supplied `ReplayListenerBus`. Returns - * an `ApplicationEventListener` instance with event data captured from the replay. - * `ReplayEventsFilter` determines what events are replayed and can therefore limit the - * data captured in the returned `ApplicationEventListener` instance. + * Replays the events in the specified log file on the supplied `ReplayListenerBus`. + * `ReplayEventsFilter` determines what events are replayed. 
*/ private def replay( eventLog: FileStatus, appCompleted: Boolean, bus: ReplayListenerBus, - eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): ApplicationEventListener = { + eventsFilter: ReplayEventsFilter = SELECT_ALL_FILTER): Unit = { val logPath = eventLog.getPath() logInfo(s"Replaying log path: $logPath") // Note that the eventLog may have *increased* in size since when we grabbed the filestatus, @@ -646,10 +671,8 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) // after it's created, so we get a file size that is no bigger than what is actually read. val logInput = EventLoggingListener.openEventLog(logPath, fs) try { - val appListener = new ApplicationEventListener - bus.addListener(appListener) bus.replay(logInput, logPath.toString, !appCompleted, eventsFilter) - appListener + logInfo(s"Finished replaying $logPath") } finally { logInput.close() } @@ -685,58 +708,75 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * @return a summary of the component state */ override def toString: String = { - val header = s""" - | FsHistoryProvider: logdir=$logDir, - | last scan time=$lastScanTime - | Cached application count =${applications.size}} - """.stripMargin - val sb = new StringBuilder(header) - applications.foreach(entry => sb.append(entry._2).append("\n")) - sb.toString + val count = listing.count(classOf[ApplicationInfoWrapper]) + s"""|FsHistoryProvider{logdir=$logDir, + | storedir=$storePath, + | last scan time=$lastScanTime + | application count=$count}""".stripMargin } /** - * Look up an application attempt - * @param appId application ID - * @param attemptId Attempt ID, if set - * @return the matching attempt, if found + * Return the last known size of the given event log, recorded the last time the file + * system scanner detected a change in the file. */ - def lookup(appId: String, attemptId: Option[String]): Option[FsApplicationAttemptInfo] = { - applications.get(appId).flatMap { appInfo => - appInfo.attempts.find(_.attemptId == attemptId) + private def recordedFileSize(log: Path): Long = { + try { + listing.read(classOf[LogInfo], log.toString()).fileSize + } catch { + case _: NoSuchElementException => 0L } } + private def load(appId: String): ApplicationInfoWrapper = { + listing.read(classOf[ApplicationInfoWrapper], appId) + } + /** - * Return true iff a newer version of the UI is available. The check is based on whether the - * fileSize for the currently loaded UI is smaller than the file size the last time - * the logs were loaded. - * - * This is a very cheap operation -- the work of loading the new attempt was already done - * by [[checkForLogs]]. - * @param appId application to probe - * @param attemptId attempt to probe - * @param prevFileSize the file size of the logs for the currently displayed UI + * Write the app's information to the given store. Serialized to avoid the (notedly rare) case + * where two threads are processing separate attempts of the same application. 
*/ - private def updateProbe( - appId: String, - attemptId: Option[String], - prevFileSize: Long)(): Boolean = { - lookup(appId, attemptId) match { - case None => - logDebug(s"Application Attempt $appId/$attemptId not found") - false - case Some(latest) => - prevFileSize < latest.fileSize + private def addListing(app: ApplicationInfoWrapper): Unit = listing.synchronized { + val attempt = app.attempts.head + + val oldApp = try { + load(app.id) + } catch { + case _: NoSuchElementException => + app + } + + def compareAttemptInfo(a1: AttemptInfoWrapper, a2: AttemptInfoWrapper): Boolean = { + a1.info.startTime.getTime() > a2.info.startTime.getTime() } + + val attempts = oldApp.attempts.filter(_.info.attemptId != attempt.info.attemptId) ++ + List(attempt) + + val newAppInfo = new ApplicationInfoWrapper( + app.info, + attempts.sortWith(compareAttemptInfo)) + listing.write(newAppInfo) } -} -private[history] object FsHistoryProvider { - val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + private def createDiskStore(path: File, conf: SparkConf): KVStore = { + val metadata = new AppStatusStoreMetadata(AppStatusStore.CURRENT_VERSION) + KVUtils.open(path, metadata) + } - private val NOT_STARTED = "" + private def getStorePath(path: File, appId: String, attemptId: Option[String]): File = { + val fileName = appId + attemptId.map("_" + _).getOrElse("") + ".ldb" + new File(path, fileName) + } + + /** For testing. Returns internal data about a single attempt. */ + private[history] def getAttempt(appId: String, attemptId: Option[String]): AttemptInfoWrapper = { + load(appId).attempts.find(_.info.attemptId == attemptId).getOrElse( + throw new NoSuchElementException(s"Cannot find attempt $attemptId of $appId.")) + } +} + +private[history] object FsHistoryProvider { private val SPARK_HISTORY_FS_NUM_REPLAY_THREADS = "spark.history.fs.numReplayThreads" private val APPL_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationStart\"" @@ -744,53 +784,157 @@ private[history] object FsHistoryProvider { private val APPL_END_EVENT_PREFIX = "{\"Event\":\"SparkListenerApplicationEnd\"" private val LOG_START_EVENT_PREFIX = "{\"Event\":\"SparkListenerLogStart\"" + + private val ENV_UPDATE_EVENT_PREXFIX = "{\"Event\":\"SparkListenerEnvironmentUpdate\"," + + /** + * Current version of the data written to the listing database. When opening an existing + * db, if the version does not match this value, the FsHistoryProvider will throw away + * all data and re-generate the listing data from the event logs. + */ + private[history] val CURRENT_LISTING_VERSION = 1L } -/** - * Application attempt information. - * - * @param logPath path to the log file, or, for a legacy log, its directory - * @param name application name - * @param appId application ID - * @param attemptId optional attempt ID - * @param startTime start time (from playback) - * @param endTime end time (from playback). -1 if the application is incomplete. - * @param lastUpdated the modification time of the log file when this entry was built by replaying - * the history. - * @param sparkUser user running the application - * @param completed flag to indicate whether or not the application has completed. 
- * @param fileSize the size of the log file the last time the file was scanned for changes - */ -private class FsApplicationAttemptInfo( +private[history] case class FsHistoryProviderMetadata( + version: Long, + uiVersion: Long, + logDir: String) + +private[history] case class LogInfo( + @KVIndexParam logPath: String, + fileSize: Long) + +private[history] class AttemptInfoWrapper( + val info: v1.ApplicationAttemptInfo, val logPath: String, - val name: String, - val appId: String, - attemptId: Option[String], - startTime: Long, - endTime: Long, - lastUpdated: Long, - sparkUser: String, - completed: Boolean, val fileSize: Long, - appSparkVersion: String) - extends ApplicationAttemptInfo( - attemptId, startTime, endTime, lastUpdated, sparkUser, completed, appSparkVersion) { + val adminAcls: Option[String], + val viewAcls: Option[String], + val adminAclsGroups: Option[String], + val viewAclsGroups: Option[String]) { + + def toAppAttemptInfo(): ApplicationAttemptInfo = { + ApplicationAttemptInfo(info.attemptId, info.startTime.getTime(), + info.endTime.getTime(), info.lastUpdated.getTime(), info.sparkUser, + info.completed, info.appSparkVersion) + } - /** extend the superclass string value with the extra attributes of this class */ - override def toString: String = { - s"FsApplicationAttemptInfo($name, $appId," + - s" ${super.toString}, source=$logPath, size=$fileSize" +} + +private[history] class ApplicationInfoWrapper( + val info: v1.ApplicationInfo, + val attempts: List[AttemptInfoWrapper]) { + + @JsonIgnore @KVIndexParam + def id: String = info.id + + @JsonIgnore @KVIndexParam("endTime") + def endTime(): Long = attempts.head.info.endTime.getTime() + + @JsonIgnore @KVIndexParam("oldestAttempt") + def oldestAttempt(): Long = attempts.map(_.info.lastUpdated.getTime()).min + + def toAppHistoryInfo(): ApplicationHistoryInfo = { + ApplicationHistoryInfo(info.id, info.name, attempts.map(_.toAppAttemptInfo())) } + } -/** - * Application history information - * @param id application ID - * @param name application name - * @param attempts list of attempts, most recent first. 
- */ -private class FsApplicationHistoryInfo( - id: String, - override val name: String, - override val attempts: List[FsApplicationAttemptInfo]) - extends ApplicationHistoryInfo(id, name, attempts) +private[history] class AppListingListener(log: FileStatus, clock: Clock) extends SparkListener { + + private val app = new MutableApplicationInfo() + private val attempt = new MutableAttemptInfo(log.getPath().getName(), log.getLen()) + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + app.id = event.appId.orNull + app.name = event.appName + + attempt.attemptId = event.appAttemptId + attempt.startTime = new Date(event.time) + attempt.lastUpdated = new Date(clock.getTimeMillis()) + attempt.sparkUser = event.sparkUser + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + attempt.endTime = new Date(event.time) + attempt.lastUpdated = new Date(log.getModificationTime()) + attempt.duration = event.time - attempt.startTime.getTime() + attempt.completed = true + } + + override def onEnvironmentUpdate(event: SparkListenerEnvironmentUpdate): Unit = { + val allProperties = event.environmentDetails("Spark Properties").toMap + attempt.viewAcls = allProperties.get("spark.ui.view.acls") + attempt.adminAcls = allProperties.get("spark.admin.acls") + attempt.viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") + attempt.adminAclsGroups = allProperties.get("spark.admin.acls.groups") + } + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(sparkVersion) => + attempt.appSparkVersion = sparkVersion + case _ => + } + + def applicationInfo: Option[ApplicationInfoWrapper] = { + if (app.id != null) { + Some(app.toView()) + } else { + None + } + } + + private class MutableApplicationInfo { + var id: String = null + var name: String = null + var coresGranted: Option[Int] = None + var maxCores: Option[Int] = None + var coresPerExecutor: Option[Int] = None + var memoryPerExecutorMB: Option[Int] = None + + def toView(): ApplicationInfoWrapper = { + val apiInfo = new v1.ApplicationInfo(id, name, coresGranted, maxCores, coresPerExecutor, + memoryPerExecutorMB, Nil) + new ApplicationInfoWrapper(apiInfo, List(attempt.toView())) + } + + } + + private class MutableAttemptInfo(logPath: String, fileSize: Long) { + var attemptId: Option[String] = None + var startTime = new Date(-1) + var endTime = new Date(-1) + var lastUpdated = new Date(-1) + var duration = 0L + var sparkUser: String = null + var completed = false + var appSparkVersion = "" + + var adminAcls: Option[String] = None + var viewAcls: Option[String] = None + var adminAclsGroups: Option[String] = None + var viewAclsGroups: Option[String] = None + + def toView(): AttemptInfoWrapper = { + val apiInfo = new v1.ApplicationAttemptInfo( + attemptId, + startTime, + endTime, + lastUpdated, + duration, + sparkUser, + completed, + appSparkVersion) + new AttemptInfoWrapper( + apiInfo, + logPath, + fileSize, + adminAcls, + viewAcls, + adminAclsGroups, + viewAclsGroups) + } + + } + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala index af14717633409..6399dccc1676a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala @@ -37,7 +37,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") val content =
-
+
    {providerConfig.map { case (k, v) =>
      <li><strong>{k}:</strong> {v}</li>
    }}
@@ -58,7 +58,7 @@ private[history] class HistoryPage(parent: HistoryServer) extends WebUIPage("") { if (allAppsSize > 0) { ++ - ++ +
++ ++ ++ diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index d9c8fda99ef97..b822a48e98e91 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -106,8 +106,8 @@ class HistoryServer( } } - def getSparkUI(appKey: String): Option[SparkUI] = { - appCache.getSparkUI(appKey) + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + appCache.withSparkUI(appId, attemptId)(fn) } initialize() @@ -140,7 +140,6 @@ class HistoryServer( override def stop() { super.stop() provider.stop() - appCache.stop() } /** Attach a reconstructed UI to this server. Only valid after bind(). */ @@ -158,6 +157,7 @@ class HistoryServer( override def detachSparkUI(appId: String, attemptId: Option[String], ui: SparkUI): Unit = { assert(serverInfo.isDefined, "HistoryServer must be bound before detaching SparkUIs") ui.getHandlers.foreach(detachHandler) + provider.onUIDetached(appId, attemptId, ui) } /** @@ -224,15 +224,13 @@ class HistoryServer( */ private def loadAppUi(appId: String, attemptId: Option[String]): Boolean = { try { - appCache.get(appId, attemptId) + appCache.withSparkUI(appId, attemptId) { _ => + // Do nothing, just force the UI to load. + } true } catch { - case NonFatal(e) => e.getCause() match { - case nsee: NoSuchElementException => - false - - case cause: Exception => throw cause - } + case NonFatal(e: NoSuchElementException) => + false } } diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala new file mode 100644 index 0000000000000..52dedc1a2ed41 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.history + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config.ConfigBuilder + +private[spark] object config { + + val DEFAULT_LOG_DIR = "file:/tmp/spark-events" + + val EVENT_LOG_DIR = ConfigBuilder("spark.history.fs.logDirectory") + .stringConf + .createWithDefault(DEFAULT_LOG_DIR) + + val MAX_LOG_AGE_S = ConfigBuilder("spark.history.fs.cleaner.maxAge") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("7d") + + val LOCAL_STORE_DIR = ConfigBuilder("spark.history.store.path") + .doc("Local directory where to cache application history information. 
By default this is " + + "not set, meaning all history information will be kept in memory.") + .stringConf + .createOptional + +} diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index e030cac60a8e4..2c78c15773af2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -581,7 +581,13 @@ private[deploy] class Master( * The number of cores assigned to each executor is configurable. When this is explicitly set, * multiple executors from the same application may be launched on the same worker if the worker * has enough cores and memory. Otherwise, each executor grabs all the cores available on the - * worker by default, in which case only one executor may be launched on each worker. + * worker by default, in which case only one executor per application may be launched on each + * worker during one single schedule iteration. + * Note that when `spark.executor.cores` is not set, we may still launch multiple executors from + * the same application on the same worker. Consider appA and appB both have one executor running + * on worker1, and appA.coresLeft > 0, then appB is finished and release all its cores on worker1, + * thus for the next schedule iteration, appA launches a new executor that grabs all the free + * cores on worker1, therefore we get multiple executors from appA running on worker1. * * It is important to allocate coresPerExecutor on each worker at a time (instead of 1 core * at a time). Consider the following example: cluster has 4 workers with 16 cores each. diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala index 0164084ab129e..22b65abce611a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/StandaloneRestServer.scala @@ -139,7 +139,9 @@ private[rest] class StandaloneSubmitRequestServlet( val driverExtraLibraryPath = sparkProperties.get("spark.driver.extraLibraryPath") val superviseDriver = sparkProperties.get("spark.driver.supervise") val appArgs = request.appArgs - val environmentVariables = request.environmentVariables + // Filter SPARK_LOCAL_(IP|HOSTNAME) environment variables from being set on the remote system. 
+ val environmentVariables = + request.environmentVariables.filterNot(x => x._1.matches("SPARK_LOCAL_(IP|HOSTNAME)")) // Construct driver description val conf = new SparkConf(false) diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala index 78b0e6b2cbf39..5dcde4ec3a8a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HBaseDelegationTokenProvider.scala @@ -56,7 +56,9 @@ private[security] class HBaseDelegationTokenProvider None } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { hbaseConf(hadoopConf).get("hbase.security.authentication") == "kerberos" } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala index c134b7ebe38fa..483d0deec8070 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenManager.scala @@ -115,7 +115,7 @@ private[spark] class HadoopDelegationTokenManager( hadoopConf: Configuration, creds: Credentials): Long = { delegationTokenProviders.values.flatMap { provider => - if (provider.delegationTokensRequired(hadoopConf)) { + if (provider.delegationTokensRequired(sparkConf, hadoopConf)) { provider.obtainDelegationTokens(hadoopConf, sparkConf, creds) } else { logDebug(s"Service ${provider.serviceName} does not require a token." + diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala index 1ba245e84af4b..ed0905088ab25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopDelegationTokenProvider.scala @@ -37,7 +37,7 @@ private[spark] trait HadoopDelegationTokenProvider { * Returns true if delegation tokens are required for this service. By default, it is based on * whether Hadoop security is enabled. */ - def delegationTokensRequired(hadoopConf: Configuration): Boolean + def delegationTokensRequired(sparkConf: SparkConf, hadoopConf: Configuration): Boolean /** * Obtain delegation tokens for this service and get the time of the next renewal. 
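The trait change above threads the SparkConf into delegationTokensRequired, so a provider can weigh Spark-level settings alongside the Hadoop configuration. A rough sketch of a provider written against the new signature follows; the class name, service name, and configuration key are hypothetical, and obtainDelegationTokens is left as a stub.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.security.{Credentials, UserGroupInformation}

import org.apache.spark.SparkConf

private class ExampleDelegationTokenProvider extends HadoopDelegationTokenProvider {

  override def serviceName: String = "example"

  override def delegationTokensRequired(
      sparkConf: SparkConf,
      hadoopConf: Configuration): Boolean = {
    // The decision can now mix Hadoop-level state with Spark-level settings,
    // much like the Hive provider below does with deploy mode and the keytab config.
    UserGroupInformation.isSecurityEnabled &&
      sparkConf.getBoolean("spark.example.tokens.enabled", defaultValue = true) // hypothetical key
  }

  override def obtainDelegationTokens(
      hadoopConf: Configuration,
      sparkConf: SparkConf,
      creds: Credentials): Option[Long] = {
    // Stub: a real provider would fetch tokens into `creds` and return the next
    // renewal time, if any.
    None
  }
}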
diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala index 300773c58b183..21ca669ea98f0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HadoopFSDelegationTokenProvider.scala @@ -69,7 +69,9 @@ private[deploy] class HadoopFSDelegationTokenProvider(fileSystems: Configuration nextRenewalDate } - def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { UserGroupInformation.isSecurityEnabled } diff --git a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala index b31cc595ed83b..ece5ce79c650d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/security/HiveDelegationTokenProvider.scala @@ -31,7 +31,9 @@ import org.apache.hadoop.security.{Credentials, UserGroupInformation} import org.apache.hadoop.security.token.Token import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.KEYTAB import org.apache.spark.util.Utils private[security] class HiveDelegationTokenProvider @@ -55,9 +57,21 @@ private[security] class HiveDelegationTokenProvider } } - override def delegationTokensRequired(hadoopConf: Configuration): Boolean = { + override def delegationTokensRequired( + sparkConf: SparkConf, + hadoopConf: Configuration): Boolean = { + // Delegation tokens are needed only when: + // - trying to connect to a secure metastore + // - either deploying in cluster mode without a keytab, or impersonating another user + // + // Other modes (such as client with or without keytab, or cluster mode with keytab) do not need + // a delegation token, since there's a valid kerberos TGT for the right user available to the + // driver, which is the only process that connects to the HMS. 
+ val deployMode = sparkConf.get("spark.submit.deployMode", "client") UserGroupInformation.isSecurityEnabled && - hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty + hiveConf(hadoopConf).getTrimmed("hive.metastore.uris", "").nonEmpty && + (SparkHadoopUtil.get.isProxyUser(UserGroupInformation.getCurrentUser()) || + (deployMode == "cluster" && !sparkConf.contains(KEYTAB))) } override def obtainDelegationTokens( @@ -83,7 +97,7 @@ private[security] class HiveDelegationTokenProvider val hive2Token = new Token[DelegationTokenIdentifier]() hive2Token.decodeFromUrlString(tokenStr) - logInfo(s"Get Token from hive metastore: ${hive2Token.toString}") + logDebug(s"Get Token from hive metastore: ${hive2Token.toString}") creds.addToken(new Text("hive.server2.delegation.token"), hive2Token) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index c1671192e0c64..b19c9904d5982 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -23,6 +23,7 @@ import org.apache.commons.lang3.StringUtils import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.{DependencyUtils, SparkHadoopUtil, SparkSubmit} +import org.apache.spark.internal.Logging import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -30,7 +31,7 @@ import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, U * Utility object for launching driver programs such that they share fate with the Worker process. * This is used in standalone cluster mode only. */ -object DriverWrapper { +object DriverWrapper extends Logging { def main(args: Array[String]) { args.toList match { /* @@ -41,8 +42,10 @@ object DriverWrapper { */ case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val rpcEnv = RpcEnv.create("Driver", - Utils.localHostName(), 0, conf, new SecurityManager(conf)) + val host: String = Utils.localHostName() + val port: Int = sys.props.getOrElse("spark.driver.port", "0").toInt + val rpcEnv = RpcEnv.create("Driver", host, port, conf, new SecurityManager(conf)) + logInfo(s"Driver address: ${rpcEnv.address}") rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 44a2815b81a73..6f0247b73070d 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -41,6 +41,29 @@ package object config { .bytesConf(ByteUnit.MiB) .createWithDefaultString("1g") + private[spark] val EVENT_LOG_COMPRESS = + ConfigBuilder("spark.eventLog.compress") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_BLOCK_UPDATES = + ConfigBuilder("spark.eventLog.logBlockUpdates.enabled") + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_TESTING = + ConfigBuilder("spark.eventLog.testing") + .internal() + .booleanConf + .createWithDefault(false) + + private[spark] val EVENT_LOG_OUTPUT_BUFFER_SIZE = ConfigBuilder("spark.eventLog.buffer.kb") + .bytesConf(ByteUnit.KiB) + .createWithDefaultString("100k") + + private[spark] val EVENT_LOG_OVERWRITE = 
+ ConfigBuilder("spark.eventLog.overwrite").booleanConf.createWithDefault(false) + private[spark] val EXECUTOR_CLASS_PATH = ConfigBuilder(SparkLauncher.EXECUTOR_EXTRA_CLASSPATH).stringConf.createOptional @@ -72,6 +95,10 @@ package object config { private[spark] val DYN_ALLOCATION_MAX_EXECUTORS = ConfigBuilder("spark.dynamicAllocation.maxExecutors").intConf.createWithDefault(Int.MaxValue) + private[spark] val LOCALITY_WAIT = ConfigBuilder("spark.locality.wait") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefaultString("3s") + private[spark] val SHUFFLE_SERVICE_ENABLED = ConfigBuilder("spark.shuffle.service.enabled").booleanConf.createWithDefault(false) @@ -199,6 +226,11 @@ package object config { private[spark] val HISTORY_UI_MAX_APPS = ConfigBuilder("spark.history.ui.maxApplications").intConf.createWithDefault(Integer.MAX_VALUE) + private[spark] val UI_SHOW_CONSOLE_PROGRESS = ConfigBuilder("spark.ui.showConsoleProgress") + .doc("When true, show the progress bar in the console.") + .booleanConf + .createWithDefault(false) + private[spark] val IO_ENCRYPTION_ENABLED = ConfigBuilder("spark.io.encryption.enabled") .booleanConf .createWithDefault(false) @@ -261,6 +293,13 @@ package object config { .longConf .createWithDefault(4 * 1024 * 1024) + private[spark] val HADOOP_RDD_IGNORE_EMPTY_SPLITS = + ConfigBuilder("spark.hadoopRDD.ignoreEmptySplits") + .internal() + .doc("When true, HadoopRDD/NewHadoopRDD will not create partitions for empty input splits.") + .booleanConf + .createWithDefault(false) + private[spark] val SECRET_REDACTION_PATTERN = ConfigBuilder("spark.redaction.regex") .doc("Regex to decide which Spark configuration properties and environment variables in " + @@ -341,13 +380,15 @@ package object config { .checkValue(_ > 0, "The max no. of blocks in flight cannot be non-positive.") .createWithDefault(Int.MaxValue) - private[spark] val REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM = - ConfigBuilder("spark.reducer.maxReqSizeShuffleToMem") - .doc("The blocks of a shuffle request will be fetched to disk when size of the request is " + + private[spark] val MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM = + ConfigBuilder("spark.maxRemoteBlockSizeFetchToMem") + .doc("Remote block will be fetched to disk when size of the block is " + "above this threshold. This is to avoid a giant request takes too much memory. We can " + - "enable this config by setting a specific value(e.g. 200m). Note that this config can " + - "be enabled only when the shuffle shuffle service is newer than Spark-2.2 or the shuffle" + - " service is disabled.") + "enable this config by setting a specific value(e.g. 200m). Note this configuration will " + + "affect both shuffle fetch and block manager remote block fetch. 
For users who " + + "enabled external shuffle service, this feature can only be worked when external shuffle" + + " service is newer than Spark 2.2.") + .withAlternative("spark.reducer.maxReqSizeShuffleToMem") .bytesConf(ByteUnit.BYTE) .createWithDefault(Long.MaxValue) @@ -410,4 +451,28 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + private[spark] val UI_X_XSS_PROTECTION = + ConfigBuilder("spark.ui.xXssProtection") + .doc("Value for HTTP X-XSS-Protection response header") + .stringConf + .createWithDefaultString("1; mode=block") + + private[spark] val UI_X_CONTENT_TYPE_OPTIONS = + ConfigBuilder("spark.ui.xContentTypeOptions.enabled") + .doc("Set to 'true' for setting X-Content-Type-Options HTTP response header to 'nosniff'") + .booleanConf + .createWithDefault(true) + + private[spark] val UI_STRICT_TRANSPORT_SECURITY = + ConfigBuilder("spark.ui.strictTransportSecurity") + .doc("Value for HTTP Strict Transport Security Response Header") + .stringConf + .createOptional + + private[spark] val EXTRA_LISTENERS = ConfigBuilder("spark.extraListeners") + .doc("Class names of listeners to add to SparkContext during initialization.") + .stringConf + .toSequence + .createOptional } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index b1d07ab2c9199..95c99d29c3a9c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -20,6 +20,7 @@ package org.apache.spark.internal.io import java.util.{Date, UUID} import scala.collection.mutable +import scala.util.Try import org.apache.hadoop.conf.Configurable import org.apache.hadoop.fs.Path @@ -35,6 +36,9 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * (from the newer mapreduce API, not the old mapred API). * * Unlike Hadoop's OutputCommitter, this implementation is serializable. + * + * @param jobId the job's or stage's id + * @param path the job's output path, or null if committer acts as a noop */ class HadoopMapReduceCommitProtocol(jobId: String, path: String) extends FileCommitProtocol with Serializable with Logging { @@ -44,6 +48,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) /** OutputCommitter from Hadoop is not serializable so marking it transient. */ @transient private var committer: OutputCommitter = _ + /** + * Checks whether there are files to be committed to a valid output location. + * + * As committing and aborting a job occurs on driver, where `addedAbsPathFiles` is always null, + * it is necessary to check whether a valid output path is specified. + * [[HadoopMapReduceCommitProtocol#path]] need not be a valid [[org.apache.hadoop.fs.Path]] for + * committers not writing to distributed file systems. + */ + private val hasValidPath = Try { new Path(path) }.isSuccess + /** * Tracks files staged by this task for absolute output paths. These outputs are not managed by * the Hadoop OutputCommitter, so we must move these to their final locations on job commit. 
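As background for the hasValidPath guard above: constructing a Hadoop Path from a null or empty string throws, so Try(...).isSuccess distinguishes a real output location from a no-op committer without touching the file system. A small standalone illustration (the helper mirrors the new field; the example path is arbitrary):

import scala.util.Try

import org.apache.hadoop.fs.Path

def hasValidPath(path: String): Boolean = Try { new Path(path) }.isSuccess

hasValidPath("/tmp/spark-output")  // true: commitJob moves staged files and deletes the staging dir
hasValidPath(null)                 // false: the absolute-path staging steps are skipped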
@@ -130,17 +144,21 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) .foldLeft(Map[String, String]())(_ ++ _) logDebug(s"Committing files staged for absolute locations $filesToMove") - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - for ((src, dst) <- filesToMove) { - fs.rename(new Path(src), new Path(dst)) + if (hasValidPath) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + for ((src, dst) <- filesToMove) { + fs.rename(new Path(src), new Path(dst)) + } + fs.delete(absPathStagingDir, true) } - fs.delete(absPathStagingDir, true) } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + if (hasValidPath) { + val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(absPathStagingDir, true) + } } override def setupTask(taskContext: TaskAttemptContext): Unit = { diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index fe5fd2da039bb..1d8a266d0079c 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -25,8 +25,8 @@ import scala.concurrent.duration.Duration import scala.reflect.ClassTag import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.storage.{BlockId, StorageLevel} import org.apache.spark.util.ThreadUtils @@ -68,7 +68,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit + tempFileManager: TempFileManager): Unit /** * Upload a single block to a remote node, available only after [[init]] is invoked. @@ -87,7 +87,12 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * * It is also only available after [[init]] is invoked. */ - def fetchBlockSync(host: String, port: Int, execId: String, blockId: String): ManagedBuffer = { + def fetchBlockSync( + host: String, + port: Int, + execId: String, + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { // A monitor for the thread to wait on. 
val result = Promise[ManagedBuffer]() fetchBlocks(host, port, execId, Array(blockId), @@ -96,12 +101,17 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.failure(exception) } override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): Unit = { - val ret = ByteBuffer.allocate(data.size.toInt) - ret.put(data.nioByteBuffer()) - ret.flip() - result.success(new NioManagedBuffer(ret)) + data match { + case f: FileSegmentManagedBuffer => + result.success(f) + case _ => + val ret = ByteBuffer.allocate(data.size.toInt) + ret.put(data.nioByteBuffer()) + ret.flip() + result.success(new NioManagedBuffer(ret)) + } } - }, tempShuffleFileManager = null) + }, tempFileManager) ThreadUtils.awaitResult(result.future, Duration.Inf) } diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala index ac4d85004bad1..b7d8c35032763 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala @@ -32,7 +32,7 @@ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.client.{RpcResponseCallback, TransportClientBootstrap, TransportClientFactory} import org.apache.spark.network.crypto.{AuthClientBootstrap, AuthServerBootstrap} import org.apache.spark.network.server._ -import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher, RetryingBlockFetcher, TempFileManager} import org.apache.spark.network.shuffle.protocol.UploadBlock import org.apache.spark.network.util.JavaUtils import org.apache.spark.serializer.JavaSerializer @@ -105,14 +105,14 @@ private[spark] class NettyBlockTransferService( execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { logTrace(s"Fetch blocks from $host:$port (executor id $execId)") try { val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter { override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) { val client = clientFactory.createClient(host, port) new OneForOneBlockFetcher(client, appId, execId, blockIds, listener, - transportConf, tempShuffleFileManager).start() + transportConf, tempFileManager).start() } } @@ -151,7 +151,7 @@ private[spark] class NettyBlockTransferService( // Convert or copy nio buffer into array in order to serialize it. 
val array = JavaUtils.bufferToArray(blockData.nioByteBuffer()) - client.sendRpc(new UploadBlock(appId, execId, blockId.toString, metadata, array).toByteBuffer, + client.sendRpc(new UploadBlock(appId, execId, blockId.name, metadata, array).toByteBuffer, new RpcResponseCallback { override def onSuccess(response: ByteBuffer): Unit = { logTrace(s"Successfully uploaded block $blockId") diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 76ea8b86c53d2..2480559a41b7a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -35,7 +35,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config._ import org.apache.spark.rdd.HadoopRDD.HadoopMapPartitionsWithSplitRDD import org.apache.spark.scheduler.{HDFSCacheTaskLocation, HostTaskLocation} import org.apache.spark.storage.StorageLevel @@ -134,6 +134,8 @@ class HadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) + // Returns a JobConf that will be used on slaves to obtain input splits for Hadoop reads. protected def getJobConf(): JobConf = { val conf: Configuration = broadcastedConf.value.value @@ -157,20 +159,25 @@ class HadoopRDD[K, V]( if (conf.isInstanceOf[JobConf]) { logDebug("Re-using user-broadcasted JobConf") conf.asInstanceOf[JobConf] - } else if (HadoopRDD.containsCachedMetadata(jobConfCacheKey)) { - logDebug("Re-using cached JobConf") - HadoopRDD.getCachedMetadata(jobConfCacheKey).asInstanceOf[JobConf] } else { - // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in the - // local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). - // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary objects. - // Synchronize to prevent ConcurrentModificationException (SPARK-1097, HADOOP-10456). - HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { - logDebug("Creating new JobConf and caching it for later re-use") - val newJobConf = new JobConf(conf) - initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) - HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) - newJobConf + Option(HadoopRDD.getCachedMetadata(jobConfCacheKey)) + .map { conf => + logDebug("Re-using cached JobConf") + conf.asInstanceOf[JobConf] + } + .getOrElse { + // Create a JobConf that will be cached and used across this RDD's getJobConf() calls in + // the local process. The local cache is accessed through HadoopRDD.putCachedMetadata(). + // The caching helps minimize GC, since a JobConf can contain ~10KB of temporary + // objects. Synchronize to prevent ConcurrentModificationException (SPARK-1097, + // HADOOP-10456). 
+ HadoopRDD.CONFIGURATION_INSTANTIATION_LOCK.synchronized { + logDebug("Creating new JobConf and caching it for later re-use") + val newJobConf = new JobConf(conf) + initLocalJobConfFuncOpt.foreach(f => f(newJobConf)) + HadoopRDD.putCachedMetadata(jobConfCacheKey, newJobConf) + newJobConf + } } } } @@ -190,8 +197,12 @@ class HadoopRDD[K, V]( val jobConf = getJobConf() // add the credentials here as this can be called before SparkContext initialized SparkHadoopUtil.get.addCredentials(jobConf) - val inputFormat = getInputFormat(jobConf) - val inputSplits = inputFormat.getSplits(jobConf, minPartitions) + val allInputSplits = getInputFormat(jobConf).getSplits(jobConf, minPartitions) + val inputSplits = if (ignoreEmptySplits) { + allInputSplits.filter(_.getLength > 0) + } else { + allInputSplits + } val array = new Array[Partition](inputSplits.size) for (i <- 0 until inputSplits.size) { array(i) = new HadoopPartition(id, i, inputSplits(i)) @@ -360,8 +371,6 @@ private[spark] object HadoopRDD extends Logging { */ def getCachedMetadata(key: String): Any = SparkEnv.get.hadoopJobMetadata.get(key) - def containsCachedMetadata(key: String): Boolean = SparkEnv.get.hadoopJobMetadata.containsKey(key) - private def putCachedMetadata(key: String, value: Any): Unit = SparkEnv.get.hadoopJobMetadata.put(key, value) diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 482875e6c1ac5..e4dd1b6a82498 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -21,6 +21,7 @@ import java.io.IOException import java.text.SimpleDateFormat import java.util.{Date, Locale} +import scala.collection.JavaConverters.asScalaBufferConverter import scala.reflect.ClassTag import org.apache.hadoop.conf.{Configurable, Configuration} @@ -34,7 +35,7 @@ import org.apache.spark._ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config._ import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} @@ -89,6 +90,8 @@ class NewHadoopRDD[K, V]( private val ignoreCorruptFiles = sparkContext.conf.get(IGNORE_CORRUPT_FILES) + private val ignoreEmptySplits = sparkContext.conf.get(HADOOP_RDD_IGNORE_EMPTY_SPLITS) + def getConf: Configuration = { val conf: Configuration = confBroadcast.value.value if (shouldCloneJobConf) { @@ -121,8 +124,12 @@ class NewHadoopRDD[K, V]( configurable.setConf(_conf) case _ => } - val jobContext = new JobContextImpl(_conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray + val allRowSplits = inputFormat.getSplits(new JobContextImpl(_conf, jobId)).asScala + val rawSplits = if (ignoreEmptySplits) { + allRowSplits.filter(_.getLength > 0) + } else { + allRowSplits + } val result = new Array[Partition](rawSplits.size) for (i <- 0 until rawSplits.size) { result(i) = new NewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala b/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala deleted file mode 100644 index 6da8865cd10d3..0000000000000 --- 
a/core/src/main/scala/org/apache/spark/scheduler/ApplicationEventListener.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.scheduler - -/** - * A simple listener for application events. - * - * This listener expects to hear events from a single application only. If events - * from multiple applications are seen, the behavior is unspecified. - */ -private[spark] class ApplicationEventListener extends SparkListener { - var appName: Option[String] = None - var appId: Option[String] = None - var appAttemptId: Option[String] = None - var sparkUser: Option[String] = None - var startTime: Option[Long] = None - var endTime: Option[Long] = None - var viewAcls: Option[String] = None - var adminAcls: Option[String] = None - var viewAclsGroups: Option[String] = None - var adminAclsGroups: Option[String] = None - var appSparkVersion: Option[String] = None - - override def onApplicationStart(applicationStart: SparkListenerApplicationStart) { - appName = Some(applicationStart.appName) - appId = applicationStart.appId - appAttemptId = applicationStart.appAttemptId - startTime = Some(applicationStart.time) - sparkUser = Some(applicationStart.sparkUser) - } - - override def onApplicationEnd(applicationEnd: SparkListenerApplicationEnd) { - endTime = Some(applicationEnd.time) - } - - override def onEnvironmentUpdate(environmentUpdate: SparkListenerEnvironmentUpdate) { - synchronized { - val environmentDetails = environmentUpdate.environmentDetails - val allProperties = environmentDetails("Spark Properties").toMap - viewAcls = allProperties.get("spark.ui.view.acls") - adminAcls = allProperties.get("spark.admin.acls") - viewAclsGroups = allProperties.get("spark.ui.view.acls.groups") - adminAclsGroups = allProperties.get("spark.admin.acls.groups") - } - } - - override def onOtherEvent(event: SparkListenerEvent): Unit = event match { - case SparkListenerLogStart(sparkVersion) => - appSparkVersion = Some(sparkVersion) - case _ => - } -} diff --git a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala index 9dafa0b7646bf..a77adc5ff3545 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala @@ -37,6 +37,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{SPARK_VERSION, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.util.{JsonProtocol, Utils} @@ -45,6 +46,7 @@ import org.apache.spark.util.{JsonProtocol, Utils} * * Event 
logging is specified by the following configurable parameters: * spark.eventLog.enabled - Whether event logging is enabled. + * spark.eventLog.logBlockUpdates.enabled - Whether to log block updates * spark.eventLog.compress - Whether to compress logged events * spark.eventLog.overwrite - Whether to overwrite any existing files. * spark.eventLog.dir - Path to the directory in which events are logged. @@ -64,10 +66,11 @@ private[spark] class EventLoggingListener( this(appId, appAttemptId, logBaseDir, sparkConf, SparkHadoopUtil.get.newConfiguration(sparkConf)) - private val shouldCompress = sparkConf.getBoolean("spark.eventLog.compress", false) - private val shouldOverwrite = sparkConf.getBoolean("spark.eventLog.overwrite", false) - private val testing = sparkConf.getBoolean("spark.eventLog.testing", false) - private val outputBufferSize = sparkConf.getInt("spark.eventLog.buffer.kb", 100) * 1024 + private val shouldCompress = sparkConf.get(EVENT_LOG_COMPRESS) + private val shouldOverwrite = sparkConf.get(EVENT_LOG_OVERWRITE) + private val shouldLogBlockUpdates = sparkConf.get(EVENT_LOG_BLOCK_UPDATES) + private val testing = sparkConf.get(EVENT_LOG_TESTING) + private val outputBufferSize = sparkConf.get(EVENT_LOG_OUTPUT_BUFFER_SIZE).toInt private val fileSystem = Utils.getHadoopFileSystem(logBaseDir, hadoopConf) private val compressionCodec = if (shouldCompress) { @@ -216,8 +219,11 @@ private[spark] class EventLoggingListener( logEvent(event, flushLogger = true) } - // No-op because logging every update would be overkill - override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = {} + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + if (shouldLogBlockUpdates) { + logEvent(event, flushLogger = true) + } + } // No-op because logging every update would be overkill override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala index e815b7e0cf6c9..233781f3d9719 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetBlacklist.scala @@ -61,6 +61,16 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private val blacklistedExecs = new HashSet[String]() private val blacklistedNodes = new HashSet[String]() + private var latestFailureReason: String = null + + /** + * Get the most recent failure reason of this TaskSet. + * @return + */ + def getLatestFailureReason: String = { + latestFailureReason + } + /** * Return true if this executor is blacklisted for the given task. 
This does *not* * need to return true if the executor is blacklisted for the entire stage, or blacklisted @@ -94,7 +104,9 @@ private[scheduler] class TaskSetBlacklist(val conf: SparkConf, val stageId: Int, private[scheduler] def updateBlacklistForFailedTask( host: String, exec: String, - index: Int): Unit = { + index: Int, + failureReason: String): Unit = { + latestFailureReason = failureReason val execFailures = execToFailures.getOrElseUpdate(exec, new ExecutorFailuresInTaskSet(host)) execFailures.updateWithFailure(index, clock.getTimeMillis()) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3804ea863b4f9..de4711f461df2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -27,7 +27,7 @@ import scala.util.control.NonFatal import org.apache.spark._ import org.apache.spark.TaskState.TaskState -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.scheduler.SchedulingMode._ import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap @@ -83,6 +83,11 @@ private[spark] class TaskSetManager( val successful = new Array[Boolean](numTasks) private val numFailures = new Array[Int](numTasks) + // Set the corresponding index to true when a task is killed by another attempt. This can + // happen when `spark.speculation` is enabled. A task killed by another attempt should not + // be resubmitted when its executor is lost. + private val killedByOtherAttempt: Array[Boolean] = new Array[Boolean](numTasks) + val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil) private[scheduler] var tasksSuccessful = 0 @@ -670,9 +675,14 @@ private[spark] class TaskSetManager( } if (blacklistedEverywhere) { val partition = tasks(indexInTaskSet).partitionId - abort(s"Aborting $taskSet because task $indexInTaskSet (partition $partition) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior " + - s"can be configured via spark.blacklist.*.") + abort(s""" + |Aborting $taskSet because task $indexInTaskSet (partition $partition) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${taskSetBlacklist.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*.
+ |""".stripMargin) } } } @@ -724,6 +734,7 @@ private[spark] class TaskSetManager( logInfo(s"Killing attempt ${attemptInfo.attemptNumber} for task ${attemptInfo.id} " + s"in stage ${taskSet.id} (TID ${attemptInfo.taskId}) on ${attemptInfo.host} " + s"as the attempt ${info.attemptNumber} succeeded on ${info.host}") + killedByOtherAttempt(index) = true sched.backend.killTask( attemptInfo.taskId, attemptInfo.executorId, @@ -837,9 +848,9 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskEnded(tasks(index), reason, null, accumUpdates, info) if (!isZombie && reason.countTowardsTaskFailures) { - taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( - info.host, info.executorId, index)) assert (null != failureReason) + taskSetBlacklistHelperOpt.foreach(_.updateBlacklistForFailedTask( + info.host, info.executorId, index, failureReason)) numFailures(index) += 1 if (numFailures(index) >= maxTaskFailures) { logError("Task %d in stage %s failed %d times; aborting job".format( @@ -910,7 +921,7 @@ private[spark] class TaskSetManager( && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index - if (successful(index)) { + if (successful(index) && !killedByOtherAttempt(index)) { successful(index) = false copiesRunning(index) -= 1 tasksSuccessful -= 1 @@ -975,7 +986,7 @@ private[spark] class TaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.get("spark.locality.wait", "3s") + val defaultWait = conf.get(config.LOCALITY_WAIT) val localityWaitKey = level match { case TaskLocality.PROCESS_LOCAL => "spark.locality.wait.process" case TaskLocality.NODE_LOCAL => "spark.locality.wait.node" @@ -984,7 +995,7 @@ private[spark] class TaskSetManager( } if (localityWaitKey != null) { - conf.getTimeAsMs(localityWaitKey, defaultWait) + conf.getTimeAsMs(localityWaitKey, defaultWait.toString) } else { 0L } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala index a4e2a74341283..505c342a889ee 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/StandaloneSchedulerBackend.scala @@ -153,7 +153,7 @@ private[spark] class StandaloneSchedulerBackend( override def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) { - logInfo("Granted executor ID %s on hostPort %s with %d cores, %s RAM".format( + logInfo("Granted executor ID %s on hostPort %s with %d core(s), %s RAM".format( fullId, hostPort, cores, Utils.megabytesToString(memory))) } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index c8d1460300934..0562d45ff57c5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -52,7 +52,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024, SparkEnv.get.conf.getInt("spark.reducer.maxReqsInFlight", Int.MaxValue), SparkEnv.get.conf.get(config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS), - SparkEnv.get.conf.get(config.REDUCER_MAX_REQ_SIZE_SHUFFLE_TO_MEM), + 
SparkEnv.get.conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM), SparkEnv.get.conf.getBoolean("spark.shuffle.detectCorrupt", true)) val serializerInstance = dep.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala new file mode 100644 index 0000000000000..658ae11671c27 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusListener.scala @@ -0,0 +1,566 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.internal.Logging +import org.apache.spark.scheduler._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.kvstore.KVStore + +/** + * A Spark listener that writes application information to a data store. The types written to the + * store are defined in the `storeTypes.scala` file and are based on the public REST API. + */ +private[spark] class AppStatusListener( + kvstore: KVStore, + conf: SparkConf, + live: Boolean) extends SparkListener with Logging { + + import config._ + + private var sparkVersion = SPARK_VERSION + private var appInfo: v1.ApplicationInfo = null + private var coresPerTask: Int = 1 + + // How often to update live entities. -1 means "never update" when replaying applications, + // meaning only the last write will happen. For live applications, this avoids a few + // operations that we can live without when rapidly processing incoming task events. + private val liveUpdatePeriodNs = if (live) conf.get(LIVE_ENTITY_UPDATE_PERIOD) else -1L + + // Keep track of live entities, so that task metrics can be efficiently updated (without + // causing too many writes to the underlying store, and other expensive operations). 
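+ // Entries are removed from these maps when the corresponding completion or removal events
+ // arrive (job end, stage completion, task end, executor removed, RDD unpersisted).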
+ private val liveStages = new HashMap[(Int, Int), LiveStage]() + private val liveJobs = new HashMap[Int, LiveJob]() + private val liveExecutors = new HashMap[String, LiveExecutor]() + private val liveTasks = new HashMap[Long, LiveTask]() + private val liveRDDs = new HashMap[Int, LiveRDD]() + + override def onOtherEvent(event: SparkListenerEvent): Unit = event match { + case SparkListenerLogStart(version) => sparkVersion = version + case _ => + } + + override def onApplicationStart(event: SparkListenerApplicationStart): Unit = { + assert(event.appId.isDefined, "Application without IDs are not supported.") + + val attempt = new v1.ApplicationAttemptInfo( + event.appAttemptId, + new Date(event.time), + new Date(-1), + new Date(event.time), + -1L, + event.sparkUser, + false, + sparkVersion) + + appInfo = new v1.ApplicationInfo( + event.appId.get, + event.appName, + None, + None, + None, + None, + Seq(attempt)) + + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onApplicationEnd(event: SparkListenerApplicationEnd): Unit = { + val old = appInfo.attempts.head + val attempt = new v1.ApplicationAttemptInfo( + old.attemptId, + old.startTime, + new Date(event.time), + new Date(event.time), + event.time - old.startTime.getTime(), + old.sparkUser, + true, + old.appSparkVersion) + + appInfo = new v1.ApplicationInfo( + appInfo.id, + appInfo.name, + None, + None, + None, + None, + Seq(attempt)) + kvstore.write(new ApplicationInfoWrapper(appInfo)) + } + + override def onExecutorAdded(event: SparkListenerExecutorAdded): Unit = { + // This needs to be an update in case an executor re-registers after the driver has + // marked it as "dead". + val exec = getOrCreateExecutor(event.executorId) + exec.host = event.executorInfo.executorHost + exec.isActive = true + exec.totalCores = event.executorInfo.totalCores + exec.maxTasks = event.executorInfo.totalCores / coresPerTask + exec.executorLogs = event.executorInfo.logUrlMap + liveUpdate(exec) + } + + override def onExecutorRemoved(event: SparkListenerExecutorRemoved): Unit = { + liveExecutors.remove(event.executorId).foreach { exec => + exec.isActive = false + update(exec) + } + } + + override def onExecutorBlacklisted(event: SparkListenerExecutorBlacklisted): Unit = { + updateBlackListStatus(event.executorId, true) + } + + override def onExecutorUnblacklisted(event: SparkListenerExecutorUnblacklisted): Unit = { + updateBlackListStatus(event.executorId, false) + } + + override def onNodeBlacklisted(event: SparkListenerNodeBlacklisted): Unit = { + updateNodeBlackList(event.hostId, true) + } + + override def onNodeUnblacklisted(event: SparkListenerNodeUnblacklisted): Unit = { + updateNodeBlackList(event.hostId, false) + } + + private def updateBlackListStatus(execId: String, blacklisted: Boolean): Unit = { + liveExecutors.get(execId).foreach { exec => + exec.isBlacklisted = blacklisted + liveUpdate(exec) + } + } + + private def updateNodeBlackList(host: String, blacklisted: Boolean): Unit = { + // Implicitly (un)blacklist every executor associated with the node. + liveExecutors.values.foreach { exec => + if (exec.hostname == host) { + exec.isBlacklisted = blacklisted + liveUpdate(exec) + } + } + } + + override def onJobStart(event: SparkListenerJobStart): Unit = { + // Compute (a potential over-estimate of) the number of tasks that will be run by this job. 
+ // This may be an over-estimate because the job start event references all of the result + // stages' transitive stage dependencies, but some of these stages might be skipped if their + // output is available from earlier runs. + // See https://github.com/apache/spark/pull/3009 for a more extensive discussion. + val numTasks = { + val missingStages = event.stageInfos.filter(_.completionTime.isEmpty) + missingStages.map(_.numTasks).sum + } + + val lastStageInfo = event.stageInfos.lastOption + val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)") + + val jobGroup = Option(event.properties) + .flatMap { p => Option(p.getProperty(SparkContext.SPARK_JOB_GROUP_ID)) } + + val job = new LiveJob( + event.jobId, + lastStageName, + Some(new Date(event.time)), + event.stageIds, + jobGroup, + numTasks) + liveJobs.put(event.jobId, job) + liveUpdate(job) + + event.stageInfos.foreach { stageInfo => + // A new job submission may re-use an existing stage, so this code needs to do an update + // instead of just a write. + val stage = getOrCreateStage(stageInfo) + stage.jobs :+= job + stage.jobIds += event.jobId + liveUpdate(stage) + } + } + + override def onJobEnd(event: SparkListenerJobEnd): Unit = { + liveJobs.remove(event.jobId).foreach { job => + job.status = event.jobResult match { + case JobSucceeded => JobExecutionStatus.SUCCEEDED + case JobFailed(_) => JobExecutionStatus.FAILED + } + + job.completionTime = Some(new Date(event.time)) + update(job) + } + } + + override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = { + val stage = getOrCreateStage(event.stageInfo) + stage.status = v1.StageStatus.ACTIVE + stage.schedulingPool = Option(event.properties).flatMap { p => + Option(p.getProperty("spark.scheduler.pool")) + }.getOrElse(SparkUI.DEFAULT_POOL_NAME) + + // Look at all active jobs to find the ones that mention this stage. + stage.jobs = liveJobs.values + .filter(_.stageIds.contains(event.stageInfo.stageId)) + .toSeq + stage.jobIds = stage.jobs.map(_.jobId).toSet + + stage.jobs.foreach { job => + job.completedStages = job.completedStages - event.stageInfo.stageId + job.activeStages += 1 + liveUpdate(job) + } + + event.stageInfo.rddInfos.foreach { info => + if (info.storageLevel.isValid) { + liveUpdate(liveRDDs.getOrElseUpdate(info.id, new LiveRDD(info))) + } + } + + liveUpdate(stage) + } + + override def onTaskStart(event: SparkListenerTaskStart): Unit = { + val task = new LiveTask(event.taskInfo, event.stageId, event.stageAttemptId) + liveTasks.put(event.taskInfo.taskId, task) + liveUpdate(task) + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + stage.activeTasks += 1 + stage.firstLaunchTime = math.min(stage.firstLaunchTime, event.taskInfo.launchTime) + liveUpdate(stage) + + stage.jobs.foreach { job => + job.activeTasks += 1 + liveUpdate(job) + } + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + exec.activeTasks += 1 + exec.totalTasks += 1 + liveUpdate(exec) + } + } + + override def onTaskGettingResult(event: SparkListenerTaskGettingResult): Unit = { + // Call update on the task so that the "getting result" time is written to the store; the + // value is part of the mutable TaskInfo state that the live entity already references. + liveTasks.get(event.taskInfo.taskId).foreach { task => + maybeUpdate(task) + } + } + + override def onTaskEnd(event: SparkListenerTaskEnd): Unit = { + // TODO: can this really happen? 
+ if (event.taskInfo == null) { + return + } + + val metricsDelta = liveTasks.remove(event.taskInfo.taskId).map { task => + val errorMessage = event.reason match { + case Success => + None + case k: TaskKilled => + Some(k.reason) + case e: ExceptionFailure => // Handle ExceptionFailure because we might have accumUpdates + Some(e.toErrorString) + case e: TaskFailedReason => // All other failure cases + Some(e.toErrorString) + case other => + logInfo(s"Unhandled task end reason: $other") + None + } + task.errorMessage = errorMessage + val delta = task.updateMetrics(event.taskMetrics) + update(task) + delta + }.orNull + + val (completedDelta, failedDelta) = event.reason match { + case Success => + (1, 0) + case _ => + (0, 1) + } + + liveStages.get((event.stageId, event.stageAttemptId)).foreach { stage => + if (metricsDelta != null) { + stage.metrics.update(metricsDelta) + } + stage.activeTasks -= 1 + stage.completedTasks += completedDelta + stage.failedTasks += failedDelta + liveUpdate(stage) + + stage.jobs.foreach { job => + job.activeTasks -= 1 + job.completedTasks += completedDelta + job.failedTasks += failedDelta + liveUpdate(job) + } + + val esummary = stage.executorSummary(event.taskInfo.executorId) + esummary.taskTime += event.taskInfo.duration + esummary.succeededTasks += completedDelta + esummary.failedTasks += failedDelta + if (metricsDelta != null) { + esummary.metrics.update(metricsDelta) + } + liveUpdate(esummary) + } + + liveExecutors.get(event.taskInfo.executorId).foreach { exec => + if (event.taskMetrics != null) { + val readMetrics = event.taskMetrics.shuffleReadMetrics + exec.totalGcTime += event.taskMetrics.jvmGCTime + exec.totalInputBytes += event.taskMetrics.inputMetrics.bytesRead + exec.totalShuffleRead += readMetrics.localBytesRead + readMetrics.remoteBytesRead + exec.totalShuffleWrite += event.taskMetrics.shuffleWriteMetrics.bytesWritten + } + + exec.activeTasks -= 1 + exec.completedTasks += completedDelta + exec.failedTasks += failedDelta + exec.totalDuration += event.taskInfo.duration + liveUpdate(exec) + } + } + + override def onStageCompleted(event: SparkListenerStageCompleted): Unit = { + liveStages.remove((event.stageInfo.stageId, event.stageInfo.attemptId)).foreach { stage => + stage.info = event.stageInfo + + // Because of SPARK-20205, old event logs may contain valid stages without a submission time + // in their start event. In those cases, we can only detect whether a stage was skipped by + // waiting until the completion event, at which point the field would have been set. + stage.status = event.stageInfo.failureReason match { + case Some(_) => v1.StageStatus.FAILED + case _ if event.stageInfo.submissionTime.isDefined => v1.StageStatus.COMPLETE + case _ => v1.StageStatus.SKIPPED + } + + stage.jobs.foreach { job => + stage.status match { + case v1.StageStatus.COMPLETE => + job.completedStages += event.stageInfo.stageId + case v1.StageStatus.SKIPPED => + job.skippedStages += event.stageInfo.stageId + job.skippedTasks += event.stageInfo.numTasks + case _ => + job.failedStages += 1 + } + job.activeStages -= 1 + liveUpdate(job) + } + + stage.executorSummaries.values.foreach(update) + update(stage) + } + } + + override def onBlockManagerAdded(event: SparkListenerBlockManagerAdded): Unit = { + // This needs to set fields that are already set by onExecutorAdded because the driver is + // considered an "executor" in the UI, but does not have a SparkListenerExecutorAdded event. 
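+ // (The driver's own BlockManager registration also arrives through this event, which is how
+ // the driver ends up with a LiveExecutor entry of its own.)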
+ val exec = getOrCreateExecutor(event.blockManagerId.executorId) + exec.hostPort = event.blockManagerId.hostPort + event.maxOnHeapMem.foreach { _ => + exec.totalOnHeap = event.maxOnHeapMem.get + exec.totalOffHeap = event.maxOffHeapMem.get + } + exec.isActive = true + exec.maxMemory = event.maxMem + liveUpdate(exec) + } + + override def onBlockManagerRemoved(event: SparkListenerBlockManagerRemoved): Unit = { + // Nothing to do here. Covered by onExecutorRemoved. + } + + override def onUnpersistRDD(event: SparkListenerUnpersistRDD): Unit = { + liveRDDs.remove(event.rddId) + kvstore.delete(classOf[RDDStorageInfoWrapper], event.rddId) + } + + override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = { + event.accumUpdates.foreach { case (taskId, sid, sAttempt, accumUpdates) => + liveTasks.get(taskId).foreach { task => + val metrics = TaskMetrics.fromAccumulatorInfos(accumUpdates) + val delta = task.updateMetrics(metrics) + maybeUpdate(task) + + liveStages.get((sid, sAttempt)).foreach { stage => + stage.metrics.update(delta) + maybeUpdate(stage) + + val esummary = stage.executorSummary(event.execId) + esummary.metrics.update(delta) + maybeUpdate(esummary) + } + } + } + } + + override def onBlockUpdated(event: SparkListenerBlockUpdated): Unit = { + event.blockUpdatedInfo.blockId match { + case block: RDDBlockId => updateRDDBlock(event, block) + case _ => // TODO: API only covers RDD storage. + } + } + + /** Flush all live entities' data to the underlying store. */ + def flush(): Unit = { + liveStages.values.foreach(update) + liveJobs.values.foreach(update) + liveExecutors.values.foreach(update) + liveTasks.values.foreach(update) + liveRDDs.values.foreach(update) + } + + private def updateRDDBlock(event: SparkListenerBlockUpdated, block: RDDBlockId): Unit = { + val executorId = event.blockUpdatedInfo.blockManagerId.executorId + + // Whether values are being added to or removed from the existing accounting. + val storageLevel = event.blockUpdatedInfo.storageLevel + val diskDelta = event.blockUpdatedInfo.diskSize * (if (storageLevel.useDisk) 1 else -1) + val memoryDelta = event.blockUpdatedInfo.memSize * (if (storageLevel.useMemory) 1 else -1) + + // Function to apply a delta to a value, but ensure that it doesn't go negative. + def newValue(old: Long, delta: Long): Long = math.max(0, old + delta) + + val updatedStorageLevel = if (storageLevel.isValid) { + Some(storageLevel.description) + } else { + None + } + + // We need information about the executor to update some memory accounting values in the + // RDD info, so read that beforehand. + val maybeExec = liveExecutors.get(executorId) + var rddBlocksDelta = 0 + + // Update the block entry in the RDD info, keeping track of the deltas above so that we + // can update the executor information too. + liveRDDs.get(block.rddId).foreach { rdd => + val partition = rdd.partition(block.name) + + val executors = if (updatedStorageLevel.isDefined) { + if (!partition.executors.contains(executorId)) { + rddBlocksDelta = 1 + } + partition.executors + executorId + } else { + rddBlocksDelta = -1 + partition.executors - executorId + } + + // Only update the partition if it's still stored in some executor, otherwise get rid of it. 
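+ // (An update that reports an invalid storage level means the block was dropped from this
+ // executor, which is why the executor was removed from the partition's executor set above.)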
+ if (executors.nonEmpty) { + if (updatedStorageLevel.isDefined) { + partition.storageLevel = updatedStorageLevel.get + } + partition.memoryUsed = newValue(partition.memoryUsed, memoryDelta) + partition.diskUsed = newValue(partition.diskUsed, diskDelta) + partition.executors = executors + } else { + rdd.removePartition(block.name) + } + + maybeExec.foreach { exec => + if (exec.rddBlocks + rddBlocksDelta > 0) { + val dist = rdd.distribution(exec) + dist.memoryRemaining = newValue(dist.memoryRemaining, -memoryDelta) + dist.memoryUsed = newValue(dist.memoryUsed, memoryDelta) + dist.diskUsed = newValue(dist.diskUsed, diskDelta) + + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + dist.offHeapUsed = newValue(dist.offHeapUsed, memoryDelta) + dist.offHeapRemaining = newValue(dist.offHeapRemaining, -memoryDelta) + } else { + dist.onHeapUsed = newValue(dist.onHeapUsed, memoryDelta) + dist.onHeapRemaining = newValue(dist.onHeapRemaining, -memoryDelta) + } + } + } else { + rdd.removeDistribution(exec) + } + } + + if (updatedStorageLevel.isDefined) { + rdd.storageLevel = updatedStorageLevel.get + } + rdd.memoryUsed = newValue(rdd.memoryUsed, memoryDelta) + rdd.diskUsed = newValue(rdd.diskUsed, diskDelta) + update(rdd) + } + + maybeExec.foreach { exec => + if (exec.hasMemoryInfo) { + if (storageLevel.useOffHeap) { + exec.usedOffHeap = newValue(exec.usedOffHeap, memoryDelta) + } else { + exec.usedOnHeap = newValue(exec.usedOnHeap, memoryDelta) + } + } + exec.memoryUsed = newValue(exec.memoryUsed, memoryDelta) + exec.diskUsed = newValue(exec.diskUsed, diskDelta) + exec.rddBlocks += rddBlocksDelta + if (exec.hasMemoryInfo || rddBlocksDelta != 0) { + liveUpdate(exec) + } + } + } + + private def getOrCreateExecutor(executorId: String): LiveExecutor = { + liveExecutors.getOrElseUpdate(executorId, new LiveExecutor(executorId)) + } + + private def getOrCreateStage(info: StageInfo): LiveStage = { + val stage = liveStages.getOrElseUpdate((info.stageId, info.attemptId), new LiveStage()) + stage.info = info + stage + } + + private def update(entity: LiveEntity): Unit = { + entity.write(kvstore) + } + + /** Update a live entity only if it hasn't been updated in the last configured period. */ + private def maybeUpdate(entity: LiveEntity): Unit = { + if (liveUpdatePeriodNs >= 0) { + val now = System.nanoTime() + if (now - entity.lastWriteTime > liveUpdatePeriodNs) { + update(entity) + } + } + } + + /** Update an entity only if in a live app; avoids redundant writes when replaying logs. */ + private def liveUpdate(entity: LiveEntity): Unit = { + if (live) { + update(entity) + } + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala new file mode 100644 index 0000000000000..2927a3227cbef --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/AppStatusStore.scala @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.util.{Arrays, List => JList} + +import scala.collection.JavaConverters._ + +import org.apache.spark.{JobExecutionStatus, SparkConf} +import org.apache.spark.scheduler.LiveListenerBus +import org.apache.spark.status.api.v1 +import org.apache.spark.util.{Distribution, Utils} +import org.apache.spark.util.kvstore.{InMemoryStore, KVStore} + +/** + * A wrapper around a KVStore that provides methods for accessing the API data stored within. + */ +private[spark] class AppStatusStore(store: KVStore) { + + def jobsList(statuses: JList[JobExecutionStatus]): Seq[v1.JobData] = { + val it = store.view(classOf[JobDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty()) { + it.filter { job => statuses.contains(job.status) }.toSeq + } else { + it.toSeq + } + } + + def job(jobId: Int): v1.JobData = { + store.read(classOf[JobDataWrapper], jobId).info + } + + def executorList(activeOnly: Boolean): Seq[v1.ExecutorSummary] = { + store.view(classOf[ExecutorSummaryWrapper]).index("active").reverse().first(true) + .last(true).asScala.map(_.info).toSeq + } + + def stageList(statuses: JList[v1.StageStatus]): Seq[v1.StageData] = { + val it = store.view(classOf[StageDataWrapper]).asScala.map(_.info) + if (!statuses.isEmpty) { + it.filter { s => statuses.contains(s.status) }.toSeq + } else { + it.toSeq + } + } + + def stageData(stageId: Int): Seq[v1.StageData] = { + store.view(classOf[StageDataWrapper]).index("stageId").first(stageId).last(stageId) + .asScala.map(_.info).toSeq + } + + def stageAttempt(stageId: Int, stageAttemptId: Int): v1.StageData = { + store.read(classOf[StageDataWrapper], Array(stageId, stageAttemptId)).info + } + + def taskSummary( + stageId: Int, + stageAttemptId: Int, + quantiles: Array[Double]): v1.TaskMetricDistributions = { + + val stage = Array(stageId, stageAttemptId) + + val rawMetrics = store.view(classOf[TaskDataWrapper]) + .index("stage") + .first(stage) + .last(stage) + .asScala + .flatMap(_.info.taskMetrics) + .toList + .view + + def metricQuantiles(f: v1.TaskMetrics => Double): IndexedSeq[Double] = + Distribution(rawMetrics.map { d => f(d) }).get.getQuantiles(quantiles) + + // We need to do a lot of similar munging to nested metrics here. For each one, + // we want (a) extract the values for nested metrics (b) make a distribution for each metric + // (c) shove the distribution into the right field in our return type and (d) only return + // a result if the option is defined for any of the tasks. MetricHelper is a little util + // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just + // implement one "build" method, which just builds the quantiles for each field. 
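+ // Each helper below only needs to say which nested metrics object it reads (getSubmetrics)
+ // and how to assemble the distribution type from per-field quantiles (build); the shared
+ // plumbing lives in MetricHelper at the bottom of this file.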
+ + val inputMetrics = + new MetricHelper[v1.InputMetrics, v1.InputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.InputMetrics = raw.inputMetrics + + def build: v1.InputMetricDistributions = new v1.InputMetricDistributions( + bytesRead = submetricQuantiles(_.bytesRead), + recordsRead = submetricQuantiles(_.recordsRead) + ) + }.build + + val outputMetrics = + new MetricHelper[v1.OutputMetrics, v1.OutputMetricDistributions](rawMetrics, quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.OutputMetrics = raw.outputMetrics + + def build: v1.OutputMetricDistributions = new v1.OutputMetricDistributions( + bytesWritten = submetricQuantiles(_.bytesWritten), + recordsWritten = submetricQuantiles(_.recordsWritten) + ) + }.build + + val shuffleReadMetrics = + new MetricHelper[v1.ShuffleReadMetrics, v1.ShuffleReadMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleReadMetrics = + raw.shuffleReadMetrics + + def build: v1.ShuffleReadMetricDistributions = new v1.ShuffleReadMetricDistributions( + readBytes = submetricQuantiles { s => s.localBytesRead + s.remoteBytesRead }, + readRecords = submetricQuantiles(_.recordsRead), + remoteBytesRead = submetricQuantiles(_.remoteBytesRead), + remoteBytesReadToDisk = submetricQuantiles(_.remoteBytesReadToDisk), + remoteBlocksFetched = submetricQuantiles(_.remoteBlocksFetched), + localBlocksFetched = submetricQuantiles(_.localBlocksFetched), + totalBlocksFetched = submetricQuantiles { s => + s.localBlocksFetched + s.remoteBlocksFetched + }, + fetchWaitTime = submetricQuantiles(_.fetchWaitTime) + ) + }.build + + val shuffleWriteMetrics = + new MetricHelper[v1.ShuffleWriteMetrics, v1.ShuffleWriteMetricDistributions](rawMetrics, + quantiles) { + def getSubmetrics(raw: v1.TaskMetrics): v1.ShuffleWriteMetrics = + raw.shuffleWriteMetrics + + def build: v1.ShuffleWriteMetricDistributions = new v1.ShuffleWriteMetricDistributions( + writeBytes = submetricQuantiles(_.bytesWritten), + writeRecords = submetricQuantiles(_.recordsWritten), + writeTime = submetricQuantiles(_.writeTime) + ) + }.build + + new v1.TaskMetricDistributions( + quantiles = quantiles, + executorDeserializeTime = metricQuantiles(_.executorDeserializeTime), + executorDeserializeCpuTime = metricQuantiles(_.executorDeserializeCpuTime), + executorRunTime = metricQuantiles(_.executorRunTime), + executorCpuTime = metricQuantiles(_.executorCpuTime), + resultSize = metricQuantiles(_.resultSize), + jvmGcTime = metricQuantiles(_.jvmGcTime), + resultSerializationTime = metricQuantiles(_.resultSerializationTime), + memoryBytesSpilled = metricQuantiles(_.memoryBytesSpilled), + diskBytesSpilled = metricQuantiles(_.diskBytesSpilled), + inputMetrics = inputMetrics, + outputMetrics = outputMetrics, + shuffleReadMetrics = shuffleReadMetrics, + shuffleWriteMetrics = shuffleWriteMetrics + ) + } + + def taskList( + stageId: Int, + stageAttemptId: Int, + offset: Int, + length: Int, + sortBy: v1.TaskSorting): Seq[v1.TaskData] = { + val stageKey = Array(stageId, stageAttemptId) + val base = store.view(classOf[TaskDataWrapper]) + val indexed = sortBy match { + case v1.TaskSorting.ID => + base.index("stage").first(stageKey).last(stageKey) + case v1.TaskSorting.INCREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(-1L)).last(stageKey ++ Array(Long.MaxValue)) + case v1.TaskSorting.DECREASING_RUNTIME => + base.index("runtime").first(stageKey ++ Array(Long.MaxValue)).last(stageKey ++ Array(-1L)) + .reverse() + } + 
indexed.skip(offset).max(length).asScala.map(_.info).toSeq + } + + def rddList(): Seq[v1.RDDStorageInfo] = { + store.view(classOf[RDDStorageInfoWrapper]).asScala.map(_.info).toSeq + } + + def rdd(rddId: Int): v1.RDDStorageInfo = { + store.read(classOf[RDDStorageInfoWrapper], rddId).info + } + + def close(): Unit = { + store.close() + } + +} + +private[spark] object AppStatusStore { + + val CURRENT_VERSION = 1L + + /** + * Create an in-memory store for a live application. + * + * @param conf Configuration. + * @param bus Where to attach the listener to populate the store. + */ + def createLiveStore(conf: SparkConf, bus: LiveListenerBus): AppStatusStore = { + val store = new InMemoryStore() + val stateStore = new AppStatusStore(store) + bus.addToStatusQueue(new AppStatusListener(store, conf, true)) + stateStore + } + +} + +/** + * Helper for getting distributions from nested metric types. + */ +private abstract class MetricHelper[I, O]( + rawMetrics: Seq[v1.TaskMetrics], + quantiles: Array[Double]) { + + def getSubmetrics(raw: v1.TaskMetrics): I + + def build: O + + val data: Seq[I] = rawMetrics.map(getSubmetrics) + + /** applies the given function to all input metrics, and returns the quantiles */ + def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { + Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) + } +} diff --git a/core/src/main/scala/org/apache/spark/status/KVUtils.scala b/core/src/main/scala/org/apache/spark/status/KVUtils.scala new file mode 100644 index 0000000000000..4638511944c61 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/KVUtils.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File + +import scala.annotation.meta.getter +import scala.language.implicitConversions +import scala.reflect.{classTag, ClassTag} + +import com.fasterxml.jackson.annotation.JsonInclude +import com.fasterxml.jackson.module.scala.DefaultScalaModule + +import org.apache.spark.internal.Logging +import org.apache.spark.util.kvstore._ + +private[spark] object KVUtils extends Logging { + + /** Use this to annotate constructor params to be used as KVStore indices. */ + type KVIndexParam = KVIndex @getter + + /** + * A KVStoreSerializer that provides Scala types serialization too, and uses the same options as + * the API serializer. + */ + private[spark] class KVStoreScalaSerializer extends KVStoreSerializer { + + mapper.registerModule(DefaultScalaModule) + mapper.setSerializationInclusion(JsonInclude.Include.NON_NULL) + + } + + /** + * Open or create a LevelDB store. + * + * @param path Location of the store. + * @param metadata Metadata value to compare to the data in the store. If the store does not + * contain any metadata (e.g. 
it's a new store), this value is written as + * the store's metadata. + */ + def open[M: ClassTag](path: File, metadata: M): LevelDB = { + require(metadata != null, "Metadata is required.") + + val db = new LevelDB(path, new KVStoreScalaSerializer()) + val dbMeta = db.getMetadata(classTag[M].runtimeClass) + if (dbMeta == null) { + db.setMetadata(metadata) + } else if (dbMeta != metadata) { + db.close() + throw new MetadataMismatchException() + } + + db + } + + private[spark] class MetadataMismatchException extends Exception + +} diff --git a/core/src/main/scala/org/apache/spark/status/LiveEntity.scala b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala new file mode 100644 index 0000000000000..337ef0b3e6c2b --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/LiveEntity.scala @@ -0,0 +1,529 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.Date + +import scala.collection.mutable.HashMap + +import org.apache.spark.JobExecutionStatus +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler.{AccumulableInfo, StageInfo, TaskInfo} +import org.apache.spark.status.api.v1 +import org.apache.spark.storage.RDDInfo +import org.apache.spark.ui.SparkUI +import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.kvstore.KVStore + +/** + * A mutable representation of a live entity in Spark (jobs, stages, tasks, et al). Every live + * entity uses one of these instances to keep track of their evolving state, and periodically + * flush an immutable view of the entity to the app state store. + */ +private[spark] abstract class LiveEntity { + + var lastWriteTime = 0L + + def write(store: KVStore): Unit = { + store.write(doUpdate()) + lastWriteTime = System.nanoTime() + } + + /** + * Returns an updated view of entity data, to be stored in the status store, reflecting the + * latest information collected by the listener. + */ + protected def doUpdate(): Any + +} + +private class LiveJob( + val jobId: Int, + name: String, + submissionTime: Option[Date], + val stageIds: Seq[Int], + jobGroup: Option[String], + numTasks: Int) extends LiveEntity { + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var skippedTasks = 0 + var skippedStages = Set[Int]() + + var status = JobExecutionStatus.RUNNING + var completionTime: Option[Date] = None + + var completedStages: Set[Int] = Set() + var activeStages = 0 + var failedStages = 0 + + override protected def doUpdate(): Any = { + val info = new v1.JobData( + jobId, + name, + None, // description is always None? 
+ submissionTime, + completionTime, + stageIds, + jobGroup, + status, + numTasks, + activeTasks, + completedTasks, + skippedTasks, + failedTasks, + activeStages, + completedStages.size, + skippedStages.size, + failedStages) + new JobDataWrapper(info, skippedStages) + } + +} + +private class LiveTask( + info: TaskInfo, + stageId: Int, + stageAttemptId: Int) extends LiveEntity { + + import LiveEntityHelpers._ + + private var recordedMetrics: v1.TaskMetrics = null + + var errorMessage: Option[String] = None + + /** + * Update the metrics for the task and return the difference between the previous and new + * values. + */ + def updateMetrics(metrics: TaskMetrics): v1.TaskMetrics = { + if (metrics != null) { + val old = recordedMetrics + recordedMetrics = new v1.TaskMetrics( + metrics.executorDeserializeTime, + metrics.executorDeserializeCpuTime, + metrics.executorRunTime, + metrics.executorCpuTime, + metrics.resultSize, + metrics.jvmGCTime, + metrics.resultSerializationTime, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + new v1.InputMetrics( + metrics.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead), + new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten), + new v1.ShuffleReadMetrics( + metrics.shuffleReadMetrics.remoteBlocksFetched, + metrics.shuffleReadMetrics.localBlocksFetched, + metrics.shuffleReadMetrics.fetchWaitTime, + metrics.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead), + new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten, + metrics.shuffleWriteMetrics.writeTime, + metrics.shuffleWriteMetrics.recordsWritten)) + if (old != null) calculateMetricsDelta(recordedMetrics, old) else recordedMetrics + } else { + null + } + } + + /** + * Return a new TaskMetrics object containing the delta of the various fields of the given + * metrics objects. This is currently targeted at updating stage data, so it does not + * necessarily calculate deltas for all the fields. 
+ */ + private def calculateMetricsDelta( + metrics: v1.TaskMetrics, + old: v1.TaskMetrics): v1.TaskMetrics = { + val shuffleWriteDelta = new v1.ShuffleWriteMetrics( + metrics.shuffleWriteMetrics.bytesWritten - old.shuffleWriteMetrics.bytesWritten, + 0L, + metrics.shuffleWriteMetrics.recordsWritten - old.shuffleWriteMetrics.recordsWritten) + + val shuffleReadDelta = new v1.ShuffleReadMetrics( + 0L, 0L, 0L, + metrics.shuffleReadMetrics.remoteBytesRead - old.shuffleReadMetrics.remoteBytesRead, + metrics.shuffleReadMetrics.remoteBytesReadToDisk - + old.shuffleReadMetrics.remoteBytesReadToDisk, + metrics.shuffleReadMetrics.localBytesRead - old.shuffleReadMetrics.localBytesRead, + metrics.shuffleReadMetrics.recordsRead - old.shuffleReadMetrics.recordsRead) + + val inputDelta = new v1.InputMetrics( + metrics.inputMetrics.bytesRead - old.inputMetrics.bytesRead, + metrics.inputMetrics.recordsRead - old.inputMetrics.recordsRead) + + val outputDelta = new v1.OutputMetrics( + metrics.outputMetrics.bytesWritten - old.outputMetrics.bytesWritten, + metrics.outputMetrics.recordsWritten - old.outputMetrics.recordsWritten) + + new v1.TaskMetrics( + 0L, 0L, + metrics.executorRunTime - old.executorRunTime, + metrics.executorCpuTime - old.executorCpuTime, + 0L, 0L, 0L, + metrics.memoryBytesSpilled - old.memoryBytesSpilled, + metrics.diskBytesSpilled - old.diskBytesSpilled, + inputDelta, + outputDelta, + shuffleReadDelta, + shuffleWriteDelta) + } + + override protected def doUpdate(): Any = { + val task = new v1.TaskData( + info.taskId, + info.index, + info.attemptNumber, + new Date(info.launchTime), + if (info.finished) Some(info.duration) else None, + info.executorId, + info.host, + info.status, + info.taskLocality.toString(), + info.speculative, + newAccumulatorInfos(info.accumulables), + errorMessage, + Option(recordedMetrics)) + new TaskDataWrapper(task, stageId, stageAttemptId) + } + +} + +private class LiveExecutor(val executorId: String) extends LiveEntity { + + var hostPort: String = null + var host: String = null + var isActive = true + var totalCores = 0 + + var rddBlocks = 0 + var memoryUsed = 0L + var diskUsed = 0L + var maxTasks = 0 + var maxMemory = 0L + + var totalTasks = 0 + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + var totalDuration = 0L + var totalGcTime = 0L + var totalInputBytes = 0L + var totalShuffleRead = 0L + var totalShuffleWrite = 0L + var isBlacklisted = false + + var executorLogs = Map[String, String]() + + // Memory metrics. They may not be recorded (e.g. old event logs) so if totalOnHeap is not + // initialized, the store will not contain this information. 
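+ // (totalOnHeap therefore starts at -1 as a sentinel for "unknown" rather than 0, and
+ // hasMemoryInfo below keys off that sentinel.)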
+ var totalOnHeap = -1L + var totalOffHeap = 0L + var usedOnHeap = 0L + var usedOffHeap = 0L + + def hasMemoryInfo: Boolean = totalOnHeap >= 0L + + def hostname: String = if (host != null) host else hostPort.split(":")(0) + + override protected def doUpdate(): Any = { + val memoryMetrics = if (totalOnHeap >= 0) { + Some(new v1.MemoryMetrics(usedOnHeap, usedOffHeap, totalOnHeap, totalOffHeap)) + } else { + None + } + + val info = new v1.ExecutorSummary( + executorId, + if (hostPort != null) hostPort else host, + isActive, + rddBlocks, + memoryUsed, + diskUsed, + totalCores, + maxTasks, + activeTasks, + failedTasks, + completedTasks, + totalTasks, + totalDuration, + totalGcTime, + totalInputBytes, + totalShuffleRead, + totalShuffleWrite, + isBlacklisted, + maxMemory, + executorLogs, + memoryMetrics) + new ExecutorSummaryWrapper(info) + } + +} + +/** Metrics tracked per stage (both total and per executor). */ +private class MetricsTracker { + var executorRunTime = 0L + var executorCpuTime = 0L + var inputBytes = 0L + var inputRecords = 0L + var outputBytes = 0L + var outputRecords = 0L + var shuffleReadBytes = 0L + var shuffleReadRecords = 0L + var shuffleWriteBytes = 0L + var shuffleWriteRecords = 0L + var memoryBytesSpilled = 0L + var diskBytesSpilled = 0L + + def update(delta: v1.TaskMetrics): Unit = { + executorRunTime += delta.executorRunTime + executorCpuTime += delta.executorCpuTime + inputBytes += delta.inputMetrics.bytesRead + inputRecords += delta.inputMetrics.recordsRead + outputBytes += delta.outputMetrics.bytesWritten + outputRecords += delta.outputMetrics.recordsWritten + shuffleReadBytes += delta.shuffleReadMetrics.localBytesRead + + delta.shuffleReadMetrics.remoteBytesRead + shuffleReadRecords += delta.shuffleReadMetrics.recordsRead + shuffleWriteBytes += delta.shuffleWriteMetrics.bytesWritten + shuffleWriteRecords += delta.shuffleWriteMetrics.recordsWritten + memoryBytesSpilled += delta.memoryBytesSpilled + diskBytesSpilled += delta.diskBytesSpilled + } + +} + +private class LiveExecutorStageSummary( + stageId: Int, + attemptId: Int, + executorId: String) extends LiveEntity { + + var taskTime = 0L + var succeededTasks = 0 + var failedTasks = 0 + var killedTasks = 0 + + val metrics = new MetricsTracker() + + override protected def doUpdate(): Any = { + val info = new v1.ExecutorStageSummary( + taskTime, + failedTasks, + succeededTasks, + metrics.inputBytes, + metrics.outputBytes, + metrics.shuffleReadBytes, + metrics.shuffleWriteBytes, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled) + new ExecutorStageSummaryWrapper(stageId, attemptId, executorId, info) + } + +} + +private class LiveStage extends LiveEntity { + + import LiveEntityHelpers._ + + var jobs = Seq[LiveJob]() + var jobIds = Set[Int]() + + var info: StageInfo = null + var status = v1.StageStatus.PENDING + + var schedulingPool: String = SparkUI.DEFAULT_POOL_NAME + + var activeTasks = 0 + var completedTasks = 0 + var failedTasks = 0 + + var firstLaunchTime = Long.MaxValue + + val metrics = new MetricsTracker() + + val executorSummaries = new HashMap[String, LiveExecutorStageSummary]() + + def executorSummary(executorId: String): LiveExecutorStageSummary = { + executorSummaries.getOrElseUpdate(executorId, + new LiveExecutorStageSummary(info.stageId, info.attemptId, executorId)) + } + + override protected def doUpdate(): Any = { + val update = new v1.StageData( + status, + info.stageId, + info.attemptId, + + activeTasks, + completedTasks, + failedTasks, + + metrics.executorRunTime, + metrics.executorCpuTime, + 
info.submissionTime.map(new Date(_)), + if (firstLaunchTime < Long.MaxValue) Some(new Date(firstLaunchTime)) else None, + info.completionTime.map(new Date(_)), + + metrics.inputBytes, + metrics.inputRecords, + metrics.outputBytes, + metrics.outputRecords, + metrics.shuffleReadBytes, + metrics.shuffleReadRecords, + metrics.shuffleWriteBytes, + metrics.shuffleWriteRecords, + metrics.memoryBytesSpilled, + metrics.diskBytesSpilled, + + info.name, + info.details, + schedulingPool, + + newAccumulatorInfos(info.accumulables.values), + None, + None) + + new StageDataWrapper(update, jobIds) + } + +} + +private class LiveRDDPartition(val blockName: String) { + + var executors = Set[String]() + var storageLevel: String = null + var memoryUsed = 0L + var diskUsed = 0L + + def toApi(): v1.RDDPartitionInfo = { + new v1.RDDPartitionInfo( + blockName, + storageLevel, + memoryUsed, + diskUsed, + executors.toSeq.sorted) + } + +} + +private class LiveRDDDistribution(val exec: LiveExecutor) { + + var memoryRemaining = exec.maxMemory + var memoryUsed = 0L + var diskUsed = 0L + + var onHeapUsed = 0L + var offHeapUsed = 0L + var onHeapRemaining = 0L + var offHeapRemaining = 0L + + def toApi(): v1.RDDDataDistribution = { + new v1.RDDDataDistribution( + exec.hostPort, + memoryUsed, + memoryRemaining, + diskUsed, + if (exec.hasMemoryInfo) Some(onHeapUsed) else None, + if (exec.hasMemoryInfo) Some(offHeapUsed) else None, + if (exec.hasMemoryInfo) Some(onHeapRemaining) else None, + if (exec.hasMemoryInfo) Some(offHeapRemaining) else None) + } + +} + +private class LiveRDD(info: RDDInfo) extends LiveEntity { + + var storageLevel: String = info.storageLevel.description + var memoryUsed = 0L + var diskUsed = 0L + + private val partitions = new HashMap[String, LiveRDDPartition]() + private val distributions = new HashMap[String, LiveRDDDistribution]() + + def partition(blockName: String): LiveRDDPartition = { + partitions.getOrElseUpdate(blockName, new LiveRDDPartition(blockName)) + } + + def removePartition(blockName: String): Unit = partitions.remove(blockName) + + def distribution(exec: LiveExecutor): LiveRDDDistribution = { + distributions.getOrElseUpdate(exec.hostPort, new LiveRDDDistribution(exec)) + } + + def removeDistribution(exec: LiveExecutor): Unit = { + distributions.remove(exec.hostPort) + } + + override protected def doUpdate(): Any = { + val parts = if (partitions.nonEmpty) { + Some(partitions.values.toList.sortBy(_.blockName).map(_.toApi())) + } else { + None + } + + val dists = if (distributions.nonEmpty) { + Some(distributions.values.toList.sortBy(_.exec.executorId).map(_.toApi())) + } else { + None + } + + val rdd = new v1.RDDStorageInfo( + info.id, + info.name, + info.numPartitions, + partitions.size, + storageLevel, + memoryUsed, + diskUsed, + dists, + parts) + + new RDDStorageInfoWrapper(rdd) + } + +} + +private object LiveEntityHelpers { + + def newAccumulatorInfos(accums: Iterable[AccumulableInfo]): Seq[v1.AccumulableInfo] = { + accums + .filter { acc => + // We don't need to store internal or SQL accumulables as their values will be shown in + // other places, so drop them to reduce the memory usage. 
+ !acc.internal && (!acc.metadata.isDefined || + acc.metadata.get != Some(AccumulatorContext.SQL_ACCUM_IDENTIFIER)) + } + .map { acc => + new v1.AccumulableInfo( + acc.id, + acc.name.map(_.intern()).orNull, + acc.update.map(_.toString()), + acc.value.map(_.toString()).orNull) + } + .toSeq + } + +} diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index 4a4ed954d689e..5f69949c618fd 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -71,7 +71,7 @@ private[v1] object AllStagesResource { val taskData = if (includeDetails) { Some(stageUiData.taskData.map { case (k, v) => - k -> convertTaskData(v, stageUiData.lastUpdateTime) }) + k -> convertTaskData(v, stageUiData.lastUpdateTime) }.toMap) } else { None } @@ -88,7 +88,7 @@ private[v1] object AllStagesResource { memoryBytesSpilled = summary.memoryBytesSpilled, diskBytesSpilled = summary.diskBytesSpilled ) - }) + }.toMap) } else { None } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala index f17b637754826..9d3833086172f 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/ApiRootResource.scala @@ -248,7 +248,13 @@ private[spark] object ApiRootResource { * interface needed for them all to expose application info as json. */ private[spark] trait UIRoot { - def getSparkUI(appKey: String): Option[SparkUI] + /** + * Runs some code with the current SparkUI instance for the app / attempt. + * + * @throws NoSuchElementException If the app / attempt pair does not exist. + */ + def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T + def getApplicationInfoList: Iterator[ApplicationInfo] def getApplicationInfo(appId: String): Option[ApplicationInfo] @@ -293,15 +299,18 @@ private[v1] trait ApiRequestContext { * to it. 
If there is no such app, throw an appropriate exception */ def withSparkUI[T](appId: String, attemptId: Option[String])(f: SparkUI => T): T = { - val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) - uiRoot.getSparkUI(appKey) match { - case Some(ui) => + try { + uiRoot.withSparkUI(appId, attemptId) { ui => val user = httpRequest.getRemoteUser() if (!ui.securityManager.checkUIViewPermissions(user)) { throw new ForbiddenException(raw"""user "$user" is not authorized""") } f(ui) - case None => throw new NotFoundException("no such app: " + appId) + } + } catch { + case _: NoSuchElementException => + val appKey = attemptId.map(appId + "/" + _).getOrElse(appId) + throw new NotFoundException(s"no such app: $appKey") } } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index 05948f2661056..bff6f90823f40 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -16,9 +16,11 @@ */ package org.apache.spark.status.api.v1 +import java.lang.{Long => JLong} import java.util.Date -import scala.collection.Map +import com.fasterxml.jackson.annotation.JsonIgnoreProperties +import com.fasterxml.jackson.databind.annotation.JsonDeserialize import org.apache.spark.JobExecutionStatus @@ -31,6 +33,9 @@ class ApplicationInfo private[spark]( val memoryPerExecutorMB: Option[Int], val attempts: Seq[ApplicationAttemptInfo]) +@JsonIgnoreProperties( + value = Array("startTimeEpoch", "endTimeEpoch", "lastUpdatedEpoch"), + allowGetters = true) class ApplicationAttemptInfo private[spark]( val attemptId: Option[String], val startTime: Date, @@ -40,9 +45,13 @@ class ApplicationAttemptInfo private[spark]( val sparkUser: String, val completed: Boolean = false, val appSparkVersion: String) { - def getStartTimeEpoch: Long = startTime.getTime - def getEndTimeEpoch: Long = endTime.getTime - def getLastUpdatedEpoch: Long = lastUpdated.getTime + + def getStartTimeEpoch: Long = startTime.getTime + + def getEndTimeEpoch: Long = endTime.getTime + + def getLastUpdatedEpoch: Long = lastUpdated.getTime + } class ExecutorStageSummary private[spark]( @@ -120,9 +129,13 @@ class RDDDataDistribution private[spark]( val memoryUsed: Long, val memoryRemaining: Long, val diskUsed: Long, + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryUsed: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val onHeapMemoryRemaining: Option[Long], + @JsonDeserialize(contentAs = classOf[JLong]) val offHeapMemoryRemaining: Option[Long]) class RDDPartitionInfo private[spark]( @@ -170,7 +183,8 @@ class TaskData private[spark]( val index: Int, val attempt: Int, val launchTime: Date, - val duration: Option[Long] = None, + @JsonDeserialize(contentAs = classOf[JLong]) + val duration: Option[Long], val executorId: String, val host: String, val status: String, diff --git a/core/src/main/scala/org/apache/spark/status/config.scala b/core/src/main/scala/org/apache/spark/status/config.scala new file mode 100644 index 0000000000000..49144fc883e69 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/config.scala @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.util.concurrent.TimeUnit + +import org.apache.spark.internal.config._ + +private[spark] object config { + + val LIVE_ENTITY_UPDATE_PERIOD = ConfigBuilder("spark.ui.liveUpdate.period") + .timeConf(TimeUnit.NANOSECONDS) + .createWithDefaultString("100ms") + +} diff --git a/core/src/main/scala/org/apache/spark/status/storeTypes.scala b/core/src/main/scala/org/apache/spark/status/storeTypes.scala new file mode 100644 index 0000000000000..340d5994a0012 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/status/storeTypes.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.lang.{Integer => JInteger, Long => JLong} + +import com.fasterxml.jackson.annotation.JsonIgnore + +import org.apache.spark.status.KVUtils._ +import org.apache.spark.status.api.v1._ +import org.apache.spark.util.kvstore.KVIndex + +private[spark] case class AppStatusStoreMetadata( + val version: Long) + +private[spark] class ApplicationInfoWrapper(val info: ApplicationInfo) { + + @JsonIgnore @KVIndex + def id: String = info.id + +} + +private[spark] class ExecutorSummaryWrapper(val info: ExecutorSummary) { + + @JsonIgnore @KVIndex + private[this] val id: String = info.id + + @JsonIgnore @KVIndex("active") + private[this] val active: Boolean = info.isActive + + @JsonIgnore @KVIndex("host") + val host: String = info.hostPort.split(":")(0) + +} + +/** + * Keep track of the existing stages when the job was submitted, and those that were + * completed during the job's execution. This allows a more accurate acounting of how + * many tasks were skipped for the job. 
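For orientation, the wrapper classes in storeTypes.scala exist so the kvstore module can index and look them up. A minimal, hypothetical sketch of that pattern (the wrapper type, values, and object names below are invented; it assumes the KVStore read/write API from org.apache.spark.util.kvstore that this patch uses elsewhere):

import org.apache.spark.util.kvstore.{InMemoryStore, KVIndex}

// Invented example type: the bare @KVIndex member is the natural key used for lookups,
// while named indices such as @KVIndex("stage") above provide secondary orderings.
class ExampleWrapper(val appId: String, val numTasks: Int) {
  @KVIndex def id: String = appId
}

object ExampleWrapperUsage {
  def main(args: Array[String]): Unit = {
    val store = new InMemoryStore()
    store.write(new ExampleWrapper("app-1", 42))
    // Reads address an object by its class plus the natural key value.
    assert(store.read(classOf[ExampleWrapper], "app-1").numTasks == 42)
  }
}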
+ */ +private[spark] class JobDataWrapper( + val info: JobData, + val skippedStages: Set[Int]) { + + @JsonIgnore @KVIndex + private[this] val id: Int = info.jobId + +} + +private[spark] class StageDataWrapper( + val info: StageData, + val jobIds: Set[Int]) { + + @JsonIgnore @KVIndex + def id: Array[Int] = Array(info.stageId, info.attemptId) + + @JsonIgnore @KVIndex("stageId") + def stageId: Int = info.stageId + +} + +/** + * The task information is always indexed with the stage ID, since that is how the UI and API + * consume it. That means every indexed value has the stage ID and attempt ID included, aside + * from the actual data being indexed. + */ +private[spark] class TaskDataWrapper( + val info: TaskData, + val stageId: Int, + val stageAttemptId: Int) { + + @JsonIgnore @KVIndex + def id: Long = info.taskId + + @JsonIgnore @KVIndex("stage") + def stage: Array[Int] = Array(stageId, stageAttemptId) + + @JsonIgnore @KVIndex("runtime") + def runtime: Array[AnyRef] = { + val _runtime = info.taskMetrics.map(_.executorRunTime).getOrElse(-1L) + Array(stageId: JInteger, stageAttemptId: JInteger, _runtime: JLong) + } + +} + +private[spark] class RDDStorageInfoWrapper(val info: RDDStorageInfo) { + + @JsonIgnore @KVIndex + def id: Int = info.id + + @JsonIgnore @KVIndex("cached") + def cached: Boolean = info.numCachedPartitions > 0 + +} + +private[spark] class ExecutorStageSummaryWrapper( + val stageId: Int, + val stageAttemptId: Int, + val executorId: String, + val info: ExecutorStageSummary) { + + @JsonIgnore @KVIndex + val id: Array[Any] = Array(stageId, stageAttemptId, executorId) + + @JsonIgnore @KVIndex("stage") + private[this] val stage: Array[Int] = Array(stageId, stageAttemptId) + +} diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala b/core/src/main/scala/org/apache/spark/storage/BlockId.scala index 524f6970992a5..7ac2c71c18eb3 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala @@ -19,6 +19,7 @@ package org.apache.spark.storage import java.util.UUID +import org.apache.spark.SparkException import org.apache.spark.annotation.DeveloperApi /** @@ -41,11 +42,6 @@ sealed abstract class BlockId { def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId] override def toString: String = name - override def hashCode: Int = name.hashCode - override def equals(other: Any): Boolean = other match { - case o: BlockId => getClass == o.getClass && name.equals(o.name) - case _ => false - } } @DeveloperApi @@ -100,6 +96,10 @@ private[spark] case class TestBlockId(id: String) extends BlockId { override def name: String = "test_" + id } +@DeveloperApi +class UnrecognizedBlockId(name: String) + extends SparkException(s"Failed to parse $name into a block ID") + @DeveloperApi object BlockId { val RDD = "rdd_([0-9]+)_([0-9]+)".r @@ -109,10 +109,11 @@ object BlockId { val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r val TASKRESULT = "taskresult_([0-9]+)".r val STREAM = "input-([0-9]+)-([0-9]+)".r + val TEMP_LOCAL = "temp_local_([-A-Fa-f0-9]+)".r + val TEMP_SHUFFLE = "temp_shuffle_([-A-Fa-f0-9]+)".r val TEST = "test_(.*)".r - /** Converts a BlockId "name" String back into a BlockId. 
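As a quick, spark-shell style illustration of the rewritten BlockId factory that follows (not part of the patch): temp_local_* and temp_shuffle_* names now parse, and unknown names raise the typed UnrecognizedBlockId instead of IllegalStateException, which is what lets DiskBlockManager.getAllBlocks, later in this patch, simply skip files that are not blocks.

import org.apache.spark.storage.{BlockId, RDDBlockId, UnrecognizedBlockId}

object BlockIdParsingSketch {
  def main(args: Array[String]): Unit = {
    // Round trip through the name-based factory.
    assert(BlockId("rdd_1_2") == RDDBlockId(1, 2))

    // Genuinely unknown names surface as a typed, catchable error.
    try {
      BlockId("not_a_block_file")
    } catch {
      case _: UnrecognizedBlockId => // callers may skip such files, as getAllBlocks now does
    }
  }
}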
*/ - def apply(id: String): BlockId = id match { + def apply(name: String): BlockId = name match { case RDD(rddId, splitIndex) => RDDBlockId(rddId.toInt, splitIndex.toInt) case SHUFFLE(shuffleId, mapId, reduceId) => @@ -127,9 +128,13 @@ object BlockId { TaskResultBlockId(taskId.toLong) case STREAM(streamId, uniqueId) => StreamBlockId(streamId.toInt, uniqueId.toLong) + case TEMP_LOCAL(uuid) => + TempLocalBlockId(UUID.fromString(uuid)) + case TEMP_SHUFFLE(uuid) => + TempShuffleBlockId(UUID.fromString(uuid)) case TEST(value) => TestBlockId(value) case _ => - throw new IllegalStateException("Unrecognized BlockId: " + id) + throw new UnrecognizedBlockId(name) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index a98083df5bd84..e0276a4dc4224 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -18,8 +18,11 @@ package org.apache.spark.storage import java.io._ +import java.lang.ref.{ReferenceQueue => JReferenceQueue, WeakReference} import java.nio.ByteBuffer import java.nio.channels.Channels +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap import scala.collection.mutable import scala.collection.mutable.HashMap @@ -39,7 +42,7 @@ import org.apache.spark.metrics.source.Source import org.apache.spark.network._ import org.apache.spark.network.buffer.ManagedBuffer import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.network.shuffle.ExternalShuffleClient +import org.apache.spark.network.shuffle.{ExternalShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.{SerializerInstance, SerializerManager} @@ -203,6 +206,13 @@ private[spark] class BlockManager( private var blockReplicationPolicy: BlockReplicationPolicy = _ + // A TempFileManager used to track all the files of remote blocks which above the + // specified memory threshold. Files will be deleted automatically based on weak reference. + // Exposed for test + private[storage] val remoteBlockTempFileManager = + new BlockManager.RemoteBlockTempFileManager(this) + private val maxRemoteBlockToMem = conf.get(config.MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM) + /** * Initializes the BlockManager with the given appId. This is not performed in the constructor as * the appId may not be known at BlockManager instantiation time (in particular for the driver, @@ -632,8 +642,8 @@ private[spark] class BlockManager( * Return a list of locations for the given block, prioritizing the local machine since * multiple block managers can share the same host, followed by hosts on the same rack. 
*/ - private def getLocations(blockId: BlockId): Seq[BlockManagerId] = { - val locs = Random.shuffle(master.getLocations(blockId)) + private def sortLocations(locations: Seq[BlockManagerId]): Seq[BlockManagerId] = { + val locs = Random.shuffle(locations) val (preferredLocs, otherLocs) = locs.partition { loc => blockManagerId.host == loc.host } blockManagerId.topologyInfo match { case None => preferredLocs ++ otherLocs @@ -653,7 +663,25 @@ private[spark] class BlockManager( require(blockId != null, "BlockId is null") var runningFailureCount = 0 var totalFailureCount = 0 - val locations = getLocations(blockId) + + // Because all the remote blocks are registered in driver, it is not necessary to ask + // all the slave executors to get block status. + val locationsAndStatus = master.getLocationsAndStatus(blockId) + val blockSize = locationsAndStatus.map { b => + b.status.diskSize.max(b.status.memSize) + }.getOrElse(0L) + val blockLocations = locationsAndStatus.map(_.locations).getOrElse(Seq.empty) + + // If the block size is above the threshold, we should pass our FileManger to + // BlockTransferService, which will leverage it to spill the block; if not, then passed-in + // null value means the block will be persisted in memory. + val tempFileManager = if (blockSize > maxRemoteBlockToMem) { + remoteBlockTempFileManager + } else { + null + } + + val locations = sortLocations(blockLocations) val maxFetchFailures = locations.size var locationIterator = locations.iterator while (locationIterator.hasNext) { @@ -661,7 +689,7 @@ private[spark] class BlockManager( logDebug(s"Getting remote block $blockId from $loc") val data = try { blockTransferService.fetchBlockSync( - loc.host, loc.port, loc.executorId, blockId.toString).nioByteBuffer() + loc.host, loc.port, loc.executorId, blockId.toString, tempFileManager).nioByteBuffer() } catch { case NonFatal(e) => runningFailureCount += 1 @@ -684,7 +712,7 @@ private[spark] class BlockManager( // take a significant amount of time. To get rid of these stale entries // we refresh the block locations after a certain number of fetch failures if (runningFailureCount >= maxFailuresBeforeLocationRefresh) { - locationIterator = getLocations(blockId).iterator + locationIterator = sortLocations(master.getLocations(blockId)).iterator logDebug(s"Refreshed locations from the driver " + s"after ${runningFailureCount} fetch failures.") runningFailureCount = 0 @@ -1512,6 +1540,7 @@ private[spark] class BlockManager( // Closing should be idempotent, but maybe not for the NioBlockTransferService. 
shuffleClient.close() } + remoteBlockTempFileManager.stop() diskBlockManager.stop() rpcEnv.stop(slaveEndpoint) blockInfoManager.clear() @@ -1552,4 +1581,65 @@ private[spark] object BlockManager { override val metricRegistry = new MetricRegistry metricRegistry.registerAll(metricSet) } + + class RemoteBlockTempFileManager(blockManager: BlockManager) + extends TempFileManager with Logging { + + private class ReferenceWithCleanup(file: File, referenceQueue: JReferenceQueue[File]) + extends WeakReference[File](file, referenceQueue) { + private val filePath = file.getAbsolutePath + + def cleanUp(): Unit = { + logDebug(s"Clean up file $filePath") + + if (!new File(filePath).delete()) { + logDebug(s"Fail to delete file $filePath") + } + } + } + + private val referenceQueue = new JReferenceQueue[File] + private val referenceBuffer = Collections.newSetFromMap[ReferenceWithCleanup]( + new ConcurrentHashMap) + + private val POLL_TIMEOUT = 1000 + @volatile private var stopped = false + + private val cleaningThread = new Thread() { override def run() { keepCleaning() } } + cleaningThread.setDaemon(true) + cleaningThread.setName("RemoteBlock-temp-file-clean-thread") + cleaningThread.start() + + override def createTempFile(): File = { + blockManager.diskBlockManager.createTempLocalBlock()._2 + } + + override def registerTempFileToClean(file: File): Boolean = { + referenceBuffer.add(new ReferenceWithCleanup(file, referenceQueue)) + } + + def stop(): Unit = { + stopped = true + cleaningThread.interrupt() + cleaningThread.join() + } + + private def keepCleaning(): Unit = { + while (!stopped) { + try { + Option(referenceQueue.remove(POLL_TIMEOUT)) + .map(_.asInstanceOf[ReferenceWithCleanup]) + .foreach { ref => + referenceBuffer.remove(ref) + ref.cleanUp() + } + } catch { + case _: InterruptedException => + // no-op + case NonFatal(e) => + logError("Error in cleaning thread", e) + } + } + } + } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 8b1dc0ba6356a..d24421b962774 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -84,6 +84,12 @@ class BlockManagerMaster( driverEndpoint.askSync[Seq[BlockManagerId]](GetLocations(blockId)) } + /** Get locations as well as status of the blockId from the driver */ + def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + driverEndpoint.askSync[Option[BlockLocationsAndStatus]]( + GetLocationsAndStatus(blockId)) + } + /** Get locations of multiple blockIds from the driver */ def getLocations(blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { driverEndpoint.askSync[IndexedSeq[Seq[BlockManagerId]]]( diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index df0a5f5e229fb..56d0266b8edad 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -82,6 +82,9 @@ class BlockManagerMasterEndpoint( case GetLocations(blockId) => context.reply(getLocations(blockId)) + case GetLocationsAndStatus(blockId) => + context.reply(getLocationsAndStatus(blockId)) + case GetLocationsMultipleBlockIds(blockIds) => context.reply(getLocationsMultipleBlockIds(blockIds)) @@ -422,6 +425,17 @@ class 
BlockManagerMasterEndpoint( if (blockLocations.containsKey(blockId)) blockLocations.get(blockId).toSeq else Seq.empty } + private def getLocationsAndStatus(blockId: BlockId): Option[BlockLocationsAndStatus] = { + val locations = Option(blockLocations.get(blockId)).map(_.toSeq).getOrElse(Seq.empty) + val status = locations.headOption.flatMap { bmId => blockManagerInfo(bmId).getStatus(blockId) } + + if (locations.nonEmpty && status.isDefined) { + Some(BlockLocationsAndStatus(locations, status.get)) + } else { + None + } + } + private def getLocationsMultipleBlockIds( blockIds: Array[BlockId]): IndexedSeq[Seq[BlockManagerId]] = { blockIds.map(blockId => getLocations(blockId)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 0c0ff144596ac..1bbe7a5b39509 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -93,6 +93,13 @@ private[spark] object BlockManagerMessages { case class GetLocations(blockId: BlockId) extends ToBlockManagerMaster + case class GetLocationsAndStatus(blockId: BlockId) extends ToBlockManagerMaster + + // The response message of `GetLocationsAndStatus` request. + case class BlockLocationsAndStatus(locations: Seq[BlockManagerId], status: BlockStatus) { + assert(locations.nonEmpty) + } + case class GetLocationsMultipleBlockIds(blockIds: Array[BlockId]) extends ToBlockManagerMaster case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 3d43e3c367aac..a69bcc9259995 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -100,7 +100,16 @@ private[spark] class DiskBlockManager(conf: SparkConf, deleteFilesOnStop: Boolea /** List all the blocks currently stored on disk by the disk manager. */ def getAllBlocks(): Seq[BlockId] = { - getAllFiles().map(f => BlockId(f.getName)) + getAllFiles().flatMap { f => + try { + Some(BlockId(f.getName)) + } catch { + case _: UnrecognizedBlockId => + // Skip files which do not correspond to blocks, for example temporary + // files created by [[SortShuffleWriter]]. + None + } + } } /** Produces a unique block id and File suitable for storing local intermediate results. */ diff --git a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala index 3579acf8d83d9..97abd92d4b70f 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskStore.scala @@ -47,9 +47,9 @@ private[spark] class DiskStore( private val minMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapThreshold", "2m") private val maxMemoryMapBytes = conf.getSizeAsBytes("spark.storage.memoryMapLimitForTests", Int.MaxValue.toString) - private val blockSizes = new ConcurrentHashMap[String, Long]() + private val blockSizes = new ConcurrentHashMap[BlockId, Long]() - def getSize(blockId: BlockId): Long = blockSizes.get(blockId.name) + def getSize(blockId: BlockId): Long = blockSizes.get(blockId) /** * Invokes the provided callback function to write the specific block. 
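A small aside (illustrative snippet, not in the patch): dropping the hand-written hashCode/equals from BlockId earlier in this patch is safe because every concrete BlockId is a case class, so structural equality is generated automatically, and that is what the ConcurrentHashMap[BlockId, Long] above relies on.

import java.util.concurrent.ConcurrentHashMap
import org.apache.spark.storage.{BlockId, RDDBlockId}

object BlockIdAsKeySketch {
  def main(args: Array[String]): Unit = {
    val sizes = new ConcurrentHashMap[BlockId, Long]()
    sizes.put(RDDBlockId(0, 0), 1024L)
    // A BlockId parsed back from its name equals the original case class instance,
    // so it hits the same entry without any custom equals/hashCode on BlockId.
    assert(sizes.get(BlockId("rdd_0_0")) == 1024L)
  }
}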
@@ -67,7 +67,7 @@ private[spark] class DiskStore( var threwException: Boolean = true try { writeFunc(out) - blockSizes.put(blockId.name, out.getCount) + blockSizes.put(blockId, out.getCount) threwException = false } finally { try { @@ -113,7 +113,7 @@ private[spark] class DiskStore( } def remove(blockId: BlockId): Boolean = { - blockSizes.remove(blockId.name) + blockSizes.remove(blockId) val file = diskManager.getFile(blockId.name) if (file.exists()) { val ret = file.delete() diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 2d176b62f8b36..98b5a735a4529 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -28,7 +28,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue} import org.apache.spark.{SparkException, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils import org.apache.spark.util.io.ChunkedByteBufferOutputStream @@ -69,7 +69,7 @@ final class ShuffleBlockFetcherIterator( maxBlocksInFlightPerAddress: Int, maxReqSizeShuffleToMem: Long, detectCorrupt: Boolean) - extends Iterator[(BlockId, InputStream)] with TempShuffleFileManager with Logging { + extends Iterator[(BlockId, InputStream)] with TempFileManager with Logging { import ShuffleBlockFetcherIterator._ @@ -162,11 +162,11 @@ final class ShuffleBlockFetcherIterator( currentResult = null } - override def createTempShuffleFile(): File = { + override def createTempFile(): File = { blockManager.diskBlockManager.createTempLocalBlock()._2 } - override def registerTempShuffleFileToClean(file: File): Boolean = synchronized { + override def registerTempFileToClean(file: File): Boolean = synchronized { if (isZombie) { false } else { diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index 651e9c7b2ab61..17f7a69ad6ba1 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -388,7 +388,13 @@ private[spark] class MemoryStore( // perform one final call to attempt to allocate additional memory if necessary. 
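The replacement logic just below requests only the difference between what the serialized output now occupies and what has already been reserved for this block. A simplified standalone sketch of that pattern (names invented, not the actual MemoryStore code):

object UnrollReservationSketch {
  // Returns the new reserved amount: unchanged if nothing extra is needed or the request fails.
  def topUp(currentSize: Long, reserved: Long, request: Long => Boolean): Long = {
    if (currentSize > reserved && request(currentSize - reserved)) currentSize else reserved
  }
}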
if (keepUnrolling) { serializationStream.close() - reserveAdditionalMemoryIfNecessary() + if (bbos.size > unrollMemoryUsedByThisBlock) { + val amountToRequest = bbos.size - unrollMemoryUsedByThisBlock + keepUnrolling = reserveUnrollMemoryForThisTask(blockId, amountToRequest, memoryMode) + if (keepUnrolling) { + unrollMemoryUsedByThisBlock += amountToRequest + } + } } if (keepUnrolling) { diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 5ee04dad6ed4d..0adeb4058b6e4 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -39,6 +39,7 @@ import org.json4s.jackson.JsonMethods.{pretty, render} import org.apache.spark.{SecurityManager, SparkConf, SSLOptions} import org.apache.spark.internal.Logging +import org.apache.spark.internal.config._ import org.apache.spark.util.Utils /** @@ -89,6 +90,14 @@ private[spark] object JettyUtils extends Logging { val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("X-Frame-Options", xFrameOptionsValue) + response.setHeader("X-XSS-Protection", conf.get(UI_X_XSS_PROTECTION)) + if (conf.get(UI_X_CONTENT_TYPE_OPTIONS)) { + response.setHeader("X-Content-Type-Options", "nosniff") + } + if (request.getScheme == "https") { + conf.get(UI_STRICT_TRANSPORT_SECURITY).foreach( + response.setHeader("Strict-Transport-Security", _)) + } response.getWriter.print(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_FORBIDDEN) diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala index 6e94073238a56..ee645f6bf8a7a 100644 --- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.{SecurityManager, SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.scheduler._ +import org.apache.spark.status.AppStatusStore import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationAttemptInfo, ApplicationInfo, UIRoot} import org.apache.spark.storage.StorageStatusListener @@ -39,6 +40,7 @@ import org.apache.spark.util.Utils * Top level user interface for a Spark application. */ private[spark] class SparkUI private ( + val store: AppStatusStore, val sc: Option[SparkContext], val conf: SparkConf, securityManager: SecurityManager, @@ -51,7 +53,8 @@ private[spark] class SparkUI private ( var appName: String, val basePath: String, val lastUpdateTime: Option[Long] = None, - val startTime: Long) + val startTime: Long, + val appSparkVersion: String) extends WebUI(securityManager, securityManager.getSSLOptions("ui"), SparkUI.getUIPort(conf), conf, basePath, "SparkUI") with Logging @@ -61,8 +64,6 @@ private[spark] class SparkUI private ( var appId: String = _ - var appSparkVersion = org.apache.spark.SPARK_VERSION - private var streamingJobProgressListener: Option[SparkListener] = None /** Initialize all components of the server. 
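Caller-side illustration of the new UIRoot contract that SparkUI implements in the following hunk (hypothetical helper, not part of the patch; it assumes the code lives under an org.apache.spark package so the private[spark] types are visible): the SparkUI is only valid inside the closure, and a missing app or attempt surfaces as NoSuchElementException, which ApiRequestContext above maps to a 404.

package org.apache.spark.status.api.v1

object WithSparkUISketch {
  def describeApp(uiRoot: UIRoot, appId: String, attemptId: Option[String]): String = {
    try {
      // The UI reference must not escape the closure.
      uiRoot.withSparkUI(appId, attemptId) { ui => s"UI for ${ui.appName} at ${ui.webUrl}" }
    } catch {
      case _: NoSuchElementException => s"no such app: $appId"
    }
  }
}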
*/ @@ -104,8 +105,12 @@ private[spark] class SparkUI private ( logInfo(s"Stopped Spark web UI at $webUrl") } - def getSparkUI(appId: String): Option[SparkUI] = { - if (appId == this.appId) Some(this) else None + override def withSparkUI[T](appId: String, attemptId: Option[String])(fn: SparkUI => T): T = { + if (appId == this.appId) { + fn(this) + } else { + throw new NoSuchElementException() + } } def getApplicationInfoList: Iterator[ApplicationInfo] = { @@ -159,63 +164,26 @@ private[spark] object SparkUI { conf.getInt("spark.ui.port", SparkUI.DEFAULT_PORT) } - def createLiveUI( - sc: SparkContext, - conf: SparkConf, - jobProgressListener: JobProgressListener, - securityManager: SecurityManager, - appName: String, - startTime: Long): SparkUI = { - create(Some(sc), conf, - sc.listenerBus.addToStatusQueue, - securityManager, appName, jobProgressListener = Some(jobProgressListener), - startTime = startTime) - } - - def createHistoryUI( - conf: SparkConf, - listenerBus: SparkListenerBus, - securityManager: SecurityManager, - appName: String, - basePath: String, - lastUpdateTime: Option[Long], - startTime: Long): SparkUI = { - val sparkUI = create(None, conf, listenerBus.addListener, securityManager, appName, basePath, - lastUpdateTime = lastUpdateTime, startTime = startTime) - - val listenerFactories = ServiceLoader.load(classOf[SparkHistoryListenerFactory], - Utils.getContextOrSparkClassLoader).asScala - listenerFactories.foreach { listenerFactory => - val listeners = listenerFactory.createListeners(conf, sparkUI) - listeners.foreach(listenerBus.addListener) - } - sparkUI - } - /** - * Create a new Spark UI. - * - * @param sc optional SparkContext; this can be None when reconstituting a UI from event logs. - * @param jobProgressListener if supplied, this JobProgressListener will be used; otherwise, the - * web UI will create and register its own JobProgressListener. + * Create a new UI backed by an AppStatusStore. 
*/ - private def create( + def create( sc: Option[SparkContext], + store: AppStatusStore, conf: SparkConf, addListenerFn: SparkListenerInterface => Unit, securityManager: SecurityManager, appName: String, - basePath: String = "", - jobProgressListener: Option[JobProgressListener] = None, + basePath: String, + startTime: Long, lastUpdateTime: Option[Long] = None, - startTime: Long): SparkUI = { + appSparkVersion: String = org.apache.spark.SPARK_VERSION): SparkUI = { - val _jobProgressListener: JobProgressListener = jobProgressListener.getOrElse { + val jobProgressListener = sc.map(_.jobProgressListener).getOrElse { val listener = new JobProgressListener(conf) addListenerFn(listener) listener } - val environmentListener = new EnvironmentListener val storageStatusListener = new StorageStatusListener(conf) val executorsListener = new ExecutorsListener(storageStatusListener, conf) @@ -228,8 +196,9 @@ private[spark] object SparkUI { addListenerFn(storageListener) addListenerFn(operationGraphListener) - new SparkUI(sc, conf, securityManager, environmentListener, storageStatusListener, - executorsListener, _jobProgressListener, storageListener, operationGraphListener, - appName, basePath, lastUpdateTime, startTime) + new SparkUI(store, sc, conf, securityManager, environmentListener, storageStatusListener, + executorsListener, jobProgressListener, storageListener, operationGraphListener, + appName, basePath, lastUpdateTime, startTime, appSparkVersion) } + } diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala index d63381c78bc3b..7b2767f0be3cd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsPage.scala @@ -82,7 +82,7 @@ private[ui] class ExecutorsPage(
++ ++ ++ diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 8406826a228db..5e60218c5740b 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -98,8 +98,8 @@ private[spark] object JsonProtocol { logStartToJson(logStart) case metricsUpdate: SparkListenerExecutorMetricsUpdate => executorMetricsUpdateToJson(metricsUpdate) - case blockUpdated: SparkListenerBlockUpdated => - throw new MatchError(blockUpdated) // TODO(ekl) implement this + case blockUpdate: SparkListenerBlockUpdated => + blockUpdateToJson(blockUpdate) case _ => parse(mapper.writeValueAsString(event)) } } @@ -246,6 +246,12 @@ private[spark] object JsonProtocol { }) } + def blockUpdateToJson(blockUpdate: SparkListenerBlockUpdated): JValue = { + val blockUpdatedInfo = blockUpdatedInfoToJson(blockUpdate.blockUpdatedInfo) + ("Event" -> SPARK_LISTENER_EVENT_FORMATTED_CLASS_NAMES.blockUpdate) ~ + ("Block Updated Info" -> blockUpdatedInfo) + } + /** ------------------------------------------------------------------- * * JSON serialization methods for classes SparkListenerEvents depend on | * -------------------------------------------------------------------- */ @@ -458,6 +464,14 @@ private[spark] object JsonProtocol { ("Log Urls" -> mapToJson(executorInfo.logUrlMap)) } + def blockUpdatedInfoToJson(blockUpdatedInfo: BlockUpdatedInfo): JValue = { + ("Block Manager ID" -> blockManagerIdToJson(blockUpdatedInfo.blockManagerId)) ~ + ("Block ID" -> blockUpdatedInfo.blockId.toString) ~ + ("Storage Level" -> storageLevelToJson(blockUpdatedInfo.storageLevel)) ~ + ("Memory Size" -> blockUpdatedInfo.memSize) ~ + ("Disk Size" -> blockUpdatedInfo.diskSize) + } + /** ------------------------------ * * Util JSON serialization methods | * ------------------------------- */ @@ -515,6 +529,7 @@ private[spark] object JsonProtocol { val executorRemoved = Utils.getFormattedClassName(SparkListenerExecutorRemoved) val logStart = Utils.getFormattedClassName(SparkListenerLogStart) val metricsUpdate = Utils.getFormattedClassName(SparkListenerExecutorMetricsUpdate) + val blockUpdate = Utils.getFormattedClassName(SparkListenerBlockUpdated) } def sparkEventFromJson(json: JValue): SparkListenerEvent = { @@ -538,6 +553,7 @@ private[spark] object JsonProtocol { case `executorRemoved` => executorRemovedFromJson(json) case `logStart` => logStartFromJson(json) case `metricsUpdate` => executorMetricsUpdateFromJson(json) + case `blockUpdate` => blockUpdateFromJson(json) case other => mapper.readValue(compact(render(json)), Utils.classForName(other)) .asInstanceOf[SparkListenerEvent] } @@ -676,6 +692,11 @@ private[spark] object JsonProtocol { SparkListenerExecutorMetricsUpdate(execInfo, accumUpdates) } + def blockUpdateFromJson(json: JValue): SparkListenerBlockUpdated = { + val blockUpdatedInfo = blockUpdatedInfoFromJson(json \ "Block Updated Info") + SparkListenerBlockUpdated(blockUpdatedInfo) + } + /** --------------------------------------------------------------------- * * JSON deserialization methods for classes SparkListenerEvents depend on | * ---------------------------------------------------------------------- */ @@ -989,6 +1010,15 @@ private[spark] object JsonProtocol { new ExecutorInfo(executorHost, totalCores, logUrls) } + def blockUpdatedInfoFromJson(json: JValue): BlockUpdatedInfo = { + val blockManagerId = blockManagerIdFromJson(json \ "Block Manager ID") + val blockId = 
BlockId((json \ "Block ID").extract[String]) + val storageLevel = storageLevelFromJson(json \ "Storage Level") + val memorySize = (json \ "Memory Size").extract[Long] + val diskSize = (json \ "Disk Size").extract[Long] + BlockUpdatedInfo(blockManagerId, blockId, storageLevel, memorySize, diskSize) + } + /** -------------------------------- * * Util JSON deserialization methods | * --------------------------------- */ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 836e33c36d9a1..930e09d90c2f5 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -19,6 +19,7 @@ package org.apache.spark.util import java.io._ import java.lang.management.{LockInfo, ManagementFactory, MonitorInfo, ThreadInfo} +import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer @@ -37,7 +38,7 @@ import scala.collection.Map import scala.collection.mutable.ArrayBuffer import scala.io.Source import scala.reflect.ClassTag -import scala.util.Try +import scala.util.{Failure, Success, Try} import scala.util.control.{ControlThrowable, NonFatal} import scala.util.matching.Regex @@ -2687,6 +2688,60 @@ private[spark] object Utils extends Logging { def stringToSeq(str: String): Seq[String] = { str.split(",").map(_.trim()).filter(_.nonEmpty) } + + /** + * Create instances of extension classes. + * + * The classes in the given list must: + * - Be sub-classes of the given base class. + * - Provide either a no-arg constructor, or a 1-arg constructor that takes a SparkConf. + * + * The constructors are allowed to throw "UnsupportedOperationException" if the extension does not + * want to be registered; this allows the implementations to check the Spark configuration (or + * other state) and decide they do not need to be added. A log message is printed in that case. + * Other exceptions are bubbled up. + */ + def loadExtensions[T](extClass: Class[T], classes: Seq[String], conf: SparkConf): Seq[T] = { + classes.flatMap { name => + try { + val klass = classForName(name) + require(extClass.isAssignableFrom(klass), + s"$name is not a subclass of ${extClass.getName()}.") + + val ext = Try(klass.getConstructor(classOf[SparkConf])) match { + case Success(ctor) => + ctor.newInstance(conf) + + case Failure(_) => + klass.getConstructor().newInstance() + } + + Some(ext.asInstanceOf[T]) + } catch { + case _: NoSuchMethodException => + throw new SparkException( + s"$name did not have a zero-argument constructor or a" + + " single-argument constructor that accepts SparkConf. 
Note: if the class is" + + " defined inside of another Scala class, then its constructors may accept an" + + " implicit parameter that references the enclosing class; in this case, you must" + + " define the class as a top-level class in order to prevent this extra" + + " parameter from breaking Spark's ability to find a valid constructor.") + + case e: InvocationTargetException => + e.getCause() match { + case uoe: UnsupportedOperationException => + logDebug(s"Extension $name not being initialized.", uoe) + logInfo(s"Extension $name not being initialized.") + None + + case null => throw e + + case cause => throw cause + } + } + } + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala index b755e5da51684..e17a9de97e335 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PartitionedPairBuffer.scala @@ -19,6 +19,8 @@ package org.apache.spark.util.collection import java.util.Comparator +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.collection.WritablePartitionedPairCollection._ /** @@ -96,7 +98,5 @@ private[spark] class PartitionedPairBuffer[K, V](initialCapacity: Int = 64) } private object PartitionedPairBuffer { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. - val MAXIMUM_CAPACITY: Int = (Int.MaxValue - 8) / 2 + val MAXIMUM_CAPACITY: Int = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH / 2 } diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index 5330a688e63e3..d0d0334add0bf 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -23,6 +23,7 @@ import java.util.LinkedList; import java.util.UUID; +import org.hamcrest.Matchers; import scala.Tuple2$; import org.junit.After; @@ -503,6 +504,41 @@ public void testGetIterator() throws Exception { verifyIntIterator(sorter.getIterator(279), 279, 300); } + @Test + public void testOOMDuringSpill() throws Exception { + final UnsafeExternalSorter sorter = newSorter(); + // we assume that given default configuration, + // the size of the data we insert to the sorter (ints) + // and assuming we shouldn't spill before pointers array is exhausted + // (memory manager is not configured to throw at this point) + // - so this loop runs a reasonable number of iterations (<2000). + // test indeed completed within <30ms (on a quad i7 laptop). + for (int i = 0; sorter.hasSpaceForAnotherRecord(); ++i) { + insertNumber(sorter, i); + } + // we expect the next insert to attempt growing the pointerssArray first + // allocation is expected to fail, then a spill is triggered which + // attempts another allocation which also fails and we expect to see this + // OOM here. the original code messed with a released array within the + // spill code and ended up with a failed assertion. 
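Circling back to Utils.loadExtensions added above: a hypothetical extension hierarchy (all names invented, and assumed to compile inside an org.apache.spark package since Utils is private[spark]) showing the documented contract, namely that a SparkConf constructor is preferred, a no-arg constructor is the fallback, and throwing UnsupportedOperationException lets an implementation opt out based on configuration.

import org.apache.spark.SparkConf
import org.apache.spark.util.Utils

trait MyExtension { def name: String }

// Preferred form: a single SparkConf argument; opts out by throwing
// UnsupportedOperationException after inspecting the configuration.
class ConfiguredExtension(conf: SparkConf) extends MyExtension {
  if (!conf.getBoolean("spark.test.myExtension.enabled", false)) {
    throw new UnsupportedOperationException("disabled by configuration")
  }
  override def name: String = "configured"
}

// Fallback form: a zero-argument constructor.
class SimpleExtension extends MyExtension {
  override def name: String = "simple"
}

object LoadExtensionsSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().set("spark.test.myExtension.enabled", "true")
    val exts = Utils.loadExtensions(
      classOf[MyExtension],
      Seq(classOf[ConfiguredExtension].getName, classOf[SimpleExtension].getName),
      conf)
    assert(exts.map(_.name) == Seq("configured", "simple"))
  }
}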
we also expect the + // location of the OOM to be + // org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset + memoryManager.markconsequentOOM(2); + try { + insertNumber(sorter, 1024); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } + // we expect an OutOfMemoryError here, anything else (i.e the original NPE is a failure) + catch (OutOfMemoryError oom){ + String oomStackTrace = Utils.exceptionString(oom); + assertThat("expected OutOfMemoryError in " + + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset", + oomStackTrace, + Matchers.containsString( + "org.apache.spark.util.collection.unsafe.sort.UnsafeInMemorySorter.reset")); + } + } + private void verifyIntIterator(UnsafeSorterIterator iter, int start, int end) throws IOException { for (int i = start; i < end; i++) { diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java index bd89085aa9a14..594f07dd780f9 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeInMemorySorterSuite.java @@ -35,6 +35,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.isIn; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; public class UnsafeInMemorySorterSuite { @@ -139,4 +140,50 @@ public int compare( } assertEquals(dataToSort.length, iterLength); } + + @Test + public void freeAfterOOM() { + final SparkConf sparkConf = new SparkConf(); + sparkConf.set("spark.memory.offHeap.enabled", "false"); + + final TestMemoryManager testMemoryManager = + new TestMemoryManager(sparkConf); + final TaskMemoryManager memoryManager = new TaskMemoryManager( + testMemoryManager, 0); + final TestMemoryConsumer consumer = new TestMemoryConsumer(memoryManager); + final MemoryBlock dataPage = memoryManager.allocatePage(2048, consumer); + final Object baseObject = dataPage.getBaseObject(); + // Write the records into the data page: + long position = dataPage.getBaseOffset(); + + final HashPartitioner hashPartitioner = new HashPartitioner(4); + // Use integer comparison for comparing prefixes (which are partition ids, in this case) + final PrefixComparator prefixComparator = PrefixComparators.LONG; + final RecordComparator recordComparator = new RecordComparator() { + @Override + public int compare( + Object leftBaseObject, + long leftBaseOffset, + Object rightBaseObject, + long rightBaseOffset) { + return 0; + } + }; + UnsafeInMemorySorter sorter = new UnsafeInMemorySorter(consumer, memoryManager, + recordComparator, prefixComparator, 100, shouldUseRadixSort()); + + testMemoryManager.markExecutionAsOutOfMemoryOnce(); + try { + sorter.reset(); + fail("expected OutOfMmoryError but it seems operation surprisingly succeeded"); + } catch (OutOfMemoryError oom) { + // as expected + } + // [SPARK-21907] this failed on NPE at + // org.apache.spark.memory.MemoryConsumer.freeArray(MemoryConsumer.java:108) + sorter.free(); + // simulate a 'back to back' free. 
+ sorter.free(); + } + } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index bea67b71a5a12..f8005610f7e4f 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -171,7 +171,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex val serializerManager = SparkEnv.get.serializerManager blockManager.master.getLocations(blockId).foreach { cmId => val bytes = blockTransfer.fetchBlockSync(cmId.host, cmId.port, cmId.executorId, - blockId.toString) + blockId.toString, null) val deserialized = serializerManager.dataDeserializeStream(blockId, new ChunkedByteBuffer(bytes.nioByteBuffer()).toInputStream())(data.elementClassTag).toList assert(deserialized === (1 to 100).toList) diff --git a/core/src/test/scala/org/apache/spark/FileSuite.scala b/core/src/test/scala/org/apache/spark/FileSuite.scala index 02728180ac82d..e9539dc73f6fa 100644 --- a/core/src/test/scala/org/apache/spark/FileSuite.scala +++ b/core/src/test/scala/org/apache/spark/FileSuite.scala @@ -31,7 +31,7 @@ import org.apache.hadoop.mapreduce.Job import org.apache.hadoop.mapreduce.lib.input.{FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} -import org.apache.spark.internal.config.IGNORE_CORRUPT_FILES +import org.apache.spark.internal.config._ import org.apache.spark.rdd.{HadoopRDD, NewHadoopRDD} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -347,10 +347,10 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } - test ("allow user to disable the output directory existence checking (old Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + test ("allow user to disable the output directory existence checking (old Hadoop API)") { + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1) randomRDD.saveAsTextFile(tempDir.getPath + "/output") assert(new File(tempDir.getPath + "/output/part-00000").exists() === true) @@ -380,9 +380,9 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } test ("allow user to disable the output directory existence checking (new Hadoop API") { - val sf = new SparkConf() - sf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") - sc = new SparkContext(sf) + val conf = new SparkConf() + conf.setAppName("test").setMaster("local").set("spark.hadoop.validateOutputSpecs", "false") + sc = new SparkContext(conf) val randomRDD = sc.parallelize( Array(("key1", "a"), ("key2", "a"), ("key3", "b"), ("key4", "c")), 1) randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]]( @@ -510,4 +510,87 @@ class FileSuite extends SparkFunSuite with LocalSparkContext { } } + test("spark.hadoopRDD.ignoreEmptySplits work correctly (old Hadoop API)") { + val conf = new SparkConf() + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new 
File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsHadoopFile[TextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-0000$i").exists() === true) + } + val hadoopRDD = sc.textFile(new File(output, "part-*").getPath) + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a"), ("key3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("key1", "a"), ("key2", "a")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } + + test("spark.hadoopRDD.ignoreEmptySplits work correctly (new Hadoop API)") { + val conf = new SparkConf() + .setAppName("test") + .setMaster("local") + .set(HADOOP_RDD_IGNORE_EMPTY_SPLITS, true) + sc = new SparkContext(conf) + + def testIgnoreEmptySplits( + data: Array[Tuple2[String, String]], + actualPartitionNum: Int, + expectedPartitionNum: Int): Unit = { + val output = new File(tempDir, "output") + sc.parallelize(data, actualPartitionNum) + .saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](output.getPath) + for (i <- 0 until actualPartitionNum) { + assert(new File(output, s"part-r-0000$i").exists() === true) + } + val hadoopRDD = sc.newAPIHadoopFile(new File(output, "part-r-*").getPath, + classOf[NewTextInputFormat], classOf[LongWritable], classOf[Text]) + .asInstanceOf[NewHadoopRDD[_, _]] + assert(hadoopRDD.partitions.length === expectedPartitionNum) + Utils.deleteRecursively(output) + } + + // Ensure that if all of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array.empty[Tuple2[String, String]], + actualPartitionNum = 1, + expectedPartitionNum = 0) + + // Ensure that if no split is empty, we don't lose any splits + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "a"), ("3", "b")), + actualPartitionNum = 2, + expectedPartitionNum = 2) + + // Ensure that if part of the splits are empty, we remove the splits correctly + testIgnoreEmptySplits( + data = Array(("1", "a"), ("2", "b")), + actualPartitionNum = 5, + expectedPartitionNum = 2) + } } diff --git a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala index 6aedcb1271ff6..1aa1c421d792e 100644 --- a/core/src/test/scala/org/apache/spark/SharedSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/SharedSparkContext.scala @@ -29,10 +29,23 @@ trait SharedSparkContext extends BeforeAndAfterAll with BeforeAndAfterEach { sel var conf = new SparkConf(false) + /** + * Initialize the [[SparkContext]]. Generally, this is just called from beforeAll; however, in + * test using styles other than FunSuite, there is often code that relies on the session between + * test group constructs and the actual tests, which may need this session. It is purely a + * semantic difference, but semantically, it makes more sense to call 'initializeContext' between + * a 'describe' and an 'it' call than it does to call 'beforeAll'. 
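For example (hypothetical suite, not in the patch; it assumes the sc accessor that SharedSparkContext already exposes), a ScalaTest FunSpec that builds fixtures while the test group is being constructed can call initializeContext up front, and beforeAll will then see the already-created context and skip re-creating it:

import org.scalatest.FunSpec
import org.apache.spark.SharedSparkContext

class ExampleFunSpecSuite extends FunSpec with SharedSparkContext {
  // 'describe' bodies run at construction time, before beforeAll(), so create the context now.
  initializeContext()

  describe("an RDD built while the group is constructed") {
    val rdd = sc.parallelize(1 to 10)

    it("has ten elements") {
      assert(rdd.count() == 10)
    }
  }
}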
+ */ + protected def initializeContext(): Unit = { + if (null == _sc) { + _sc = new SparkContext( + "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + } + } + override def beforeAll() { super.beforeAll() - _sc = new SparkContext( - "local[4]", "test", conf.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)) + initializeContext() } override def afterAll() { diff --git a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala index ad801bf8519a6..cfbf56fb8c369 100644 --- a/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/SparkSubmitSuite.scala @@ -100,6 +100,8 @@ class SparkSubmitSuite with TimeLimits with TestPrematureExit { + import SparkSubmitSuite._ + override def beforeEach() { super.beforeEach() System.setProperty("spark.testing", "true") @@ -174,10 +176,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, _) = prepareSubmitEnvironment(appArgs) + val (_, _, conf, _) = prepareSubmitEnvironment(appArgs) appArgs.deployMode should be ("client") - sysProps("spark.submit.deployMode") should be ("client") + conf.get("spark.submit.deployMode") should be ("client") // Both cmd line and configuration are specified, cmdline option takes the priority val clArgs1 = Seq( @@ -188,10 +190,10 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs1 = new SparkSubmitArguments(clArgs1) - val (_, _, sysProps1, _) = prepareSubmitEnvironment(appArgs1) + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) appArgs1.deployMode should be ("cluster") - sysProps1("spark.submit.deployMode") should be ("cluster") + conf1.get("spark.submit.deployMode") should be ("cluster") // Neither cmdline nor configuration are specified, client mode is the default choice val clArgs2 = Seq( @@ -202,9 +204,9 @@ class SparkSubmitSuite val appArgs2 = new SparkSubmitArguments(clArgs2) appArgs2.deployMode should be (null) - val (_, _, sysProps2, _) = prepareSubmitEnvironment(appArgs2) + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) appArgs2.deployMode should be ("client") - sysProps2("spark.submit.deployMode") should be ("client") + conf2.get("spark.submit.deployMode") should be ("client") } test("handles YARN cluster mode") { @@ -225,7 +227,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") childArgsStr should include ("--class org.SomeClass") childArgsStr should include ("--arg arg1 --arg arg2") @@ -238,16 +240,16 @@ class SparkSubmitSuite classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.driver.memory") should be ("4g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.app.name") should be ("beauty") - sysProps("spark.ui.enabled") should be ("false") - 
sysProps("SPARK_SUBMIT") should be ("true") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.driver.memory") should be ("4g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.app.name") should be ("beauty") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles YARN client mode") { @@ -268,7 +270,7 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (4) @@ -276,17 +278,17 @@ class SparkSubmitSuite classpath(1) should endWith ("one.jar") classpath(2) should endWith ("two.jar") classpath(3) should endWith ("three.jar") - sysProps("spark.app.name") should be ("trill") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.executor.cores") should be ("5") - sysProps("spark.yarn.queue") should be ("thequeue") - sysProps("spark.executor.instances") should be ("6") - sysProps("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") - sysProps("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") - sysProps("spark.yarn.dist.jars") should include + conf.get("spark.app.name") should be ("trill") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.executor.cores") should be ("5") + conf.get("spark.yarn.queue") should be ("thequeue") + conf.get("spark.executor.instances") should be ("6") + conf.get("spark.yarn.dist.files") should include regex (".*file1.txt,.*file2.txt") + conf.get("spark.yarn.dist.archives") should include regex (".*archive1.txt,.*archive2.txt") + conf.get("spark.yarn.dist.jars") should include regex (".*one.jar,.*two.jar,.*three.jar,.*thejar.jar") - sysProps("SPARK_SUBMIT") should be ("true") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") } test("handles standalone cluster mode") { @@ -314,7 +316,7 @@ class SparkSubmitSuite "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) appArgs.useRest = useRest - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) val childArgsStr = childArgs.mkString(" ") if (useRest) { childArgsStr should endWith ("thejar.jar org.SomeClass arg1 arg2") @@ -325,17 +327,18 @@ class SparkSubmitSuite mainClass should be ("org.apache.spark.deploy.Client") } classpath should have size 0 - sysProps should have size 9 - sysProps.keys should contain ("SPARK_SUBMIT") - sysProps.keys should contain ("spark.master") - sysProps.keys should contain ("spark.app.name") - sysProps.keys should contain ("spark.jars") - sysProps.keys should contain ("spark.driver.memory") - sysProps.keys should contain ("spark.driver.cores") - sysProps.keys should contain ("spark.driver.supervise") - sysProps.keys should contain 
("spark.ui.enabled") - sysProps.keys should contain ("spark.submit.deployMode") - sysProps("spark.ui.enabled") should be ("false") + sys.props("SPARK_SUBMIT") should be ("true") + + val confMap = conf.getAll.toMap + confMap.keys should contain ("spark.master") + confMap.keys should contain ("spark.app.name") + confMap.keys should contain ("spark.jars") + confMap.keys should contain ("spark.driver.memory") + confMap.keys should contain ("spark.driver.cores") + confMap.keys should contain ("spark.driver.supervise") + confMap.keys should contain ("spark.ui.enabled") + confMap.keys should contain ("spark.submit.deployMode") + conf.get("spark.ui.enabled") should be ("false") } test("handles standalone client mode") { @@ -350,14 +353,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles mesos client mode") { @@ -372,14 +375,14 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (childArgs, classpath, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) + val (childArgs, classpath, conf, mainClass) = prepareSubmitEnvironment(appArgs) childArgs.mkString(" ") should be ("arg1 arg2") mainClass should be ("org.SomeClass") classpath should have length (1) classpath(0) should endWith ("thejar.jar") - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.cores.max") should be ("5") - sysProps("spark.ui.enabled") should be ("false") + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.cores.max") should be ("5") + conf.get("spark.ui.enabled") should be ("false") } test("handles confs with flag equivalents") { @@ -392,13 +395,28 @@ class SparkSubmitSuite "thejar.jar", "arg1", "arg2") val appArgs = new SparkSubmitArguments(clArgs) - val (_, _, sysProps, mainClass) = prepareSubmitEnvironment(appArgs) - sysProps("spark.executor.memory") should be ("5g") - sysProps("spark.master") should be ("yarn") - sysProps("spark.submit.deployMode") should be ("cluster") + val (_, _, conf, mainClass) = prepareSubmitEnvironment(appArgs) + conf.get("spark.executor.memory") should be ("5g") + conf.get("spark.master") should be ("yarn") + conf.get("spark.submit.deployMode") should be ("cluster") mainClass should be ("org.apache.spark.deploy.yarn.Client") } + test("SPARK-21568 ConsoleProgressBar should be enabled only in shells") { + // Unset from system properties since this config is defined in the root pom's test config. 
+ sys.props -= UI_SHOW_CONSOLE_PROGRESS.key + + val clArgs1 = Seq("--class", "org.apache.spark.repl.Main", "spark-shell") + val appArgs1 = new SparkSubmitArguments(clArgs1) + val (_, _, conf1, _) = prepareSubmitEnvironment(appArgs1) + conf1.get(UI_SHOW_CONSOLE_PROGRESS) should be (true) + + val clArgs2 = Seq("--class", "org.SomeClass", "thejar.jar") + val appArgs2 = new SparkSubmitArguments(clArgs2) + val (_, _, conf2, _) = prepareSubmitEnvironment(appArgs2) + assert(!conf2.contains(UI_SHOW_CONSOLE_PROGRESS)) + } + test("launch simple application with spark-submit") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val args = Seq( @@ -571,11 +589,11 @@ class SparkSubmitSuite "--files", files, "thejar.jar") val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) appArgs.jars should be (Utils.resolveURIs(jars)) appArgs.files should be (Utils.resolveURIs(files)) - sysProps("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be (Utils.resolveURIs(files)) + conf.get("spark.jars") should be (Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be (Utils.resolveURIs(files)) // Test files and archives (Yarn) val clArgs2 = Seq( @@ -586,11 +604,11 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) appArgs2.files should be (Utils.resolveURIs(files)) appArgs2.archives should be (Utils.resolveURIs(archives)) - sysProps2("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) + conf2.get("spark.yarn.dist.files") should be (Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be (Utils.resolveURIs(archives)) // Test python files val clArgs3 = Seq( @@ -601,12 +619,12 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) appArgs3.pyFiles should be (Utils.resolveURIs(pyFiles)) - sysProps3("spark.submit.pyFiles") should be ( + conf3.get("spark.submit.pyFiles") should be ( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) - sysProps3(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") - sysProps3(PYSPARK_PYTHON.key) should be ("python3.5") + conf3.get(PYSPARK_DRIVER_PYTHON.key) should be ("python3.4") + conf3.get(PYSPARK_PYTHON.key) should be ("python3.5") } test("resolves config paths correctly") { @@ -630,9 +648,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs = new SparkSubmitArguments(clArgs) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) - sysProps("spark.files") should be(Utils.resolveURIs(files)) + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.jars") should be(Utils.resolveURIs(jars + ",thejar.jar")) + conf.get("spark.files") should be(Utils.resolveURIs(files)) // Test files and archives (Yarn) val f2 = File.createTempFile("test-submit-files-archives", "", tmpDir) @@ -647,9 +665,9 @@ class SparkSubmitSuite "thejar.jar" ) val appArgs2 = new SparkSubmitArguments(clArgs2) - val sysProps2 = 
SparkSubmit.prepareSubmitEnvironment(appArgs2)._3 - sysProps2("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) - sysProps2("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) + val (_, _, conf2, _) = SparkSubmit.prepareSubmitEnvironment(appArgs2) + conf2.get("spark.yarn.dist.files") should be(Utils.resolveURIs(files)) + conf2.get("spark.yarn.dist.archives") should be(Utils.resolveURIs(archives)) // Test python files val f3 = File.createTempFile("test-submit-python-files", "", tmpDir) @@ -662,8 +680,8 @@ class SparkSubmitSuite "mister.py" ) val appArgs3 = new SparkSubmitArguments(clArgs3) - val sysProps3 = SparkSubmit.prepareSubmitEnvironment(appArgs3)._3 - sysProps3("spark.submit.pyFiles") should be( + val (_, _, conf3, _) = SparkSubmit.prepareSubmitEnvironment(appArgs3) + conf3.get("spark.submit.pyFiles") should be( PythonRunner.formatPaths(Utils.resolveURIs(pyFiles)).mkString(",")) // Test remote python files @@ -679,11 +697,9 @@ class SparkSubmitSuite "hdfs:///tmp/mister.py" ) val appArgs4 = new SparkSubmitArguments(clArgs4) - val sysProps4 = SparkSubmit.prepareSubmitEnvironment(appArgs4)._3 + val (_, _, conf4, _) = SparkSubmit.prepareSubmitEnvironment(appArgs4) // Should not format python path for yarn cluster mode - sysProps4("spark.submit.pyFiles") should be( - Utils.resolveURIs(remotePyFiles) - ) + conf4.get("spark.submit.pyFiles") should be(Utils.resolveURIs(remotePyFiles)) } test("user classpath first in driver") { @@ -757,14 +773,14 @@ class SparkSubmitSuite jar2.toString) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs)._3 - sysProps("spark.yarn.dist.jars").split(",").toSet should be + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs) + conf.get("spark.yarn.dist.jars").split(",").toSet should be (Set(jar1.toURI.toString, jar2.toURI.toString)) - sysProps("spark.yarn.dist.files").split(",").toSet should be + conf.get("spark.yarn.dist.files").split(",").toSet should be (Set(file1.toURI.toString, file2.toURI.toString)) - sysProps("spark.yarn.dist.pyFiles").split(",").toSet should be + conf.get("spark.yarn.dist.pyFiles").split(",").toSet should be (Set(pyFile1.getAbsolutePath, pyFile2.getAbsolutePath)) - sysProps("spark.yarn.dist.archives").split(",").toSet should be + conf.get("spark.yarn.dist.archives").split(",").toSet should be (Set(archive1.toURI.toString, archive2.toURI.toString)) } @@ -883,18 +899,18 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) // All the resources should still be remote paths, so that YARN client will not upload again. - sysProps("spark.yarn.dist.jars") should be (tmpJarPath) - sysProps("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") - sysProps("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") + conf.get("spark.yarn.dist.jars") should be (tmpJarPath) + conf.get("spark.yarn.dist.files") should be (s"s3a://${file.getAbsolutePath}") + conf.get("spark.yarn.dist.pyFiles") should be (s"s3a://${pyFile.getAbsolutePath}") // Local repl jars should be a local path. - sysProps("spark.repl.local.jars") should (startWith("file:")) + conf.get("spark.repl.local.jars") should (startWith("file:")) // local py files should not be a URI format. 
- sysProps("spark.submit.pyFiles") should (startWith("/")) + conf.get("spark.submit.pyFiles") should (startWith("/")) } test("download remote resource if it is not supported by yarn service") { @@ -941,9 +957,9 @@ class SparkSubmitSuite ) val appArgs = new SparkSubmitArguments(args) - val sysProps = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf))._3 + val (_, _, conf, _) = SparkSubmit.prepareSubmitEnvironment(appArgs, Some(hadoopConf)) - val jars = sysProps("spark.yarn.dist.jars").split(",").toSet + val jars = conf.get("spark.yarn.dist.jars").split(",").toSet // The URI of remote S3 resource should still be remote. assert(jars.contains(tmpS3JarPath)) @@ -962,30 +978,6 @@ class SparkSubmitSuite } } - // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. - private def runSparkSubmit(args: Seq[String]): Unit = { - val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) - val sparkSubmitFile = if (Utils.isWindows) { - new File("..\\bin\\spark-submit.cmd") - } else { - new File("../bin/spark-submit") - } - val process = Utils.executeCommand( - Seq(sparkSubmitFile.getCanonicalPath) ++ args, - new File(sparkHome), - Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) - - try { - val exitCode = failAfter(60 seconds) { process.waitFor() } - if (exitCode != 0) { - fail(s"Process returned with exit code $exitCode. See the log4j logs for more detail.") - } - } finally { - // Ensure we still kill the process in case it timed out - process.destroy() - } - } - private def forConfDir(defaults: Map[String, String]) (f: String => Unit) = { val tmpDir = Utils.createTempDir() @@ -1006,6 +998,47 @@ class SparkSubmitSuite conf.set("fs.s3a.impl", classOf[TestFileSystem].getCanonicalName) conf.set("fs.s3a.impl.disable.cache", "true") } + + test("start SparkApplication without modifying system properties") { + val args = Array( + "--class", classOf[TestSparkApplication].getName(), + "--master", "local", + "--conf", "spark.test.hello=world", + "spark-internal", + "hello") + + val exception = intercept[SparkException] { + SparkSubmit.main(args) + } + + assert(exception.getMessage() === "hello") + } +} + +object SparkSubmitSuite extends SparkFunSuite with TimeLimits { + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. + def runSparkSubmit(args: Seq[String], root: String = ".."): Unit = { + val sparkHome = sys.props.getOrElse("spark.test.home", fail("spark.test.home is not set!")) + val sparkSubmitFile = if (Utils.isWindows) { + new File(s"$root\\bin\\spark-submit.cmd") + } else { + new File(s"$root/bin/spark-submit") + } + val process = Utils.executeCommand( + Seq(sparkSubmitFile.getCanonicalPath) ++ args, + new File(sparkHome), + Map("SPARK_TESTING" -> "1", "SPARK_HOME" -> sparkHome)) + + try { + val exitCode = failAfter(60 seconds) { process.waitFor() } + if (exitCode != 0) { + fail(s"Process returned with exit code $exitCode. 
See the log4j logs for more detail.") + } + } finally { + // Ensure we still kill the process in case it timed out + process.destroy() + } + } } object JarCreationTest extends Logging { @@ -1099,3 +1132,17 @@ class TestFileSystem extends org.apache.hadoop.fs.LocalFileSystem { override def open(path: Path): FSDataInputStream = super.open(local(path)) } + +class TestSparkApplication extends SparkApplication with Matchers { + + override def start(args: Array[String], conf: SparkConf): Unit = { + assert(args.size === 1) + assert(args(0) === "hello") + assert(conf.get("spark.test.hello") === "world") + assert(sys.props.get("spark.test.hello") === None) + + // This is how the test verifies the application was actually run. + throw new SparkException(args(0)) + } + +} diff --git a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala index 6e50e84549047..44f9c566a380d 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/ApplicationCacheSuite.scala @@ -18,15 +18,11 @@ package org.apache.spark.deploy.history import java.util.{Date, NoSuchElementException} -import javax.servlet.Filter import javax.servlet.http.{HttpServletRequest, HttpServletResponse} import scala.collection.mutable -import scala.collection.mutable.ListBuffer import com.codahale.metrics.Counter -import com.google.common.cache.LoadingCache -import com.google.common.util.concurrent.UncheckedExecutionException import org.eclipse.jetty.servlet.ServletContextHandler import org.mockito.Matchers._ import org.mockito.Mockito._ @@ -39,23 +35,10 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.status.api.v1.{ApplicationAttemptInfo => AttemptInfo, ApplicationInfo} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{Clock, ManualClock, Utils} +import org.apache.spark.util.ManualClock class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar with Matchers { - /** - * subclass with access to the cache internals - * @param retainedApplications number of retained applications - */ - class TestApplicationCache( - operations: ApplicationCacheOperations = new StubCacheOperations(), - retainedApplications: Int, - clock: Clock = new ManualClock(0)) - extends ApplicationCache(operations, retainedApplications, clock) { - - def cache(): LoadingCache[CacheKey, CacheEntry] = appCache - } - /** * Stub cache operations. 
* The state is kept in a map of [[CacheKey]] to [[CacheEntry]], @@ -77,8 +60,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar override def getAppUI(appId: String, attemptId: Option[String]): Option[LoadedAppUI] = { logDebug(s"getAppUI($appId, $attemptId)") getAppUICount += 1 - instances.get(CacheKey(appId, attemptId)).map( e => - LoadedAppUI(e.ui, () => updateProbe(appId, attemptId, e.probeTime))) + instances.get(CacheKey(appId, attemptId)).map { e => e.loadedUI } } override def attachSparkUI( @@ -96,10 +78,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = putAppUI(appId, attemptId, completed, started, ended, timestamp) - attachSparkUI(appId, attemptId, ui, completed) + ended: Long): LoadedAppUI = { + val ui = putAppUI(appId, attemptId, completed, started, ended) + attachSparkUI(appId, attemptId, ui.ui, completed) ui } @@ -108,23 +89,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attemptId: Option[String], completed: Boolean, started: Long, - ended: Long, - timestamp: Long): SparkUI = { - val ui = newUI(appId, attemptId, completed, started, ended) - putInstance(appId, attemptId, ui, completed, timestamp) + ended: Long): LoadedAppUI = { + val ui = LoadedAppUI(newUI(appId, attemptId, completed, started, ended)) + instances(CacheKey(appId, attemptId)) = new CacheEntry(ui, completed) ui } - def putInstance( - appId: String, - attemptId: Option[String], - ui: SparkUI, - completed: Boolean, - timestamp: Long): Unit = { - instances += (CacheKey(appId, attemptId) -> - new CacheEntry(ui, completed, () => updateProbe(appId, attemptId, timestamp), timestamp)) - } - /** * Detach a reconstructed UI * @@ -146,23 +116,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar attached.get(CacheKey(appId, attemptId)) } - /** - * The update probe. 
- * @param appId application to probe - * @param attemptId attempt to probe - * @param updateTime timestamp of this UI load - */ - private[history] def updateProbe( - appId: String, - attemptId: Option[String], - updateTime: Long)(): Boolean = { - updateProbeCount += 1 - logDebug(s"isUpdated($appId, $attemptId, ${updateTime})") - val entry = instances.get(CacheKey(appId, attemptId)).get - val updated = entry.probeTime > updateTime - logDebug(s"entry = $entry; updated = $updated") - updated - } } /** @@ -210,15 +163,13 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val now = clock.getTimeMillis() // add the entry - operations.putAppUI(app1, None, true, now, now, now) + operations.putAppUI(app1, None, true, now, now) // make sure its local operations.getAppUI(app1, None).get operations.getAppUICount = 0 // now expect it to be found - val cacheEntry = cache.lookupCacheEntry(app1, None) - assert(1 === cacheEntry.probeTime) - assert(cacheEntry.completed) + cache.withSparkUI(app1, None) { _ => } // assert about queries made of the operations assert(1 === operations.getAppUICount, "getAppUICount") assert(1 === operations.attachCount, "attachCount") @@ -236,8 +187,8 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar assert(0 === operations.detachCount, "attachCount") // evict the entry - operations.putAndAttach("2", None, true, time2, time2, time2) - operations.putAndAttach("3", None, true, time2, time2, time2) + operations.putAndAttach("2", None, true, time2, time2) + operations.putAndAttach("3", None, true, time2, time2) cache.get("2") cache.get("3") @@ -248,7 +199,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val appId = "app1" val attemptId = Some("_01") val time3 = clock.getTimeMillis() - operations.putAppUI(appId, attemptId, false, time3, 0, time3) + operations.putAppUI(appId, attemptId, false, time3, 0) // expect an error here assertNotFound(appId, None) } @@ -256,10 +207,11 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Test that if an attempt ID is set, it must be used in lookups") { val operations = new StubCacheOperations() val clock = new ManualClock(1) - implicit val cache = new ApplicationCache(operations, retainedApplications = 10, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = 10, + clock = clock) val appId = "app1" val attemptId = Some("_01") - operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0, 0) + operations.putAppUI(appId, attemptId, false, clock.getTimeMillis(), 0) assertNotFound(appId, None) } @@ -271,50 +223,29 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Incomplete apps refreshed") { val operations = new StubCacheOperations() val clock = new ManualClock(50) - val window = 500 - implicit val cache = new ApplicationCache(operations, retainedApplications = 5, clock = clock) + implicit val cache = new ApplicationCache(operations, 5, clock) val metrics = cache.metrics // add the incomplete app // add the entry val started = clock.getTimeMillis() val appId = "app1" val attemptId = Some("001") - operations.putAppUI(appId, attemptId, false, started, 0, started) - val firstEntry = cache.lookupCacheEntry(appId, attemptId) - assert(started === firstEntry.probeTime, s"timestamp in $firstEntry") - assert(!firstEntry.completed, s"entry is complete: $firstEntry") - assertMetric("lookupCount", metrics.lookupCount, 1) + val 
initialUI = operations.putAndAttach(appId, attemptId, false, started, 0) + val firstUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assertMetric("lookupCount", metrics.lookupCount, 1) assert(0 === operations.updateProbeCount, "expected no update probe on that first get") - val checkTime = window * 2 - clock.setTime(checkTime) - val entry3 = cache.lookupCacheEntry(appId, attemptId) - assert(firstEntry !== entry3, s"updated entry test from $cache") + // Invalidate the first entry to trigger a re-load. + initialUI.invalidate() + + // Update the UI in the stub so that a new one is provided to the cache. + operations.putAppUI(appId, attemptId, true, started, started + 10) + + val updatedUI = cache.withSparkUI(appId, attemptId) { ui => ui } + assert(firstUI !== updatedUI, s"expected updated UI") assertMetric("lookupCount", metrics.lookupCount, 2) - assertMetric("updateProbeCount", metrics.updateProbeCount, 1) - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 0) - assert(1 === operations.updateProbeCount, s"refresh count in $cache") - assert(0 === operations.detachCount, s"detach count") - assert(entry3.probeTime === checkTime) - - val updateTime = window * 3 - // update the cached value - val updatedApp = operations.putAppUI(appId, attemptId, true, started, updateTime, updateTime) - val endTime = window * 10 - clock.setTime(endTime) - logDebug(s"Before operation = $cache") - val entry5 = cache.lookupCacheEntry(appId, attemptId) - assertMetric("lookupCount", metrics.lookupCount, 3) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) - // the update was triggered - assertMetric("updateTriggeredCount", metrics.updateTriggeredCount, 1) - assert(updatedApp === entry5.ui, s"UI {$updatedApp} did not match entry {$entry5} in $cache") - - // at which point, the refreshes stop - clock.setTime(window * 20) - assertCacheEntryEquals(appId, attemptId, entry5) - assertMetric("updateProbeCount", metrics.updateProbeCount, 2) + assert(1 === operations.detachCount, s"detach count") } /** @@ -337,27 +268,6 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } } - /** - * Look up the cache entry and assert that it matches in the expected value. - * This assertion works if the two CacheEntries are different -it looks at the fields. - * UI are compared on object equality; the timestamp and completed flags directly. - * @param appId application ID - * @param attemptId attempt ID - * @param expected expected value - * @param cache app cache - */ - def assertCacheEntryEquals( - appId: String, - attemptId: Option[String], - expected: CacheEntry) - (implicit cache: ApplicationCache): Unit = { - val actual = cache.lookupCacheEntry(appId, attemptId) - val errorText = s"Expected get($appId, $attemptId) -> $expected, but got $actual from $cache" - assert(expected.ui === actual.ui, errorText + " SparkUI reference") - assert(expected.completed === actual.completed, errorText + " -completed flag") - assert(expected.probeTime === actual.probeTime, errorText + " -timestamp") - } - /** * Assert that a key wasn't found in cache or loaded. 
* @@ -370,14 +280,9 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar appId: String, attemptId: Option[String]) (implicit cache: ApplicationCache): Unit = { - val ex = intercept[UncheckedExecutionException] { + val ex = intercept[NoSuchElementException] { cache.get(appId, attemptId) } - var cause = ex.getCause - assert(cause !== null) - if (!cause.isInstanceOf[NoSuchElementException]) { - throw cause - } } test("Large Scale Application Eviction") { @@ -385,12 +290,12 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar val clock = new ManualClock(0) val size = 5 // only two entries are retained, so we expect evictions to occur on lookups - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = size, clock = clock) + implicit val cache = new ApplicationCache(operations, retainedApplications = size, + clock = clock) val attempt1 = Some("01") - val ids = new ListBuffer[String]() + val ids = new mutable.ListBuffer[String]() // build a list of applications val count = 100 for (i <- 1 to count ) { @@ -398,7 +303,7 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar ids += appId clock.advance(10) val t = clock.getTimeMillis() - operations.putAppUI(appId, attempt1, true, t, t, t) + operations.putAppUI(appId, attempt1, true, t, t) } // now go through them in sequence reading them, expect evictions ids.foreach { id => @@ -413,20 +318,19 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar test("Attempts are Evicted") { val operations = new StubCacheOperations() - implicit val cache: ApplicationCache = new TestApplicationCache(operations, - retainedApplications = 4) + implicit val cache = new ApplicationCache(operations, 4, new ManualClock()) val metrics = cache.metrics val appId = "app1" val attempt1 = Some("01") val attempt2 = Some("02") val attempt3 = Some("03") - operations.putAppUI(appId, attempt1, true, 100, 110, 110) - operations.putAppUI(appId, attempt2, true, 200, 210, 210) - operations.putAppUI(appId, attempt3, true, 300, 310, 310) + operations.putAppUI(appId, attempt1, true, 100, 110) + operations.putAppUI(appId, attempt2, true, 200, 210) + operations.putAppUI(appId, attempt3, true, 300, 310) val attempt4 = Some("04") - operations.putAppUI(appId, attempt4, true, 400, 410, 410) + operations.putAppUI(appId, attempt4, true, 400, 410) val attempt5 = Some("05") - operations.putAppUI(appId, attempt5, true, 500, 510, 510) + operations.putAppUI(appId, attempt5, true, 500, 510) def expectLoadAndEvictionCounts(expectedLoad: Int, expectedEvictionCount: Int): Unit = { assertMetric("loadCount", metrics.loadCount, expectedLoad) @@ -457,20 +361,14 @@ class ApplicationCacheSuite extends SparkFunSuite with Logging with MockitoSugar } - test("Instantiate Filter") { - // this is a regression test on the filter being constructable - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val instance = clazz.newInstance() - instance shouldBe a [Filter] - } - test("redirect includes query params") { - val clazz = Utils.classForName(ApplicationCacheCheckFilterRelay.FILTER_NAME) - val filter = clazz.newInstance().asInstanceOf[ApplicationCacheCheckFilter] - filter.appId = "local-123" + val operations = new StubCacheOperations() + val ui = operations.putAndAttach("foo", None, true, 0, 10) val cache = mock[ApplicationCache] - when(cache.checkForUpdates(any(), any())).thenReturn(true) - 
ApplicationCacheCheckFilterRelay.setApplicationCache(cache) + when(cache.operations).thenReturn(operations) + val filter = new ApplicationCacheCheckFilter(new CacheKey("foo", None), ui, cache) + ui.invalidate() + val request = mock[HttpServletRequest] when(request.getMethod()).thenReturn("GET") when(request.getRequestURI()).thenReturn("http://localhost:18080/history/local-123/jobs/job/") diff --git a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala index 7109146ece371..86c8cdf43258c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/FsHistoryProviderSuite.scala @@ -36,10 +36,12 @@ import org.scalatest.Matchers import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.history.config._ import org.apache.spark.internal.Logging import org.apache.spark.io._ import org.apache.spark.scheduler._ import org.apache.spark.security.GroupMappingServiceProvider +import org.apache.spark.status.AppStatusStore import org.apache.spark.util.{Clock, JsonProtocol, ManualClock, Utils} class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matchers with Logging { @@ -66,9 +68,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new File(logPath) } - test("Parse application logs") { + Seq(true, false).foreach { inMemory => + test(s"Parse application logs (inMemory = $inMemory)") { + testAppLogParsing(inMemory) + } + } + + private def testAppLogParsing(inMemory: Boolean) { val clock = new ManualClock(12345678) - val provider = new FsHistoryProvider(createTestConf(), clock) + val provider = new FsHistoryProvider(createTestConf(inMemory = inMemory), clock) // Write a new-style application log. 
val newAppComplete = newLogFile("new1", None, inProgress = false) @@ -172,20 +180,18 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc ) updateAndCheck(provider) { list => list.size should be (1) - list.head.attempts.head.asInstanceOf[FsApplicationAttemptInfo].logPath should - endWith(EventLoggingListener.IN_PROGRESS) + provider.getAttempt("app1", None).logPath should endWith(EventLoggingListener.IN_PROGRESS) } logFile1.renameTo(newLogFile("app1", None, inProgress = false)) updateAndCheck(provider) { list => list.size should be (1) - list.head.attempts.head.asInstanceOf[FsApplicationAttemptInfo].logPath should not - endWith(EventLoggingListener.IN_PROGRESS) + provider.getAttempt("app1", None).logPath should not endWith(EventLoggingListener.IN_PROGRESS) } } test("Parse logs that application is not started") { - val provider = new FsHistoryProvider((createTestConf())) + val provider = new FsHistoryProvider(createTestConf()) val logFile1 = newLogFile("app1", None, inProgress = true) writeFile(logFile1, true, None, @@ -342,17 +348,23 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc provider.checkForLogs() // This should not trigger any cleanup - updateAndCheck(provider)(list => list.size should be(2)) + updateAndCheck(provider) { list => + list.size should be(2) + } // Should trigger cleanup for first file but not second one clock.setTime(firstFileModifiedTime + maxAge + 1) - updateAndCheck(provider)(list => list.size should be(1)) + updateAndCheck(provider) { list => + list.size should be(1) + } assert(!log1.exists()) assert(log2.exists()) // Should cleanup the second file as well. clock.setTime(secondFileModifiedTime + maxAge + 1) - updateAndCheck(provider)(list => list.size should be(0)) + updateAndCheck(provider) { list => + list.size should be(0) + } assert(!log1.exists()) assert(!log2.exists()) } @@ -580,7 +592,71 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc securityManager.checkUIViewPermissions("user4") should be (false) securityManager.checkUIViewPermissions("user5") should be (false) } - } + } + + test("mismatched version discards old listing") { + val conf = createTestConf() + val oldProvider = new FsHistoryProvider(conf) + + val logFile1 = newLogFile("app1", None, inProgress = false) + writeFile(logFile1, true, None, + SparkListenerLogStart("2.3"), + SparkListenerApplicationStart("test", Some("test"), 1L, "test", None), + SparkListenerApplicationEnd(5L) + ) + + updateAndCheck(oldProvider) { list => + list.size should be (1) + } + assert(oldProvider.listing.count(classOf[ApplicationInfoWrapper]) === 1) + + // Manually overwrite the version in the listing db; this should cause the new provider to + // discard all data because the versions don't match. + val meta = new FsHistoryProviderMetadata(FsHistoryProvider.CURRENT_LISTING_VERSION + 1, + AppStatusStore.CURRENT_VERSION, conf.get(LOCAL_STORE_DIR).get) + oldProvider.listing.setMetadata(meta) + oldProvider.stop() + + val mismatchedVersionProvider = new FsHistoryProvider(conf) + assert(mismatchedVersionProvider.listing.count(classOf[ApplicationInfoWrapper]) === 0) + } + + test("invalidate cached UI") { + val provider = new FsHistoryProvider(createTestConf()) + val appId = "new1" + + // Write an incomplete app log. + val appLog = newLogFile(appId, None, inProgress = true) + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None) + ) + provider.checkForLogs() + + // Load the app UI. 
+ val oldUI = provider.getAppUI(appId, None) + assert(oldUI.isDefined) + intercept[NoSuchElementException] { + oldUI.get.ui.store.job(0) + } + + // Add more info to the app log, and trigger the provider to update things. + writeFile(appLog, true, None, + SparkListenerApplicationStart(appId, Some(appId), 1L, "test", None), + SparkListenerJobStart(0, 1L, Nil, null), + SparkListenerApplicationEnd(5L) + ) + provider.checkForLogs() + + // Manually detach the old UI; ApplicationCache would do this automatically in a real SHS + // when the app's UI was requested. + provider.onUIDetached(appId, None, oldUI.get.ui) + + // Load the UI again and make sure we can get the new info added to the logs. + val freshUI = provider.getAppUI(appId, None) + assert(freshUI.isDefined) + assert(freshUI != oldUI) + freshUI.get.ui.store.job(0) + } /** * Asks the provider to check for logs and calls a function to perform checks on the updated @@ -623,8 +699,15 @@ class FsHistoryProviderSuite extends SparkFunSuite with BeforeAndAfter with Matc new FileOutputStream(file).close() } - private def createTestConf(): SparkConf = { - new SparkConf().set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + private def createTestConf(inMemory: Boolean = false): SparkConf = { + val conf = new SparkConf() + .set("spark.history.fs.logDirectory", testDir.getAbsolutePath()) + + if (!inMemory) { + conf.set(LOCAL_STORE_DIR, Utils.createTempDir().getAbsolutePath()) + } + + conf } private class SafeModeTestProvider(conf: SparkConf, clock: Clock) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 18da8c18939ed..010a8dd004d4f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -43,6 +43,7 @@ import org.scalatest.mockito.MockitoSugar import org.scalatest.selenium.WebBrowser import org.apache.spark._ +import org.apache.spark.deploy.history.config._ import org.apache.spark.ui.SparkUI import org.apache.spark.ui.jobs.UIData.JobUIData import org.apache.spark.util.{ResetSystemProperties, Utils} @@ -64,16 +65,20 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers private val logDir = getTestResourcePath("spark-events") private val expRoot = getTestResourceFile("HistoryServerExpectations") + private val storeDir = Utils.createTempDir(namePrefix = "history") private var provider: FsHistoryProvider = null private var server: HistoryServer = null private var port: Int = -1 def init(extraConf: (String, String)*): Unit = { + Utils.deleteRecursively(storeDir) + assert(storeDir.mkdir()) val conf = new SparkConf() .set("spark.history.fs.logDirectory", logDir) .set("spark.history.fs.update.interval", "0") .set("spark.testing", "true") + .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) conf.setAll(extraConf) provider = new FsHistoryProvider(conf) provider.checkForLogs() @@ -87,14 +92,13 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers def stop(): Unit = { server.stop() + server = null } before { - init() - } - - after{ - stop() + if (server == null) { + init() + } } val cases = Seq( @@ -290,20 +294,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val uiRoot = "/testwebproxybase" System.setProperty("spark.ui.proxyBase", uiRoot) - server.stop() - - val conf = new SparkConf() - 
.set("spark.history.fs.logDirectory", logDir) - .set("spark.history.fs.update.interval", "0") - .set("spark.testing", "true") - - provider = new FsHistoryProvider(conf) - provider.checkForLogs() - val securityManager = HistoryServer.createSecurityManager(conf) - - server = new HistoryServer(conf, provider, securityManager, 18080) - server.initialize() - server.bind() + stop() + init() val port = server.boundPort @@ -372,7 +364,6 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } test("incomplete apps get refreshed") { - implicit val webDriver: WebDriver = new HtmlUnitDriver implicit val formats = org.json4s.DefaultFormats @@ -382,12 +373,14 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers // a new conf is used with the background thread set and running at its fastest // allowed refresh rate (1Hz) + stop() val myConf = new SparkConf() .set("spark.history.fs.logDirectory", logDir.getAbsolutePath) .set("spark.eventLog.dir", logDir.getAbsolutePath) .set("spark.history.fs.update.interval", "1s") .set("spark.eventLog.enabled", "true") .set("spark.history.cache.window", "250ms") + .set(LOCAL_STORE_DIR, storeDir.getAbsolutePath()) .remove("spark.testing") val provider = new FsHistoryProvider(myConf) val securityManager = HistoryServer.createSecurityManager(myConf) @@ -413,9 +406,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers } } - // stop the server with the old config, and start the new one - server.stop() - server = new HistoryServer(myConf, provider, securityManager, 18080) + server = new HistoryServer(myConf, provider, securityManager, 0) server.initialize() server.bind() val port = server.boundPort @@ -461,7 +452,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers rootAppPage should not be empty def getAppUI: SparkUI = { - provider.getAppUI(appId, None).get.ui + server.withSparkUI(appId, None) { ui => ui } } // selenium isn't that useful on failures...add our own reporting @@ -516,7 +507,7 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers getNumJobs("") should be (1) getNumJobs("/jobs") should be (1) getNumJobsRestful() should be (1) - assert(metrics.lookupCount.getCount > 1, s"lookup count too low in $metrics") + assert(metrics.lookupCount.getCount > 0, s"lookup count too low in $metrics") // dump state before the next bit of test, which is where update // checking really gets stressed diff --git a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala index 70887dc5dd97a..490baf040491f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/rest/StandaloneRestSubmitSuite.scala @@ -445,9 +445,9 @@ class StandaloneRestSubmitSuite extends SparkFunSuite with BeforeAndAfterEach { "--class", mainClass, mainJar) ++ appArgs val args = new SparkSubmitArguments(commandLineArgs) - val (_, _, sparkProperties, _) = SparkSubmit.prepareSubmitEnvironment(args) + val (_, _, sparkConf, _) = SparkSubmit.prepareSubmitEnvironment(args) new RestSubmissionClient("spark://host:port").constructSubmitRequest( - mainJar, mainClass, appArgs, sparkProperties.toMap, Map.empty) + mainJar, mainClass, appArgs, sparkConf.getAll.toMap, Map.empty) } /** Return the response as a submit response, or fail with error otherwise. 
*/ diff --git a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala index 5f699df8211de..c26945fa5fa31 100644 --- a/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala +++ b/core/src/test/scala/org/apache/spark/memory/TestMemoryManager.scala @@ -27,8 +27,8 @@ class TestMemoryManager(conf: SparkConf) numBytes: Long, taskAttemptId: Long, memoryMode: MemoryMode): Long = { - if (oomOnce) { - oomOnce = false + if (consequentOOM > 0) { + consequentOOM -= 1 0 } else if (available >= numBytes) { available -= numBytes @@ -58,11 +58,15 @@ class TestMemoryManager(conf: SparkConf) override def maxOffHeapStorageMemory: Long = 0L - private var oomOnce = false + private var consequentOOM = 0 private var available = Long.MaxValue def markExecutionAsOutOfMemoryOnce(): Unit = { - oomOnce = true + markconsequentOOM(1) + } + + def markconsequentOOM(n : Int) : Unit = { + consequentOOM += n } def limit(avail: Long): Unit = { diff --git a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala index 44dd955ce8690..0a248b6064ee8 100644 --- a/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/PairRDDFunctionsSuite.scala @@ -26,7 +26,7 @@ import org.apache.commons.math3.distribution.{BinomialDistribution, PoissonDistr import org.apache.hadoop.conf.{Configurable, Configuration} import org.apache.hadoop.fs.FileSystem import org.apache.hadoop.mapred._ -import org.apache.hadoop.mapreduce.{JobContext => NewJobContext, +import org.apache.hadoop.mapreduce.{Job => NewJob, JobContext => NewJobContext, OutputCommitter => NewOutputCommitter, OutputFormat => NewOutputFormat, RecordWriter => NewRecordWriter, TaskAttemptContext => NewTaskAttempContext} import org.apache.hadoop.util.Progressable @@ -568,6 +568,50 @@ class PairRDDFunctionsSuite extends SparkFunSuite with SharedSparkContext { assert(FakeWriterWithCallback.exception.getMessage contains "failed to write") } + test("saveAsNewAPIHadoopDataset should support invalid output paths when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + def saveRddWithPath(path: String): Unit = { + val job = NewJob.getInstance(new Configuration(sc.hadoopConfiguration)) + job.setOutputKeyClass(classOf[Integer]) + job.setOutputValueClass(classOf[Integer]) + job.setOutputFormatClass(classOf[NewFakeFormat]) + if (null != path) { + job.getConfiguration.set("mapred.output.dir", path) + } else { + job.getConfiguration.unset("mapred.output.dir") + } + val jobConfiguration = job.getConfiguration + + // just test that the job does not fail with java.lang.IllegalArgumentException. + pairs.saveAsNewAPIHadoopDataset(jobConfiguration) + } + + saveRddWithPath(null) + saveRddWithPath("") + saveRddWithPath("::invalid::") + } + + // In spark 2.1, only null was supported - not other invalid paths. + // org.apache.hadoop.mapred.FileOutputFormat.getOutputPath fails with IllegalArgumentException + // for non-null invalid paths. 
+ test("saveAsHadoopDataset should respect empty output directory when " + + "there are no files to be committed to an absolute output location") { + val pairs = sc.parallelize(Array((new Integer(1), new Integer(2))), 1) + + val conf = new JobConf() + conf.setOutputKeyClass(classOf[Integer]) + conf.setOutputValueClass(classOf[Integer]) + conf.setOutputFormat(classOf[FakeOutputFormat]) + conf.setOutputCommitter(classOf[FakeOutputCommitter]) + + FakeOutputCommitter.ran = false + pairs.saveAsHadoopDataset(conf) + + assert(FakeOutputCommitter.ran, "OutputCommitter was never called") + } + test("lookup") { val pairs = sc.parallelize(Array((1, 2), (3, 4), (5, 6), (5, 7))) diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala index f6015cd51c2bd..d3bbfd11d406d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistIntegrationSuite.scala @@ -115,8 +115,9 @@ class BlacklistIntegrationSuite extends SchedulerIntegrationSuite[MultiExecutorM withBackend(runBackend _) { val jobFuture = submit(new MockRDD(sc, 10, Nil), (0 until 10).toArray) awaitJobTermination(jobFuture, duration) - val pattern = ("Aborting TaskSet 0.0 because task .* " + - "cannot run anywhere due to node and executor blacklist").r + val pattern = ( + s"""|Aborting TaskSet 0.0 because task .* + |cannot run anywhere due to node and executor blacklist""".stripMargin).r assert(pattern.findFirstIn(failure.getMessage).isDefined, s"Couldn't find $pattern in ${failure.getMessage()}") } diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala index a136d69b36d6c..cd1b7a9e5ab18 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala @@ -110,7 +110,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist = createTaskSetBlacklist(stageId) if (stageId % 2 == 0) { // fail one task in every other taskset - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") failuresSoFar += 1 } blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) @@ -132,7 +133,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // for many different stages, executor 1 fails a task, and then the taskSet fails. 
(0 until failuresUntilBlacklisted * 10).foreach { stage => val taskSetBlacklist = createTaskSetBlacklist(stage) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") } assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) } @@ -147,7 +149,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val numFailures = math.max(conf.get(config.MAX_FAILURES_PER_EXEC), conf.get(config.MAX_FAILURES_PER_EXEC_STAGE)) (0 until numFailures).foreach { index => - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = index) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = index, failureReason = "testing") } assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set()) @@ -170,7 +173,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assert(blacklist.nodeBlacklist() === Set()) @@ -183,7 +187,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) assert(blacklist.nodeBlacklist() === Set("hostA")) @@ -207,7 +212,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail one more task, but executor isn't put back into blacklist since the count of failures // on that executor should have been reset to 0. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(2, 0, taskSetBlacklist2.execToFailures) assert(blacklist.nodeBlacklist() === Set()) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) @@ -221,7 +227,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Lets say that executor 1 dies completely. We get some task failures, but // the taskset then finishes successfully (elsewhere). 
(0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("1") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -236,7 +243,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Now another executor gets spun up on that host, but it also dies. val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) (0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.handleRemovedExecutor("2") blacklist.updateBlacklistForSuccessfulTaskSet( @@ -279,7 +287,7 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M def failOneTaskInTaskSet(exec: String): Unit = { val taskSetBlacklist = createTaskSetBlacklist(stageId = stageId) - taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0) + taskSetBlacklist.updateBlacklistForFailedTask("host-" + exec, exec, 0, "testing") blacklist.updateBlacklistForSuccessfulTaskSet(stageId, 0, taskSetBlacklist.execToFailures) stageId += 1 } @@ -354,12 +362,12 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 2) // Taskset1 has one failure immediately - taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist1.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Then we have a *long* delay, much longer than the timeout, before any other failures or // taskset completion clock.advance(blacklist.BLACKLIST_TIMEOUT_MILLIS * 5) // After the long delay, we have one failure on taskset 2, on the same executor - taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0) + taskSetBlacklist2.updateBlacklistForFailedTask("host-1", "1", 0, "testing") // Finally, we complete both tasksets. Its important here to complete taskset2 *first*. We // want to make sure that when taskset 1 finishes, even though we've now got two task failures, // we realize that the task failure we just added was well before the timeout. 
@@ -377,16 +385,20 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // we blacklist executors on two different hosts -- make sure that doesn't lead to any // node blacklisting val taskSetBlacklist0 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "1", 2)) assertEquivalentToSet(blacklist.isNodeBlacklisted(_), Set()) val taskSetBlacklist1 = createTaskSetBlacklist(stageId = 1) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist1.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(1, 0, taskSetBlacklist1.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "2", 2)) @@ -395,8 +407,10 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Finally, blacklist another executor on the same node as the original blacklisted executor, // and make sure this time we *do* blacklist the node. val taskSetBlacklist2 = createTaskSetBlacklist(stageId = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 0) - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "3", index = 1) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 0, failureReason = "testing") + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 1, failureReason = "testing") blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) assertEquivalentToSet(blacklist.isExecutorBlacklisted(_), Set("1", "2", "3")) verify(listenerBusMock).post(SparkListenerExecutorBlacklisted(0, "3", 2)) @@ -486,7 +500,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist0.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist0.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist0.execToFailures) @@ -497,7 +512,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. 
(0 until 4).foreach { partition => - taskSetBlacklist1.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist1.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist1.execToFailures) @@ -512,7 +528,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // Fail 4 tasks in one task set on executor 1, so that executor gets blacklisted for the whole // application. (0 until 4).foreach { partition => - taskSetBlacklist2.updateBlacklistForFailedTask("hostA", exec = "1", index = partition) + taskSetBlacklist2.updateBlacklistForFailedTask( + "hostA", exec = "1", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist2.execToFailures) @@ -523,7 +540,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M // application. Since that's the second executor that is blacklisted on the same node, we also // blacklist that node. (0 until 4).foreach { partition => - taskSetBlacklist3.updateBlacklistForFailedTask("hostA", exec = "2", index = partition) + taskSetBlacklist3.updateBlacklistForFailedTask( + "hostA", exec = "2", index = partition, failureReason = "testing") } blacklist.updateBlacklistForSuccessfulTaskSet(0, 0, taskSetBlacklist3.execToFailures) diff --git a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala index 6b42775ccb0f6..a9e92fa07b9dd 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/EventLoggingListenerSuite.scala @@ -228,6 +228,7 @@ class EventLoggingListenerSuite extends SparkFunSuite with LocalSparkContext wit SparkListenerStageCompleted, SparkListenerTaskStart, SparkListenerTaskEnd, + SparkListenerBlockUpdated, SparkListenerApplicationEnd).map(Utils.getFormattedClassName) Utils.tryWithSafeFinally { val logStart = SparkListenerLogStart(SPARK_VERSION) @@ -291,6 +292,7 @@ object EventLoggingListenerSuite { def getLoggingConf(logDir: Path, compressionCodec: Option[String] = None): SparkConf = { val conf = new SparkConf conf.set("spark.eventLog.enabled", "true") + conf.set("spark.eventLog.logBlockUpdates.enabled", "true") conf.set("spark.eventLog.testing", "true") conf.set("spark.eventLog.dir", logDir.toString) compressionCodec.foreach { codec => diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index fe6de2bd98850..109d4a0a870b8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -19,8 +19,7 @@ package org.apache.spark.scheduler import java.util.Properties -import org.apache.spark.SparkEnv -import org.apache.spark.TaskContext +import org.apache.spark.{Partition, SparkEnv, TaskContext} import org.apache.spark.executor.TaskMetrics class FakeTask( @@ -58,4 +57,21 @@ object FakeTask { } new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) } + + def createShuffleMapTaskSet( + numTasks: Int, + stageId: Int, + stageAttemptId: Int, + prefLocs: Seq[TaskLocation]*): TaskSet = { + if (prefLocs.size != 0 && prefLocs.size != numTasks) { + throw new IllegalArgumentException("Wrong number of task locations") + } + val tasks = 
Array.tabulate[Task[_]](numTasks) { i => + new ShuffleMapTask(stageId, stageAttemptId, null, new Partition { + override def index: Int = i + }, prefLocs(i), new Properties, + SparkEnv.get.closureSerializer.newInstance().serialize(TaskMetrics.registered).array()) + } + new TaskSet(tasks, stageId, stageAttemptId, priority = 0, null) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index d061c7845f4a6..1beb36afa95f0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.Matchers import org.apache.spark._ import org.apache.spark.executor.TaskMetrics -import org.apache.spark.internal.config.LISTENER_BUS_EVENT_QUEUE_CAPACITY +import org.apache.spark.internal.config._ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{ResetSystemProperties, RpcUtils} @@ -446,13 +446,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match classOf[FirehoseListenerThatAcceptsSparkConf], classOf[BasicJobCounter]) val conf = new SparkConf().setMaster("local").setAppName("test") - .set("spark.extraListeners", listeners.map(_.getName).mkString(",")) + .set(EXTRA_LISTENERS, listeners.map(_.getName)) sc = new SparkContext(conf) sc.listenerBus.listeners.asScala.count(_.isInstanceOf[BasicJobCounter]) should be (1) sc.listenerBus.listeners.asScala .count(_.isInstanceOf[ListenerThatAcceptsSparkConf]) should be (1) sc.listenerBus.listeners.asScala - .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) + .count(_.isInstanceOf[FirehoseListenerThatAcceptsSparkConf]) should be (1) } test("add and remove listeners to/from LiveListenerBus queues") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala index b8626bf777598..6003899bb7bef 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala @@ -660,9 +660,14 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B assert(tsm.isZombie) assert(failedTaskSet) val idx = failedTask.index - assert(failedTaskSetReason === s"Aborting TaskSet 0.0 because task $idx (partition $idx) " + - s"cannot run anywhere due to node and executor blacklist. Blacklisting behavior can be " + - s"configured via spark.blacklist.*.") + assert(failedTaskSetReason === s""" + |Aborting $taskSet because task $idx (partition $idx) + |cannot run anywhere due to node and executor blacklist. + |Most recent failure: + |${tsm.taskSetBlacklistHelperOpt.get.getLatestFailureReason} + | + |Blacklisting behavior can be configured via spark.blacklist.*. 
+ |""".stripMargin) } test("don't abort if there is an executor available, though it hasn't had scheduled tasks yet") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala index f1392e9db6bfd..18981d5be2f94 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetBlacklistSuite.scala @@ -37,7 +37,8 @@ class TaskSetBlacklistSuite extends SparkFunSuite { // First, mark task 0 as failed on exec1. // task 0 should be blacklisted on exec1, and nowhere else - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 0, failureReason = "testing") for { executor <- (1 to 4).map(_.toString) index <- 0 until 10 @@ -49,17 +50,20 @@ class TaskSetBlacklistSuite extends SparkFunSuite { assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark task 1 failed on exec1 -- this pushes the executor into the blacklist - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark one task as failed on exec2 -- not enough for any further blacklisting yet. - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Mark another task as failed on exec2 -- now we blacklist exec2, which also leads to // blacklisting the entire node. 
- taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "exec2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "exec2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("exec2")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) @@ -108,34 +112,41 @@ class TaskSetBlacklistSuite extends SparkFunSuite { .set(config.MAX_FAILED_EXEC_PER_NODE_STAGE, 3) val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) // Fail a task twice on hostA, exec:1 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTask("1", 0)) assert(!taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail the same task once more on hostA, exec:2 - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 0) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 0, failureReason = "testing") assert(taskSetBlacklist.isNodeBlacklistedForTask("hostA", 0)) assert(!taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail another task on hostA, exec:1. Now that executor has failures on two different tasks, // so its blacklisted - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a third task on hostA, exec:2, so that exec is blacklisted for the whole task set - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "2", index = 2) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "2", index = 2, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) // Fail a fourth & fifth task on hostA, exec:3. Now we've got three executors that are // blacklisted for the taskset, so blacklist the whole node. 
- taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 3) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "3", index = 4) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 3, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "3", index = 4, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("3")) assert(taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) } @@ -147,13 +158,17 @@ class TaskSetBlacklistSuite extends SparkFunSuite { val conf = new SparkConf().setAppName("test").setMaster("local") .set(config.BLACKLIST_ENABLED.key, "true") val taskSetBlacklist = new TaskSetBlacklist(conf, stageId = 0, new SystemClock()) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostA", exec = "1", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostA", exec = "1", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 0) - taskSetBlacklist.updateBlacklistForFailedTask("hostB", exec = "2", index = 1) + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 0, failureReason = "testing") + taskSetBlacklist.updateBlacklistForFailedTask( + "hostB", exec = "2", index = 1, failureReason = "testing") assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("1")) assert(taskSetBlacklist.isExecutorBlacklistedForTaskSet("2")) assert(!taskSetBlacklist.isNodeBlacklistedForTaskSet("hostA")) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ae43f4cadc037..2ce81ae27daf6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -744,6 +744,113 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(resubmittedTasks === 0) } + + test("[SPARK-22074] Task killed by other attempt task should not be resubmitted") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + // Set the speculation multiplier to be 0 so speculative tasks are launched immediately + sc.conf.set("spark.speculation.multiplier", "0.0") + sc.conf.set("spark.speculation.quantile", "0.5") + sc.conf.set("spark.speculation", "true") + + var killTaskCalled = false + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), + ("exec2", "host2"), ("exec3", "host3")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask( + taskId: Long, + executorId: String, + interruptThread: Boolean, + reason: String): Unit = { + // Check the only one killTask event in this case, which triggered by + // task 2.1 completed. + assert(taskId === 2) + assert(executorId === "exec3") + assert(interruptThread) + assert(reason === "another attempt succeeded") + killTaskCalled = true + } + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. 
+    var resubmittedTasks = 0
+    val dagScheduler = new FakeDAGScheduler(sc, sched) {
+      override def taskEnded(
+          task: Task[_],
+          reason: TaskEndReason,
+          result: Any,
+          accumUpdates: Seq[AccumulatorV2[_, _]],
+          taskInfo: TaskInfo): Unit = {
+        super.taskEnded(task, reason, result, accumUpdates, taskInfo)
+        reason match {
+          case Resubmitted => resubmittedTasks += 1
+          case _ =>
+        }
+      }
+    }
+    sched.setDAGScheduler(dagScheduler)
+
+    val taskSet = FakeTask.createShuffleMapTaskSet(4, 0, 0,
+      Seq(TaskLocation("host1", "exec1")),
+      Seq(TaskLocation("host1", "exec1")),
+      Seq(TaskLocation("host3", "exec3")),
+      Seq(TaskLocation("host2", "exec2")))
+
+    val clock = new ManualClock()
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
+      task.metrics.internalAccums
+    }
+    // Offer resources for 4 tasks to start
+    for ((exec, host) <- Seq(
+        "exec1" -> "host1",
+        "exec1" -> "host1",
+        "exec3" -> "host3",
+        "exec2" -> "host2")) {
+      val taskOption = manager.resourceOffer(exec, host, NO_PREF)
+      assert(taskOption.isDefined)
+      val task = taskOption.get
+      assert(task.executorId === exec)
+      // Add an extra assert to make sure task 2.0 is running on exec3
+      if (task.index == 2) {
+        assert(task.attemptNumber === 0)
+        assert(task.executorId === "exec3")
+      }
+    }
+    assert(sched.startedTasks.toSet === Set(0, 1, 2, 3))
+    clock.advance(1)
+    // Complete two of the tasks and leave the other two running
+    for (id <- Set(0, 1)) {
+      manager.handleSuccessfulTask(id, createTaskResult(id, accumUpdatesByTask(id)))
+      assert(sched.endedTasks(id) === Success)
+    }
+
+    // checkSpeculatableTasks checks that the task runtime is greater than the threshold for
+    // speculating. Since we use a threshold of 0 for speculation, tasks need to be running for
+    // > 0ms, so advance the clock by 1ms here.
+    clock.advance(1)
+    assert(manager.checkSpeculatableTasks(0))
+    assert(sched.speculativeTasks.toSet === Set(2, 3))
+
+    // Offer a resource to start the speculative attempt for the running task 2.0
+    val taskOption = manager.resourceOffer("exec2", "host2", ANY)
+    assert(taskOption.isDefined)
+    val task4 = taskOption.get
+    assert(task4.index === 2)
+    assert(task4.taskId === 4)
+    assert(task4.executorId === "exec2")
+    assert(task4.attemptNumber === 1)
+    // Complete the speculative attempt for the running task
+    manager.handleSuccessfulTask(4, createTaskResult(2, accumUpdatesByTask(2)))
+    // Make sure schedBackend.killTask(2, "exec3", true, "another attempt succeeded") gets called
+    assert(killTaskCalled)
+    // Host 3 is lost; the only task on it was 2.0, which was killed by task 2.1
+    manager.executorLost("exec3", "host3", SlaveLost())
+    // Check that no tasks were resubmitted
+    assert(resubmittedTasks === 0)
+  }
+
   test("speculative and noPref task should be scheduled after node-local") {
     sc = new SparkContext("local", "test")
     sched = new FakeTaskScheduler(
@@ -1146,7 +1253,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     // Make sure that the blacklist ignored all of the task failures above, since they aren't
     // the fault of the executor where the task was running.
verify(blacklist, never()) - .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) + .updateBlacklistForFailedTask(anyString(), anyString(), anyInt(), anyString()) } test("update application blacklist for shuffle-fetch") { diff --git a/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala new file mode 100644 index 0000000000000..7ac1ce19f8ddf --- /dev/null +++ b/core/src/test/scala/org/apache/spark/status/AppStatusListenerSuite.scala @@ -0,0 +1,703 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.status + +import java.io.File +import java.lang.{Integer => JInteger, Long => JLong} +import java.util.{Arrays, Date, Properties} + +import scala.collection.JavaConverters._ +import scala.reflect.{classTag, ClassTag} + +import org.scalatest.BeforeAndAfter + +import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.scheduler._ +import org.apache.spark.scheduler.cluster._ +import org.apache.spark.status.api.v1 +import org.apache.spark.storage._ +import org.apache.spark.util.Utils +import org.apache.spark.util.kvstore._ + +class AppStatusListenerSuite extends SparkFunSuite with BeforeAndAfter { + + import config._ + + private val conf = new SparkConf().set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + + private var time: Long = _ + private var testDir: File = _ + private var store: KVStore = _ + + before { + time = 0L + testDir = Utils.createTempDir() + store = KVUtils.open(testDir, getClass().getName()) + } + + after { + store.close() + Utils.deleteRecursively(testDir) + } + + test("scheduler events") { + val listener = new AppStatusListener(store, conf, true) + + // Start the application. + time += 1 + listener.onApplicationStart(SparkListenerApplicationStart( + "name", + Some("id"), + time, + "user", + Some("attempt"), + None)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(time)) + assert(attempt.lastUpdated === new Date(time)) + assert(attempt.endTime.getTime() === -1L) + assert(attempt.sparkUser === "user") + assert(!attempt.completed) + } + + // Start a couple of executors. 
+ time += 1 + val execIds = Array("1", "2") + + execIds.foreach { id => + listener.onExecutorAdded(SparkListenerExecutorAdded(time, id, + new ExecutorInfo(s"$id.example.com", 1, Map()))) + } + + execIds.foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(exec.info.hostPort === s"$id.example.com") + assert(exec.info.isActive) + } + } + + // Start a job with 2 stages / 4 tasks each + time += 1 + val stages = Seq( + new StageInfo(1, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(2, 0, "stage2", 4, Nil, Seq(1), "details2")) + + val jobProps = new Properties() + jobProps.setProperty(SparkContext.SPARK_JOB_GROUP_ID, "jobGroup") + jobProps.setProperty("spark.scheduler.pool", "schedPool") + + listener.onJobStart(SparkListenerJobStart(1, time, stages, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.jobId === 1) + assert(job.info.name === stages.last.name) + assert(job.info.description === None) + assert(job.info.status === JobExecutionStatus.RUNNING) + assert(job.info.submissionTime === Some(new Date(time))) + assert(job.info.jobGroup === Some("jobGroup")) + } + + stages.foreach { info => + check[StageDataWrapper](key(info)) { stage => + assert(stage.info.status === v1.StageStatus.PENDING) + assert(stage.jobIds === Set(1)) + } + } + + // Submit stage 1 + time += 1 + stages.head.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.head, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.head.submissionTime.get))) + assert(stage.info.schedulingPool === "schedPool") + } + + // Start tasks from stage 1 + time += 1 + var _taskIdTracker = -1L + def nextTaskId(): Long = { + _taskIdTracker += 1 + _taskIdTracker + } + + def createTasks(count: Int, time: Long): Seq[TaskInfo] = { + (1 to count).map { id => + val exec = execIds(id.toInt % execIds.length) + val taskId = nextTaskId() + new TaskInfo(taskId, taskId.toInt, 1, time, exec, s"$exec.example.com", + TaskLocality.PROCESS_LOCAL, id % 2 == 0) + } + } + + val s1Tasks = createTasks(4, time) + s1Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, task)) + } + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numActiveTasks === s1Tasks.size) + assert(stage.info.firstTaskLaunchedTime === Some(new Date(s1Tasks.head.launchTime))) + } + + s1Tasks.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.taskId === task.taskId) + assert(wrapper.stageId === stages.head.stageId) + assert(wrapper.stageAttemptId === stages.head.attemptId) + assert(Arrays.equals(wrapper.stage, Array(stages.head.stageId, stages.head.attemptId))) + + val runtime = Array[AnyRef](stages.head.stageId: JInteger, stages.head.attemptId: JInteger, + -1L: JLong) + assert(Arrays.equals(wrapper.runtime, runtime)) + + assert(wrapper.info.index === task.index) + assert(wrapper.info.attempt === task.attemptNumber) + assert(wrapper.info.launchTime === new Date(task.launchTime)) + assert(wrapper.info.executorId === task.executorId) + assert(wrapper.info.host === task.host) + assert(wrapper.info.status === 
task.status) + assert(wrapper.info.taskLocality === task.taskLocality.toString()) + assert(wrapper.info.speculative === task.speculative) + } + } + + // Send executor metrics update. Only update one metric to avoid a lot of boilerplate code. + s1Tasks.foreach { task => + val accum = new AccumulableInfo(1L, Some(InternalAccumulator.MEMORY_BYTES_SPILLED), + Some(1L), None, true, false, None) + listener.onExecutorMetricsUpdate(SparkListenerExecutorMetricsUpdate( + task.executorId, + Seq((task.taskId, stages.head.stageId, stages.head.attemptId, Seq(accum))))) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.memoryBytesSpilled === s1Tasks.size) + } + + val execs = store.view(classOf[ExecutorStageSummaryWrapper]).index("stage") + .first(key(stages.head)).last(key(stages.head)).asScala.toSeq + assert(execs.size > 0) + execs.foreach { exec => + assert(exec.info.memoryBytesSpilled === s1Tasks.size / 2) + } + + // Fail one of the tasks, re-start it. + time += 1 + s1Tasks.head.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", TaskResultLost, s1Tasks.head, null)) + + time += 1 + val reattempt = { + val orig = s1Tasks.head + // Task reattempts have a different ID, but the same index as the original. + new TaskInfo(nextTaskId(), orig.index, orig.attemptNumber + 1, time, orig.executorId, + s"${orig.executorId}.example.com", TaskLocality.PROCESS_LOCAL, orig.speculative) + } + listener.onTaskStart(SparkListenerTaskStart(stages.head.stageId, stages.head.attemptId, + reattempt)) + + assert(store.count(classOf[TaskDataWrapper]) === s1Tasks.size + 1) + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === s1Tasks.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === s1Tasks.size) + } + + check[TaskDataWrapper](s1Tasks.head.taskId) { task => + assert(task.info.status === s1Tasks.head.status) + assert(task.info.duration === Some(s1Tasks.head.duration)) + assert(task.info.errorMessage == Some(TaskResultLost.toErrorString)) + } + + check[TaskDataWrapper](reattempt.taskId) { task => + assert(task.info.index === s1Tasks.head.index) + assert(task.info.attempt === reattempt.attemptNumber) + } + + // Succeed all tasks in stage 1. + val pending = s1Tasks.drop(1) ++ Seq(reattempt) + + val s1Metrics = TaskMetrics.empty + s1Metrics.setExecutorCpuTime(2L) + s1Metrics.setExecutorRunTime(4L) + + time += 1 + pending.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.head.stageId, stages.head.attemptId, + "taskType", Success, task, s1Metrics)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1) + assert(job.info.numActiveTasks === 0) + assert(job.info.numCompletedTasks === pending.size) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + pending.foreach { task => + check[TaskDataWrapper](task.taskId) { wrapper => + assert(wrapper.info.errorMessage === None) + assert(wrapper.info.taskMetrics.get.executorCpuTime === 2L) + assert(wrapper.info.taskMetrics.get.executorRunTime === 4L) + } + } + + assert(store.count(classOf[TaskDataWrapper]) === pending.size + 1) + + // End stage 1. 
+ time += 1 + stages.head.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(stages.head)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numCompletedStages === 1) + } + + check[StageDataWrapper](key(stages.head)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numFailedTasks === 1) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === pending.size) + } + + // Submit stage 2. + time += 1 + stages.last.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(stages.last, jobProps)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.ACTIVE) + assert(stage.info.submissionTime === Some(new Date(stages.last.submissionTime.get))) + } + + // Start and fail all tasks of stage 2. + time += 1 + val s2Tasks = createTasks(4, time) + s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(stages.last.stageId, stages.last.attemptId, task)) + } + + time += 1 + s2Tasks.foreach { task => + task.markFinished(TaskState.FAILED, time) + listener.onTaskEnd(SparkListenerTaskEnd(stages.last.stageId, stages.last.attemptId, + "taskType", TaskResultLost, task, null)) + } + + check[JobDataWrapper](1) { job => + assert(job.info.numFailedTasks === 1 + s2Tasks.size) + assert(job.info.numActiveTasks === 0) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + } + + // Fail stage 2. + time += 1 + stages.last.completionTime = Some(time) + stages.last.failureReason = Some("uh oh") + listener.onStageCompleted(SparkListenerStageCompleted(stages.last)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 1) + assert(job.info.numFailedStages === 1) + } + + check[StageDataWrapper](key(stages.last)) { stage => + assert(stage.info.status === v1.StageStatus.FAILED) + assert(stage.info.numFailedTasks === s2Tasks.size) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === 0) + } + + // - Re-submit stage 2, all tasks, and succeed them and the stage. 
+ val oldS2 = stages.last + val newS2 = new StageInfo(oldS2.stageId, oldS2.attemptId + 1, oldS2.name, oldS2.numTasks, + oldS2.rddInfos, oldS2.parentIds, oldS2.details, oldS2.taskMetrics) + + time += 1 + newS2.submissionTime = Some(time) + listener.onStageSubmitted(SparkListenerStageSubmitted(newS2, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 3) + + val newS2Tasks = createTasks(4, time) + + newS2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(newS2.stageId, newS2.attemptId, task)) + } + + time += 1 + newS2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(newS2.stageId, newS2.attemptId, "taskType", Success, + task, null)) + } + + time += 1 + newS2.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(newS2)) + + check[JobDataWrapper](1) { job => + assert(job.info.numActiveStages === 0) + assert(job.info.numFailedStages === 1) + assert(job.info.numCompletedStages === 2) + } + + check[StageDataWrapper](key(newS2)) { stage => + assert(stage.info.status === v1.StageStatus.COMPLETE) + assert(stage.info.numActiveTasks === 0) + assert(stage.info.numCompleteTasks === newS2Tasks.size) + } + + // End job. + time += 1 + listener.onJobEnd(SparkListenerJobEnd(1, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + } + + // Submit a second job that re-uses stage 1 and stage 2. Stage 1 won't be re-run, but + // stage 2 will. In any case, the DAGScheduler creates new info structures that are copies + // of the old stages, so mimic that behavior here. The "new" stage 1 is submitted without + // a submission time, which means it is "skipped", and the stage 2 re-execution should not + // change the stats of the already finished job. + time += 1 + val j2Stages = Seq( + new StageInfo(3, 0, "stage1", 4, Nil, Nil, "details1"), + new StageInfo(4, 0, "stage2", 4, Nil, Seq(3), "details2")) + j2Stages.last.submissionTime = Some(time) + listener.onJobStart(SparkListenerJobStart(2, time, j2Stages, null)) + assert(store.count(classOf[JobDataWrapper]) === 2) + + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.head, jobProps)) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.head)) + listener.onStageSubmitted(SparkListenerStageSubmitted(j2Stages.last, jobProps)) + assert(store.count(classOf[StageDataWrapper]) === 5) + + time += 1 + val j2s2Tasks = createTasks(4, time) + + j2s2Tasks.foreach { task => + listener.onTaskStart(SparkListenerTaskStart(j2Stages.last.stageId, j2Stages.last.attemptId, + task)) + } + + time += 1 + j2s2Tasks.foreach { task => + task.markFinished(TaskState.FINISHED, time) + listener.onTaskEnd(SparkListenerTaskEnd(j2Stages.last.stageId, j2Stages.last.attemptId, + "taskType", Success, task, null)) + } + + time += 1 + j2Stages.last.completionTime = Some(time) + listener.onStageCompleted(SparkListenerStageCompleted(j2Stages.last)) + + time += 1 + listener.onJobEnd(SparkListenerJobEnd(2, time, JobSucceeded)) + + check[JobDataWrapper](1) { job => + assert(job.info.numCompletedStages === 2) + assert(job.info.numCompletedTasks === s1Tasks.size + s2Tasks.size) + } + + check[JobDataWrapper](2) { job => + assert(job.info.status === JobExecutionStatus.SUCCEEDED) + assert(job.info.numCompletedStages === 1) + assert(job.info.numCompletedTasks === j2s2Tasks.size) + assert(job.info.numSkippedStages === 1) + assert(job.info.numSkippedTasks === s1Tasks.size) + } + + // Blacklist an executor. 
+ time += 1 + listener.onExecutorBlacklisted(SparkListenerExecutorBlacklisted(time, "1", 42)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onExecutorUnblacklisted(SparkListenerExecutorUnblacklisted(time, "1")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Blacklist a node. + time += 1 + listener.onNodeBlacklisted(SparkListenerNodeBlacklisted(time, "1.example.com", 2)) + check[ExecutorSummaryWrapper]("1") { exec => + assert(exec.info.isBlacklisted) + } + + time += 1 + listener.onNodeUnblacklisted(SparkListenerNodeUnblacklisted(time, "1.example.com")) + check[ExecutorSummaryWrapper]("1") { exec => + assert(!exec.info.isBlacklisted) + } + + // Stop executors. + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "1", "Test")) + listener.onExecutorRemoved(SparkListenerExecutorRemoved(41L, "2", "Test")) + + Seq("1", "2").foreach { id => + check[ExecutorSummaryWrapper](id) { exec => + assert(exec.info.id === id) + assert(!exec.info.isActive) + } + } + + // End the application. + listener.onApplicationEnd(SparkListenerApplicationEnd(42L)) + + check[ApplicationInfoWrapper]("id") { app => + assert(app.info.name === "name") + assert(app.info.id === "id") + assert(app.info.attempts.size === 1) + + val attempt = app.info.attempts.head + assert(attempt.attemptId === Some("attempt")) + assert(attempt.startTime === new Date(1L)) + assert(attempt.lastUpdated === new Date(42L)) + assert(attempt.endTime === new Date(42L)) + assert(attempt.duration === 41L) + assert(attempt.sparkUser === "user") + assert(attempt.completed) + } + } + + test("storage events") { + val listener = new AppStatusListener(store, conf, true) + val maxMemory = 42L + + // Register a couple of block managers. + val bm1 = BlockManagerId("1", "1.example.com", 42) + val bm2 = BlockManagerId("2", "2.example.com", 84) + Seq(bm1, bm2).foreach { bm => + listener.onExecutorAdded(SparkListenerExecutorAdded(1L, bm.executorId, + new ExecutorInfo(bm.host, 1, Map()))) + listener.onBlockManagerAdded(SparkListenerBlockManagerAdded(1L, bm, maxMemory)) + check[ExecutorSummaryWrapper](bm.executorId) { exec => + assert(exec.info.maxMemory === maxMemory) + } + } + + val rdd1b1 = RDDBlockId(1, 1) + val level = StorageLevel.MEMORY_AND_DISK + + // Submit a stage and make sure the RDD is recorded. + val rddInfo = new RDDInfo(rdd1b1.rddId, "rdd1", 2, level, Nil) + val stage = new StageInfo(1, 0, "stage1", 4, Seq(rddInfo), Nil, "details1") + listener.onStageSubmitted(SparkListenerStageSubmitted(stage, new Properties())) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.name === rddInfo.name) + assert(wrapper.info.numPartitions === rddInfo.numPartitions) + assert(wrapper.info.storageLevel === rddInfo.storageLevel.description) + } + + // Add partition 1 replicated on two block managers. 
+ listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 1L) + assert(wrapper.info.diskUsed === 1L) + + assert(wrapper.info.dataDistribution.isDefined) + assert(wrapper.info.dataDistribution.get.size === 1) + + val dist = wrapper.info.dataDistribution.get.head + assert(dist.address === bm1.hostPort) + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + assert(wrapper.info.partitions.isDefined) + assert(wrapper.info.partitions.get.size === 1) + + val part = wrapper.info.partitions.get.head + assert(part.blockName === rdd1b1.name) + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 1L) + assert(part.diskUsed === 1L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1, level, 1L, 1L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 2L) + assert(wrapper.info.diskUsed === 2L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 1L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm2.hostPort).get + assert(dist.memoryUsed === 1L) + assert(dist.diskUsed === 1L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get(0) + assert(part.memoryUsed === 2L) + assert(part.diskUsed === 2L) + assert(part.executors === Seq(bm1.executorId, bm2.executorId)) + } + + check[ExecutorSummaryWrapper](bm2.executorId) { exec => + assert(exec.info.rddBlocks === 1L) + assert(exec.info.memoryUsed === 1L) + assert(exec.info.diskUsed === 1L) + } + + // Add a second partition only to bm 1. + val rdd1b2 = RDDBlockId(1, 2) + listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b2, level, + 3L, 3L))) + + check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper => + assert(wrapper.info.memoryUsed === 5L) + assert(wrapper.info.diskUsed === 5L) + assert(wrapper.info.dataDistribution.get.size === 2L) + assert(wrapper.info.partitions.get.size === 2L) + + val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get + assert(dist.memoryUsed === 4L) + assert(dist.diskUsed === 4L) + assert(dist.memoryRemaining === maxMemory - dist.memoryUsed) + + val part = wrapper.info.partitions.get.find(_.blockName === rdd1b2.name).get + assert(part.storageLevel === level.description) + assert(part.memoryUsed === 3L) + assert(part.diskUsed === 3L) + assert(part.executors === Seq(bm1.executorId)) + } + + check[ExecutorSummaryWrapper](bm1.executorId) { exec => + assert(exec.info.rddBlocks === 2L) + assert(exec.info.memoryUsed === 4L) + assert(exec.info.diskUsed === 4L) + } + + // Remove block 1 from bm 1. 
+    listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm1, rdd1b1,
+      StorageLevel.NONE, 1L, 1L)))
+
+    check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper =>
+      assert(wrapper.info.memoryUsed === 4L)
+      assert(wrapper.info.diskUsed === 4L)
+      assert(wrapper.info.dataDistribution.get.size === 2L)
+      assert(wrapper.info.partitions.get.size === 2L)
+
+      val dist = wrapper.info.dataDistribution.get.find(_.address == bm1.hostPort).get
+      assert(dist.memoryUsed === 3L)
+      assert(dist.diskUsed === 3L)
+      assert(dist.memoryRemaining === maxMemory - dist.memoryUsed)
+
+      val part = wrapper.info.partitions.get.find(_.blockName === rdd1b1.name).get
+      assert(part.storageLevel === level.description)
+      assert(part.memoryUsed === 1L)
+      assert(part.diskUsed === 1L)
+      assert(part.executors === Seq(bm2.executorId))
+    }
+
+    check[ExecutorSummaryWrapper](bm1.executorId) { exec =>
+      assert(exec.info.rddBlocks === 1L)
+      assert(exec.info.memoryUsed === 3L)
+      assert(exec.info.diskUsed === 3L)
+    }
+
+    // Remove block 1 from bm 2. This should leave only block 2's info in the store.
+    listener.onBlockUpdated(SparkListenerBlockUpdated(BlockUpdatedInfo(bm2, rdd1b1,
+      StorageLevel.NONE, 1L, 1L)))
+
+    check[RDDStorageInfoWrapper](rdd1b1.rddId) { wrapper =>
+      assert(wrapper.info.memoryUsed === 3L)
+      assert(wrapper.info.diskUsed === 3L)
+      assert(wrapper.info.dataDistribution.get.size === 1L)
+      assert(wrapper.info.partitions.get.size === 1L)
+      assert(wrapper.info.partitions.get(0).blockName === rdd1b2.name)
+    }
+
+    check[ExecutorSummaryWrapper](bm2.executorId) { exec =>
+      assert(exec.info.rddBlocks === 0L)
+      assert(exec.info.memoryUsed === 0L)
+      assert(exec.info.diskUsed === 0L)
+    }
+
+    // Unpersist RDD1.
+    listener.onUnpersistRDD(SparkListenerUnpersistRDD(rdd1b1.rddId))
+    intercept[NoSuchElementException] {
+      check[RDDStorageInfoWrapper](rdd1b1.rddId) { _ => () }
+    }
+
+  }
+
+  private def key(stage: StageInfo): Array[Int] = Array(stage.stageId, stage.attemptId)
+
+  private def check[T: ClassTag](key: Any)(fn: T => Unit): Unit = {
+    val value = store.read(classTag[T].runtimeClass, key).asInstanceOf[T]
+    fn(value)
+  }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index f0c521b00b583..ff4755833a916 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -35,13 +35,8 @@ class BlockIdSuite extends SparkFunSuite {
   }
 
   test("test-bad-deserialization") {
-    try {
-      // Try to deserialize an invalid block id.
+ intercept[UnrecognizedBlockId] { BlockId("myblock") - fail() - } catch { - case e: IllegalStateException => // OK - case _: Throwable => fail() } } @@ -139,6 +134,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 5) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("temp shuffle") { @@ -151,6 +147,7 @@ class BlockIdSuite extends SparkFunSuite { assert(id.id.getMostSignificantBits() === 1) assert(id.id.getLeastSignificantBits() === 2) assert(!id.isShuffle) + assertSame(id, BlockId(id.toString)) } test("test") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index dd61dcd11bcda..c2101ba828553 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -198,55 +198,6 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite } } - test("block replication - deterministic node selection") { - val blockSize = 1000 - val storeSize = 10000 - val stores = (1 to 5).map { - i => makeBlockManager(storeSize, s"store$i") - } - val storageLevel2x = StorageLevel.MEMORY_AND_DISK_2 - val storageLevel3x = StorageLevel(true, true, false, true, 3) - val storageLevel4x = StorageLevel(true, true, false, true, 4) - - def putBlockAndGetLocations(blockId: String, level: StorageLevel): Set[BlockManagerId] = { - stores.head.putSingle(blockId, new Array[Byte](blockSize), level) - val locations = master.getLocations(blockId).sortBy { _.executorId }.toSet - stores.foreach { _.removeBlock(blockId) } - master.removeBlock(blockId) - locations - } - - // Test if two attempts to 2x replication returns same set of locations - val a1Locs = putBlockAndGetLocations("a1", storageLevel2x) - assert(putBlockAndGetLocations("a1", storageLevel2x) === a1Locs, - "Inserting a 2x replicated block second time gave different locations from the first") - - // Test if two attempts to 3x replication returns same set of locations - val a2Locs3x = putBlockAndGetLocations("a2", storageLevel3x) - assert(putBlockAndGetLocations("a2", storageLevel3x) === a2Locs3x, - "Inserting a 3x replicated block second time gave different locations from the first") - - // Test if 2x replication of a2 returns a strict subset of the locations of 3x replication - val a2Locs2x = putBlockAndGetLocations("a2", storageLevel2x) - assert( - a2Locs2x.subsetOf(a2Locs3x), - "Inserting a with 2x replication gave locations that are not a subset of locations" + - s" with 3x replication [3x: ${a2Locs3x.mkString(",")}; 2x: ${a2Locs2x.mkString(",")}" - ) - - // Test if 4x replication of a2 returns a strict superset of the locations of 3x replication - val a2Locs4x = putBlockAndGetLocations("a2", storageLevel4x) - assert( - a2Locs3x.subsetOf(a2Locs4x), - "Inserting a with 4x replication gave locations that are not a superset of locations " + - s"with 3x replication [3x: ${a2Locs3x.mkString(",")}; 4x: ${a2Locs4x.mkString(",")}" - ) - - // Test if 3x replication of two different blocks gives two different sets of locations - val a3Locs3x = putBlockAndGetLocations("a3", storageLevel3x) - assert(a3Locs3x !== a2Locs3x, "Two blocks gave same locations with 3x replication") - } - test("block replication - replication failures") { /* Create a system of three block managers / stores. 
One of them (say, failableStore) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index cfe89fde63f88..d45c194d31adc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark.storage -import java.io.File import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -45,14 +44,14 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.client.{RpcResponseCallback, TransportClient} import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} import org.apache.spark.network.server.{NoOpRpcHandler, TransportServer, TransportServerBootstrap} -import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient, TempFileManager} import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, RegisterExecutor} import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.security.{CryptoStreamUtils, EncryptionFunSuite} import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} import org.apache.spark.shuffle.sort.SortShuffleManager -import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat +import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -512,8 +511,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE when(bmMaster.getLocations(mc.any[BlockId])).thenReturn(Seq(bmId1, bmId2, bmId3)) val blockManager = makeBlockManager(128, "exec", bmMaster) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost)) } @@ -535,8 +534,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE val blockManager = makeBlockManager(128, "exec", bmMaster) blockManager.blockManagerId = BlockManagerId(SparkContext.DRIVER_IDENTIFIER, localHost, 1, Some(localRack)) - val getLocations = PrivateMethod[Seq[BlockManagerId]]('getLocations) - val locations = blockManager invokePrivate getLocations(BroadcastBlockId(0)) + val sortLocations = PrivateMethod[Seq[BlockManagerId]]('sortLocations) + val locations = blockManager invokePrivate sortLocations(bmMaster.getLocations("test")) assert(locations.map(_.host) === Seq(localHost, localHost, otherHost, otherHost, otherHost)) assert(locations.flatMap(_.topologyInfo) === Seq(localRack, localRack, localRack, otherRack, otherRack)) @@ -1274,13 +1273,18 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE // so that we have a chance to do location refresh val blockManagerIds = (0 to maxFailuresBeforeLocationRefresh) .map { i => BlockManagerId(s"id-$i", s"host-$i", i + 1) } - when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockManagerIds) + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + 
Option(BlockLocationsAndStatus(blockManagerIds, BlockStatus.empty))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn( + blockManagerIds) + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, transferService = Option(mockBlockTransferService)) val block = store.getRemoteBytes("item") .asInstanceOf[Option[ByteBuffer]] assert(block.isDefined) - verify(mockBlockManagerMaster, times(2)).getLocations("item") + verify(mockBlockManagerMaster, times(1)).getLocationsAndStatus("item") + verify(mockBlockManagerMaster, times(1)).getLocations("item") } test("SPARK-17484: block status is properly updated following an exception in put()") { @@ -1371,8 +1375,32 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE server.close() } + test("fetch remote block to local disk if block size is larger than threshold") { + conf.set(MAX_REMOTE_BLOCK_SIZE_FETCH_TO_MEM, 1000L) + + val mockBlockManagerMaster = mock(classOf[BlockManagerMaster]) + val mockBlockTransferService = new MockBlockTransferService(0) + val blockLocations = Seq(BlockManagerId("id-0", "host-0", 1)) + val blockStatus = BlockStatus(StorageLevel.DISK_ONLY, 0L, 2000L) + + when(mockBlockManagerMaster.getLocationsAndStatus(mc.any[BlockId])).thenReturn( + Option(BlockLocationsAndStatus(blockLocations, blockStatus))) + when(mockBlockManagerMaster.getLocations(mc.any[BlockId])).thenReturn(blockLocations) + + store = makeBlockManager(8000, "executor1", mockBlockManagerMaster, + transferService = Option(mockBlockTransferService)) + val block = store.getRemoteBytes("item") + .asInstanceOf[Option[ByteBuffer]] + + assert(block.isDefined) + assert(mockBlockTransferService.numCalls === 1) + // assert FileManager is not null if the block size is larger than threshold. 
+ assert(mockBlockTransferService.tempFileManager === store.remoteBlockTempFileManager) + } + class MockBlockTransferService(val maxFailures: Int) extends BlockTransferService { var numCalls = 0 + var tempFileManager: TempFileManager = null override def init(blockDataManager: BlockDataManager): Unit = {} @@ -1382,7 +1410,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE execId: String, blockIds: Array[String], listener: BlockFetchingListener, - tempShuffleFileManager: TempShuffleFileManager): Unit = { + tempFileManager: TempFileManager): Unit = { listener.onBlockFetchSuccess("mockBlockId", new NioManagedBuffer(ByteBuffer.allocate(1))) } @@ -1394,7 +1422,8 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE override def uploadBlock( hostname: String, - port: Int, execId: String, + port: Int, + execId: String, blockId: BlockId, blockData: ManagedBuffer, level: StorageLevel, @@ -1407,12 +1436,14 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE host: String, port: Int, execId: String, - blockId: String): ManagedBuffer = { + blockId: String, + tempFileManager: TempFileManager): ManagedBuffer = { numCalls += 1 + this.tempFileManager = tempFileManager if (numCalls <= maxFailures) { throw new RuntimeException("Failing block fetch in the mock block transfer service") } - super.fetchBlockSync(host, port, execId, blockId) + super.fetchBlockSync(host, port, execId, blockId, tempFileManager) } } } diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 7859b0bba2b48..0c4f3c48ef802 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.storage import java.io.{File, FileWriter} +import java.util.UUID import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach} @@ -79,6 +80,12 @@ class DiskBlockManagerSuite extends SparkFunSuite with BeforeAndAfterEach with B assert(diskBlockManager.getAllBlocks.toSet === ids.toSet) } + test("SPARK-22227: non-block files are skipped") { + val file = diskBlockManager.getFile("unmanaged_file") + writeToFile(file, 10) + assert(diskBlockManager.getAllBlocks().isEmpty) + } + def writeToFile(file: File, numBytes: Int) { val writer = new FileWriter(file, true) for (i <- 0 until numBytes) writer.write(i) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index c371cbcf8dff5..5bfe9905ff17b 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.shuffle.{BlockFetchingListener, TempShuffleFileManager} +import org.apache.spark.network.shuffle.{BlockFetchingListener, TempFileManager} import org.apache.spark.network.util.LimitedInputStream import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.Utils @@ -437,12 +437,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with 
PrivateMethodT val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) val transfer = mock(classOf[BlockTransferService]) - var tempShuffleFileManager: TempShuffleFileManager = null + var tempFileManager: TempFileManager = null when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) .thenAnswer(new Answer[Unit] { override def answer(invocation: InvocationOnMock): Unit = { val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempShuffleFileManager = invocation.getArguments()(5).asInstanceOf[TempShuffleFileManager] + tempFileManager = invocation.getArguments()(5).asInstanceOf[TempFileManager] Future { listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) @@ -472,13 +472,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT fetchShuffleBlock(blocksByAddress1) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. - assert(tempShuffleFileManager == null) + assert(tempFileManager == null) val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) fetchShuffleBlock(blocksByAddress2) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. - assert(tempShuffleFileManager != null) + assert(tempFileManager != null) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index a1a858765a7d4..4abbb8e7894f5 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -96,6 +96,9 @@ class JsonProtocolSuite extends SparkFunSuite { .zipWithIndex.map { case (a, i) => a.copy(id = i) } SparkListenerExecutorMetricsUpdate("exec3", Seq((1L, 2, 3, accumUpdates))) } + val blockUpdated = + SparkListenerBlockUpdated(BlockUpdatedInfo(BlockManagerId("Stars", + "In your multitude...", 300), RDDBlockId(0, 0), StorageLevel.MEMORY_ONLY, 100L, 0L)) testEvent(stageSubmitted, stageSubmittedJsonString) testEvent(stageCompleted, stageCompletedJsonString) @@ -120,6 +123,7 @@ class JsonProtocolSuite extends SparkFunSuite { testEvent(nodeBlacklisted, nodeBlacklistedJsonString) testEvent(nodeUnblacklisted, nodeUnblacklistedJsonString) testEvent(executorMetricsUpdate, executorMetricsUpdateJsonString) + testEvent(blockUpdated, blockUpdatedJsonString) } test("Dependent Classes") { @@ -2007,6 +2011,29 @@ private[spark] object JsonProtocolSuite extends Assertions { |} """.stripMargin + private val blockUpdatedJsonString = + """ + |{ + | "Event": "SparkListenerBlockUpdated", + | "Block Updated Info": { + | "Block Manager ID": { + | "Executor ID": "Stars", + | "Host": "In your multitude...", + | "Port": 300 + | }, + | "Block ID": "rdd_0_0", + | "Storage Level": { + | "Use Disk": false, + | "Use Memory": true, + | "Deserialized": true, + | "Replication": 1 + | }, + | "Memory Size": 100, + | "Disk Size": 0 + | } + |} + """.stripMargin + private val executorBlacklistedJsonString = s""" |{ diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 2b16cc4852ba8..4d3adeb968e84 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ 
-38,9 +38,10 @@ import org.apache.commons.math3.stat.inference.ChiSquareTest import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path -import org.apache.spark.{SparkConf, SparkFunSuite, TaskContext} +import org.apache.spark.{SparkConf, SparkException, SparkFunSuite, TaskContext} import org.apache.spark.internal.Logging import org.apache.spark.network.util.ByteUnit +import org.apache.spark.scheduler.SparkListener class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { @@ -1110,4 +1111,57 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging { Utils.tryWithSafeFinallyAndFailureCallbacks {}(catchBlock = {}, finallyBlock = {}) TaskContext.unset } + + test("load extensions") { + val extensions = Seq( + classOf[SimpleExtension], + classOf[ExtensionWithConf], + classOf[UnregisterableExtension]).map(_.getName()) + + val conf = new SparkConf(false) + val instances = Utils.loadExtensions(classOf[Object], extensions, conf) + assert(instances.size === 2) + assert(instances.count(_.isInstanceOf[SimpleExtension]) === 1) + + val extWithConf = instances.find(_.isInstanceOf[ExtensionWithConf]) + .map(_.asInstanceOf[ExtensionWithConf]) + .get + assert(extWithConf.conf eq conf) + + class NestedExtension { } + + val invalid = Seq(classOf[NestedExtension].getName()) + intercept[SparkException] { + Utils.loadExtensions(classOf[Object], invalid, conf) + } + + val error = Seq(classOf[ExtensionWithError].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Object], error, conf) + } + + val wrongType = Seq(classOf[ListenerImpl].getName()) + intercept[IllegalArgumentException] { + Utils.loadExtensions(classOf[Seq[_]], wrongType, conf) + } + } + +} + +private class SimpleExtension + +private class ExtensionWithConf(val conf: SparkConf) + +private class UnregisterableExtension { + + throw new UnsupportedOperationException() + +} + +private class ExtensionWithError { + + throw new IllegalArgumentException() + } + +private class ListenerImpl extends SparkListener diff --git a/dev/checkstyle-suppressions.xml b/dev/checkstyle-suppressions.xml index 6e15f6955984e..bbda824dd13b4 100644 --- a/dev/checkstyle-suppressions.xml +++ b/dev/checkstyle-suppressions.xml @@ -40,8 +40,6 @@ files="src/main/java/org/apache/hive/service/*"/> - - diff --git a/dev/create-release/release-build.sh b/dev/create-release/release-build.sh index 8de1d6a37dc25..7e8d5c7075195 100755 --- a/dev/create-release/release-build.sh +++ b/dev/create-release/release-build.sh @@ -74,7 +74,7 @@ GIT_REF=${GIT_REF:-master} # Destination directory parent on remote server REMOTE_PARENT_DIR=${REMOTE_PARENT_DIR:-/home/$ASF_USERNAME/public_html} -GPG="gpg --no-tty --batch" +GPG="gpg -u $GPG_KEY --no-tty --batch" NEXUS_ROOT=https://repository.apache.org/service/local/staging NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads BASE_DIR=$(pwd) @@ -84,9 +84,9 @@ MVN="build/mvn --force" # Hive-specific profiles for some builds HIVE_PROFILES="-Phive -Phive-thriftserver" # Profiles for publishing snapshots and release to Maven Central -PUBLISH_PROFILES="-Pmesos -Pyarn $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" +PUBLISH_PROFILES="-Pmesos -Pyarn -Pflume $HIVE_PROFILES -Pspark-ganglia-lgpl -Pkinesis-asl" # Profiles for building binary releases -BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Psparkr" +BASE_RELEASE_PROFILES="-Pmesos -Pyarn -Pflume -Psparkr" # Scala 2.11 only profiles for some builds SCALA_2_11_PROFILES="-Pkafka-0-8" # Scala 2.12 only profiles for some 
builds @@ -125,7 +125,7 @@ else echo "Please set JAVA_HOME correctly." exit 1 else - JAVA_HOME="$JAVA_7_HOME" + export JAVA_HOME="$JAVA_7_HOME" fi fi fi @@ -140,7 +140,7 @@ DEST_DIR_NAME="spark-$SPARK_PACKAGE_VERSION" function LFTP { SSH="ssh -o ConnectTimeout=300 -o StrictHostKeyChecking=no -i $ASF_RSA_KEY" COMMANDS=$(cat <= (2, 7): - subprocess_check_output = subprocess.check_output - subprocess_check_call = subprocess.check_call -else: - # SPARK-8763 - # backported from subprocess module in Python 2.7 - def subprocess_check_output(*popenargs, **kwargs): - if 'stdout' in kwargs: - raise ValueError('stdout argument not allowed, it will be overridden.') - process = subprocess.Popen(stdout=subprocess.PIPE, *popenargs, **kwargs) - output, unused_err = process.communicate() - retcode = process.poll() - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise subprocess.CalledProcessError(retcode, cmd, output=output) - return output - - # backported from subprocess module in Python 2.7 - def subprocess_check_call(*popenargs, **kwargs): - retcode = call(*popenargs, **kwargs) - if retcode: - cmd = kwargs.get("args") - if cmd is None: - cmd = popenargs[0] - raise CalledProcessError(retcode, cmd) - return 0 +subprocess_check_output = subprocess.check_output +subprocess_check_call = subprocess.check_call def exit_from_command_with_retcode(cmd, retcode): diff --git a/dev/test-dependencies.sh b/dev/test-dependencies.sh index c7714578bd005..58b295d4f6e00 100755 --- a/dev/test-dependencies.sh +++ b/dev/test-dependencies.sh @@ -29,7 +29,7 @@ export LC_ALL=C # TODO: This would be much nicer to do in SBT, once SBT supports Maven-style resolution. # NOTE: These should match those in the release publishing script -HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Phive" +HADOOP2_MODULE_PROFILES="-Phive-thriftserver -Pmesos -Pkafka-0-8 -Pyarn -Pflume -Phive" MVN="build/mvn" HADOOP_PROFILES=( hadoop-2.6 diff --git a/docs/building-spark.md b/docs/building-spark.md index 57baa503259c1..98f7df155456f 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -100,6 +100,13 @@ Note: Kafka 0.8 support is deprecated as of Spark 2.3.0. Kafka 0.10 support is still automatically built. +## Building with Flume support + +Apache Flume support must be explicitly enabled with the `flume` profile. +Note: Flume support is deprecated as of Spark 2.3.0. + + ./build/mvn -Pflume -DskipTests clean package + ## Building submodules individually It's possible to build Spark sub-modules using the `mvn -pl` option. diff --git a/docs/cluster-overview.md b/docs/cluster-overview.md index a2ad958959a50..c42bb4bb8377e 100644 --- a/docs/cluster-overview.md +++ b/docs/cluster-overview.md @@ -58,6 +58,9 @@ for providing container-centric infrastructure. Kubernetes support is being acti developed in an [apache-spark-on-k8s](https://github.com/apache-spark-on-k8s/) Github organization. For documentation, refer to that project's README. +A third-party project (not supported by the Spark project) exists to add support for +[Nomad](https://github.com/hashicorp/nomad-spark) as a cluster manager. + # Submitting Applications Applications can be submitted to a cluster of any type using the `spark-submit` script. 
diff --git a/docs/configuration.md b/docs/configuration.md
index 6e9fe591b70a3..d3c358bb74173 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -547,13 +547,14 @@ Apart from these, the following properties are also available, and may be useful
-  <td><code>spark.reducer.maxReqSizeShuffleToMem</code></td>
+  <td><code>spark.maxRemoteBlockSizeFetchToMem</code></td>
   <td>Long.MaxValue</td>
   <td>
-    The blocks of a shuffle request will be fetched to disk when size of the request is above
-    this threshold. This is to avoid a giant request takes too much memory. We can enable this
-    config by setting a specific value(e.g. 200m). Note that this config can be enabled only when
-    the shuffle shuffle service is newer than Spark-2.2 or the shuffle service is disabled.
+    The remote block will be fetched to disk when the size of the block is above this threshold.
+    This is to avoid a giant request taking too much memory. We can enable this config by setting
+    a specific value (e.g. 200m). Note that this configuration affects both shuffle fetch
+    and block manager remote block fetch. For users who enable the external shuffle service,
+    this feature only works when the external shuffle service is newer than Spark 2.2.
   </td>
@@ -713,6 +714,14 @@ Apart from these, the following properties are also available, and may be useful
 <tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.eventLog.logBlockUpdates.enabled</code></td>
+  <td>false</td>
+  <td>
+    Whether to log events for every block update, if <code>spark.eventLog.enabled</code> is true.
+    *Warning*: This will increase the size of the event log considerably.
+  </td>
+</tr>
@@ -739,6 +748,20 @@ Apart from these, the following properties are also available, and may be useful
     finished.
 <tr>
   <td><code>spark.eventLog.compress</code></td>
   <td>false</td>
+<tr>
+  <td><code>spark.eventLog.overwrite</code></td>
+  <td>false</td>
+  <td>
+    Whether to overwrite any existing files.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.eventLog.buffer.kb</code></td>
+  <td>100k</td>
+  <td>
+    Buffer size in KB to use when writing to output streams.
+  </td>
+</tr>
@@ -1015,7 +1038,7 @@ Apart from these, the following properties are also available, and may be useful
spark.ui.enabled true0.5 Amount of storage memory immune to eviction, expressed as a fraction of the size of the - region set aside by s​park.memory.fraction. The higher this is, the less + region set aside by spark.memory.fraction. The higher this is, the less working memory may be available to execution and tasks may spill to disk more often. Leaving this at the default value is recommended. For more detail, see this description. @@ -1041,7 +1064,7 @@ Apart from these, the following properties are also available, and may be useful spark.memory.useLegacyMode false - ​Whether to enable the legacy memory management mode used in Spark 1.5 and before. + Whether to enable the legacy memory management mode used in Spark 1.5 and before. The legacy mode rigidly partitions the heap space into fixed-size regions, potentially leading to excessive spilling if the application was not tuned. The following deprecated memory fraction configurations are not read unless this is enabled: @@ -1115,11 +1138,8 @@ Apart from these, the following properties are also available, and may be useful The number of cores to use on each executor. - In standalone and Mesos coarse-grained modes, setting this - parameter allows an application to run multiple executors on the - same worker, provided that there are enough cores on that - worker. Otherwise, only one executor per application will run on - each worker. + In standalone and Mesos coarse-grained modes, for more detail, see + this description.
+    <tr>
+      <td>spark.history.store.path</td>
+      <td>(none)</td>
+      <td>
+        Local directory where to cache application history data. If set, the history
+        server will store application data on disk instead of keeping it in memory. The data
+        written to disk will be re-used in the event of a history server restart.
+      </td>
+    </tr>
Note that in all of these UIs, the tables are sortable by clicking their headers, diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 432639588cc2b..9599d40c545b2 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -401,6 +401,15 @@ To use a custom metrics.properties for the application master and executors, upd Principal to be used to login to KDC, while running on secure HDFS. (Works also with the "local" master) + + spark.yarn.kerberos.relogin.period + 1m + + How often to check whether the kerberos TGT should be renewed. This should be set to a value + that is shorter than the TGT renewal period (or the TGT lifetime if TGT renewal is not enabled). + The default value should be enough for most deployments. + + spark.yarn.config.gatewayPath (none) diff --git a/docs/security.md b/docs/security.md index 1d004003f9a32..15aadf07cf873 100644 --- a/docs/security.md +++ b/docs/security.md @@ -186,7 +186,54 @@ configure those ports. +### HTTP Security Headers + +Apache Spark can be configured to include HTTP Headers which aids in preventing Cross +Site Scripting (XSS), Cross-Frame Scripting (XFS), MIME-Sniffing and also enforces HTTP +Strict Transport Security. + + + + + + + + + + + + + + + + + + +
+<table class="table">
+<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
+<tr>
+  <td><code>spark.ui.xXssProtection</code></td>
+  <td><code>1; mode=block</code></td>
+  <td>
+    Value for HTTP X-XSS-Protection response header. You can choose appropriate value
+    from below:
+    <ul>
+      <li><code>0</code> (Disables XSS filtering)</li>
+      <li><code>1</code> (Enables XSS filtering. If a cross-site scripting attack is detected,
+        the browser will sanitize the page.)</li>
+      <li><code>1; mode=block</code> (Enables XSS filtering. The browser will prevent rendering
+        of the page if an attack is detected.)</li>
+    </ul>
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.xContentTypeOptions.enabled</code></td>
+  <td><code>true</code></td>
+  <td>
+    When value is set to "true", X-Content-Type-Options HTTP response header will be set
+    to "nosniff". Set "false" to disable.
+  </td>
+</tr>
+<tr>
+  <td><code>spark.ui.strictTransportSecurity</code></td>
+  <td>None</td>
+  <td>
+    Value for HTTP Strict Transport Security (HSTS) Response Header. You can choose appropriate
+    value from below and set expire-time accordingly, when Spark is SSL/TLS enabled.
+    <ul>
+      <li><code>max-age=&lt;expire-time&gt;</code></li>
+      <li><code>max-age=&lt;expire-time&gt;; includeSubDomains</code></li>
+      <li><code>max-age=&lt;expire-time&gt;; preload</code></li>
+    </ul>
+  </td>
+</tr>
+</table>
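To make the table above concrete, here is a minimal, illustrative sketch (not part of the patch) of how an application might set these headers when building a `SparkSession`. The property names come from the table; the application name and the HSTS max-age value are placeholders chosen for the example.

```scala
import org.apache.spark.sql.SparkSession

// Illustrative only: enable the HTTP security headers documented above.
// The HSTS max-age (one year) is a placeholder, not a recommendation.
val spark = SparkSession.builder()
  .appName("secure-ui-example") // hypothetical application name
  .config("spark.ui.xXssProtection", "1; mode=block")
  .config("spark.ui.xContentTypeOptions.enabled", "true")
  .config("spark.ui.strictTransportSecurity", "max-age=31536000") // only meaningful with SSL/TLS enabled
  .getOrCreate()
```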
+ See the [configuration page](configuration.html) for more details on the security configuration parameters, and org.apache.spark.SecurityManager for implementation details about security. + diff --git a/docs/spark-standalone.md b/docs/spark-standalone.md index 1095386c31ab8..f51c5cc38f4de 100644 --- a/docs/spark-standalone.md +++ b/docs/spark-standalone.md @@ -328,6 +328,14 @@ export SPARK_MASTER_OPTS="-Dspark.deploy.defaultCores=" This is useful on shared clusters where users might not have configured a maximum number of cores individually. +# Executors Scheduling + +The number of cores assigned to each executor is configurable. When `spark.executor.cores` is +explicitly set, multiple executors from the same application may be launched on the same worker +if the worker has enough cores and memory. Otherwise, each executor grabs all the cores available +on the worker by default, in which case only one executor per application may be launched on each +worker during one single schedule iteration. + # Monitoring and Logging Spark's standalone mode offers a web-based user interface to monitor the cluster. The master and each worker has its own web UI that shows cluster and job statistics. By default you can access the web UI for the master at port 8080. The port can be changed either in the configuration file or via command-line options. diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index a095263bfa619..639a8ea7bb8ad 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -461,6 +461,8 @@ name (i.e., `org.apache.spark.sql.parquet`), but for built-in sources you can al names (`json`, `parquet`, `jdbc`, `orc`, `libsvm`, `csv`, `text`). DataFrames loaded from any data source type can be converted into other types using this syntax. +To load a JSON file you can use: +
{% include_example manual_load_options scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %} @@ -479,6 +481,26 @@ source type can be converted into other types using this syntax.
+To load a CSV file you can use:
+
+<div class="codetabs">
+<div data-lang="scala" markdown="1">
+{% include_example manual_load_options_csv scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+{% include_example manual_load_options_csv java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
+</div>
+
+<div data-lang="python" markdown="1">
+{% include_example manual_load_options_csv python/sql/datasource.py %}
+</div>
+
+<div data-lang="r" markdown="1">
+{% include_example manual_load_options_csv r/RSparkSQLExample.R %}
+</div>
+
+</div>
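For readers scanning the raw patch, the Scala snippet pulled in by the include above is the one added to `SQLDataSourceExample.scala` later in this diff. It is reproduced here, with the `SparkSession` boilerplate added around it (the application name is arbitrary), so the doc change can be read on its own:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("csv-load-example").getOrCreate()

// Read a semicolon-separated CSV file with a header row, inferring column types.
val peopleDFCsv = spark.read.format("csv")
  .option("sep", ";")
  .option("inferSchema", "true")
  .option("header", "true")
  .load("examples/src/main/resources/people.csv")
```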
### Run SQL on files directly Instead of using read API to load a file into DataFrame and query it, you can also query that @@ -573,7 +595,7 @@ Note that partition information is not gathered by default when creating externa ### Bucketing, Sorting and Partitioning -For file-based data source, it is also possible to bucket and sort or partition the output. +For file-based data source, it is also possible to bucket and sort or partition the output. Bucketing and sorting are applicable only to persistent tables:
@@ -598,7 +620,7 @@ CREATE TABLE users_bucketed_by_name( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet CLUSTERED BY(name) INTO 42 BUCKETS; {% endhighlight %} @@ -629,7 +651,7 @@ while partitioning can be used with both `save` and `saveAsTable` when using the {% highlight sql %} CREATE TABLE users_by_favorite_color( - name STRING, + name STRING, favorite_color STRING, favorite_numbers array ) USING csv PARTITIONED BY(favorite_color); @@ -664,7 +686,7 @@ CREATE TABLE users_bucketed_and_partitioned( name STRING, favorite_color STRING, favorite_numbers array -) USING parquet +) USING parquet PARTITIONED BY (favorite_color) CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS; @@ -675,7 +697,7 @@ CLUSTERED BY(name) SORTED BY (favorite_numbers) INTO 42 BUCKETS;
`partitionBy` creates a directory structure as described in the [Partition Discovery](#partition-discovery) section. -Thus, it has limited applicability to columns with high cardinality. In contrast +Thus, it has limited applicability to columns with high cardinality. In contrast `bucketBy` distributes data across a fixed number of buckets and can be used when a number of unique values is unbounded. diff --git a/docs/streaming-flume-integration.md b/docs/streaming-flume-integration.md index a5d36da5b6de9..257a4f7d4f3ca 100644 --- a/docs/streaming-flume-integration.md +++ b/docs/streaming-flume-integration.md @@ -5,6 +5,8 @@ title: Spark Streaming + Flume Integration Guide [Apache Flume](https://flume.apache.org/) is a distributed, reliable, and available service for efficiently collecting, aggregating, and moving large amounts of log data. Here we explain how to configure Flume and Spark Streaming to receive data from Flume. There are two approaches to this. +**Note: Flume support is deprecated as of Spark 2.3.0.** + ## Approach 1: Flume-style Push-based Approach Flume is designed to push data between Flume agents. In this approach, Spark Streaming essentially sets up a receiver that acts an Avro agent for Flume, to which Flume can push the data. Here are the configuration steps. @@ -44,8 +46,7 @@ configuring Flume agents. val flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) - See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala). + See the [API docs](api/scala/index.html#org.apache.spark.streaming.flume.FlumeUtils$).
import org.apache.spark.streaming.flume.*; @@ -53,8 +54,7 @@ configuring Flume agents. JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]); - See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html) - and the [example]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java). + See the [API docs](api/java/index.html?org/apache/spark/streaming/flume/FlumeUtils.html).
from pyspark.streaming.flume import FlumeUtils @@ -62,8 +62,7 @@ configuring Flume agents. flumeStream = FlumeUtils.createStream(streamingContext, [chosen machine's hostname], [chosen port]) By default, the Python API will decode Flume event body as UTF8 encoded strings. You can specify your custom decoding function to decode the body byte arrays in Flume events to any arbitrary data type. - See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils) - and the [example]({{site.SPARK_GITHUB_URL}}/blob/v{{site.SPARK_VERSION_SHORT}}/examples/src/main/python/streaming/flume_wordcount.py). + See the [API docs](api/python/pyspark.streaming.html#pyspark.streaming.flume.FlumeUtils).
@@ -162,8 +161,6 @@ configuring Flume agents. - See the Scala example [FlumePollingEventCount]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala). - Note that each input DStream can be configured to receive data from multiple sinks. 3. **Deploying:** This is same as the first approach. diff --git a/examples/pom.xml b/examples/pom.xml index 52a6764ae26a5..1791dbaad775e 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -34,7 +34,6 @@ examples none package - provided provided provided provided @@ -78,12 +77,6 @@ ${project.version} provided
- - org.apache.spark - spark-streaming-flume_${scala.binary.version} - ${project.version} - provided - org.apache.spark spark-streaming-kafka-0-10_${scala.binary.version} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index fe4d6bc83f04a..27052be87b82e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -118,9 +118,18 @@ public static void main(String[] args) { Dataset userRecs = model.recommendForAllUsers(10); // Generate top 10 user recommendations for each movie Dataset movieRecs = model.recommendForAllItems(10); + + // Generate top 10 movie recommendations for a specified set of users + Dataset users = ratings.select(als.getUserCol()).distinct().limit(3); + Dataset userSubsetRecs = model.recommendForUserSubset(users, 10); + // Generate top 10 user recommendations for a specified set of movies + Dataset movies = ratings.select(als.getItemCol()).distinct().limit(3); + Dataset movieSubSetRecs = model.recommendForItemSubset(movies, 10); // $example off$ userRecs.show(); movieRecs.show(); + userSubsetRecs.show(); + movieSubSetRecs.show(); spark.stop(); } diff --git a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java index 95859c52c2aeb..ef3c904775697 100644 --- a/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java +++ b/examples/src/main/java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java @@ -116,6 +116,13 @@ private static void runBasicDataSourceExample(SparkSession spark) { spark.read().format("json").load("examples/src/main/resources/people.json"); peopleDF.select("name", "age").write().format("parquet").save("namesAndAges.parquet"); // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + Dataset peopleDFCsv = spark.read().format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv"); + // $example off:manual_load_options_csv$ // $example on:direct_sql$ Dataset sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`"); diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1672d552eb1d5..8b7ec9c439f9f 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -60,8 +60,17 @@ userRecs = model.recommendForAllUsers(10) # Generate top 10 user recommendations for each movie movieRecs = model.recommendForAllItems(10) + + # Generate top 10 movie recommendations for a specified set of users + users = ratings.select(als.getUserCol()).distinct().limit(3) + userSubsetRecs = model.recommendForUserSubset(users, 10) + # Generate top 10 user recommendations for a specified set of movies + movies = ratings.select(als.getItemCol()).distinct().limit(3) + movieSubSetRecs = model.recommendForItemSubset(movies, 10) # $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() diff --git a/examples/src/main/python/sql/datasource.py b/examples/src/main/python/sql/datasource.py index f86012ea382e8..b375fa775de39 100644 --- a/examples/src/main/python/sql/datasource.py +++ b/examples/src/main/python/sql/datasource.py @@ -53,6 +53,11 @@ def 
basic_datasource_example(spark): df.select("name", "age").write.save("namesAndAges.parquet", format="parquet") # $example off:manual_load_options$ + # $example on:manual_load_options_csv$ + df = spark.read.load("examples/src/main/resources/people.csv", + format="csv", sep=":", inferSchema="true", header="true") + # $example off:manual_load_options_csv$ + # $example on:write_sorting_and_bucketing$ df.write.bucketBy(42, "name").sortBy("age").saveAsTable("people_bucketed") # $example off:write_sorting_and_bucketing$ diff --git a/examples/src/main/r/RSparkSQLExample.R b/examples/src/main/r/RSparkSQLExample.R index 3734568d872d0..a5ed723da47ca 100644 --- a/examples/src/main/r/RSparkSQLExample.R +++ b/examples/src/main/r/RSparkSQLExample.R @@ -113,6 +113,12 @@ write.df(namesAndAges, "namesAndAges.parquet", "parquet") # $example off:manual_load_options$ +# $example on:manual_load_options_csv$ +df <- read.df("examples/src/main/resources/people.csv", "csv") +namesAndAges <- select(df, "name", "age") +# $example off:manual_load_options_csv$ + + # $example on:direct_sql$ df <- sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") # $example off:direct_sql$ diff --git a/examples/src/main/resources/people.csv b/examples/src/main/resources/people.csv new file mode 100644 index 0000000000000..7fe5adba93d77 --- /dev/null +++ b/examples/src/main/resources/people.csv @@ -0,0 +1,3 @@ +name;age;job +Jorge;30;Developer +Bob;32;Developer diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index 07b15dfa178f7..8091838a2301e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -80,9 +80,18 @@ object ALSExample { val userRecs = model.recommendForAllUsers(10) // Generate top 10 user recommendations for each movie val movieRecs = model.recommendForAllItems(10) + + // Generate top 10 movie recommendations for a specified set of users + val users = ratings.select(als.getUserCol).distinct().limit(3) + val userSubsetRecs = model.recommendForUserSubset(users, 10) + // Generate top 10 user recommendations for a specified set of movies + val movies = ratings.select(als.getItemCol).distinct().limit(3) + val movieSubSetRecs = model.recommendForItemSubset(movies, 10) // $example off$ userRecs.show() movieRecs.show() + userSubsetRecs.show() + movieSubSetRecs.show() spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala index 86b3dc4a84f58..f9477969a4bb5 100644 --- a/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala @@ -49,6 +49,14 @@ object SQLDataSourceExample { val peopleDF = spark.read.format("json").load("examples/src/main/resources/people.json") peopleDF.select("name", "age").write.format("parquet").save("namesAndAges.parquet") // $example off:manual_load_options$ + // $example on:manual_load_options_csv$ + val peopleDFCsv = spark.read.format("csv") + .option("sep", ";") + .option("inferSchema", "true") + .option("header", "true") + .load("examples/src/main/resources/people.csv") + // $example off:manual_load_options_csv$ + // $example on:direct_sql$ val sqlDF = spark.sql("SELECT * FROM parquet.`examples/src/main/resources/users.parquet`") // 
$example off:direct_sql$ diff --git a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala index 7680ae3835132..90343182712ed 100644 --- a/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala +++ b/external/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleIntegrationSuite.scala @@ -21,7 +21,7 @@ import java.sql.{Connection, Date, Timestamp} import java.util.Properties import java.math.BigDecimal -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Row, SaveMode} import org.apache.spark.sql.execution.{WholeStageCodegenExec, RowDataSourceScanExec} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -52,7 +52,7 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo import testImplicits._ override val db = new DatabaseOnDocker { - override val imageName = "wnameless/oracle-xe-11g:14.04.4" + override val imageName = "wnameless/oracle-xe-11g:16.04" override val env = Map( "ORACLE_ROOT_PASSWORD" -> "oracle" ) @@ -104,15 +104,18 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo """.stripMargin.replaceAll("\n", " ")) - conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate(); + conn.prepareStatement("CREATE TABLE numerics (b DECIMAL(1), f DECIMAL(3, 2), i DECIMAL(10))").executeUpdate() conn.prepareStatement( - "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate(); - conn.commit(); + "INSERT INTO numerics VALUES (4, 1.23, 9999999999)").executeUpdate() + conn.commit() + + conn.prepareStatement("CREATE TABLE oracle_types (d BINARY_DOUBLE, f BINARY_FLOAT)").executeUpdate() + conn.commit() } test("SPARK-16625 : Importing Oracle numeric types") { - val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties); + val df = sqlContext.read.jdbc(jdbcUrl, "numerics", new Properties) val rows = df.collect() assert(rows.size == 1) val row = rows(0) @@ -307,4 +310,32 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationSuite with SharedSQLCo assert(values.getInt(1).equals(1)) assert(values.getBoolean(2).equals(false)) } + + test("SPARK-22303: handle BINARY_DOUBLE and BINARY_FLOAT as DoubleType and FloatType") { + val tableName = "oracle_types" + val schema = StructType(Seq( + StructField("d", DoubleType, true), + StructField("f", FloatType, true))) + val props = new Properties() + + // write it back to the table (append mode) + val data = spark.sparkContext.parallelize(Seq(Row(1.1, 2.2f))) + val dfWrite = spark.createDataFrame(data, schema) + dfWrite.write.mode(SaveMode.Append).jdbc(jdbcUrl, tableName, props) + + // read records from oracle_types + val dfRead = sqlContext.read.jdbc(jdbcUrl, tableName, new Properties) + val rows = dfRead.collect() + assert(rows.size == 1) + + // check data types + val types = dfRead.schema.map(field => field.dataType) + assert(types(0).equals(DoubleType)) + assert(types(1).equals(FloatType)) + + // check values + val values = rows(0) + assert(values.getDouble(0) === 1.1) + assert(values.getFloat(1) === 2.2f) + } } diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java similarity index 98% rename from 
examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java rename to external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java index 0c651049d0ffa..4e3420d9c3b06 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/external/flume/src/main/java/org/apache/spark/examples/JavaFlumeEventCount.java @@ -48,8 +48,6 @@ public static void main(String[] args) throws Exception { System.exit(1); } - StreamingExamples.setStreamingLogLevels(); - String host = args[0]; int port = Integer.parseInt(args[1]); diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala index 91e52e4eff5a7..f877f79391b37 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumeEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumeEventCount.scala @@ -47,8 +47,6 @@ object FlumeEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala similarity index 98% rename from examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala rename to external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala index dd725d72c23ef..79a4027ca5bde 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/FlumePollingEventCount.scala +++ b/external/flume/src/main/scala/org/apache/spark/examples/FlumePollingEventCount.scala @@ -44,8 +44,6 @@ object FlumePollingEventCount { System.exit(1) } - StreamingExamples.setStreamingLogLevels() - val Array(host, IntParam(port)) = args val batchInterval = Milliseconds(2000) diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala index 3e3ed712f0dbf..707193a957700 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeUtils.scala @@ -30,6 +30,7 @@ import org.apache.spark.streaming.StreamingContext import org.apache.spark.streaming.api.java.{JavaPairDStream, JavaReceiverInputDStream, JavaStreamingContext} import org.apache.spark.streaming.dstream.ReceiverInputDStream +@deprecated("Deprecated without replacement", "2.3.0") object FlumeUtils { private val DEFAULT_POLLING_PARALLELISM = 5 private val DEFAULT_POLLING_BATCH_SIZE = 1000 diff --git a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java index 865d4926da6a9..4353e3f263c51 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java +++ b/launcher/src/main/java/org/apache/spark/launcher/LauncherServer.java @@ -232,20 +232,20 @@ public void run() { }; ServerConnection clientConnection = new ServerConnection(client, timeout); Thread clientThread = factory.newThread(clientConnection); - synchronized (timeout) { - 
clientThread.start(); - synchronized (clients) { - clients.add(clientConnection); - } - long timeoutMs = getConnectionTimeout(); - // 0 is used for testing to avoid issues with clock resolution / thread scheduling, - // and force an immediate timeout. - if (timeoutMs > 0) { - timeoutTimer.schedule(timeout, getConnectionTimeout()); - } else { - timeout.run(); - } + synchronized (clients) { + clients.add(clientConnection); } + + long timeoutMs = getConnectionTimeout(); + // 0 is used for testing to avoid issues with clock resolution / thread scheduling, + // and force an immediate timeout. + if (timeoutMs > 0) { + timeoutTimer.schedule(timeout, timeoutMs); + } else { + timeout.run(); + } + + clientThread.start(); } } catch (IOException ioe) { if (running) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index 718a368a8e731..75b8ef5ca5ef4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -625,7 +625,7 @@ private static class ArgumentValidator extends SparkSubmitOptionParser { @Override protected boolean handle(String opt, String value) { if (value == null && hasValue) { - throw new IllegalArgumentException(String.format("'%s' does not expect a value.", opt)); + throw new IllegalArgumentException(String.format("'%s' expects a value.", opt)); } return true; } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index ef08134809915..730fcab333e11 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -230,21 +230,23 @@ private[ml] object ProbabilisticClassificationModel { * Normalize a vector of raw predictions to be a multinomial probability vector, in place. * * The input raw predictions should be nonnegative. - * The output vector sums to 1, unless the input vector is all-0 (in which case the output is - * all-0 too). + * The output vector sums to 1. * * NOTE: This is NOT applicable to all models, only ones which effectively use class * instance counts for raw predictions. + * + * @throws IllegalArgumentException if the input vector is all-0 or including negative values */ def normalizeToProbabilitiesInPlace(v: DenseVector): Unit = { + v.values.foreach(value => require(value >= 0, + "The input raw predictions should be nonnegative.")) val sum = v.values.sum - if (sum != 0) { - var i = 0 - val size = v.size - while (i < size) { - v.values(i) /= sum - i += 1 - } + require(sum > 0, "Can't normalize the 0-vector.") + var i = 0 + val size = v.size + while (i < size) { + v.values(i) /= sum + i += 1 } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala index 3da29b1c816b1..4bab670cc159f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala @@ -458,7 +458,7 @@ abstract class LDAModel private[ml] ( if ($(topicDistributionCol).nonEmpty) { // TODO: Make the transformer natively in ml framework to avoid extra conversion. 
- val transformer = oldLocalModel.getTopicDistributionMethod(sparkSession.sparkContext) + val transformer = oldLocalModel.getTopicDistributionMethod val t = udf { (v: Vector) => transformer(OldVectors.fromML(v)).asML } dataset.withColumn($(topicDistributionCol), t(col($(featuresCol)))).toDF() diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala index 1f36eced3d08f..4663f16b5f5dc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Imputer.scala @@ -223,20 +223,18 @@ class ImputerModel private[ml] ( override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) - var outputDF = dataset val surrogates = surrogateDF.select($(inputCols).map(col): _*).head().toSeq - $(inputCols).zip($(outputCols)).zip(surrogates).foreach { + val newCols = $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), surrogate) => val inputType = dataset.schema(inputCol).dataType val ic = col(inputCol) - outputDF = outputDF.withColumn(outputCol, - when(ic.isNull, surrogate) + when(ic.isNull, surrogate) .when(ic === $(missingValue), surrogate) .otherwise(ic) - .cast(inputType)) + .cast(inputType) } - outputDF.toDF() + dataset.withColumns($(outputCols), newCols).toDF() } override def transformSchema(schema: StructType): StructType = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 3d5fd1794de23..a8843661c873b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -344,6 +344,21 @@ class ALSModel private[ml] ( recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) } + /** + * Returns top `numItems` items recommended for each user id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of user ids. The column name must match `userCol`. + * @param numItems max number of recommendations for each user. + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.3.0") + def recommendForUserSubset(dataset: Dataset[_], numItems: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, userFactors, $(userCol)) + recommendForAll(srcFactorSubset, itemFactors, $(userCol), $(itemCol), numItems) + } + /** * Returns top `numUsers` users recommended for each item, for all items. * @param numUsers max number of recommendations for each item @@ -355,6 +370,39 @@ class ALSModel private[ml] ( recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) } + /** + * Returns top `numUsers` users recommended for each item id in the input data set. Note that if + * there are duplicate ids in the input dataset, only one set of recommendations per unique id + * will be returned. + * @param dataset a Dataset containing a column of item ids. The column name must match `itemCol`. + * @param numUsers max number of recommendations for each item. + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. 
+ */ + @Since("2.3.0") + def recommendForItemSubset(dataset: Dataset[_], numUsers: Int): DataFrame = { + val srcFactorSubset = getSourceFactorSubset(dataset, itemFactors, $(itemCol)) + recommendForAll(srcFactorSubset, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Returns a subset of a factor DataFrame limited to only those unique ids contained + * in the input dataset. + * @param dataset input Dataset containing id column to user to filter factors. + * @param factors factor DataFrame to filter. + * @param column column name containing the ids in the input dataset. + * @return DataFrame containing factors only for those ids present in both the input dataset and + * the factor DataFrame. + */ + private def getSourceFactorSubset( + dataset: Dataset[_], + factors: DataFrame, + column: String): DataFrame = { + factors + .join(dataset.select(column), factors("id") === dataset(column), joinType = "left_semi") + .select(factors("id"), factors("features")) + } + /** * Makes recommendations for all users (or items). * diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala index 4ab420058f33d..b8a6e94248421 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAModel.scala @@ -371,7 +371,7 @@ class LocalLDAModel private[spark] ( /** * Get a method usable as a UDF for `topicDistributions()` */ - private[spark] def getTopicDistributionMethod(sc: SparkContext): Vector => Vector = { + private[spark] def getTopicDistributionMethod: Vector => Vector = { val expElogbeta = exp(LDAUtils.dirichletExpectation(topicsMatrix.asBreeze.toDenseMatrix.t).t) val docConcentrationBrz = this.docConcentration.asBreeze val gammaShape = this.gammaShape diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala index d633893e55f55..693a2a31f026b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala @@ -26,6 +26,7 @@ import breeze.stats.distributions.{Gamma, RandBasis} import org.apache.spark.annotation.{DeveloperApi, Since} import org.apache.spark.graphx._ import org.apache.spark.graphx.util.PeriodicGraphCheckpointer +import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors} import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel @@ -259,7 +260,7 @@ final class EMLDAOptimizer extends LDAOptimizer { */ @Since("1.4.0") @DeveloperApi -final class OnlineLDAOptimizer extends LDAOptimizer { +final class OnlineLDAOptimizer extends LDAOptimizer with Logging { // LDA common parameters private var k: Int = 0 @@ -462,31 +463,61 @@ final class OnlineLDAOptimizer extends LDAOptimizer { val expElogbetaBc = batch.sparkContext.broadcast(expElogbeta) val alpha = this.alpha.asBreeze val gammaShape = this.gammaShape - - val stats: RDD[(BDM[Double], List[BDV[Double]])] = batch.mapPartitions { docs => + val optimizeDocConcentration = this.optimizeDocConcentration + // If and only if optimizeDocConcentration is set true, + // we calculate logphat in the same pass as other statistics. + // No calculation of loghat happens otherwise. 
+ val logphatPartOptionBase = () => if (optimizeDocConcentration) { + Some(BDV.zeros[Double](k)) + } else { + None + } + + val stats: RDD[(BDM[Double], Option[BDV[Double]], Long)] = batch.mapPartitions { docs => val nonEmptyDocs = docs.filter(_._2.numNonzeros > 0) val stat = BDM.zeros[Double](k, vocabSize) - var gammaPart = List[BDV[Double]]() + val logphatPartOption = logphatPartOptionBase() + var nonEmptyDocCount: Long = 0L nonEmptyDocs.foreach { case (_, termCounts: Vector) => + nonEmptyDocCount += 1 val (gammad, sstats, ids) = OnlineLDAOptimizer.variationalTopicInference( termCounts, expElogbetaBc.value, alpha, gammaShape, k) - stat(::, ids) := stat(::, ids).toDenseMatrix + sstats - gammaPart = gammad :: gammaPart + stat(::, ids) := stat(::, ids) + sstats + logphatPartOption.foreach(_ += LDAUtils.dirichletExpectation(gammad)) } - Iterator((stat, gammaPart)) - }.persist(StorageLevel.MEMORY_AND_DISK) - val statsSum: BDM[Double] = stats.map(_._1).treeAggregate(BDM.zeros[Double](k, vocabSize))( - _ += _, _ += _) - val gammat: BDM[Double] = breeze.linalg.DenseMatrix.vertcat( - stats.map(_._2).flatMap(list => list).collect().map(_.toDenseMatrix): _*) - stats.unpersist() + Iterator((stat, logphatPartOption, nonEmptyDocCount)) + } + + val elementWiseSum = ( + u: (BDM[Double], Option[BDV[Double]], Long), + v: (BDM[Double], Option[BDV[Double]], Long)) => { + u._1 += v._1 + u._2.foreach(_ += v._2.get) + (u._1, u._2, u._3 + v._3) + } + + val (statsSum: BDM[Double], logphatOption: Option[BDV[Double]], nonEmptyDocsN: Long) = stats + .treeAggregate((BDM.zeros[Double](k, vocabSize), logphatPartOptionBase(), 0L))( + elementWiseSum, elementWiseSum + ) + expElogbetaBc.destroy(false) - val batchResult = statsSum *:* expElogbeta.t + if (nonEmptyDocsN == 0) { + logWarning("No non-empty documents were submitted in the batch.") + // Therefore, there is no need to update any of the model parameters + return this + } + + val batchResult = statsSum *:* expElogbeta.t // Note that this is an optimization to avoid batch.count - updateLambda(batchResult, (miniBatchFraction * corpusSize).ceil.toInt) - if (optimizeDocConcentration) updateAlpha(gammat) + val batchSize = (miniBatchFraction * corpusSize).ceil.toInt + updateLambda(batchResult, batchSize) + + logphatOption.foreach(_ /= nonEmptyDocsN.toDouble) + logphatOption.foreach(updateAlpha(_, nonEmptyDocsN)) + this } @@ -503,21 +534,22 @@ final class OnlineLDAOptimizer extends LDAOptimizer { } /** - * Update alpha based on `gammat`, the inferred topic distributions for documents in the - * current mini-batch. Uses Newton-Rhapson method. + * Update alpha based on `logphat`. + * Uses Newton-Rhapson method. * @see Section 3.3, Huang: Maximum Likelihood Estimation of Dirichlet Distribution Parameters * (http://jonathan-huang.org/research/dirichlet/dirichlet.pdf) + * @param logphat Expectation of estimated log-posterior distribution of + * topics in a document averaged over the batch. 
+ * @param nonEmptyDocsN number of non-empty documents */ - private def updateAlpha(gammat: BDM[Double]): Unit = { + private def updateAlpha(logphat: BDV[Double], nonEmptyDocsN: Double): Unit = { val weight = rho() - val N = gammat.rows.toDouble val alpha = this.alpha.asBreeze.toDenseVector - val logphat: BDV[Double] = - sum(LDAUtils.dirichletExpectation(gammat)(::, breeze.linalg.*)).t / N - val gradf = N * (-LDAUtils.dirichletExpectation(alpha) + logphat) - val c = N * trigamma(sum(alpha)) - val q = -N * trigamma(alpha) + val gradf = nonEmptyDocsN * (-LDAUtils.dirichletExpectation(alpha) + logphat) + + val c = nonEmptyDocsN * trigamma(sum(alpha)) + val q = -nonEmptyDocsN * trigamma(alpha) val b = sum(gradf / q) / (1D / c + sum(1D / q)) val dalpha = -(gradf - b) / q diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 6f96813497b62..b8c306d86bace 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -353,11 +353,14 @@ class Word2Vec extends Serializable with Logging { val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) + val totalWordsCounts = numIterations * trainWordsCount + 1 var alpha = learningRate for (k <- 1 to numIterations) { val bcSyn0Global = sc.broadcast(syn0Global) val bcSyn1Global = sc.broadcast(syn1Global) + val numWordsProcessedInPreviousIterations = (k - 1) * trainWordsCount + val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) => val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) @@ -368,11 +371,12 @@ class Word2Vec extends Serializable with Logging { var wc = wordCount if (wordCount - lastWordCount > 10000) { lwc = wordCount - // TODO: discount by iteration? 
- alpha = - learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1)) + alpha = learningRate * + (1 - (numPartitions * wordCount.toDouble + numWordsProcessedInPreviousIterations) / + totalWordsCounts) if (alpha < learningRate * 0.0001) alpha = learningRate * 0.0001 - logInfo("wordCount = " + wordCount + ", alpha = " + alpha) + logInfo(s"wordCount = ${wordCount + numWordsProcessedInPreviousIterations}, " + + s"alpha = $alpha") } wc += sentence.length var pos = 0 diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 9730dd68a3b27..0d3adf993383f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.classification import scala.util.Random import breeze.linalg.{DenseVector => BDV, Vector => BV} -import breeze.stats.distributions.{Multinomial => BrzMultinomial} +import breeze.stats.distributions.{Multinomial => BrzMultinomial, RandBasis => BrzRandBasis} import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial} @@ -335,6 +335,7 @@ object NaiveBayesSuite { val _pi = pi.map(math.exp) val _theta = theta.map(row => row.map(math.exp)) + implicit val rngForBrzMultinomial = BrzRandBasis.withSeed(seed) for (i <- 0 until nPoints) yield { val y = calcLabel(rnd.nextDouble(), _pi) val xi = modelType match { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala index 4ecd5a05365eb..d649ceac949c4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/ProbabilisticClassifierSuite.scala @@ -80,6 +80,24 @@ class ProbabilisticClassifierSuite extends SparkFunSuite { new TestProbabilisticClassificationModel("myuid", 2, 2).setThresholds(Array(-0.1, 0.1)) } } + + test("normalizeToProbabilitiesInPlace") { + val vec1 = Vectors.dense(1.0, 2.0, 3.0).toDense + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec1) + assert(vec1 ~== Vectors.dense(1.0 / 6, 2.0 / 6, 3.0 / 6) relTol 1e-3) + + // all-0 input test + val vec2 = Vectors.dense(0.0, 0.0, 0.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec2) + } + + // negative input test + val vec3 = Vectors.dense(1.0, -1.0, 2.0).toDense + intercept[IllegalArgumentException] { + ProbabilisticClassificationModel.normalizeToProbabilitiesInPlace(vec3) + } + } } object ProbabilisticClassifierSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala index ee2ba73fa96d5..c08b35b419266 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ImputerSuite.scala @@ -43,7 +43,7 @@ class ImputerSuite extends SparkFunSuite with MLlibTestSparkContext with Default (0, 1.0, 1.0, 1.0), (1, 3.0, 3.0, 3.0), (2, Double.NaN, Double.NaN, Double.NaN), - (3, -1.0, 2.0, 3.0) + (3, -1.0, 2.0, 1.0) )).toDF("id", "value", "expected_mean_value", "expected_median_value") val imputer = new 
Imputer().setInputCols(Array("value")).setOutputCols(Array("out")) .setMissingValue(-1.0) diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index ac7319110159b..addcd21d50aac 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -723,9 +723,9 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), - 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), - 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Seq((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) ) Seq(2, 4, 6).foreach { k => @@ -743,10 +743,10 @@ class ALSSuite val numUsers = model.userFactors.count val numItems = model.itemFactors.count val expected = Map( - 3 -> Array((0, 54f), (2, 51f), (1, 39f)), - 4 -> Array((0, 44f), (2, 30f), (1, 26f)), - 5 -> Array((2, 45f), (0, 42f), (1, 33f)), - 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 4 -> Seq((0, 44f), (2, 30f), (1, 26f)), + 5 -> Seq((2, 45f), (0, 42f), (1, 33f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) ) Seq(2, 3, 4).foreach { k => @@ -759,9 +759,93 @@ class ALSSuite } } + test("recommendForUserSubset with k <, = and > num_items") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numItems = model.itemFactors.count + val expected = Map( + 0 -> Seq((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 2 -> Seq((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + val userSubset = expected.keys.toSeq.toDF("user") + val numUsersSubset = userSubset.count + + Seq(2, 4, 6).foreach { k => + val n = math.min(k, numItems).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topItems = model.recommendForUserSubset(userSubset, k) + assert(topItems.count() == numUsersSubset) + assert(topItems.columns.contains("user")) + checkRecommendations(topItems, expectedUpToN, "item") + } + } + + test("recommendForItemSubset with k <, = and > num_users") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val numUsers = model.userFactors.count + val expected = Map( + 3 -> Seq((0, 54f), (2, 51f), (1, 39f)), + 6 -> Seq((0, 28f), (2, 18f), (1, 16f)) + ) + val itemSubset = expected.keys.toSeq.toDF("item") + val numItemsSubset = itemSubset.count + + Seq(2, 3, 4).foreach { k => + val n = math.min(k, numUsers).toInt + val expectedUpToN = expected.mapValues(_.slice(0, n)) + val topUsers = model.recommendForItemSubset(itemSubset, k) + assert(topUsers.count() == numItemsSubset) + assert(topUsers.columns.contains("item")) + checkRecommendations(topUsers, expectedUpToN, "user") + } + } + + test("subset recommendations eliminate duplicate ids, returns same results as unique ids") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val users = Seq(0, 1).toDF("user") + val dupUsers = Seq(0, 1, 0, 1).toDF("user") + val singleUserRecs = model.recommendForUserSubset(users, k) + val dupUserRecs = model.recommendForUserSubset(dupUsers, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleUserRecs.count == dupUserRecs.size) + checkRecommendations(singleUserRecs, dupUserRecs, "item") + + val items = Seq(3, 4, 5).toDF("item") + val dupItems = 
Seq(3, 4, 5, 4, 5).toDF("item") + val singleItemRecs = model.recommendForItemSubset(items, k) + val dupItemRecs = model.recommendForItemSubset(dupItems, k) + .as[(Int, Seq[(Int, Float)])].collect().toMap + assert(singleItemRecs.count == dupItemRecs.size) + checkRecommendations(singleItemRecs, dupItemRecs, "user") + } + + test("subset recommendations on full input dataset equivalent to recommendForAll") { + val spark = this.spark + import spark.implicits._ + val model = getALSModel + val k = 2 + + val userSubset = model.userFactors.withColumnRenamed("id", "user").drop("features") + val userSubsetRecs = model.recommendForUserSubset(userSubset, k) + val allUserRecs = model.recommendForAllUsers(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(userSubsetRecs, allUserRecs, "item") + + val itemSubset = model.itemFactors.withColumnRenamed("id", "item").drop("features") + val itemSubsetRecs = model.recommendForItemSubset(itemSubset, k) + val allItemRecs = model.recommendForAllItems(k).as[(Int, Seq[(Int, Float)])].collect().toMap + checkRecommendations(itemSubsetRecs, allItemRecs, "user") + } + private def checkRecommendations( topK: DataFrame, - expected: Map[Int, Array[(Int, Float)]], + expected: Map[Int, Seq[(Int, Float)]], dstColName: String): Unit = { val spark = this.spark import spark.implicits._ diff --git a/pom.xml b/pom.xml index b0408ecca0f66..2d59f06811a82 100644 --- a/pom.xml +++ b/pom.xml @@ -98,15 +98,13 @@ sql/core sql/hive assembly - external/flume - external/flume-sink - external/flume-assembly examples repl launcher external/kafka-0-10 external/kafka-0-10-assembly external/kafka-0-10-sql + @@ -130,7 +128,7 @@ 1.2.1 10.12.1.1 1.8.2 - 1.4.0 + 1.4.1 nohive 1.6.0 9.3.20.v20170531 @@ -179,7 +177,10 @@ 4.7 1.1 2.52.0 - 2.6 + + 2.8 1.8 1.0.0 0.4.0 @@ -637,11 +638,6 @@ - - com.fasterxml.jackson.module - jackson-module-paranamer - ${fasterxml.jackson.version} - com.fasterxml.jackson.module jackson-module-jaxb-annotations @@ -1716,6 +1712,10 @@ org.apache.hive hive-storage-api + + io.airlift + slice + @@ -2488,6 +2488,7 @@ maven-checkstyle-plugin 2.17 + false true ${basedir}/src/main/java,${basedir}/src/main/scala ${basedir}/src/test/java @@ -2584,6 +2585,15 @@ + + flume + + external/flume + external/flume-sink + external/flume-assembly + + + spark-ganglia-lgpl @@ -2682,7 +2692,7 @@ scala-2.12 - 2.12.3 + 2.12.4 2.12 diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index dd299e074535e..99cac34c85ebc 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,10 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // SPARK-18085: Better History Server scalability for many / large applications + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.ExecutorSummary.executorLogs"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.deploy.history.HistoryServer.getSparkUI"), + // [SPARK-20495][SQL] Add StorageLevel to cacheTable API ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.cacheTable"), diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index a568d264cb2db..9501eed1e906b 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -43,11 +43,8 @@ object BuildCommons { "catalyst", "sql", "hive", "hive-thriftserver", "sql-kafka-0-10" ).map(ProjectRef(buildLocation, _)) - val streamingProjects@Seq( - streaming, streamingFlumeSink, streamingFlume, streamingKafka010 
- ) = Seq( - "streaming", "streaming-flume-sink", "streaming-flume", "streaming-kafka-0-10" - ).map(ProjectRef(buildLocation, _)) + val streamingProjects@Seq(streaming, streamingKafka010) = + Seq("streaming", "streaming-kafka-0-10").map(ProjectRef(buildLocation, _)) val allProjects@Seq( core, graphx, mllib, mllibLocal, repl, networkCommon, networkShuffle, launcher, unsafe, tags, sketch, kvstore, _* @@ -56,9 +53,13 @@ object BuildCommons { "tags", "sketch", "kvstore" ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects - val optionallyEnabledProjects@Seq(mesos, yarn, streamingKafka, sparkGangliaLgpl, - streamingKinesisAsl, dockerIntegrationTests, hadoopCloud) = - Seq("mesos", "yarn", "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", + val optionallyEnabledProjects@Seq(mesos, yarn, + streamingFlumeSink, streamingFlume, + streamingKafka, sparkGangliaLgpl, streamingKinesisAsl, + dockerIntegrationTests, hadoopCloud) = + Seq("mesos", "yarn", + "streaming-flume-sink", "streaming-flume", + "streaming-kafka-0-8", "ganglia-lgpl", "streaming-kinesis-asl", "docker-integration-tests", "hadoop-cloud").map(ProjectRef(buildLocation, _)) val assemblyProjects@Seq(networkYarn, streamingFlumeAssembly, streamingKafkaAssembly, streamingKafka010Assembly, streamingKinesisAslAssembly) = diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index bcfb36880eb02..e8bcbe4cd34cb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -90,6 +90,14 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha >>> item_recs.where(item_recs.item == 2)\ .select("recommendations.user", "recommendations.rating").collect() [Row(user=[2, 1, 0], rating=[4.901..., 3.981..., -0.138...])] + >>> user_subset = df.where(df.user == 2) + >>> user_subset_recs = model.recommendForUserSubset(user_subset, 3) + >>> user_subset_recs.select("recommendations.item", "recommendations.rating").first() + Row(item=[2, 1, 0], rating=[4.901..., 1.056..., -1.501...]) + >>> item_subset = df.where(df.item == 0) + >>> item_subset_recs = model.recommendForItemSubset(item_subset, 3) + >>> item_subset_recs.select("recommendations.user", "recommendations.rating").first() + Row(user=[0, 1, 2], rating=[3.910..., 2.625..., -1.501...]) >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) @@ -414,6 +422,36 @@ def recommendForAllItems(self, numUsers): """ return self._call_java("recommendForAllItems", numUsers) + @since("2.3.0") + def recommendForUserSubset(self, dataset, numItems): + """ + Returns top `numItems` items recommended for each user id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of user ids. The column name must match + `userCol`. + :param numItems: max number of recommendations for each user + :return: a DataFrame of (userCol, recommendations), where recommendations are + stored as an array of (itemCol, rating) Rows. + """ + return self._call_java("recommendForUserSubset", dataset, numItems) + + @since("2.3.0") + def recommendForItemSubset(self, dataset, numUsers): + """ + Returns top `numUsers` users recommended for each item id in the input data set. Note that + if there are duplicate ids in the input dataset, only one set of recommendations per unique + id will be returned. + + :param dataset: a Dataset containing a column of item ids. 
The column name must match + `itemCol`. + :param numUsers: max number of recommendations for each item + :return: a DataFrame of (itemCol, recommendations), where recommendations are + stored as an array of (userCol, rating) Rows. + """ + return self._call_java("recommendForItemSubset", dataset, numUsers) + if __name__ == "__main__": import doctest diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 67772910c0d38..c3c47bd79459a 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -175,7 +175,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jwrite.context(sqlContext._ssql_ctx) return self @@ -256,7 +258,9 @@ def context(self, sqlContext): .. note:: Deprecated in 2.1 and will be removed in 3.0, use session instead. """ - warnings.warn("Deprecated in 2.1 and will be removed in 3.0, use session instead.") + warnings.warn( + "Deprecated in 2.1 and will be removed in 3.0, use session instead.", + DeprecationWarning) self._jread.context(sqlContext._ssql_ctx) return self diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py index e04eeb2b60d71..cce703d432b5a 100644 --- a/python/pyspark/mllib/classification.py +++ b/python/pyspark/mllib/classification.py @@ -311,7 +311,7 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, """ warnings.warn( "Deprecated in 2.0.0. Use ml.classification.LogisticRegression or " - "LogisticRegressionWithLBFGS.") + "LogisticRegressionWithLBFGS.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index fc2a0b3b5038a..2cd1da3fbf9aa 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -234,7 +234,7 @@ def precision(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("precision") else: return self.call("precision", float(label)) @@ -246,7 +246,7 @@ def recall(self, label=None): """ if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("recall") else: return self.call("recall", float(label)) @@ -259,7 +259,7 @@ def fMeasure(self, label=None, beta=None): if beta is None: if label is None: # note:: Deprecated in 2.0.0. Use accuracy. - warnings.warn("Deprecated in 2.0.0. Use accuracy.") + warnings.warn("Deprecated in 2.0.0. Use accuracy.", DeprecationWarning) return self.call("fMeasure") else: return self.call("fMeasure", label) diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py index 1b66f5b51044b..ea107d400621d 100644 --- a/python/pyspark/mllib/regression.py +++ b/python/pyspark/mllib/regression.py @@ -278,7 +278,8 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, A condition which decides iteration termination. (default: 0.001) """ - warnings.warn("Deprecated in 2.0.0. 
Use ml.regression.LinearRegression.") + warnings.warn( + "Deprecated in 2.0.0. Use ml.regression.LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations), @@ -421,7 +422,8 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, """ warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 1.0. " - "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.") + "Note the default regParam is 0.01 for LassoWithSGD, but is 0.0 for LinearRegression.", + DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step), @@ -566,7 +568,7 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, warnings.warn( "Deprecated in 2.0.0. Use ml.regression.LinearRegression with elasticNetParam = 0.0. " "Note the default regParam is 0.01 for RidgeRegressionWithSGD, but is 0.0 for " - "LinearRegression.") + "LinearRegression.", DeprecationWarning) def train(rdd, i): return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step), diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 7c1fbadcb82be..a0adeed994456 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -79,12 +79,14 @@ class SpecialLengths(object): TIMING_DATA = -3 END_OF_STREAM = -4 NULL = -5 + START_ARROW_STREAM = -6 class PythonEvalType(object): NON_UDF = 0 SQL_BATCHED_UDF = 1 SQL_PANDAS_UDF = 2 + SQL_PANDAS_GROUPED_UDF = 3 class Serializer(object): @@ -190,7 +192,7 @@ def loads(self, obj): class ArrowSerializer(FramedSerializer): """ - Serializes an Arrow stream. + Serializes bytes as Arrow data with the Arrow file format. """ def dumps(self, batch): @@ -211,44 +213,61 @@ def __repr__(self): return "ArrowSerializer" -class ArrowPandasSerializer(ArrowSerializer): +def _create_batch(series): + import pyarrow as pa + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or \ + (len(series) == 2 and isinstance(series[1], pa.DataType)): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + # If a nullable integer series has been promoted to floating point with NaNs, need to cast + # NOTE: this is not necessary with Arrow >= 0.7 + def cast_series(s, t): + if t is None or s.dtype == t.to_pandas_dtype(): + return s + else: + return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) + + arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) + + +class ArrowStreamPandasSerializer(Serializer): """ - Serializes Pandas.Series as Arrow data. + Serializes Pandas.Series as Arrow data with Arrow streaming format. """ - def dumps(self, series): + def dump_stream(self, iterator, stream): """ - Make an ArrowRecordBatch from a Pandas Series and serialize. Input is a single series or + Make ArrowRecordBatches from Pandas Series and serialize. Input is a single series or a list of series accompanied by an optional pyarrow type to coerce the data to. """ import pyarrow as pa - # Make input conform to [(series1, type1), (series2, type2), ...] 
- if not isinstance(series, (list, tuple)) or \ - (len(series) == 2 and isinstance(series[1], pa.DataType)): - series = [series] - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - # If a nullable integer series has been promoted to floating point with NaNs, need to cast - # NOTE: this is not necessary with Arrow >= 0.7 - def cast_series(s, t): - if t is None or s.dtype == t.to_pandas_dtype(): - return s - else: - return s.fillna(0).astype(t.to_pandas_dtype(), copy=False) - - arrs = [pa.Array.from_pandas(cast_series(s, t), mask=s.isnull(), type=t) for s, t in series] - batch = pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in xrange(len(arrs))]) - return super(ArrowPandasSerializer, self).dumps(batch) + writer = None + try: + for series in iterator: + batch = _create_batch(series) + if writer is None: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + finally: + if writer is not None: + writer.close() - def loads(self, obj): + def load_stream(self, stream): """ - Deserialize an ArrowRecordBatch to an Arrow table and return as a list of pandas.Series. + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. """ - table = super(ArrowPandasSerializer, self).loads(obj) - return [c.to_pandas() for c in table.itercolumns()] + import pyarrow as pa + reader = pa.open_stream(stream) + for batch in reader: + table = pa.Table.from_batches([batch]) + yield [c.to_pandas() for c in table.itercolumns()] def __repr__(self): - return "ArrowPandasSerializer" + return "ArrowStreamPandasSerializer" class BatchedSerializer(Serializer): diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py index b7ce9a83a616d..c0b574e2b93a1 100644 --- a/python/pyspark/sql/dataframe.py +++ b/python/pyspark/sql/dataframe.py @@ -130,6 +130,8 @@ def registerTempTable(self, name): .. note:: Deprecated in 2.0, use createOrReplaceTempView instead. """ + warnings.warn( + "Deprecated in 2.0, use createOrReplaceTempView instead.", DeprecationWarning) self._jdf.createOrReplaceTempView(name) @since(2.0) @@ -1038,8 +1040,8 @@ def summary(self, *statistics): | mean| 3.5| null| | stddev|2.1213203435596424| null| | min| 2|Alice| - | 25%| 5| null| - | 50%| 5| null| + | 25%| 2| null| + | 50%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+------------------+-----+ @@ -1050,7 +1052,7 @@ def summary(self, *statistics): +-------+---+-----+ | count| 2| 2| | min| 2|Alice| - | 25%| 5| null| + | 25%| 2| null| | 75%| 5| null| | max| 5| Bob| +-------+---+-----+ @@ -1227,7 +1229,7 @@ def groupBy(self, *cols): """ jgd = self._jdf.groupBy(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def rollup(self, *cols): @@ -1248,7 +1250,7 @@ def rollup(self, *cols): """ jgd = self._jdf.rollup(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.4) def cube(self, *cols): @@ -1271,7 +1273,7 @@ def cube(self, *cols): """ jgd = self._jdf.cube(self._jcols(*cols)) from pyspark.sql.group import GroupedData - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self) @since(1.3) def agg(self, *exprs): @@ -1308,6 +1310,7 @@ def unionAll(self, other): .. note:: Deprecated in 2.0, use :func:`union` instead. 
""" + warnings.warn("Deprecated in 2.0, use union instead.", DeprecationWarning) return self.union(other) @since(2.3) @@ -1878,7 +1881,7 @@ def toPandas(self): 1 5 Bob """ import pandas as pd - if self.sql_ctx.getConf("spark.sql.execution.arrow.enable", "false").lower() == "true": + if self.sql_ctx.getConf("spark.sql.execution.arrow.enabled", "false").lower() == "true": try: import pyarrow tables = self._collectAsArrow() @@ -1889,7 +1892,7 @@ def toPandas(self): return pd.DataFrame.from_records([], columns=self.columns) except ImportError as e: msg = "note: pyarrow must be installed and available on calling Python process " \ - "if using spark.sql.execution.arrow.enable=true" + "if using spark.sql.execution.arrow.enabled=true" raise ImportError("%s\n%s" % (e.message, msg)) else: pdf = pd.DataFrame.from_records(self.collect(), columns=self.columns) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 63e9a830bbc9e..0d40368c9cd6e 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -21,6 +21,7 @@ import math import sys import functools +import warnings if sys.version < "3": from itertools import imap as map @@ -44,6 +45,14 @@ def _(col): return _ +def _wrap_deprecated_function(func, message): + """ Wrap the deprecated function to print out deprecation warnings""" + def _(col): + warnings.warn(message, DeprecationWarning) + return func(col) + return functools.wraps(func)(_) + + def _create_binary_mathfunction(name, doc=""): """ Create a binary mathfunction by name""" def _(col1, col2): @@ -207,6 +216,12 @@ def _(): """returns the relative rank (i.e. percentile) of rows within a window partition.""", } +# Wraps deprecated functions (keys) with the messages (values). +_functions_deprecated = { + 'toDegrees': 'Deprecated in 2.1, use degrees instead.', + 'toRadians': 'Deprecated in 2.1, use radians instead.', +} + for _name, _doc in _functions.items(): globals()[_name] = since(1.3)(_create_function(_name, _doc)) for _name, _doc in _functions_1_4.items(): @@ -219,6 +234,8 @@ def _(): globals()[_name] = since(1.6)(_create_function(_name, _doc)) for _name, _doc in _functions_2_1.items(): globals()[_name] = since(2.1)(_create_function(_name, _doc)) +for _name, _message in _functions_deprecated.items(): + globals()[_name] = _wrap_deprecated_function(globals()[_name], _message) del _name, _doc @@ -227,6 +244,7 @@ def approxCountDistinct(col, rsd=None): """ .. note:: Deprecated in 2.1, use :func:`approx_count_distinct` instead. """ + warnings.warn("Deprecated in 2.1, use approx_count_distinct instead.", DeprecationWarning) return approx_count_distinct(col, rsd) @@ -2038,13 +2056,22 @@ def _wrap_function(sc, func, returnType): sc.pythonVer, broadcast_vars, sc._javaAccumulator) +class PythonUdfType(object): + # row-at-a-time UDFs + NORMAL_UDF = 0 + # scalar vectorized UDFs + PANDAS_UDF = 1 + # grouped vectorized UDFs + PANDAS_GROUPED_UDF = 2 + + class UserDefinedFunction(object): """ User defined function in Python .. 
versionadded:: 1.3 """ - def __init__(self, func, returnType, name=None, vectorized=False): + def __init__(self, func, returnType, name=None, pythonUdfType=PythonUdfType.NORMAL_UDF): if not callable(func): raise TypeError( "Not a function or callable (__call__ is not defined): " @@ -2058,7 +2085,7 @@ def __init__(self, func, returnType, name=None, vectorized=False): self._name = name or ( func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) - self._vectorized = vectorized + self.pythonUdfType = pythonUdfType @property def returnType(self): @@ -2090,7 +2117,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self._vectorized) + self._name, wrapped_func, jdt, self.pythonUdfType) return judf def __call__(self, *cols): @@ -2118,20 +2145,26 @@ def wrapper(*args): wrapper.__name__ = self._name wrapper.__module__ = (self.func.__module__ if hasattr(self.func, '__module__') else self.func.__class__.__module__) + wrapper.func = self.func wrapper.returnType = self.returnType + wrapper.pythonUdfType = self.pythonUdfType return wrapper -def _create_udf(f, returnType, vectorized): +def _create_udf(f, returnType, pythonUdfType): - def _udf(f, returnType=StringType(), vectorized=vectorized): - if vectorized: + def _udf(f, returnType=StringType(), pythonUdfType=pythonUdfType): + if pythonUdfType == PythonUdfType.PANDAS_UDF: import inspect - if len(inspect.getargspec(f).args) == 0: - raise NotImplementedError("0-parameter pandas_udfs are not currently supported") - udf_obj = UserDefinedFunction(f, returnType, vectorized=vectorized) + argspec = inspect.getargspec(f) + if len(argspec.args) == 0 and argspec.varargs is None: + raise ValueError( + "0-arg pandas_udfs are not supported. " + "Instead, create a 1-arg pandas_udf and ignore the arg in your function." + ) + udf_obj = UserDefinedFunction(f, returnType, pythonUdfType=pythonUdfType) return udf_obj._wrapped() # decorator @udf, @udf(), @udf(dataType()), or similar with @pandas_udf @@ -2139,14 +2172,14 @@ def _udf(f, returnType=StringType(), vectorized=vectorized): # If DataType has been passed as a positional argument # for decorator use it as a returnType return_type = f or returnType - return functools.partial(_udf, returnType=return_type, vectorized=vectorized) + return functools.partial(_udf, returnType=return_type, pythonUdfType=pythonUdfType) else: - return _udf(f=f, returnType=returnType, vectorized=vectorized) + return _udf(f=f, returnType=returnType, pythonUdfType=pythonUdfType) @since(1.3) def udf(f=None, returnType=StringType()): - """Creates a :class:`Column` expression representing a user defined function (UDF). + """Creates a user defined function (UDF). .. note:: The user-defined functions must be deterministic. 
Due to optimization, duplicate invocations may be eliminated or the function may even be invoked more times than @@ -2175,40 +2208,78 @@ def udf(f=None, returnType=StringType()): | 8| JOHN DOE| 22| +----------+--------------+------------+ """ - return _create_udf(f, returnType=returnType, vectorized=False) + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.NORMAL_UDF) @since(2.3) def pandas_udf(f=None, returnType=StringType()): """ - Creates a :class:`Column` expression representing a user defined function (UDF) that accepts - `Pandas.Series` as input arguments and outputs a `Pandas.Series` of the same length. + Creates a vectorized user defined function (UDF). - :param f: python function if used as a standalone function + :param f: user-defined function. A python function if used as a standalone function :param returnType: a :class:`pyspark.sql.types.DataType` object - >>> from pyspark.sql.types import IntegerType, StringType - >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) - >>> @pandas_udf(returnType=StringType()) - ... def to_upper(s): - ... return s.str.upper() - ... - >>> @pandas_udf(returnType="integer") - ... def add_one(x): - ... return x + 1 - ... - >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) - >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ - ... .show() # doctest: +SKIP - +----------+--------------+------------+ - |slen(name)|to_upper(name)|add_one(age)| - +----------+--------------+------------+ - | 8| JOHN DOE| 22| - +----------+--------------+------------+ - """ - wrapped_udf = _create_udf(f, returnType=returnType, vectorized=True) - - return wrapped_udf + The user-defined function can define one of the following transformations: + + 1. One or more `pandas.Series` -> A `pandas.Series` + + This udf is used with :meth:`pyspark.sql.DataFrame.withColumn` and + :meth:`pyspark.sql.DataFrame.select`. + The returnType should be a primitive data type, e.g., `DoubleType()`. + The length of the returned `pandas.Series` must be of the same as the input `pandas.Series`. + + >>> from pyspark.sql.types import IntegerType, StringType + >>> slen = pandas_udf(lambda s: s.str.len(), IntegerType()) + >>> @pandas_udf(returnType=StringType()) + ... def to_upper(s): + ... return s.str.upper() + ... + >>> @pandas_udf(returnType="integer") + ... def add_one(x): + ... return x + 1 + ... + >>> df = spark.createDataFrame([(1, "John Doe", 21)], ("id", "name", "age")) + >>> df.select(slen("name").alias("slen(name)"), to_upper("name"), add_one("age")) \\ + ... .show() # doctest: +SKIP + +----------+--------------+------------+ + |slen(name)|to_upper(name)|add_one(age)| + +----------+--------------+------------+ + | 8| JOHN DOE| 22| + +----------+--------------+------------+ + + 2. A `pandas.DataFrame` -> A `pandas.DataFrame` + + This udf is only used with :meth:`pyspark.sql.GroupedData.apply`. + The returnType should be a :class:`StructType` describing the schema of the returned + `pandas.DataFrame`. + + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... 
return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. note:: This type of udf cannot be used with functions such as `withColumn` or `select` + because it defines a `DataFrame` transformation rather than a `Column` + transformation. + + .. seealso:: :meth:`pyspark.sql.GroupedData.apply` + + .. note:: The user-defined function must be deterministic. + """ + return _create_udf(f, returnType=returnType, pythonUdfType=PythonUdfType.PANDAS_UDF) blacklist = ['map', 'since', 'ignore_unicode_prefix'] diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py index f2092f9c63054..e11388d604312 100644 --- a/python/pyspark/sql/group.py +++ b/python/pyspark/sql/group.py @@ -19,6 +19,7 @@ from pyspark.rdd import ignore_unicode_prefix from pyspark.sql.column import Column, _to_seq, _to_java_column, _create_column_from_literal from pyspark.sql.dataframe import DataFrame +from pyspark.sql.functions import PythonUdfType, UserDefinedFunction from pyspark.sql.types import * __all__ = ["GroupedData"] @@ -54,9 +55,10 @@ class GroupedData(object): .. versionadded:: 1.3 """ - def __init__(self, jgd, sql_ctx): + def __init__(self, jgd, df): self._jgd = jgd - self.sql_ctx = sql_ctx + self._df = df + self.sql_ctx = df.sql_ctx @ignore_unicode_prefix @since(1.3) @@ -170,7 +172,7 @@ def sum(self, *cols): @since(1.6) def pivot(self, pivot_col, values=None): """ - Pivots a column of the current [[DataFrame]] and perform the specified aggregation. + Pivots a column of the current :class:`DataFrame` and perform the specified aggregation. There are two versions of pivot function: one that requires the caller to specify the list of distinct values to pivot on, and one that does not. The latter is more concise but less efficient, because Spark needs to first compute the list of distinct values internally. @@ -192,7 +194,88 @@ def pivot(self, pivot_col, values=None): jgd = self._jgd.pivot(pivot_col) else: jgd = self._jgd.pivot(pivot_col, values) - return GroupedData(jgd, self.sql_ctx) + return GroupedData(jgd, self._df) + + @since(2.3) + def apply(self, udf): + """ + Maps each group of the current :class:`DataFrame` using a pandas udf and returns the result + as a `DataFrame`. + + The user-defined function should take a `pandas.DataFrame` and return another + `pandas.DataFrame`. For each group, all columns are passed together as a `pandas.DataFrame` + to the user-function and the returned `pandas.DataFrame`s are combined as a + :class:`DataFrame`. + The returned `pandas.DataFrame` can be of arbitrary length and its schema must match the + returnType of the pandas udf. + + This function does not support partial aggregation, and requires shuffling all the data in + the :class:`DataFrame`. + + :param udf: A function object returned by :meth:`pyspark.sql.functions.pandas_udf` + + >>> from pyspark.sql.functions import pandas_udf + >>> df = spark.createDataFrame( + ... [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)], + ... ("id", "v")) + >>> @pandas_udf(returnType=df.schema) + ... def normalize(pdf): + ... v = pdf.v + ... 
return pdf.assign(v=(v - v.mean()) / v.std()) + >>> df.groupby('id').apply(normalize).show() # doctest: +SKIP + +---+-------------------+ + | id| v| + +---+-------------------+ + | 1|-0.7071067811865475| + | 1| 0.7071067811865475| + | 2|-0.8320502943378437| + | 2|-0.2773500981126146| + | 2| 1.1094003924504583| + +---+-------------------+ + + .. seealso:: :meth:`pyspark.sql.functions.pandas_udf` + + """ + import inspect + + # Columns are special because hasattr always return True + if isinstance(udf, Column) or not hasattr(udf, 'func') \ + or udf.pythonUdfType != PythonUdfType.PANDAS_UDF \ + or len(inspect.getargspec(udf.func).args) != 1: + raise ValueError("The argument to apply must be a 1-arg pandas_udf") + if not isinstance(udf.returnType, StructType): + raise ValueError("The returnType of the pandas_udf must be a StructType") + + df = self._df + func = udf.func + returnType = udf.returnType + + # The python executors expects the function to use pd.Series as input and output + # So we to create a wrapper function that turns that to a pd.DataFrame before passing + # down to the user function, then turn the result pd.DataFrame back into pd.Series + columns = df.columns + + def wrapped(*cols): + from pyspark.sql.types import to_arrow_type + import pandas as pd + result = func(pd.concat(cols, axis=1, keys=columns)) + if not isinstance(result, pd.DataFrame): + raise TypeError("Return type of the user-defined function should be " + "Pandas.DataFrame, but is {}".format(type(result))) + if not len(result.columns) == len(returnType): + raise RuntimeError( + "Number of columns of the returned Pandas.DataFrame " + "doesn't match specified schema. " + "Expected: {} Actual: {}".format(len(returnType), len(result.columns))) + arrow_return_types = (to_arrow_type(field.dataType) for field in returnType) + return [(result[result.columns[i]], arrow_type) + for i, arrow_type in enumerate(arrow_return_types)] + + udf_obj = UserDefinedFunction( + wrapped, returnType, name=udf.__name__, pythonUdfType=PythonUdfType.PANDAS_GROUPED_UDF) + udf_column = udf_obj(*[df[col] for col in df.columns]) + jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr()) + return DataFrame(jdf, self.sql_ctx) def _test(): @@ -206,6 +289,7 @@ def _test(): .getOrCreate() sc = spark.sparkContext globs['sc'] = sc + globs['spark'] = spark globs['df'] = sc.parallelize([(2, 'Alice'), (5, 'Bob')]) \ .toDF(StructType([StructField('age', IntegerType()), StructField('name', StringType())])) diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index cb847a0420311..f3092918abb54 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -335,7 +335,8 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non ``inferSchema`` is enabled. To avoid going through the entire data once, disable ``inferSchema`` option or specify the schema explicitly using ``schema``. - :param path: string, or list of strings, for input path(s). + :param path: string, or list of strings, for input path(s), + or RDD of Strings storing CSV rows. :param schema: an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). :param sep: sets the single character as a separator for each field and value. 
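Illustrative sketch (not part of the patch): a minimal exercise of the new RDD-backed csv() path that the next hunk documents, assuming an active SparkSession named `spark` and the `python/test_support/sql/ages.csv` file used in the doctests.

    # Read CSV rows from an RDD of strings (the SPARK-22112 path) and from a file path;
    # both reads should produce the same string-typed columns.
    rdd = spark.sparkContext.textFile("python/test_support/sql/ages.csv")
    df_from_rdd = spark.read.csv(rdd)
    df_from_path = spark.read.csv("python/test_support/sql/ages.csv")
    assert df_from_rdd.dtypes == df_from_path.dtypes  # [('_c0', 'string'), ('_c1', 'string')]

    # Anything other than a str, list or RDD is rejected by the new type check.
    try:
        spark.read.csv(42)
    except TypeError as e:
        print(e)  # "path can be only string, list or RDD"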
@@ -408,6 +409,10 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non >>> df = spark.read.csv('python/test_support/sql/ages.csv') >>> df.dtypes [('_c0', 'string'), ('_c1', 'string')] + >>> rdd = sc.textFile('python/test_support/sql/ages.csv') + >>> df2 = spark.read.csv(rdd) + >>> df2.dtypes + [('_c0', 'string'), ('_c1', 'string')] """ self._set_opts( schema=schema, sep=sep, encoding=encoding, quote=quote, escape=escape, comment=comment, @@ -420,7 +425,29 @@ def csv(self, path, schema=None, sep=None, encoding=None, quote=None, escape=Non columnNameOfCorruptRecord=columnNameOfCorruptRecord, multiLine=multiLine) if isinstance(path, basestring): path = [path] - return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) + if type(path) == list: + return self._df(self._jreader.csv(self._spark._sc._jvm.PythonUtils.toSeq(path))) + elif isinstance(path, RDD): + def func(iterator): + for x in iterator: + if not isinstance(x, basestring): + x = unicode(x) + if isinstance(x, unicode): + x = x.encode("utf-8") + yield x + keyed = path.mapPartitions(func) + keyed._bypass_serializer = True + jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) + # see SPARK-22112 + # There aren't any jvm api for creating a dataframe from rdd storing csv. + # We can do it through creating a jvm dataset firstly and using the jvm api + # for creating a dataframe from dataset storing csv. + jdataset = self._spark._ssql_ctx.createDataset( + jrdd.rdd(), + self._spark._jvm.Encoders.STRING()) + return self._df(self._jreader.csv(jdataset)) + else: + raise TypeError("path can be only string, list or RDD") @since(1.5) def orc(self, path): diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 1b3af42c47ad2..685eebcafefba 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -3088,7 +3088,7 @@ class ArrowTests(ReusedPySparkTestCase): def setUpClass(cls): ReusedPySparkTestCase.setUpClass() cls.spark = SparkSession(cls.sc) - cls.spark.conf.set("spark.sql.execution.arrow.enable", "true") + cls.spark.conf.set("spark.sql.execution.arrow.enabled", "true") cls.schema = StructType([ StructField("1_str_t", StringType(), True), StructField("2_int_t", IntegerType(), True), @@ -3120,9 +3120,9 @@ def test_null_conversion(self): def test_toPandas_arrow_toggle(self): df = self.spark.createDataFrame(self.data, schema=self.schema) - self.spark.conf.set("spark.sql.execution.arrow.enable", "false") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "false") pdf = df.toPandas() - self.spark.conf.set("spark.sql.execution.arrow.enable", "true") + self.spark.conf.set("spark.sql.execution.arrow.enabled", "true") pdf_arrow = df.toPandas() self.assertFramesEqual(pdf_arrow, pdf) @@ -3256,17 +3256,17 @@ def test_vectorized_udf_null_string(self): def test_vectorized_udf_zero_parameter(self): from pyspark.sql.functions import pandas_udf - error_str = '0-parameter pandas_udfs.*not.*supported' + error_str = '0-arg pandas_udfs.*not.*supported' with QuietTest(self.sc): - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): pandas_udf(lambda: 1, LongType()) - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf def zero_no_type(): return 1 - with self.assertRaisesRegexp(NotImplementedError, error_str): + with self.assertRaisesRegexp(ValueError, error_str): @pandas_udf(LongType()) def zero_with_type(): return 1 @@ -3348,7 
+3348,7 @@ def test_vectorized_udf_wrong_return_type(self): df = self.spark.range(10) f = pandas_udf(lambda x: x * 1.0, StringType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Invalid.*type.*string'): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): df.select(f(col('id'))).collect() def test_vectorized_udf_return_scalar(self): @@ -3356,7 +3356,7 @@ def test_vectorized_udf_return_scalar(self): df = self.spark.range(10) f = pandas_udf(lambda x: 1.0, DoubleType()) with QuietTest(self.sc): - with self.assertRaisesRegexp(Exception, 'Return.*type.*pandas_udf.*Series'): + with self.assertRaisesRegexp(Exception, 'Return.*type.*Series'): df.select(f(col('id'))).collect() def test_vectorized_udf_decorator(self): @@ -3376,6 +3376,188 @@ def test_vectorized_udf_empty_partition(self): res = df.select(f(col('id'))) self.assertEquals(df.collect(), res.collect()) + def test_vectorized_udf_varargs(self): + from pyspark.sql.functions import pandas_udf, col + df = self.spark.createDataFrame(self.sc.parallelize([Row(id=1)], 2)) + f = pandas_udf(lambda *v: v[0], LongType()) + res = df.select(f(col('id'))) + self.assertEquals(df.collect(), res.collect()) + + def test_vectorized_udf_unsupported_types(self): + from pyspark.sql.functions import pandas_udf, col + schema = StructType([StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, DateType()) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.select(f(col('dt'))).collect() + + +@unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") +class GroupbyApplyTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.spark = SparkSession(cls.sc) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + cls.spark.stop() + + def assertFramesEqual(self, expected, result): + msg = ("DataFrames are not equal: " + + ("\n\nExpected:\n%s\n%s" % (expected, expected.dtypes)) + + ("\n\nResult:\n%s\n%s" % (result, result.dtypes))) + self.assertTrue(expected.equals(result), msg=msg) + + @property + def data(self): + from pyspark.sql.functions import array, explode, col, lit + return self.spark.range(10).toDF('id') \ + .withColumn("vs", array([lit(i) for i in range(20, 30)])) \ + .withColumn("v", explode(col('vs'))).drop('vs') + + def test_simple(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_decorator(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('v1', DoubleType()), + StructField('v2', LongType())])) + def foo(pdf): + return pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_coerce(self): + 
from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', DoubleType())])) + + result = df.groupby('id').apply(foo).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo.func).reset_index(drop=True) + expected = expected.assign(v=expected.v.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_complex_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby(col('id') % 2 == 0).apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = pdf.groupby(pdf['id'] % 2 == 0).apply(normalize.func) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_empty_groupby(self): + from pyspark.sql.functions import pandas_udf, col + df = self.data + + @pandas_udf(StructType( + [StructField('id', LongType()), + StructField('v', IntegerType()), + StructField('norm', DoubleType())])) + def normalize(pdf): + v = pdf.v + return pdf.assign(norm=(v - v.mean()) / v.std()) + + result = df.groupby().apply(normalize).sort('id', 'v').toPandas() + pdf = df.toPandas() + expected = normalize.func(pdf) + expected = expected.sort_values(['id', 'v']).reset_index(drop=True) + expected = expected.assign(norm=expected.norm.astype('float64')) + self.assertFramesEqual(expected, result) + + def test_datatype_string(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo_udf = pandas_udf( + lambda pdf: pdf.assign(v1=pdf.v * pdf.id * 1.0, v2=pdf.v + pdf.id), + "id long, v int, v1 double, v2 long") + + result = df.groupby('id').apply(foo_udf).sort('id').toPandas() + expected = df.toPandas().groupby('id').apply(foo_udf.func).reset_index(drop=True) + self.assertFramesEqual(expected, result) + + def test_wrong_return_type(self): + from pyspark.sql.functions import pandas_udf + df = self.data + + foo = pandas_udf( + lambda pdf: pdf, + StructType([StructField('id', LongType()), StructField('v', StringType())])) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Invalid.*type'): + df.groupby('id').apply(foo).sort('id').toPandas() + + def test_wrong_args(self): + from pyspark.sql.functions import udf, pandas_udf, sum + df = self.data + + with QuietTest(self.sc): + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(lambda x: x) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(udf(lambda x: x, DoubleType())) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(sum(df.v)) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply(df.v + 1) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda: 1, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'pandas_udf'): + df.groupby('id').apply( + pandas_udf(lambda x, y: x, StructType([StructField("d", DoubleType())]))) + with self.assertRaisesRegexp(ValueError, 'returnType'): + df.groupby('id').apply(pandas_udf(lambda x: x, DoubleType())) + + def test_unsupported_types(self): + from 
pyspark.sql.functions import pandas_udf, col + schema = StructType( + [StructField("id", LongType(), True), StructField("dt", DateType(), True)]) + df = self.spark.createDataFrame([(1, datetime.date(1970, 1, 1),)], schema=schema) + f = pandas_udf(lambda x: x, df.schema) + with QuietTest(self.sc): + with self.assertRaisesRegexp(Exception, 'Unsupported data type'): + df.groupby('id').apply(f).collect() + + if __name__ == "__main__": from pyspark.sql.tests import * if xmlrunner: diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index ebdc11c3b744a..f65273d5f0b6c 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1597,7 +1597,7 @@ def convert(self, obj, gateway_client): register_input_converter(DateConverter()) -def toArrowType(dt): +def to_arrow_type(dt): """ Convert Spark data type to pyarrow type """ import pyarrow as pa diff --git a/python/pyspark/streaming/flume.py b/python/pyspark/streaming/flume.py index cd30483fc636a..5a975d050b0d8 100644 --- a/python/pyspark/streaming/flume.py +++ b/python/pyspark/streaming/flume.py @@ -53,7 +53,14 @@ def createStream(ssc, hostname, port, :param enableDecompression: Should netty server decompress input stream :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) helper = FlumeUtils._get_helper(ssc._sc) jstream = helper.createStream(ssc._jssc, hostname, port, jlevel, enableDecompression) @@ -79,7 +86,14 @@ def createPollingStream(ssc, addresses, will result in this stream using more threads :param bodyDecoder: A function used to decode body (default is utf8_decoder) :return: A DStream object + + .. note:: Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. + See SPARK-22142. """ + warnings.warn( + "Deprecated in 2.3.0. Flume support is deprecated as of Spark 2.3.0. " + "See SPARK-22142.", + DeprecationWarning) jlevel = ssc._sc._getJavaStorageLevel(storageLevel) hosts = [] ports = [] diff --git a/python/pyspark/streaming/kafka.py b/python/pyspark/streaming/kafka.py index 4af4135c81958..fdb9308604489 100644 --- a/python/pyspark/streaming/kafka.py +++ b/python/pyspark/streaming/kafka.py @@ -15,6 +15,8 @@ # limitations under the License. # +import warnings + from py4j.protocol import Py4JJavaError from pyspark.rdd import RDD @@ -56,8 +58,13 @@ def createStream(ssc, zkQuorum, groupId, topics, kafkaParams=None, :param valueDecoder: A function used to decode value (default is utf8_decoder) :return: A DStream object - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if kafkaParams is None: kafkaParams = dict() kafkaParams.update({ @@ -105,8 +112,13 @@ def createDirectStream(ssc, topics, kafkaParams, fromOffsets=None, :return: A DStream object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. 
" + "See SPARK-21893.", + DeprecationWarning) if fromOffsets is None: fromOffsets = dict() if not isinstance(topics, list): @@ -159,8 +171,13 @@ def createRDD(sc, kafkaParams, offsetRanges, leaders=None, :return: An RDD object .. note:: Experimental - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) if leaders is None: leaders = dict() if not isinstance(kafkaParams, dict): @@ -229,7 +246,8 @@ class OffsetRange(object): """ Represents a range of offsets from a single Kafka TopicAndPartition. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, fromOffset, untilOffset): @@ -240,6 +258,10 @@ def __init__(self, topic, partition, fromOffset, untilOffset): :param fromOffset: Inclusive starting offset. :param untilOffset: Exclusive ending offset. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.fromOffset = fromOffset @@ -270,7 +292,8 @@ class TopicAndPartition(object): """ Represents a specific topic and partition for Kafka. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition): @@ -279,6 +302,10 @@ def __init__(self, topic, partition): :param topic: Kafka topic name. :param partition: Kafka partition id. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._topic = topic self._partition = partition @@ -303,7 +330,8 @@ class Broker(object): """ Represent the host and port info for a Kafka broker. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, host, port): @@ -312,6 +340,10 @@ def __init__(self, host, port): :param host: Broker's hostname. :param port: Broker's port. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self._host = host self._port = port @@ -323,10 +355,15 @@ class KafkaRDD(RDD): """ A Python wrapper of KafkaRDD, to provide additional information on normal RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jrdd, ctx, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) RDD.__init__(self, jrdd, ctx, jrdd_deserializer) def offsetRanges(self): @@ -345,10 +382,15 @@ class KafkaDStream(DStream): """ A Python wrapper of KafkaDStream - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, jdstream, ssc, jrdd_deserializer): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. 
" + "See SPARK-21893.", + DeprecationWarning) DStream.__init__(self, jdstream, ssc, jrdd_deserializer) def foreachRDD(self, func): @@ -383,10 +425,15 @@ class KafkaTransformedDStream(TransformedDStream): """ Kafka specific wrapper of TransformedDStream to transform on Kafka RDD. - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, prev, func): + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) TransformedDStream.__init__(self, prev, func) @property @@ -405,7 +452,8 @@ class KafkaMessageAndMetadata(object): """ Kafka message and metadata information. Including topic, partition, offset and message - .. note:: Deprecated in 2.3.0 + .. note:: Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. + See SPARK-21893. """ def __init__(self, topic, partition, offset, key, message): @@ -419,6 +467,10 @@ def __init__(self, topic, partition, offset, key, message): :param message: actual message payload of this Kafka message, the return data is undecoded bytearray. """ + warnings.warn( + "Deprecated in 2.3.0. Kafka 0.8 support is deprecated as of Spark 2.3.0. " + "See SPARK-21893.", + DeprecationWarning) self.topic = topic self.partition = partition self.offset = offset diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index 229cf53e47359..5b86c1cb2c390 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -1478,7 +1478,7 @@ def search_kafka_assembly_jar(): ("Failed to find Spark Streaming kafka assembly jar in %s. " % kafka_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/package streaming-kafka-0-8-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pkafka-0-8 package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Kafka assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1495,7 +1495,7 @@ def search_flume_assembly_jar(): ("Failed to find Spark Streaming Flume assembly jar in %s. 
" % flume_assembly_dir) + "You need to build Spark with " "'build/sbt assembly/assembly streaming-flume-assembly/assembly' or " - "'build/mvn package' before running this test.") + "'build/mvn -Pflume package' before running this test.") elif len(jars) > 1: raise Exception(("Found multiple Spark Streaming Flume assembly JARs: %s; please " "remove all but one") % (", ".join(jars))) @@ -1516,6 +1516,9 @@ def search_kinesis_asl_assembly_jar(): return jars[0] +# Must be same as the variable and condition defined in modules.py +flume_test_environ_var = "ENABLE_FLUME_TESTS" +are_flume_tests_enabled = os.environ.get(flume_test_environ_var) == '1' # Must be same as the variable and condition defined in modules.py kafka_test_environ_var = "ENABLE_KAFKA_0_8_TESTS" are_kafka_tests_enabled = os.environ.get(kafka_test_environ_var) == '1' @@ -1538,9 +1541,16 @@ def search_kinesis_asl_assembly_jar(): os.environ["PYSPARK_SUBMIT_ARGS"] = "--jars %s pyspark-shell" % jars testcases = [BasicOperationTests, WindowFunctionTests, StreamingContextTests, CheckpointTests, - FlumeStreamTests, FlumePollingStreamTests, StreamingListenerTests] + if are_flume_tests_enabled: + testcases.append(FlumeStreamTests) + testcases.append(FlumePollingStreamTests) + else: + sys.stderr.write( + "Skipped test_flume_stream (enable by setting environment variable %s=1" + % flume_test_environ_var) + if are_kafka_tests_enabled: testcases.append(KafkaStreamTests) else: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index fd917c400c872..5e100e0a9a95d 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, SpecialLengths, PythonEvalType, UTF8Deserializer, PickleSerializer, \ - BatchedSerializer, ArrowPandasSerializer -from pyspark.sql.types import toArrowType + BatchedSerializer, ArrowStreamPandasSerializer +from pyspark.sql.types import to_arrow_type from pyspark import shuffle pickleSer = PickleSerializer() @@ -74,16 +74,18 @@ def wrap_udf(f, return_type): def wrap_pandas_udf(f, return_type): - arrow_return_type = toArrowType(return_type) + arrow_return_type = to_arrow_type(return_type) def verify_result_length(*a): result = f(*a) if not hasattr(result, "__len__"): - raise TypeError("Return type of pandas_udf should be a Pandas.Series") + raise TypeError("Return type of the user-defined functon should be " + "Pandas.Series, but is {}".format(type(result))) if len(result) != len(a[0]): raise RuntimeError("Result vector from pandas_udf was not the required length: " "expected %d, got %d" % (len(a[0]), len(result))) return result + return lambda *a: (verify_result_length(*a), arrow_return_type) @@ -100,6 +102,9 @@ def read_single_udf(pickleSer, infile, eval_type): # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_PANDAS_UDF: return arg_offsets, wrap_pandas_udf(row_func, return_type) + elif eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + # a groupby apply udf has already been wrapped under apply() + return arg_offsets, row_func else: return arg_offsets, wrap_udf(row_func, return_type) @@ -122,8 +127,9 @@ def read_udfs(pickleSer, infile, eval_type): func = lambda _, it: map(mapper, it) - if eval_type == PythonEvalType.SQL_PANDAS_UDF: - ser = ArrowPandasSerializer() + if eval_type == PythonEvalType.SQL_PANDAS_UDF \ + or eval_type == PythonEvalType.SQL_PANDAS_GROUPED_UDF: + ser = ArrowStreamPandasSerializer() else: ser = 
BatchedSerializer(PickleSerializer(), 100) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala index 1fe94974c8e36..76aded4edb431 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSource.scala @@ -23,8 +23,9 @@ import org.apache.spark.metrics.source.Source private[mesos] class MesosClusterSchedulerSource(scheduler: MesosClusterScheduler) extends Source { - override def sourceName: String = "mesos_cluster" - override def metricRegistry: MetricRegistry = new MetricRegistry() + + override val sourceName: String = "mesos_cluster" + override val metricRegistry: MetricRegistry = new MetricRegistry() metricRegistry.register(MetricRegistry.name("waitingDrivers"), new Gauge[Int] { override def getValue: Int = scheduler.getQueuedDriversSize diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index 26699873145b4..603c980cb268d 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -32,6 +32,7 @@ import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskStat import org.apache.spark.deploy.mesos.config._ import org.apache.spark.deploy.security.HadoopDelegationTokenManager import org.apache.spark.internal.config +import org.apache.spark.launcher.{LauncherBackend, SparkAppHandle} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.mesos.MesosExternalShuffleClient import org.apache.spark.rpc.RpcEndpointAddress @@ -89,6 +90,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Synchronization protected by stateLock private[this] var stopCalled: Boolean = false + private val launcherBackend = new LauncherBackend() { + override protected def onStopRequest(): Unit = { + stopSchedulerBackend() + setState(SparkAppHandle.State.KILLED) + } + } + // If shuffle service is enabled, the Spark driver will register with the shuffle service. // This is for cleaning up shuffle files reliably. private val shuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false) @@ -99,6 +107,14 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( private var totalCoresAcquired = 0 private var totalGpusAcquired = 0 + // The amount of time to wait for locality scheduling + private val localityWait = conf.get(config.LOCALITY_WAIT) + // The start of the waiting, for data local scheduling + private var localityWaitStartTime = System.currentTimeMillis() + // If true, the scheduler is in the process of launching executors to reach the requested + // executor limit + private var launchingExecutors = false + // SlaveID -> Slave // This map accumulates entries for the duration of the job. Slaves are never deleted, because // we need to maintain e.g. failure state and connection state. 
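[Editor's note] The hunk above gives MesosCoarseGrainedSchedulerBackend a LauncherBackend handle plus the locality-wait bookkeeping fields; the hunks that follow wire that handle into the scheduler lifecycle (SUBMITTED on start, RUNNING on (re)registration, SUBMITTED again on disconnect, FINISHED or KILLED on stop). As a reading aid only, here is a self-contained Scala sketch of that state flow; the event names below are illustrative stand-ins, not Spark APIs.

// Plain-Scala model of the SparkAppHandle.State transitions wired up in this patch.
object LauncherStateFlowSketch {
  sealed trait Event
  case object SchedulerStarted extends Event        // start() in client deploy mode
  case object FrameworkRegistered extends Event     // registered() / reregistered()
  case object DriverDisconnected extends Event      // disconnected()
  case object BackendStopped extends Event          // stop()
  case object StopRequestedByLauncher extends Event // LauncherBackend.onStopRequest()

  def reportedState(e: Event): String = e match {
    case SchedulerStarted        => "SUBMITTED"
    case FrameworkRegistered     => "RUNNING"
    case DriverDisconnected      => "SUBMITTED"  // waiting to re-register with Mesos
    case BackendStopped          => "FINISHED"
    case StopRequestedByLauncher => "KILLED"
  }

  def main(args: Array[String]): Unit =
    Seq(SchedulerStarted, FrameworkRegistered, BackendStopped)
      .foreach(e => println(s"$e -> ${reportedState(e)}"))
}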
@@ -174,6 +190,9 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def start() { super.start() + if (sc.deployMode == "client") { + launcherBackend.connect() + } val startedBefore = IdHelper.startedBefore.getAndSet(true) val suffix = if (startedBefore) { @@ -194,6 +213,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( sc.conf.getOption("spark.mesos.driver.frameworkId").map(_ + suffix) ) + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) startScheduler(driver) } @@ -287,15 +307,21 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( this.mesosExternalShuffleClient.foreach(_.init(appId)) this.schedulerDriver = driver markRegistered() + launcherBackend.setAppId(appId) + launcherBackend.setState(SparkAppHandle.State.RUNNING) } override def sufficientResourcesRegistered(): Boolean = { totalCoreCount.get >= maxCoresOption.getOrElse(0) * minRegisteredRatio } - override def disconnected(d: org.apache.mesos.SchedulerDriver) {} + override def disconnected(d: org.apache.mesos.SchedulerDriver) { + launcherBackend.setState(SparkAppHandle.State.SUBMITTED) + } - override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) {} + override def reregistered(d: org.apache.mesos.SchedulerDriver, masterInfo: MasterInfo) { + launcherBackend.setState(SparkAppHandle.State.RUNNING) + } /** * Method called by Mesos to offer resources on slaves. We respond by launching an executor, @@ -311,6 +337,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( return } + if (numExecutors >= executorLimit) { + logDebug("Executor limit reached. numExecutors: " + numExecutors + + " executorLimit: " + executorLimit) + offers.asScala.map(_.getId).foreach(d.declineOffer) + launchingExecutors = false + return + } else { + if (!launchingExecutors) { + launchingExecutors = true + localityWaitStartTime = System.currentTimeMillis() + } + } + logDebug(s"Received ${offers.size} resource offers.") val (matchedOffers, unmatchedOffers) = offers.asScala.partition { offer => @@ -413,7 +452,7 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( val offerId = offer.getId.getValue val resources = remainingResources(offerId) - if (canLaunchTask(slaveId, resources)) { + if (canLaunchTask(slaveId, offer.getHostname, resources)) { // Create a task launchTasks = true val taskId = newMesosTaskId() @@ -477,7 +516,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpuResourcesToUse ++ memResourcesToUse ++ portResourcesToUse ++ gpuResourcesToUse) } - private def canLaunchTask(slaveId: String, resources: JList[Resource]): Boolean = { + private def canLaunchTask(slaveId: String, offerHostname: String, + resources: JList[Resource]): Boolean = { val offerMem = getResource(resources, "mem") val offerCPUs = getResource(resources, "cpus").toInt val cpus = executorCores(offerCPUs) @@ -489,9 +529,10 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( cpus <= offerCPUs && cpus + totalCoresAcquired <= maxCores && mem <= offerMem && - numExecutors() < executorLimit && + numExecutors < executorLimit && slaves.get(slaveId).map(_.taskFailures).getOrElse(0) < MAX_SLAVE_FAILURES && - meetsPortRequirements + meetsPortRequirements && + satisfiesLocality(offerHostname) } private def executorCores(offerCPUs: Int): Int = { @@ -500,6 +541,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( ) } + private def satisfiesLocality(offerHostname: String): Boolean = { + if (!Utils.isDynamicAllocationEnabled(conf) || hostToLocalTaskCount.isEmpty) { + return true + } + + // 
Check the locality information + val currentHosts = slaves.values.filter(_.taskIDs.nonEmpty).map(_.hostname).toSet + val allDesiredHosts = hostToLocalTaskCount.keys.toSet + // Try to match locality for hosts which do not have executors yet, to potentially + // increase coverage. + val remainingHosts = allDesiredHosts -- currentHosts + if (!remainingHosts.contains(offerHostname) && + (System.currentTimeMillis() - localityWaitStartTime <= localityWait)) { + logDebug("Skipping host and waiting for locality. host: " + offerHostname) + return false + } + return true + } + override def statusUpdate(d: org.apache.mesos.SchedulerDriver, status: TaskStatus) { val taskId = status.getTaskId.getValue val slaveId = status.getSlaveId.getValue @@ -569,6 +629,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def stop() { + stopSchedulerBackend() + launcherBackend.setState(SparkAppHandle.State.FINISHED) + launcherBackend.close() + } + + private def stopSchedulerBackend() { // Make sure we're not launching tasks during shutdown stateLock.synchronized { if (stopCalled) { @@ -646,6 +712,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // since at coarse grain it depends on the amount of slaves available. logInfo("Capping the total amount of executors to " + requestedTotal) executorLimitOption = Some(requestedTotal) + // Update the locality wait start time to continue trying for locality. + localityWaitStartTime = System.currentTimeMillis() true } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index f6bae01c3af59..6c40792112f49 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -604,6 +604,55 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite assert(backend.isReady) } + test("supports data locality with dynamic allocation") { + setBackend(Map( + "spark.dynamicAllocation.enabled" -> "true", + "spark.dynamicAllocation.testing" -> "true", + "spark.locality.wait" -> "1s")) + + assert(backend.getExecutorIds().isEmpty) + + backend.requestTotalExecutors(2, 2, Map("hosts10" -> 1, "hosts11" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(1, false) + offerResourcesAndVerify(2, false) + + // Offer local resource + offerResourcesAndVerify(10, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be accepted + offerResourcesAndVerify(1, true) + + // Update total executors + backend.requestTotalExecutors(3, 3, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Update total executors + backend.requestTotalExecutors(4, 4, Map("hosts10" -> 1, "hosts11" -> 1, "hosts12" -> 1, + "hosts13" -> 1)) + + // Offer non-local resources, which should be rejected + offerResourcesAndVerify(3, false) + + // Offer local resource + offerResourcesAndVerify(13, true) + + // Wait longer than spark.locality.wait + Thread.sleep(2000) + + // Offer non-local resource, which should be 
accepted + offerResourcesAndVerify(2, true) + } + private case class Resources(mem: Int, cpus: Int, gpus: Int = 0) private def registerMockExecutor(executorId: String, slaveId: String, cores: Integer) = { @@ -631,6 +680,19 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite backend.resourceOffers(driver, mesosOffers.asJava) } + private def offerResourcesAndVerify(id: Int, expectAccept: Boolean): Unit = { + offerResources(List(Resources(backend.executorMemory(sc), 1)), id) + if (expectAccept) { + val numExecutors = backend.getExecutorIds().size + val launchedTasks = verifyTaskLaunched(driver, s"o$id") + assert(s"s$id" == launchedTasks.head.getSlaveId.getValue) + registerMockExecutor(launchedTasks.head.getTaskId.getValue, s"s$id", 1) + assert(backend.getExecutorIds().size == numExecutors + 1) + } else { + verifyTaskNotLaunched(driver, s"o$id") + } + } + private def createTaskStatus(taskId: String, slaveId: String, state: TaskState): TaskStatus = { TaskStatus.newBuilder() .setTaskId(TaskID.newBuilder().setValue(taskId).build()) diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 2a67cbc913ffe..833db0c1ff334 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -84,6 +84,12 @@ object Utils { captor.getValue.asScala.toList } + def verifyTaskNotLaunched(driver: SchedulerDriver, offerId: String): Unit = { + verify(driver, times(0)).launchTasks( + Matchers.eq(Collections.singleton(createOfferId(offerId))), + Matchers.any(classOf[java.util.Collection[TaskInfo]])) + } + def createOfferId(offerId: String): OfferID = { OfferID.newBuilder().setValue(offerId).build() } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index e227bff88f71d..244d912b9f3aa 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -20,6 +20,7 @@ package org.apache.spark.deploy.yarn import java.io.{File, IOException} import java.lang.reflect.InvocationTargetException import java.net.{Socket, URI, URL} +import java.security.PrivilegedExceptionAction import java.util.concurrent.{TimeoutException, TimeUnit} import scala.collection.mutable.HashMap @@ -28,6 +29,7 @@ import scala.concurrent.duration.Duration import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.conf.YarnConfiguration @@ -49,10 +51,7 @@ import org.apache.spark.util._ /** * Common application master functionality for Spark on Yarn. */ -private[spark] class ApplicationMaster( - args: ApplicationMasterArguments, - client: YarnRMClient) - extends Logging { +private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends Logging { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. 
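[Editor's note] The test above ("supports data locality with dynamic allocation") captures the intended behaviour of the new locality gate: while dynamic allocation is enabled, offers from hosts that are not in the requested host set are declined until spark.locality.wait has elapsed since the executor target last changed. Below is a minimal standalone sketch of that gate with illustrative parameter names; the real check is satisfiesLocality in MesosCoarseGrainedSchedulerBackend.

// Decline non-local offers only while the locality wait window is still open.
object LocalityGateSketch {
  def satisfiesLocality(
      offerHostname: String,
      desiredHosts: Set[String],       // hosts passed to requestTotalExecutors
      hostsWithExecutors: Set[String], // hosts that already run an executor
      localityWaitStartMs: Long,
      localityWaitMs: Long,
      nowMs: Long): Boolean = {
    if (desiredHosts.isEmpty) {
      true // no locality preferences, accept anything
    } else {
      val remainingHosts = desiredHosts -- hostsWithExecutors
      val stillWaiting = nowMs - localityWaitStartMs <= localityWaitMs
      remainingHosts.contains(offerHostname) || !stillWaiting
    }
  }

  def main(args: Array[String]): Unit = {
    val start = 0L
    // Non-local offer inside the wait window: declined.
    println(satisfiesLocality("host1", Set("host10"), Set.empty, start, 1000L, nowMs = 100L))
    // The same offer after the wait has expired: accepted.
    println(satisfiesLocality("host1", Set("host10"), Set.empty, start, 1000L, nowMs = 2000L))
  }
}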
@@ -62,6 +61,44 @@ private[spark] class ApplicationMaster( .asInstanceOf[YarnConfiguration] private val isClusterMode = args.userClass != null + private val ugi = { + val original = UserGroupInformation.getCurrentUser() + + // If a principal and keytab were provided, log in to kerberos, and set up a thread to + // renew the kerberos ticket when needed. Because the UGI API does not expose the TTL + // of the TGT, use a configuration to define how often to check that a relogin is necessary. + // checkTGTAndReloginFromKeytab() is a no-op if the relogin is not yet needed. + val principal = sparkConf.get(PRINCIPAL).orNull + val keytab = sparkConf.get(KEYTAB).orNull + if (principal != null && keytab != null) { + UserGroupInformation.loginUserFromKeytab(principal, keytab) + + val renewer = new Thread() { + override def run(): Unit = Utils.tryLogNonFatalError { + while (true) { + TimeUnit.SECONDS.sleep(sparkConf.get(KERBEROS_RELOGIN_PERIOD)) + UserGroupInformation.getCurrentUser().checkTGTAndReloginFromKeytab() + } + } + } + renewer.setName("am-kerberos-renewer") + renewer.setDaemon(true) + renewer.start() + + // Transfer the original user's tokens to the new user, since that's needed to connect to + // YARN. It also copies over any delegation tokens that might have been created by the + // client, which will then be transferred over when starting executors (until new ones + // are created by the periodic task). + val newUser = UserGroupInformation.getCurrentUser() + SparkHadoopUtil.get.transferCredentials(original, newUser) + newUser + } else { + SparkHadoopUtil.get.createSparkUser() + } + } + + private val client = doAsUser { new YarnRMClient() } + // Default to twice the number of executors (twice the maximum number of executors if dynamic // allocation is enabled), with a minimum of 3. @@ -139,7 +176,7 @@ private[spark] class ApplicationMaster( // Load the list of localized files set by the client. This is used when launching executors, // and is loaded here so that these configs don't pollute the Web UI's environment page in // cluster mode. - private val localResources = { + private val localResources = doAsUser { logInfo("Preparing Local resources") val resources = HashMap[String, LocalResource]() @@ -201,6 +238,13 @@ private[spark] class ApplicationMaster( } final def run(): Int = { + doAsUser { + runImpl() + } + exitCode + } + + private def runImpl(): Unit = { try { val appAttemptId = client.getAttemptId() @@ -254,11 +298,6 @@ private[spark] class ApplicationMaster( } } - // Call this to force generation of secret so it gets populated into the - // Hadoop UGI. This has to happen before the startUserApplication which does a - // doAs in order for the credentials to be passed on to the executor containers. - val securityMgr = new SecurityManager(sparkConf) - // If the credentials file config is present, we must periodically renew tokens. So create // a new AMDelegationTokenRenewer if (sparkConf.contains(CREDENTIALS_FILE_PATH)) { @@ -284,6 +323,9 @@ private[spark] class ApplicationMaster( credentialRenewerThread.join() } + // Call this to force generation of secret so it gets populated into the Hadoop UGI. 
+ val securityMgr = new SecurityManager(sparkConf) + if (isClusterMode) { runDriver(securityMgr) } else { @@ -297,7 +339,6 @@ private[spark] class ApplicationMaster( ApplicationMaster.EXIT_UNCAUGHT_EXCEPTION, "Uncaught exception: " + e) } - exitCode } /** @@ -747,6 +788,12 @@ private[spark] class ApplicationMaster( } } + private def doAsUser[T](fn: => T): T = { + ugi.doAs(new PrivilegedExceptionAction[T]() { + override def run: T = fn + }) + } + } object ApplicationMaster extends Logging { @@ -775,10 +822,8 @@ object ApplicationMaster extends Logging { sys.props(k) = v } } - SparkHadoopUtil.get.runAsSparkUser { () => - master = new ApplicationMaster(amArgs, new YarnRMClient) - System.exit(master.run()) - } + master = new ApplicationMaster(amArgs) + System.exit(master.run()) } private[spark] def sparkContextInitialized(sc: SparkContext): Unit = { diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 64b2b4d4db549..1fe25c4ddaabf 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -394,7 +394,10 @@ private[spark] class Client( if (credentials != null) { // Add credentials to current user's UGI, so that following operations don't need to use the // Kerberos tgt to get delegations again in the client side. - UserGroupInformation.getCurrentUser.addCredentials(credentials) + val currentUser = UserGroupInformation.getCurrentUser() + if (SparkHadoopUtil.get.isProxyUser(currentUser)) { + currentUser.addCredentials(credentials) + } logDebug(YarnSparkHadoopUtil.get.dumpTokens(credentials).mkString("\n")) } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala index 187803cc6050b..e1af8ba087d6e 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config.scala @@ -347,6 +347,10 @@ package object config { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(Long.MaxValue) + private[spark] val KERBEROS_RELOGIN_PERIOD = ConfigBuilder("spark.yarn.kerberos.relogin.period") + .timeConf(TimeUnit.SECONDS) + .createWithDefaultString("1m") + // The list of cache-related config entries. This is used by Client and the AM to clean // up the environment so that these settings do not appear on the web UI. private[yarn] val CACHE_CONFIGS = Seq( diff --git a/scalastyle-config.xml b/scalastyle-config.xml index bd7f462b722cd..7bdd3fac773a3 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -86,7 +86,7 @@ This file is divided into 3 sections: - + diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d0a54288780ea..17c8404f8a79c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -25,7 +25,7 @@ grammar SqlBase; * For char stream "2.3", "2." is not a valid decimal token, because it is followed by digit '3'. * For char stream "2.3_", "2.3" is not a valid decimal token, because it is followed by '_'. * For char stream "2.3W", "2.3" is not a valid decimal token, because it is followed by 'W'. 
- * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is folllowed + * For char stream "12.0D 34.E2+0.12 " 12.0D is a valid decimal token because it is followed * by a space. 34.E2 is a valid decimal token because it is followed by symbol '+' * which is not a digit or letter or underscore. */ @@ -40,10 +40,6 @@ grammar SqlBase; } } -tokens { - DELIMITER -} - singleStatement : statement EOF ; @@ -447,12 +443,15 @@ joinCriteria ; sample - : TABLESAMPLE '(' - ( (negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | sampleType=BYTELENGTH_LITERAL - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON (identifier | qualifiedName '(' ')'))?)) - ')' + : TABLESAMPLE '(' sampleMethod? ')' + ; + +sampleMethod + : negativeSign=MINUS? percentage=(INTEGER_VALUE | DECIMAL_VALUE) PERCENTLIT #sampleByPercentile + | expression ROWS #sampleByRows + | sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE + (ON (identifier | qualifiedName '(' ')'))? #sampleByBucket + | bytes=expression #sampleByBytes ; identifierList @@ -1004,10 +1003,6 @@ TINYINT_LITERAL : DIGIT+ 'Y' ; -BYTELENGTH_LITERAL - : DIGIT+ ('B' | 'K' | 'M' | 'G') - ; - INTEGER_VALUE : DIGIT+ ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java index 971d19973f067..259976118c12f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolder.java @@ -19,6 +19,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow; import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; /** * A helper class to manage the data buffer for an unsafe row. The data buffer can grow and @@ -36,9 +37,7 @@ */ public class BufferHolder { - // Some JVMs can't allocate arrays of length Integer.MAX_VALUE; actual max is somewhat - // smaller. Be conservative and lower the cap a little. 
- private static final int ARRAY_MAX = Integer.MAX_VALUE - 8; + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; public byte[] buffer; public int cursor = Platform.BYTE_ARRAY_OFFSET; @@ -51,7 +50,7 @@ public BufferHolder(UnsafeRow row) { public BufferHolder(UnsafeRow row, int initialSize) { int bitsetWidthInBytes = UnsafeRow.calculateBitSetWidthInBytes(row.numFields()); - if (row.numFields() > (Integer.MAX_VALUE - initialSize - bitsetWidthInBytes) / 8) { + if (row.numFields() > (ARRAY_MAX - initialSize - bitsetWidthInBytes) / 8) { throw new UnsupportedOperationException( "Cannot create BufferHolder for input UnsafeRow because there are " + "too many fields (number of fields: " + row.numFields() + ")"); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8edf575db7969..d6a962a14dc9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.SubExprUtils._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.objects.{LambdaVariable, MapObjects, NewInstance, UnresolvedMapObjects} import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, _} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.catalyst.trees.TreeNodeRef import org.apache.spark.sql.catalyst.util.toPrettySQL @@ -293,12 +293,6 @@ class Analyzer( Seq(Seq.empty) } - private def hasGroupingAttribute(expr: Expression): Boolean = { - expr.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u - }.isDefined - } - private[analysis] def hasGroupingFunction(e: Expression): Boolean = { e.collectFirst { case g: Grouping => g @@ -452,9 +446,6 @@ class Analyzer( // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. - case p if p.expressions.exists(hasGroupingAttribute) => - failAnalysis( - s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") // Ensure group by expressions and aggregate expressions have been resolved. case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) @@ -1174,6 +1165,10 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. 
+ case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => + withPosition(u) { + Alias(GroupingID(Nil), VirtualColumn.hiveGroupingIdName)() + } case u @ UnresolvedGenerator(name, children) => withPosition(u) { catalog.lookupFunction(name, children) match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d9906bb6e6ede..b5e8bdd79869e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -272,10 +272,23 @@ trait CheckAnalysis extends PredicateHelper { case o if o.children.nonEmpty && o.missingInput.nonEmpty => val missingAttributes = o.missingInput.mkString(",") val input = o.inputSet.mkString(",") + val msgForMissingAttributes = s"Resolved attribute(s) $missingAttributes missing " + + s"from $input in operator ${operator.simpleString}." - failAnalysis( - s"resolved attribute(s) $missingAttributes missing from $input " + - s"in operator ${operator.simpleString}") + val resolver = plan.conf.resolver + val attrsWithSameName = o.missingInput.filter { missing => + o.inputSet.exists(input => resolver(missing.name, input.name)) + } + + val msg = if (attrsWithSameName.nonEmpty) { + val sameNames = attrsWithSameName.map(_.name).mkString(",") + s"$msgForMissingAttributes Attribute(s) with the same name appear in the " + + s"operation: $sameNames. Please check if the right attribute(s) are used." + } else { + msgForMissingAttributes + } + + failAnalysis(msg) case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => failAnalysis( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala new file mode 100644 index 0000000000000..072dc954879ca --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelper.scala @@ -0,0 +1,286 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import scala.util.control.NonFatal + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, AttributeSet, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + + +/** + * Helper object for stream joins. See [[StreamingSymmetricHashJoinExec]] in SQL for more details. + */ +object StreamingJoinHelper extends PredicateHelper with Logging { + + /** + * Check the provided logical plan to see if its join keys contain a watermark attribute. + * + * Will return false if the plan is not an equijoin. + * @param plan the logical plan to check + */ + def isWatermarkInJoinKeys(plan: LogicalPlan): Boolean = { + plan match { + case ExtractEquiJoinKeys(_, leftKeys, rightKeys, _, _, _) => + (leftKeys ++ rightKeys).exists { + case a: AttributeReference => a.metadata.contains(EventTimeWatermark.delayKey) + case _ => false + } + case _ => false + } + } + + /** + * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) + * given the join condition and the event time watermark. This is how it works. + * - The condition is split into conjunctive predicates, and we find the predicates of the + * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). + * - We canoncalize the predicate and solve it with the event time watermark value to find the + * value of the state watermark. + * This function is supposed to make best-effort attempt to get the state watermark. If there is + * any error, it will return None. + * + * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark + * is to be calculated + * @param attributesWithEventWatermark attributes of the other side which has a watermark column + * @param joinCondition join condition + * @param eventWatermark watermark defined on the input event data + * @return state value watermark in milliseconds, is possible. + */ + def getStateValueWatermark( + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + joinCondition: Option[Expression], + eventWatermark: Option[Long]): Option[Long] = { + + // If condition or event time watermark is not provided, then cannot calculate state watermark + if (joinCondition.isEmpty || eventWatermark.isEmpty) return None + + // If there is not watermark attribute, then cannot define state watermark + if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None + + def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { + try { + getStateWatermarkFromLessThenPredicate( + l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) + } catch { + case NonFatal(e) => + logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) + None + } + } + + val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => + + // The generated the state watermark cleanup expression is inclusive of the state watermark. 
+ // If state watermark is W, all state where timestamp <= W will be cleaned up. + // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean + // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. + val stateWatermark = predicate match { + case LessThan(l, r) => getStateWatermarkSafely(l, r) + case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) + case GreaterThan(l, r) => getStateWatermarkSafely(r, l) + case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) + case _ => None + } + if (stateWatermark.nonEmpty) { + logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") + } + stateWatermark + } + allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) + } + + /** + * Extract the state value watermark (milliseconds) from the condition + * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for + * leftTime using the watermark on the rightTime. Example: + * + * Input: rightTime-with-watermark + c1 < leftTime + c2 + * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 + * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime + * With watermark value: watermark-value + c1 + (-c2) < leftTime + */ + private def getStateWatermarkFromLessThenPredicate( + leftExpr: Expression, + rightExpr: Expression, + attributesToFindStateWatermarkFor: AttributeSet, + attributesWithEventWatermark: AttributeSet, + eventWatermark: Option[Long]): Option[Long] = { + + val attributesInCondition = AttributeSet( + leftExpr.collect { case a: AttributeReference => a } ++ + rightExpr.collect { case a: AttributeReference => a } + ) + if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || + attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { + // If more than attributes present in condition from one side, then it cannot be solved + return None + } + + def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { + e.collectLeaves().collectFirst { + case a @ AttributeReference(_, _, _, _) + if attributesToFindStateWatermarkFor.contains(a) => a + }.nonEmpty + } + + // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 + val allOnLeftExpr = Subtract(leftExpr, rightExpr) + logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") + + // Canonicalization step 2: extract commutative terms + // rightTime-with-watermark, c1, -leftTime, -c2 + val terms = ExpressionSet(collectTerms(allOnLeftExpr)) + logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) + + // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor + val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) + + // Verify there is only one correct constraint term and of the correct type + if (constraintTerms.size > 1) { + logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + if (constraintTerms.isEmpty) { + logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + + terms.mkString("\n\t")) + return None + } + val constraintTerm = constraintTerms.head + if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { + // Incorrect condition. 
We want the constraint term in canonical form to be `-leftTime` + // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. + // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark + // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about + // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, + // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. + return None + } + + // Replace watermark attribute with watermark value, and generate the resolved expression + // from the other terms. That is, + // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) + logDebug(s"Constraint term from join condition:\t$constraintTerm") + val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => + term.transform { + case a @ AttributeReference(_, _, _, metadata) + if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => + Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) + } + }.reduceLeft(Add) + + // Calculate the constraint value + logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") + val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] + Some((Double2double(constraintValue) / 1000.0).toLong) + } + + /** + * Collect all the terms present in an expression after converting it into the form + * a + b + c + d where each term be either an attribute or a literal casted to long, + * optionally wrapped in a unary minus. + */ + private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { + var invalid = false + + /** Wrap a term with UnaryMinus if its needs to be negated. */ + def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { + if (minus) UnaryMinus(expr) else expr + } + + /** + * Recursively split the expression into its leaf terms contains attributes or literals. + * Returns terms only of the forms: + * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), + * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) + * Multiply(Literal), UnaryMinus(Multiply(Literal)) + * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) + * + * Note: + * - If term needs to be negated for making it a commutative term, + * then it will be wrapped in UnaryMinus(...) + * - Each terms will be representing timestamp value or time interval in microseconds, + * typed as doubles. 
+ */ + def collect(expr: Expression, negate: Boolean): Seq[Expression] = { + expr match { + case Add(left, right) => + collect(left, negate) ++ collect(right, negate) + case Subtract(left, right) => + collect(left, negate) ++ collect(right, !negate) + case TimeAdd(left, right, _) => + collect(left, negate) ++ collect(right, negate) + case TimeSub(left, right, _) => + collect(left, negate) ++ collect(right, !negate) + case UnaryMinus(child) => + collect(child, !negate) + case CheckOverflow(child, _) => + collect(child, negate) + case Cast(child, dataType, _) => + dataType match { + case _: NumericType | _: TimestampType => collect(child, negate) + case _ => + invalid = true + Seq.empty + } + case a: AttributeReference => + val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a + Seq(negateIfNeeded(castedRef, negate)) + case lit: Literal => + // If literal of type calendar interval, then explicitly convert to millis + // Convert other number like literal to doubles representing millis (by x1000) + val castedLit = lit.dataType match { + case CalendarIntervalType => + val calendarInterval = lit.value.asInstanceOf[CalendarInterval] + if (calendarInterval.months > 0) { + invalid = true + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom " + + s"as imprecise intervals like months and years cannot be used for" + + s"watermark calculation. Use interval in terms of day instead.") + Literal(0.0) + } else { + Literal(calendarInterval.microseconds.toDouble) + } + case DoubleType => + Multiply(lit, Literal(1000000.0)) + case _: NumericType => + Multiply(Cast(lit, DoubleType), Literal(1000000.0)) + case _: TimestampType => + Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) + } + Seq(negateIfNeeded(castedLit, negate)) + case a @ _ => + logWarning( + s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") + invalid = true + Seq.empty + } + } + + val terms = collect(exprToCollectFrom, negate = false) + if (!invalid) terms else Seq.empty + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index 9ffe646b5e4ec..532d22dbf2321 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -100,6 +100,16 @@ object TypeCoercion { case (_: TimestampType, _: DateType) | (_: DateType, _: TimestampType) => Some(TimestampType) + case (t1 @ StructType(fields1), t2 @ StructType(fields2)) if t1.sameType(t2) => + Some(StructType(fields1.zip(fields2).map { case (f1, f2) => + // Since `t1.sameType(t2)` is true, two StructTypes have the same DataType + // except `name` (in case of `spark.sql.caseSensitive=false`) and `nullable`. + // - Different names: use f1.name + // - Different nullabilities: `nullable` is true iff one of them is nullable. 
+ val dataType = findTightestCommonType(f1.dataType, f2.dataType).get + StructField(f1.name, dataType, nullable = f1.nullable || f2.nullable) + })) + case _ => None } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala index d1d705691b076..04502d04d9509 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet, MonotonicallyIncreasingID} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes @@ -128,6 +129,16 @@ object UnsupportedOperationChecker { !subplan.isStreaming || (aggs.nonEmpty && outputMode == InternalOutputModes.Complete) } + def checkUnsupportedExpressions(implicit operator: LogicalPlan): Unit = { + val unsupportedExprs = operator.expressions.flatMap(_.collect { + case m: MonotonicallyIncreasingID => m + }).distinct + if (unsupportedExprs.nonEmpty) { + throwError("Expression(s): " + unsupportedExprs.map(_.sql).mkString(", ") + + " is not supported with streaming DataFrames/Datasets") + } + } + plan.foreachUp { implicit subPlan => // Operations that cannot exists anywhere in a streaming plan @@ -217,7 +228,7 @@ object UnsupportedOperationChecker { throwError("dropDuplicates is not supported after aggregation on a " + "streaming DataFrame/Dataset") - case Join(left, right, joinType, _) => + case Join(left, right, joinType, condition) => joinType match { @@ -233,16 +244,52 @@ object UnsupportedOperationChecker { throwError("Full outer joins with streaming DataFrames/Datasets are not supported") } - case LeftOuter | LeftSemi | LeftAnti => + case LeftSemi | LeftAnti => if (right.isStreaming) { - throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + - "on the right is not supported") + throwError("Left semi/anti joins with a streaming DataFrame/Dataset " + + "on the right are not supported") } + // We support streaming left outer joins with static on the right always, and with + // stream on both sides under the appropriate conditions. 
+ case LeftOuter => + if (!left.isStreaming && right.isStreaming) { + throwError("Left outer join with a streaming DataFrame/Dataset " + + "on the right and a static DataFrame/Dataset on the left is not supported") + } else if (left.isStreaming && right.isStreaming) { + val watermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + left.outputSet, right.outputSet, condition, Some(1000000)).isDefined + + if (!watermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } + } + + // We support streaming right outer joins with static on the left always, and with + // stream on both sides under the appropriate conditions. case RightOuter => - if (left.isStreaming) { - throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + - "not supported") + if (left.isStreaming && !right.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left and " + + "a static DataFrame/DataSet on the right not supported") + } else if (left.isStreaming && right.isStreaming) { + val isWatermarkInJoinKeys = StreamingJoinHelper.isWatermarkInJoinKeys(subPlan) + + // Check if the nullable side has a watermark, and there's a range condition which + // implies a state value watermark on the first side. + val hasValidWatermarkRange = + StreamingJoinHelper.getStateValueWatermark( + right.outputSet, left.outputSet, condition, Some(1000000)).isDefined + + if (!isWatermarkInJoinKeys && !hasValidWatermarkRange) { + throwError("Stream-stream outer join between two streaming DataFrame/Datasets " + + "is not supported without a watermark in the join keys, or a watermark on " + + "the nullable side and an appropriate range condition") + } } case NaturalJoin(_) | UsingJoin(_, _) => @@ -286,6 +333,9 @@ object UnsupportedOperationChecker { case _ => } + + // Check if there are unsupported expressions in streaming query plan. 
+ checkUnsupportedExpressions(subPlan) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 9407b727bca4c..95bc3d674b4f8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.catalog -import java.lang.reflect.InvocationTargetException import java.net.URI import java.util.Locale import java.util.concurrent.Callable @@ -25,7 +24,6 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.mutable import scala.util.{Failure, Success, Try} -import scala.util.control.NonFatal import com.google.common.cache.{Cache, CacheBuilder} import org.apache.hadoop.conf.Configuration @@ -41,7 +39,6 @@ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParserInterface} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias, View} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -52,7 +49,7 @@ object SessionCatalog { /** * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary - * tables and functions of the Spark Session that it belongs to. + * views and functions of the Spark Session that it belongs to. * * This class must be thread-safe. */ @@ -90,19 +87,19 @@ class SessionCatalog( new SQLConf().copy(SQLConf.CASE_SENSITIVE -> true)) } - /** List of temporary tables, mapping from table name to their logical plan. */ + /** List of temporary views, mapping from table name to their logical plan. */ @GuardedBy("this") - protected val tempTables = new mutable.HashMap[String, LogicalPlan] + protected val tempViews = new mutable.HashMap[String, LogicalPlan] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first - // check whether the temporary table or function exists, then, if not, operate on + // check whether the temporary view or function exists, then, if not, operate on // the corresponding item in the current database. @GuardedBy("this") protected var currentDb: String = formatDatabaseName(DEFAULT_DATABASE) /** - * Checks if the given name conforms the Hive standard ("[a-zA-z_0-9]+"), + * Checks if the given name conforms the Hive standard ("[a-zA-Z_0-9]+"), * i.e. if this name only contains characters, numbers, and _. * * This method is intended to have the same behavior of @@ -272,8 +269,8 @@ class SessionCatalog( // ---------------------------------------------------------------------------- // Tables // ---------------------------------------------------------------------------- - // There are two kinds of tables, temporary tables and metastore tables. - // Temporary tables are isolated across sessions and do not belong to any + // There are two kinds of tables, temporary views and metastore tables. + // Temporary views are isolated across sessions and do not belong to any // particular database. 
Metastore tables can be used across multiple // sessions as their metadata is persisted in the underlying catalog. // ---------------------------------------------------------------------------- @@ -462,10 +459,10 @@ class SessionCatalog( tableDefinition: LogicalPlan, overrideIfExists: Boolean): Unit = synchronized { val table = formatTableName(name) - if (tempTables.contains(table) && !overrideIfExists) { + if (tempViews.contains(table) && !overrideIfExists) { throw new TempTableAlreadyExistsException(name) } - tempTables.put(table, tableDefinition) + tempViews.put(table, tableDefinition) } /** @@ -487,7 +484,7 @@ class SessionCatalog( viewDefinition: LogicalPlan): Boolean = synchronized { val viewName = formatTableName(name.table) if (name.database.isEmpty) { - if (tempTables.contains(viewName)) { + if (tempViews.contains(viewName)) { createTempView(viewName, viewDefinition, overrideIfExists = true) true } else { @@ -504,7 +501,7 @@ class SessionCatalog( * Return a local temporary view exactly as it was stored. */ def getTempView(name: String): Option[LogicalPlan] = synchronized { - tempTables.get(formatTableName(name)) + tempViews.get(formatTableName(name)) } /** @@ -520,7 +517,7 @@ class SessionCatalog( * Returns true if this view is dropped successfully, false otherwise. */ def dropTempView(name: String): Boolean = synchronized { - tempTables.remove(formatTableName(name)).isDefined + tempViews.remove(formatTableName(name)).isDefined } /** @@ -572,7 +569,7 @@ class SessionCatalog( * Rename a table. * * If a database is specified in `oldName`, this will rename the table in that database. - * If no database is specified, this will first attempt to rename a temporary table with + * If no database is specified, this will first attempt to rename a temporary view with * the same name, then, if that does not exist, rename the table in the current database. * * This assumes the database specified in `newName` matches the one in `oldName`. @@ -592,7 +589,7 @@ class SessionCatalog( globalTempViewManager.rename(oldTableName, newTableName) } else { requireDbExists(db) - if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { + if (oldName.database.isDefined || !tempViews.contains(oldTableName)) { requireTableExists(TableIdentifier(oldTableName, Some(db))) requireTableNotExists(TableIdentifier(newTableName, Some(db))) validateName(newTableName) @@ -600,16 +597,16 @@ class SessionCatalog( } else { if (newName.database.isDefined) { throw new AnalysisException( - s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': cannot specify database " + + s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': cannot specify database " + s"name '${newName.database.get}' in the destination table") } - if (tempTables.contains(newTableName)) { - throw new AnalysisException(s"RENAME TEMPORARY TABLE from '$oldName' to '$newName': " + + if (tempViews.contains(newTableName)) { + throw new AnalysisException(s"RENAME TEMPORARY VIEW from '$oldName' to '$newName': " + "destination table already exists") } - val table = tempTables(oldTableName) - tempTables.remove(oldTableName) - tempTables.put(newTableName, table) + val table = tempViews(oldTableName) + tempViews.remove(oldTableName) + tempViews.put(newTableName, table) } } } @@ -618,7 +615,7 @@ class SessionCatalog( * Drop a table. * * If a database is specified in `name`, this will drop the table from that database. 
- * If no database is specified, this will first attempt to drop a temporary table with + * If no database is specified, this will first attempt to drop a temporary view with * the same name, then, if that does not exist, drop the table from the current database. */ def dropTable( @@ -633,7 +630,7 @@ class SessionCatalog( throw new NoSuchTableException(globalTempViewManager.database, table) } } else { - if (name.database.isDefined || !tempTables.contains(table)) { + if (name.database.isDefined || !tempViews.contains(table)) { requireDbExists(db) // When ignoreIfNotExists is false, no exception is issued when the table does not exist. // Instead, log it as an error message. @@ -643,7 +640,7 @@ class SessionCatalog( throw new NoSuchTableException(db = db, table = table) } } else { - tempTables.remove(table) + tempViews.remove(table) } } } @@ -652,7 +649,7 @@ class SessionCatalog( * Return a [[LogicalPlan]] that represents the given table or view. * * If a database is specified in `name`, this will return the table/view from that database. - * If no database is specified, this will first attempt to return a temporary table/view with + * If no database is specified, this will first attempt to return a temporary view with * the same name, then, if that does not exist, return the table/view from the current database. * * Note that, the global temp view database is also valid here, this will return the global temp @@ -671,7 +668,7 @@ class SessionCatalog( globalTempViewManager.get(table).map { viewDef => SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) - } else if (name.database.isDefined || !tempTables.contains(table)) { + } else if (name.database.isDefined || !tempViews.contains(table)) { val metadata = externalCatalog.getTable(db, table) if (metadata.tableType == CatalogTableType.VIEW) { val viewText = metadata.viewText.getOrElse(sys.error("Invalid view without text.")) @@ -687,21 +684,21 @@ class SessionCatalog( SubqueryAlias(table, UnresolvedCatalogRelation(metadata)) } } else { - SubqueryAlias(table, tempTables(table)) + SubqueryAlias(table, tempViews(table)) } } } /** - * Return whether a table with the specified name is a temporary table. + * Return whether a table with the specified name is a temporary view. * - * Note: The temporary table cache is checked only when database is not + * Note: The temporary view cache is checked only when database is not * explicitly specified. */ def isTemporaryTable(name: TableIdentifier): Boolean = synchronized { val table = formatTableName(name.table) if (name.database.isEmpty) { - tempTables.contains(table) + tempViews.contains(table) } else if (formatDatabaseName(name.database.get) == globalTempViewManager.database) { globalTempViewManager.get(table).isDefined } else { @@ -710,7 +707,7 @@ class SessionCatalog( } /** - * List all tables in the specified database, including local temporary tables. + * List all tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. @@ -718,7 +715,7 @@ class SessionCatalog( def listTables(db: String): Seq[TableIdentifier] = listTables(db, "*") /** - * List all matching tables in the specified database, including local temporary tables. + * List all matching tables in the specified database, including local temporary views. * * Note that, if the specified database is global temporary view database, we will list global * temporary views. 
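[Editor's note] The renamed tempViews map and the Scaladoc updates in this file all describe the same resolution rule: when a table name carries no database, the local temporary view (if any) wins; otherwise the lookup goes to the named database in the metastore. A toy sketch of that ordering, using plain maps rather than SessionCatalog (names below are illustrative, and the global temp view database handled by the real code is omitted):

// Name resolution order for temporary views vs. metastore tables.
object TempViewResolutionSketch {
  private val tempViews = scala.collection.mutable.Map.empty[String, String]
  private val metastore = scala.collection.mutable.Map.empty[(String, String), String]

  def lookup(db: Option[String], table: String, currentDb: String): Option[String] = db match {
    case Some(d) =>
      // An explicit database bypasses local temporary views entirely.
      metastore.get((d, table))
    case None =>
      // No database: prefer the local temporary view, then the current database.
      tempViews.get(table).orElse(metastore.get((currentDb, table)))
  }

  def main(args: Array[String]): Unit = {
    tempViews("people") = "temp view plan"
    metastore(("default", "people")) = "metastore table"
    println(lookup(None, "people", "default"))            // Some(temp view plan)
    println(lookup(Some("default"), "people", "default")) // Some(metastore table)
  }
}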
@@ -736,7 +733,7 @@ class SessionCatalog( } } val localTempViews = synchronized { - StringUtils.filterPattern(tempTables.keys.toSeq, pattern).map { name => + StringUtils.filterPattern(tempViews.keys.toSeq, pattern).map { name => TableIdentifier(name) } } @@ -750,11 +747,11 @@ class SessionCatalog( val dbName = formatDatabaseName(name.database.getOrElse(currentDb)) val tableName = formatTableName(name.table) - // Go through temporary tables and invalidate them. + // Go through temporary views and invalidate them. // If the database is defined, this may be a global temporary view. - // If the database is not defined, there is a good chance this is a temp table. + // If the database is not defined, there is a good chance this is a temp view. if (name.database.isEmpty) { - tempTables.get(tableName).foreach(_.refresh()) + tempViews.get(tableName).foreach(_.refresh()) } else if (dbName == globalTempViewManager.database) { globalTempViewManager.get(tableName).foreach(_.refresh()) } @@ -765,11 +762,11 @@ class SessionCatalog( } /** - * Drop all existing temporary tables. + * Drop all existing temporary views. * For testing only. */ def clearTempTables(): Unit = synchronized { - tempTables.clear() + tempViews.clear() } // ---------------------------------------------------------------------------- @@ -1337,7 +1334,7 @@ class SessionCatalog( */ private[sql] def copyStateTo(target: SessionCatalog): Unit = synchronized { target.currentDb = currentDb - // copy over temporary tables - tempTables.foreach(kv => target.tempTables.put(kv._1, kv._2)) + // copy over temporary views + tempViews.foreach(kv => target.tempViews.put(kv._1, kv._2)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 1965144e81197..1dbae4d37d8f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -22,8 +22,6 @@ import java.util.Date import scala.collection.mutable -import com.google.common.base.Objects - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation @@ -307,7 +305,7 @@ case class CatalogTable( identifier.database.foreach(map.put("Database", _)) map.put("Table", identifier.table) - if (owner.nonEmpty) map.put("Owner", owner) + if (owner != null && owner.nonEmpty) map.put("Owner", owner) map.put("Created Time", new Date(createTime).toString) map.put("Last Access", new Date(lastAccessTime).toString) map.put("Created By", "Spark " + createVersion) @@ -405,6 +403,11 @@ object CatalogTypes { * Specifications of a table partition. Mapping column name to column value. */ type TablePartitionSpec = Map[String, String] + + /** + * Initialize an empty spec. 
+ */ + lazy val emptyTablePartitionSpec: TablePartitionSpec = Map.empty[String, String] } /** @@ -435,15 +438,6 @@ case class HiveTableRelation( def isPartitioned: Boolean = partitionCols.nonEmpty - override def equals(relation: Any): Boolean = relation match { - case other: HiveTableRelation => tableMeta == other.tableMeta && output == other.output - case _ => false - } - - override def hashCode(): Int = { - Objects.hashCode(tableMeta.identifier, output) - } - override lazy val canonicalized: HiveTableRelation = copy( tableMeta = tableMeta.copy( storage = CatalogStorageFormat.empty, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala index cd97304302e48..65bb9a8c642b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CallMethodViaReflection.scala @@ -76,7 +76,7 @@ case class CallMethodViaReflection(children: Seq[Expression]) } } - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = true override val dataType: DataType = StringType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index c058425b4bc36..0e75ac88dc2b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -79,7 +79,7 @@ abstract class Expression extends TreeNode[Expression] { * An example would be `SparkPartitionID` that relies on the partition id returned by TaskContext. * By default leaf expressions are deterministic as Nil.forall(_.deterministic) returns true. */ - def deterministic: Boolean = children.forall(_.deterministic) + lazy val deterministic: Boolean = children.forall(_.deterministic) def nullable: Boolean @@ -265,7 +265,7 @@ trait NonSQLExpression extends Expression { * An expression that is nondeterministic. */ trait Nondeterministic extends Expression { - final override def deterministic: Boolean = false + final override lazy val deterministic: Boolean = false final override def foldable: Boolean = false @transient diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala index 305ac90e245b8..7e8e7b8cd5f18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSet.scala @@ -30,8 +30,9 @@ object ExpressionSet { } /** - * A [[Set]] where membership is determined based on a canonical representation of an [[Expression]] - * (i.e. one that attempts to ignore cosmetic differences). See [[Canonicalize]] for more details. + * A [[Set]] where membership is determined based on determinacy and a canonical representation of + * an [[Expression]] (i.e. one that attempts to ignore cosmetic differences). + * See [[Canonicalize]] for more details. * * Internally this set uses the canonical representation, but keeps also track of the original * expressions to ease debugging. 
Since different expressions can share the same canonical @@ -46,6 +47,10 @@ object ExpressionSet { * set.contains(1 + a) => true * set.contains(a + 2) => false * }}} + * + * For non-deterministic expressions, they are always considered as not contained in the [[Set]]. + * On adding a non-deterministic expression, simply append it to the original expressions. + * This is consistent with how we define `semanticEquals` between two expressions. */ class ExpressionSet protected( protected val baseSet: mutable.Set[Expression] = new mutable.HashSet, @@ -53,7 +58,9 @@ class ExpressionSet protected( extends Set[Expression] { protected def add(e: Expression): Unit = { - if (!baseSet.contains(e.canonicalized)) { + if (!e.deterministic) { + originals += e + } else if (!baseSet.contains(e.canonicalized) ) { baseSet.add(e.canonicalized) originals += e } @@ -74,9 +81,13 @@ class ExpressionSet protected( } override def -(elem: Expression): ExpressionSet = { - val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) - val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) - new ExpressionSet(newBaseSet, newOriginals) + if (elem.deterministic) { + val newBaseSet = baseSet.clone().filterNot(_ == elem.canonicalized) + val newOriginals = originals.clone().filterNot(_.canonicalized == elem.canonicalized) + new ExpressionSet(newBaseSet, newOriginals) + } else { + new ExpressionSet(baseSet.clone(), originals.clone()) + } } override def iterator: Iterator[Expression] = originals.iterator diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 527f1670c25e1..179853032035e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -49,7 +49,7 @@ case class ScalaUDF( udfDeterministic: Boolean = true) extends Expression with ImplicitCastInputTypes with NonSQLExpression with UserDefinedExpression { - override def deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = udfDeterministic && children.forall(_.deterministic) override def toString: String = s"${udfName.map(name => s"UDF:$name").getOrElse("UDF")}(${children.mkString(", ")})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala index 096d1b35a8620..d4421ca20a9bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervals.scala @@ -22,9 +22,10 @@ import java.util import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, ExpectsInputTypes, Expression} +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, GenericInternalRow} import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, HyperLogLogPlusPlusHelper} import org.apache.spark.sql.types._ +import 
org.apache.spark.unsafe.Platform /** * This function counts the approximate number of distinct values (ndv) in @@ -46,16 +47,7 @@ case class ApproxCountDistinctForIntervals( relativeSD: Double = 0.05, mutableAggBufferOffset: Int = 0, inputAggBufferOffset: Int = 0) - extends ImperativeAggregate with ExpectsInputTypes { - - def this(child: Expression, endpointsExpression: Expression) = { - this( - child = child, - endpointsExpression = endpointsExpression, - relativeSD = 0.05, - mutableAggBufferOffset = 0, - inputAggBufferOffset = 0) - } + extends TypedImperativeAggregate[Array[Long]] with ExpectsInputTypes { def this(child: Expression, endpointsExpression: Expression, relativeSD: Expression) = { this( @@ -114,29 +106,11 @@ case class ApproxCountDistinctForIntervals( private lazy val totalNumWords = numWordsPerHllpp * hllppArray.length /** Allocate enough words to store all registers. */ - override lazy val aggBufferAttributes: Seq[AttributeReference] = { - Seq.tabulate(totalNumWords) { i => - AttributeReference(s"MS[$i]", LongType)() - } + override def createAggregationBuffer(): Array[Long] = { + Array.fill(totalNumWords)(0L) } - override def aggBufferSchema: StructType = StructType.fromAttributes(aggBufferAttributes) - - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. - override lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) - - /** Fill all words with zeros. */ - override def initialize(buffer: InternalRow): Unit = { - var word = 0 - while (word < totalNumWords) { - buffer.setLong(mutableAggBufferOffset + word, 0) - word += 1 - } - } - - override def update(buffer: InternalRow, input: InternalRow): Unit = { + override def update(buffer: Array[Long], input: InternalRow): Array[Long] = { val value = child.eval(input) // Ignore empty rows if (value != null) { @@ -153,13 +127,14 @@ case class ApproxCountDistinctForIntervals( // endpoints are sorted into ascending order already if (endpoints.head > doubleValue || endpoints.last < doubleValue) { // ignore if the value is out of the whole range - return + return buffer } val hllppIndex = findHllppIndex(doubleValue) - val offset = mutableAggBufferOffset + hllppIndex * numWordsPerHllpp - hllppArray(hllppIndex).update(buffer, offset, value, child.dataType) + val offset = hllppIndex * numWordsPerHllpp + hllppArray(hllppIndex).update(LongArrayInternalRow(buffer), offset, value, child.dataType) } + buffer } // Find which interval (HyperLogLogPlusPlusHelper) should receive the given value. @@ -196,17 +171,18 @@ case class ApproxCountDistinctForIntervals( } } - override def merge(buffer1: InternalRow, buffer2: InternalRow): Unit = { + override def merge(buffer1: Array[Long], buffer2: Array[Long]): Array[Long] = { for (i <- hllppArray.indices) { hllppArray(i).merge( - buffer1 = buffer1, - buffer2 = buffer2, - offset1 = mutableAggBufferOffset + i * numWordsPerHllpp, - offset2 = inputAggBufferOffset + i * numWordsPerHllpp) + buffer1 = LongArrayInternalRow(buffer1), + buffer2 = LongArrayInternalRow(buffer2), + offset1 = i * numWordsPerHllpp, + offset2 = i * numWordsPerHllpp) } + buffer1 } - override def eval(buffer: InternalRow): Any = { + override def eval(buffer: Array[Long]): Any = { val ndvArray = hllppResults(buffer) // If the endpoints contains multiple elements with the same value, // we set ndv=1 for intervals between these elements. 
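The rewrite above moves ApproxCountDistinctForIntervals from a row-slot based ImperativeAggregate to a typed aggregate whose buffer is a plain `Array[Long]`, so offsets in `update`/`merge` become 0-based within the buffer itself and the `mutableAggBufferOffset`/`inputAggBufferOffset` bookkeeping disappears. A simplified sketch of that buffer lifecycle, with an invented trait and a trivial averaging aggregate rather than Spark's TypedImperativeAggregate:

```scala
// Hedged sketch: the general shape of a typed aggregate whose buffer is a plain
// Array[Long] instead of slots in a mutable aggregation row.
trait TypedAggSketch[B] {
  def createAggregationBuffer(): B
  def update(buffer: B, input: Long): B
  def merge(b1: B, b2: B): B
  def eval(buffer: B): Any
}

object LongAverage extends TypedAggSketch[Array[Long]] {
  // word 0: running sum, word 1: running count -- offsets are 0-based in the buffer
  override def createAggregationBuffer(): Array[Long] = Array.fill(2)(0L)

  override def update(buffer: Array[Long], input: Long): Array[Long] = {
    buffer(0) += input
    buffer(1) += 1
    buffer
  }

  override def merge(b1: Array[Long], b2: Array[Long]): Array[Long] = {
    b1(0) += b2(0)
    b1(1) += b2(1)
    b1
  }

  override def eval(buffer: Array[Long]): Any =
    if (buffer(1) == 0) null else buffer(0).toDouble / buffer(1)
}
```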
@@ -218,19 +194,23 @@ case class ApproxCountDistinctForIntervals( new GenericArrayData(ndvArray) } - def hllppResults(buffer: InternalRow): Array[Long] = { + def hllppResults(buffer: Array[Long]): Array[Long] = { val ndvArray = new Array[Long](hllppArray.length) for (i <- ndvArray.indices) { - ndvArray(i) = hllppArray(i).query(buffer, mutableAggBufferOffset + i * numWordsPerHllpp) + ndvArray(i) = hllppArray(i).query(LongArrayInternalRow(buffer), i * numWordsPerHllpp) } ndvArray } - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(mutableAggBufferOffset = newMutableAggBufferOffset) + } - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int) + : ApproxCountDistinctForIntervals = { copy(inputAggBufferOffset = newInputAggBufferOffset) + } override def children: Seq[Expression] = Seq(child, endpointsExpression) @@ -239,4 +219,31 @@ case class ApproxCountDistinctForIntervals( override def dataType: DataType = ArrayType(LongType) override def prettyName: String = "approx_count_distinct_for_intervals" + + override def serialize(obj: Array[Long]): Array[Byte] = { + val byteArray = new Array[Byte](obj.length * 8) + var i = 0 + while (i < obj.length) { + Platform.putLong(byteArray, Platform.BYTE_ARRAY_OFFSET + i * 8, obj(i)) + i += 1 + } + byteArray + } + + override def deserialize(bytes: Array[Byte]): Array[Long] = { + assert(bytes.length % 8 == 0) + val length = bytes.length / 8 + val longArray = new Array[Long](length) + var i = 0 + while (i < length) { + longArray(i) = Platform.getLong(bytes, Platform.BYTE_ARRAY_OFFSET + i * 8) + i += 1 + } + longArray + } + + private case class LongArrayInternalRow(array: Array[Long]) extends GenericInternalRow { + override def getLong(offset: Int): Long = array(offset) + override def setLong(offset: Int, value: Long): Unit = { array(offset) = value } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index c423e17169e85..708bdbfc36058 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -80,7 +80,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(sum, dt) / Cast(count, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, DecimalType.bounded(DecimalType.MAX_PRECISION, 0)), + resultType) case _ => Cast(sum, resultType) / Cast(count, resultType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index bfc58c22886cc..4e671e1f3e6eb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -44,7 +44,7 @@ case class First(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true 
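The `serialize`/`deserialize` pair added above packs the `Array[Long]` buffer into a byte array, eight bytes per word, using Spark's internal `Platform.putLong`/`getLong` helpers. The same round trip can be sketched with `java.nio.ByteBuffer`, used here only to keep the example self-contained; it is not what the patch itself uses:

```scala
import java.nio.ByteBuffer

object LongArrayCodecSketch {
  def serialize(longs: Array[Long]): Array[Byte] = {
    val buf = ByteBuffer.allocate(longs.length * 8)
    longs.foreach(buf.putLong)
    buf.array()
  }

  def deserialize(bytes: Array[Byte]): Array[Long] = {
    require(bytes.length % 8 == 0, "byte length must be a multiple of 8")
    val buf = ByteBuffer.wrap(bytes)
    Array.fill(bytes.length / 8)(buf.getLong())
  }
}
```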
// First is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index 96a6ec08a160a..0ccabb9d98914 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -44,7 +44,7 @@ case class Last(child: Expression, ignoreNullsExpr: Expression) override def nullable: Boolean = true // Last is not a deterministic function. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false // Return data type. override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala index 405c2065680f5..be972f006352e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/collect.scala @@ -44,7 +44,7 @@ abstract class Collect[T <: Growable[Any] with Iterable[Any]] extends TypedImper // Both `CollectList` and `CollectSet` are non-deterministic since their results depend on the // actual order of input rows. - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def update(buffer: T, input: InternalRow): T = { val value = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f3b45799c5688..2cb66599076a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -373,20 +373,6 @@ class CodegenContext { */ private val placeHolderToComments = new mutable.HashMap[String, String] - /** - * It will count the lines of every Java function generated by whole-stage codegen, - * if there is a function of length greater than spark.sql.codegen.maxLinesPerFunction, - * it will return true. - */ - def isTooLongGeneratedFunction: Boolean = { - classFunctions.values.exists { _.values.exists { - code => - val codeWithoutComments = CodeFormatter.stripExtraNewLinesAndComments(code) - codeWithoutComments.count(_ == '\n') > SQLConf.get.maxLinesPerFunction - } - } - } - /** * Returns a term name that is unique within this instance of a `CodegenContext`. */ @@ -786,16 +772,19 @@ class CodegenContext { foldFunctions: Seq[String] => String = _.mkString("", ";\n", ";")): String = { val blocks = new ArrayBuffer[String]() val blockBuilder = new StringBuilder() + var length = 0 for (code <- expressions) { // We can't know how many bytecode will be generated, so use the length of source code // as metric. A method should not go beyond 8K, otherwise it will not be JITted, should // also not be too small, or it will have many function calls (for wide table), see the // results in BenchmarkWideTable. 
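The change that follows replaces the raw `blockBuilder.length` check with a running counter of comment-stripped code length, so generated comments and blank lines no longer push a block over the split threshold prematurely. A rough standalone sketch of the heuristic, with a naive `stripComments` helper standing in for `CodeFormatter.stripExtraNewLinesAndComments`:

```scala
import scala.collection.mutable.ArrayBuffer

object SplitByCodeLengthSketch {
  // Crude stand-in: drops /* ... */ and // ... comments before measuring.
  def stripComments(code: String): String =
    code.replaceAll("(?s)/\\*.*?\\*/", "").replaceAll("//.*", "")

  def split(expressions: Seq[String], threshold: Int = 1024): Seq[String] = {
    val blocks = ArrayBuffer[String]()
    val current = new StringBuilder
    var length = 0
    for (code <- expressions) {
      if (length > threshold) { // split on the stripped length, not the raw builder size
        blocks += current.toString()
        current.clear()
        length = 0
      }
      current.append(code)
      length += stripComments(code).length
    }
    blocks += current.toString()
    blocks
  }
}
```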
- if (blockBuilder.length > 1024) { + if (length > 1024) { blocks += blockBuilder.toString() blockBuilder.clear() + length = 0 } blockBuilder.append(code) + length += CodeFormatter.stripExtraNewLinesAndComments(code).length } blocks += blockBuilder.toString() @@ -1020,10 +1009,16 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin } object CodeGenerator extends Logging { + + // This is the value of HugeMethodLimit in the OpenJDK JVM settings + val DEFAULT_JVM_HUGE_METHOD_LIMIT = 8000 + /** * Compile the Java source code into a Java class, using Janino. + * + * @return a pair of a generated class and the max bytecode size of generated functions. */ - def compile(code: CodeAndComment): GeneratedClass = try { + def compile(code: CodeAndComment): (GeneratedClass, Int) = try { cache.get(code) } catch { // Cache.get() may wrap the original exception. See the following URL @@ -1036,7 +1031,7 @@ object CodeGenerator extends Logging { /** * Compile the Java source code into a Java class, using Janino. */ - private[this] def doCompile(code: CodeAndComment): GeneratedClass = { + private[this] def doCompile(code: CodeAndComment): (GeneratedClass, Int) = { val evaluator = new ClassBodyEvaluator() // A special classloader used to wrap the actual parent classloader of @@ -1075,9 +1070,9 @@ object CodeGenerator extends Logging { s"\n${CodeFormatter.format(code)}" }) - try { + val maxCodeSize = try { evaluator.cook("generated.java", code.body) - recordCompilationStats(evaluator) + updateAndGetCompilationStats(evaluator) } catch { case e: JaninoRuntimeException => val msg = s"failed to compile: $e" @@ -1092,13 +1087,15 @@ object CodeGenerator extends Logging { logInfo(s"\n${CodeFormatter.format(code, maxLines)}") throw new CompileException(msg, e.getLocation) } - evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] + + (evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass], maxCodeSize) } /** - * Records the generated class and method bytecode sizes by inspecting janino private fields. + * Returns the max bytecode size of the generated functions by inspecting janino private fields. + * Also, this method updates the metrics information. */ - private def recordCompilationStats(evaluator: ClassBodyEvaluator): Unit = { + private def updateAndGetCompilationStats(evaluator: ClassBodyEvaluator): Int = { // First retrieve the generated classes. 
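With `compile` now returning both the generated class and the maximum bytecode size among its methods, call sites destructure the pair, and the size can be compared against a huge-method threshold (8000 is HotSpot's HugeMethodLimit, as noted in the hunk above). The caller-side pattern, sketched with placeholder helpers that are not Spark's API:

```scala
object CompileResultSketch {
  // Methods above this bytecode size are never JIT-compiled by the OpenJDK JVM.
  val DefaultJvmHugeMethodLimit = 8000

  final case class Compiled(instance: AnyRef, maxMethodCodeSize: Int)

  // Placeholder "compilation": stands in for Janino plus the bytecode-size bookkeeping.
  def compileSketch(source: String): Compiled =
    Compiled(new Object, source.length / 10)

  def shouldUseCodegen(source: String): Boolean = {
    val Compiled(_, maxCodeSize) = compileSketch(source)
    maxCodeSize <= DefaultJvmHugeMethodLimit // otherwise fall back for this subtree
  }
}
```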
val classes = { val resultField = classOf[SimpleCompiler].getDeclaredField("result") @@ -1113,23 +1110,26 @@ object CodeGenerator extends Logging { val codeAttr = Utils.classForName("org.codehaus.janino.util.ClassFile$CodeAttribute") val codeAttrField = codeAttr.getDeclaredField("code") codeAttrField.setAccessible(true) - classes.foreach { case (_, classBytes) => + val codeSizes = classes.flatMap { case (_, classBytes) => CodegenMetrics.METRIC_GENERATED_CLASS_BYTECODE_SIZE.update(classBytes.length) try { val cf = new ClassFile(new ByteArrayInputStream(classBytes)) - cf.methodInfos.asScala.foreach { method => - method.getAttributes().foreach { a => - if (a.getClass.getName == codeAttr.getName) { - CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update( - codeAttrField.get(a).asInstanceOf[Array[Byte]].length) - } + val stats = cf.methodInfos.asScala.flatMap { method => + method.getAttributes().filter(_.getClass.getName == codeAttr.getName).map { a => + val byteCodeSize = codeAttrField.get(a).asInstanceOf[Array[Byte]].length + CodegenMetrics.METRIC_GENERATED_METHOD_BYTECODE_SIZE.update(byteCodeSize) + byteCodeSize } } + Some(stats) } catch { case NonFatal(e) => logWarning("Error calculating stats of compiled class.", e) + None } - } + }.flatten + + codeSizes.max } /** @@ -1144,8 +1144,8 @@ object CodeGenerator extends Logging { private val cache = CacheBuilder.newBuilder() .maximumSize(100) .build( - new CacheLoader[CodeAndComment, GeneratedClass]() { - override def load(code: CodeAndComment): GeneratedClass = { + new CacheLoader[CodeAndComment, (GeneratedClass, Int)]() { + override def load(code: CodeAndComment): (GeneratedClass, Int) = { val startTime = System.nanoTime() val result = doCompile(code) val endTime = System.nanoTime() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 3768dcde00a4e..b5429fade53cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -142,7 +142,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableP new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 4e47895985209..1639d1b9dda1f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -185,7 +185,8 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated Ordering by ${ordering.mkString(",")}:\n${CodeFormatter.format(code)}") - 
CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[BaseOrdering] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[BaseOrdering] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index e35b9dda6c017..e0fabad6d089a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -78,6 +78,7 @@ object GeneratePredicate extends CodeGenerator[Expression, Predicate] { new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated predicate '$predicate':\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(ctx.references.toArray).asInstanceOf[Predicate] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[Predicate] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index 192701a829686..1e4ac3f2afd52 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -189,8 +189,8 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) + val (clazz, _) = CodeGenerator.compile(code) val resultRow = new SpecificInternalRow(expressions.map(_.dataType)) - c.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] + clazz.generate(ctx.references.toArray :+ resultRow).asInstanceOf[Projection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index f2a66efc98e71..4bd50aee05514 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -409,7 +409,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(ctx.references.toArray).asInstanceOf[UnsafeProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala index 4aa5ec82471ec..6bc72a0d75c6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala @@ -196,7 +196,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U val code = CodeFormatter.stripOverlappingComments(new CodeAndComment(codeBody, Map.empty)) logDebug(s"SpecificUnsafeRowJoiner($schema1, $schema2):\n${CodeFormatter.format(code)}") - val c = CodeGenerator.compile(code) - c.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[UnsafeRowJoiner] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index ef293ff3f18ea..b86e271fe2958 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -119,7 +119,7 @@ case class CurrentDatabase() extends LeafExpression with Unevaluable { // scalastyle:on line.size.limit case class Uuid() extends LeafExpression { - override def deterministic: Boolean = false + override lazy val deterministic: Boolean = false override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a602894efbcae..d829e01441dcc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -124,8 +124,6 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) SimplifyCreateMapOps, CombineConcats) ++ extendedOperatorOptimizationRules: _*) :: - Batch("Check Cartesian Products", Once, - CheckCartesianProducts) :: Batch("Join Reorder", Once, CostBasedJoinReorder) :: Batch("Decimal Optimizations", fixedPoint, @@ -136,6 +134,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation, PropagateEmptyRelation) :: + // The following batch should be executed after batch "Join Reorder" and "LocalRelation". + Batch("Check Cartesian Products", Once, + CheckCartesianProducts) :: Batch("OptimizeCodegen", Once, OptimizeCodegen) :: Batch("RewriteSubquery", Once, @@ -304,13 +305,20 @@ object LimitPushDown extends Rule[LogicalPlan] { } } - private def maybePushLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { - (limitExp, plan.maxRows) match { - case (IntegerLiteral(maxRow), Some(childMaxRows)) if maxRow < childMaxRows => + private def maybePushLocalLimit(limitExp: Expression, plan: LogicalPlan): LogicalPlan = { + (limitExp, plan.maxRowsPerPartition) match { + case (IntegerLiteral(newLimit), Some(childMaxRows)) if newLimit < childMaxRows => + // If the child has a cap on max rows per partition and the cap is larger than + // the new limit, put a new LocalLimit there. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) + case (_, None) => + // If the child has no cap, put the new LocalLimit. LocalLimit(limitExp, stripGlobalLimitIfPresent(plan)) - case _ => plan + + case _ => + // Otherwise, don't put a new LocalLimit. + plan } } @@ -322,7 +330,7 @@ object LimitPushDown extends Rule[LogicalPlan] { // pushdown Limit through it. Once we add UNION DISTINCT, however, we will not be able to // pushdown Limit. 
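The renamed `maybePushLocalLimit` helper below keys off the new `maxRowsPerPartition` bound: a `LocalLimit` is inserted under a child only when the child has no per-partition cap, or its cap is larger than the limit being pushed, so redundant limits are not stacked. A toy model of that decision (the case classes are illustrative, not Catalyst's, and the stripping of an existing GlobalLimit is omitted):

```scala
object PushLocalLimitSketch {
  sealed trait Plan { def maxRowsPerPartition: Option[Long] }
  case class Relation(maxRowsPerPartition: Option[Long]) extends Plan
  case class LocalLimit(limit: Long, child: Plan) extends Plan {
    override def maxRowsPerPartition: Option[Long] = Some(limit)
  }

  def maybePushLocalLimit(limit: Long, plan: Plan): Plan =
    plan.maxRowsPerPartition match {
      case Some(childMax) if limit < childMax => LocalLimit(limit, plan)
      case None                               => LocalLimit(limit, plan)
      case _                                  => plan // child already emits fewer rows per partition
    }
}
```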
case LocalLimit(exp, Union(children)) => - LocalLimit(exp, Union(children.map(maybePushLimit(exp, _)))) + LocalLimit(exp, Union(children.map(maybePushLocalLimit(exp, _)))) // Add extra limits below OUTER JOIN. For LEFT OUTER and FULL OUTER JOIN we push limits to the // left and right sides, respectively. For FULL OUTER JOIN, we can only push limits to one side // because we need to ensure that rows from the limited side still have an opportunity to match @@ -334,19 +342,19 @@ object LimitPushDown extends Rule[LogicalPlan] { // - If neither side is limited, limit the side that is estimated to be bigger. case LocalLimit(exp, join @ Join(left, right, joinType, _)) => val newJoin = joinType match { - case RightOuter => join.copy(right = maybePushLimit(exp, right)) - case LeftOuter => join.copy(left = maybePushLimit(exp, left)) + case RightOuter => join.copy(right = maybePushLocalLimit(exp, right)) + case LeftOuter => join.copy(left = maybePushLocalLimit(exp, left)) case FullOuter => (left.maxRows, right.maxRows) match { case (None, None) => if (left.stats.sizeInBytes >= right.stats.sizeInBytes) { - join.copy(left = maybePushLimit(exp, left)) + join.copy(left = maybePushLocalLimit(exp, left)) } else { - join.copy(right = maybePushLimit(exp, right)) + join.copy(right = maybePushLocalLimit(exp, right)) } case (Some(_), Some(_)) => join - case (Some(_), None) => join.copy(left = maybePushLimit(exp, left)) - case (None, Some(_)) => join.copy(right = maybePushLimit(exp, right)) + case (Some(_), None) => join.copy(left = maybePushLocalLimit(exp, left)) + case (None, Some(_)) => join.copy(right = maybePushLocalLimit(exp, right)) } case _ => join @@ -444,6 +452,8 @@ object ColumnPruning extends Rule[LogicalPlan] { // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => a.copy(child = prunedChild(child, a.references)) + case f @ FlatMapGroupsInPandas(_, _, _, child) if (child.outputSet -- f.references).nonEmpty => + f.copy(child = prunedChild(child, f.references)) case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty => e.copy(child = prunedChild(child, e.references)) case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty => @@ -1089,6 +1099,9 @@ object CombineLimits extends Rule[LogicalPlan] { * SELECT * from R, S where R.r = S.s, * the join between R and S is not a cartesian product and therefore should be allowed. * The predicate R.r = S.s is not recognized as a join condition until the ReorderJoin rule. + * + * This rule must be run AFTER the batch "LocalRelation", since a join with empty relation should + * not be a cartesian product. */ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 273bc6ce27c5d..523b53b39d6b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -169,13 +169,16 @@ object ReorderAssociativeOperator extends Rule[LogicalPlan] { /** * Optimize IN predicates: - * 1. Removes literal repetitions. - * 2. Replaces [[In (value, seq[Literal])]] with optimized version + * 1. Converts the predicate to false when the list is empty and + * the value is not nullable. + * 2. 
Removes literal repetitions. + * 3. Replaces [[In (value, seq[Literal])]] with optimized version * [[InSet (value, HashSet[Literal])]] which is much faster. */ object OptimizeIn extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsDown { + case In(v, list) if list.isEmpty && !v.nullable => FalseLiteral case expr @ In(v, list) if expr.inSetConvertible => val newList = ExpressionSet(list).toSeq if (newList.size > SQLConf.get.optimizerInSetConversionThreshold) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 85b492e83446e..ce367145bc637 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -699,20 +699,30 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query) } - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => + if (ctx.sampleMethod() == null) { + throw new ParseException("TABLESAMPLE does not accept empty inputs.", ctx) + } + + ctx.sampleMethod() match { + case ctx: SampleByRowsContext => Limit(expression(ctx.expression), query) - case SqlBaseParser.PERCENTLIT => + case ctx: SampleByPercentileContext => val fraction = ctx.percentage.getText.toDouble val sign = if (ctx.negativeSign == null) 1 else -1 sample(sign * fraction / 100.0d) - case SqlBaseParser.BYTELENGTH_LITERAL => - throw new ParseException( - "TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + case ctx: SampleByBytesContext => + val bytesStr = ctx.bytes.getText + if (bytesStr.matches("[0-9]+[bBkKmMgG]")) { + throw new ParseException("TABLESAMPLE(byteLengthLiteral) is not supported", ctx) + } else { + throw new ParseException( + bytesStr + " is not a valid byte length literal, " + + "expected syntax: DIGIT+ ('B' | 'K' | 'M' | 'G')", ctx) + } - case SqlBaseParser.BUCKET if ctx.ON != null => + case ctx: SampleByBucketContext if ctx.ON() != null => if (ctx.identifier != null) { throw new ParseException( "TABLESAMPLE(BUCKET x OUT OF y ON colname) is not supported", ctx) @@ -721,7 +731,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging "TABLESAMPLE(BUCKET x OUT OF y ON function) is not supported", ctx) } - case SqlBaseParser.BUCKET => + case ctx: SampleByBucketContext => sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 8d034c21a4960..cc391aae55787 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -205,14 +205,17 @@ object PhysicalAggregation { case logical.Aggregate(groupingExpressions, resultExpressions, child) => // A single aggregate expression might appear multiple times in resultExpressions. 
// In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. + // build a set of semantically distinct aggregate expressions and re-write expressions so + // that they reference the single copy of the aggregate function which actually gets computed. + // Non-deterministic aggregate expressions are not deduplicated. + val equivalentAggregateExpressions = new EquivalentExpressions val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { - case agg: AggregateExpression => agg + // addExpr() always returns false for non-deterministic expressions and do not add them. + case agg: AggregateExpression + if (!equivalentAggregateExpressions.addExpr(agg)) => agg } - }.distinct + } val namedGroupingExpressions = groupingExpressions.map { case ne: NamedExpression => ne -> ne @@ -236,7 +239,8 @@ object PhysicalAggregation { case ae: AggregateExpression => // The final aggregation buffer's attributes will be `finalAggregationAttributes`, // so replace each aggregate expression by its corresponding attribute in the set: - ae.resultAttribute + equivalentAggregateExpressions.getEquivalentExprs(ae).headOption + .getOrElse(ae).asInstanceOf[AggregateExpression].resultAttribute case expression => // Since we're using `namedGroupingAttributes` to extract the grouping key // columns, we need to replace grouping key expressions with their corresponding diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 7addbaaa9afa5..c7952e3ff8280 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -178,7 +178,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT }) } - override def innerChildren: Seq[QueryPlan[_]] = subqueries + override protected def innerChildren: Seq[QueryPlan[_]] = subqueries /** * Returns a plan where a best effort attempt has been made to transform `this` in a way diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 68aae720e026a..14188829db2af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -97,6 +97,11 @@ abstract class LogicalPlan */ def maxRows: Option[Long] = None + /** + * Returns the maximum number of rows this plan may compute on each partition. + */ + def maxRowsPerPartition: Option[Long] = maxRows + /** * Returns true if this expression and all its children have been resolved to a specific schema * and false if it still contains any unresolved placeholders. 
Implementations of LogicalPlan diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f443cd5a69de3..80243d3d356ca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -191,6 +191,9 @@ object Union { } } +/** + * Logical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + */ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { override def maxRows: Option[Long] = { if (children.exists(_.maxRows.isEmpty)) { @@ -200,6 +203,17 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { } } + /** + * Note the definition has assumption about how union is implemented physically. + */ + override def maxRowsPerPartition: Option[Long] = { + if (children.exists(_.maxRowsPerPartition.isEmpty)) { + None + } else { + Some(children.flatMap(_.maxRowsPerPartition).sum) + } + } + // updating nullability to make all the children consistent override def output: Seq[Attribute] = children.map(_.output).transpose.map(attrs => @@ -669,6 +683,27 @@ case class Pivot( } } +/** + * A constructor for creating a logical limit, which is split into two separate logical nodes: + * a [[LocalLimit]], which is a partition local limit, followed by a [[GlobalLimit]]. + * + * This muds the water for clean logical/physical separation, and is done for better limit pushdown. + * In distributed query processing, a non-terminal global limit is actually an expensive operation + * because it requires coordination (in Spark this is done using a shuffle). + * + * In most cases when we want to push down limit, it is often better to only push some partition + * local limit. Consider the following: + * + * GlobalLimit(Union(A, B)) + * + * It is better to do + * GlobalLimit(Union(LocalLimit(A), LocalLimit(B))) + * + * than + * Union(GlobalLimit(A), GlobalLimit(B)). + * + * So we introduced LocalLimit and GlobalLimit in the logical plan node for limit pushdown. + */ object Limit { def apply(limitExpr: Expression, child: LogicalPlan): UnaryNode = { GlobalLimit(limitExpr, LocalLimit(limitExpr, child)) @@ -682,6 +717,11 @@ object Limit { } } +/** + * A global (coordinated) limit. This operator can emit at most `limitExpr` number in total. + * + * See [[Limit]] for more information. + */ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output override def maxRows: Option[Long] = { @@ -692,9 +732,16 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } } +/** + * A partition-local (non-coordinated) limit. This operator can emit at most `limitExpr` number + * of tuples on each physical partition. + * + * See [[Limit]] for more information. 
+ */ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output - override def maxRows: Option[Long] = { + + override def maxRowsPerPartition: Option[Long] = { limitExpr match { case IntegerLiteral(limit) => Some(limit) case _ => None diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala new file mode 100644 index 0000000000000..254687ec00880 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/pythonLogicalOperators.scala @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression} + +/** + * FlatMap groups using an udf: pandas.Dataframe -> pandas.DataFrame. + * This is used by DataFrame.groupby().apply(). + */ +case class FlatMapGroupsInPandas( + groupingAttributes: Seq[Attribute], + functionExpr: Expression, + output: Seq[Attribute], + child: LogicalPlan) extends UnaryNode { + + /** + * This is needed because output attributes are considered `references` when + * passed through the constructor. + * + * Without this, catalyst will complain that output attributes are missing + * from the input. + */ + override val producedAttributes = AttributeSet(output) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala index 2ab46dc8330aa..9fac95aed8f12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow trait BroadcastMode { def transform(rows: Array[InternalRow]): Any + def transform(rows: Iterator[InternalRow], sizeHint: Option[Long]): Any + def canonicalized: BroadcastMode } @@ -36,5 +38,9 @@ case object IdentityBroadcastMode extends BroadcastMode { // TODO: pack the UnsafeRows into single bytes array. 
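The new `transform` overload below consumes an iterator together with an optional size hint instead of a pre-materialized array. A simplified, generic version of the identity behaviour, written outside Spark's `InternalRow` types purely for illustration:

```scala
import scala.reflect.ClassTag

object IdentityTransformSketch {
  // Materialize the iterator, pre-sizing the builder when a hint is available.
  def transform[T: ClassTag](rows: Iterator[T], sizeHint: Option[Long]): Array[T] = {
    val builder = Array.newBuilder[T]
    sizeHint.foreach(h => builder.sizeHint(h.toInt))
    rows.foreach(builder += _)
    builder.result()
  }
}
```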
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): Array[InternalRow] = rows.toArray + override def canonicalized: BroadcastMode = this } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index 51d78dd1233fe..e57c842ce2a36 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -49,7 +49,9 @@ case object AllTuples extends Distribution * can mean such tuples are either co-located in the same partition or they will be contiguous * within a single partition. */ -case class ClusteredDistribution(clustering: Seq[Expression]) extends Distribution { +case class ClusteredDistribution( + clustering: Seq[Expression], + numPartitions: Option[Int] = None) extends Distribution { require( clustering != Nil, "The clustering expressions of a ClusteredDistribution should not be Nil. " + @@ -221,6 +223,7 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = required match { case _: BroadcastDistribution => false + case ClusteredDistribution(_, desiredPartitions) => desiredPartitions.forall(_ == 1) case _ => true } @@ -243,8 +246,9 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true - case ClusteredDistribution(requiredClustering) => - expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + expressions.forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } @@ -289,8 +293,9 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case OrderedDistribution(requiredOrdering) => val minSize = Seq(requiredOrdering.size, ordering.size).min requiredOrdering.take(minSize) == ordering.take(minSize) - case ClusteredDistribution(requiredClustering) => - ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) + case ClusteredDistribution(requiredClustering, desiredPartitions) => + ordering.map(_.child).forall(x => requiredClustering.exists(_.semanticEquals(x))) && + desiredPartitions.forall(_ == numPartitions) // if desiredPartitions = None, returns true case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index af543b04ba780..eb7941cf9e6af 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -193,10 +193,10 @@ class QuantileSummaries( // Target rank val rank = math.ceil(quantile * count).toInt - val targetError = math.ceil(relativeError * count) + val targetError = relativeError * count // Minimum rank at current sample var minRank = 0 - var i = 1 + var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) minRank += curSample.g diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d00c672487532..4cfe53b2c115b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -30,6 +30,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -172,6 +173,13 @@ object SQLConf { .intConf .createWithDefault(4) + val ADVANCED_PARTITION_PREDICATE_PUSHDOWN = + buildConf("spark.sql.hive.advancedPartitionPredicatePushdown.enabled") + .internal() + .doc("When true, advanced partition predicate pushdown into Hive metastore is enabled.") + .booleanConf + .createWithDefault(true) + val ENABLE_FALL_BACK_TO_HDFS_FOR_STATS = buildConf("spark.sql.statistics.fallBackToHdfs") .doc("If the table statistics are not available from table metadata enable fall back to hdfs." + @@ -305,8 +313,9 @@ object SQLConf { val PARQUET_OUTPUT_COMMITTER_CLASS = buildConf("spark.sql.parquet.output.committer.class") .doc("The output committer class used by Parquet. The specified class needs to be a " + - "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + - "of org.apache.parquet.hadoop.ParquetOutputCommitter.") + "subclass of org.apache.hadoop.mapreduce.OutputCommitter. Typically, it's also a subclass " + + "of org.apache.parquet.hadoop.ParquetOutputCommitter. If it is not, then metadata summaries" + + "will never be created, irrespective of the value of parquet.enable.summary-metadata") .internal() .stringConf .createWithDefault("org.apache.parquet.hadoop.ParquetOutputCommitter") @@ -575,15 +584,15 @@ object SQLConf { "disable logging or -1 to apply no limit.") .createWithDefault(1000) - val WHOLESTAGE_MAX_LINES_PER_FUNCTION = buildConf("spark.sql.codegen.maxLinesPerFunction") + val WHOLESTAGE_HUGE_METHOD_LIMIT = buildConf("spark.sql.codegen.hugeMethodLimit") .internal() - .doc("The maximum lines of a single Java function generated by whole-stage codegen. " + - "When the generated function exceeds this threshold, " + + .doc("The maximum bytecode size of a single compiled Java function generated by whole-stage " + + "codegen. When the compiled function exceeds this threshold, " + "the whole-stage codegen is deactivated for this subtree of the current query plan. 
" + - "The default value 4000 is the max length of byte code JIT supported " + - "for a single function(8000) divided by 2.") + s"The default value is ${CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT} and " + + "this is a limit in the OpenJDK JVM implementation.") .intConf - .createWithDefault(4000) + .createWithDefault(CodeGenerator.DEFAULT_JVM_HUGE_METHOD_LIMIT) val FILES_MAX_PARTITION_BYTES = buildConf("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") @@ -668,7 +677,7 @@ object SQLConf { .createWithDefault(40) val ENABLE_TWOLEVEL_AGG_MAP = - buildConf("spark.sql.codegen.aggregate.map.twolevel.enable") + buildConf("spark.sql.codegen.aggregate.map.twolevel.enabled") .internal() .doc("Enable two-level aggregate hash map. When enabled, records will first be " + "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " + @@ -907,8 +916,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION = + buildConf("spark.sql.execution.rangeExchange.sampleSizePerPartition") + .internal() + .doc("Number of points to sample per partition in order to determine the range boundaries" + + " for range partitioning, typically used in global sorting (without limit).") + .intConf + .createWithDefault(100) + val ARROW_EXECUTION_ENABLE = - buildConf("spark.sql.execution.arrow.enable") + buildConf("spark.sql.execution.arrow.enabled") .internal() .doc("Make use of Apache Arrow for columnar data transfers. Currently available " + "for use with pyspark.sql.DataFrame.toPandas with the following data types: " + @@ -1050,7 +1067,7 @@ class SQLConf extends Serializable with Logging { def loggingMaxLinesForCodegen: Int = getConf(CODEGEN_LOGGING_MAX_LINES) - def maxLinesPerFunction: Int = getConf(WHOLESTAGE_MAX_LINES_PER_FUNCTION) + def hugeMethodLimit: Int = getConf(WHOLESTAGE_HUGE_METHOD_LIMIT) def tableRelationCacheSize: Int = getConf(StaticSQLConf.FILESOURCE_TABLE_RELATION_CACHE_SIZE) @@ -1082,6 +1099,9 @@ class SQLConf extends Serializable with Logging { def limitScaleUpFactor: Int = getConf(LIMIT_SCALE_UP_FACTOR) + def advancedPartitionPredicatePushdownEnabled: Boolean = + getConf(ADVANCED_PARTITION_PREDICATE_PUSHDOWN) + def fallBackToHdfsForStatsEnabled: Boolean = getConf(ENABLE_FALL_BACK_TO_HDFS_FOR_STATS) def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) @@ -1199,6 +1219,8 @@ class SQLConf extends Serializable with Logging { def supportQuotedRegexColumnName: Boolean = getConf(SUPPORT_QUOTED_REGEX_COLUMN_NAME) + def rangeExchangeSampleSizePerPartition: Int = getConf(RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION) + def arrowEnable: Boolean = getConf(ARROW_EXECUTION_ENABLE) def arrowMaxRecordsPerBatch: Int = getConf(ARROW_EXECUTION_MAX_RECORDS_PER_BATCH) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala index c6c0a605d89ff..c018fc8a332fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/StaticSQLConf.scala @@ -87,4 +87,12 @@ object StaticSQLConf { "implement Function1[SparkSessionExtension, Unit], and must have a no-args constructor.") .stringConf .createOptional + + val QUERY_EXECUTION_LISTENERS = buildStaticConf("spark.sql.queryExecutionListeners") + .doc("List of class names implementing QueryExecutionListener that will be automatically " + 
+ "added to newly created sessions. The classes should have either a no-arg constructor, " + + "or a constructor that expects a SparkConf argument.") + .stringConf + .toSequence + .createOptional } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala index 30745c6a9d42a..d6e0df12218ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DataType.scala @@ -26,6 +26,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.annotation.InterfaceStability import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.Utils /** @@ -80,7 +81,11 @@ abstract class DataType extends AbstractDataType { * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). */ private[spark] def sameType(other: DataType): Boolean = - DataType.equalsIgnoreNullability(this, other) + if (SQLConf.get.caseSensitiveAnalysis) { + DataType.equalsIgnoreNullability(this, other) + } else { + DataType.equalsIgnoreCaseAndNullability(this, other) + } /** * Returns the same data type but set all nullability fields are true diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index 884e113537c93..5d2f8e735e3d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -408,16 +408,25 @@ class AnalysisErrorSuite extends AnalysisTest { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept // LongType, DoubleType, and DecimalType. We use LongType as the type of a. - val plan = - Aggregate( - Nil, - Alias(sum(AttributeReference("a", LongType)(exprId = ExprId(1))), "b")() :: Nil, - LocalRelation( - AttributeReference("a", LongType)(exprId = ExprId(2)))) + val attrA = AttributeReference("a", LongType)(exprId = ExprId(1)) + val otherA = AttributeReference("a", LongType)(exprId = ExprId(2)) + val attrC = AttributeReference("c", LongType)(exprId = ExprId(3)) + val aliases = Alias(sum(attrA), "b")() :: Alias(sum(attrC), "d")() :: Nil + val plan = Aggregate( + Nil, + aliases, + LocalRelation(otherA)) assert(plan.resolved) - assertAnalysisError(plan, "resolved attribute(s) a#1L missing from a#2L" :: Nil) + val resolved = s"${attrA.toString},${attrC.toString}" + + val errorMsg = s"Resolved attribute(s) $resolved missing from ${otherA.toString} " + + s"in operator !Aggregate [${aliases.mkString(", ")}]. " + + s"Attribute(s) with the same name appear in the operation: a. " + + "Please check if the right attribute(s) are used." 
+ + assertAnalysisError(plan, errorMsg :: Nil) } test("error test for self-join") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala new file mode 100644 index 0000000000000..d670053ba1b5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CatalogSuite.scala @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.types.StructType + + +class CatalogSuite extends AnalysisTest { + + test("desc table when owner is set to null") { + val table = CatalogTable( + identifier = TableIdentifier("tbl", Some("db1")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + owner = null, + schema = new StructType().add("col1", "int").add("col2", "string"), + provider = Some("parquet")) + table.toLinkedHashMap + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala new file mode 100644 index 0000000000000..8cf41a02320d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/StreamingJoinHelperSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, AttributeSet} +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter, LeafNode, LocalRelation} +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, TimestampType} + +class StreamingJoinHelperSuite extends AnalysisTest { + + test("extract watermark from time condition") { + val attributesToFindConstraintFor = Seq( + AttributeReference("leftTime", TimestampType)(), + AttributeReference("leftOther", IntegerType)()) + val metadataWithWatermark = new MetadataBuilder() + .putLong(EventTimeWatermark.delayKey, 1000) + .build() + val attributesWithWatermark = Seq( + AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), + AttributeReference("rightOther", IntegerType)()) + + case class DummyLeafNode() extends LeafNode { + override def output: Seq[Attribute] = + attributesToFindConstraintFor ++ attributesWithWatermark + } + + def watermarkFrom( + conditionStr: String, + rightWatermark: Option[Long] = Some(10000)): Option[Long] = { + val conditionExpr = Some(conditionStr).map { str => + val plan = + Filter( + CatalystSqlParser.parseExpression(str), + DummyLeafNode()) + val optimized = SimpleTestOptimizer.execute(SimpleAnalyzer.execute(plan)) + optimized.asInstanceOf[Filter].condition + } + StreamingJoinHelper.getStateValueWatermark( + AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), + conditionExpr, rightWatermark) + } + + // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, + // then cannot define constraint on leftTime. + assert(watermarkFrom("leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) + assert(watermarkFrom("leftTime < rightTime") === None) + assert(watermarkFrom("leftTime <= rightTime") === None) + assert(watermarkFrom("rightTime > leftTime") === None) + assert(watermarkFrom("rightTime >= leftTime") === None) + assert(watermarkFrom("rightTime < leftTime") === Some(10000)) + assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) + + // Test type conversions + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) + assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) + + // Test with timestamp type + calendar interval on either side of equation + // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. 
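Aside, not part of the patch: the interval assertions below reduce to simple arithmetic on the right-side watermark. A minimal sketch of that arithmetic, assuming the semantics these tests exercise (a condition of the form leftTime > rightTime + delta, with the right side watermarked at W, lets left-side state with leftTime <= W + delta be dropped):

// Matches watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000) below.
val rightWatermarkMs = 10000L                               // watermark on the right-side attribute
val deltaMs = 1000L                                         // "interval 1 second" from the condition
val leftStateValueWatermarkMs = rightWatermarkMs + deltaMs  // 11000
// With "leftTime + interval 2 seconds > rightTime", the delta moves to the left side,
// so the constraint becomes 10000 - 2000 = 8000, matching Some(8000) below.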
+ assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) + assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) + assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) + assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) + assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") + === Some(12000)) + + // Test with casted long type + constants on either side of equation + // Note: long type and constants commute, so more combinations to test. + // -- Constants on the right + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") + === Some(11000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") + === Some(10100)) + assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") + === Some(10200)) + // -- Constants on the left + assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) + assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) + assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") + === Some(7000)) + assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) + assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) + assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") + === Some(12000)) + assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") + === Some(10100)) + // -- Constants on both sides, mixed types + assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") + === Some(13000)) + + // Test multiple conditions, should return minimum watermark + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === + Some(7000)) // first condition wins + assert(watermarkFrom( + "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === + Some(6000)) // second condition wins + + // Test invalid comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes + assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) + assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes + assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed + + // Test static comparisons + assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) + + // Test non-positive results + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 10") === Some(0)) + assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 100") === Some(-90000)) + } +} diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index d62e3b6dfe34f..793e04f66f0f9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -131,14 +131,17 @@ class TypeCoercionSuite extends AnalysisTest { widenFunc: (DataType, DataType) => Option[DataType], t1: DataType, t2: DataType, - expected: Option[DataType]): Unit = { + expected: Option[DataType], + isSymmetric: Boolean = true): Unit = { var found = widenFunc(t1, t2) assert(found == expected, s"Expected $expected as wider common type for $t1 and $t2, found $found") // Test both directions to make sure the widening is symmetric. - found = widenFunc(t2, t1) - assert(found == expected, - s"Expected $expected as wider common type for $t2 and $t1, found $found") + if (isSymmetric) { + found = widenFunc(t2, t1) + assert(found == expected, + s"Expected $expected as wider common type for $t2 and $t1, found $found") + } } test("implicit type cast - ByteType") { @@ -385,6 +388,47 @@ class TypeCoercionSuite extends AnalysisTest { widenTest(NullType, StructType(Seq()), Some(StructType(Seq()))) widenTest(StringType, MapType(IntegerType, StringType, true), None) widenTest(ArrayType(IntegerType), StructType(Seq()), None) + + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("b", IntegerType))), + None) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", DoubleType, nullable = false))), + None) + + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = false))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = false))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = false))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + widenTest( + StructType(Seq(StructField("a", IntegerType, nullable = true))), + StructType(Seq(StructField("a", IntegerType, nullable = true))), + Some(StructType(Seq(StructField("a", IntegerType, nullable = true))))) + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + widenTest( + StructType(Seq(StructField("a", IntegerType))), + StructType(Seq(StructField("A", IntegerType))), + None) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + checkWidenType( + TypeCoercion.findTightestCommonType, + StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType))), + StructType(Seq(StructField("A", IntegerType), StructField("b", IntegerType))), + Some(StructType(Seq(StructField("a", IntegerType), StructField("B", IntegerType)))), + isSymmetric = false) + } } test("wider common type for decimal and array") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala index 11f48a39c1e25..60d1351fda264 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -24,13 +24,14 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, MonotonicallyIncreasingID, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.Count import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{FlatMapGroupsWithState, _} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.{IntegerType, LongType, MetadataBuilder} +import org.apache.spark.unsafe.types.CalendarInterval /** A dummy command for testing unsupported operations. */ case class DummyCommand() extends Command @@ -417,9 +418,57 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testBinaryOperationInStreamingPlan( "left outer join", _.join(_, joinType = LeftOuter), - streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + streamStreamSupported = false, + expectedMsg = "outer join") + + // Left outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = LeftOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Left outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + joinType = LeftOuter, + condition = Some(attribute > rightTimeWithWatermark + 10)) + }, + OutputMode.Append()) + + // Left outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"left outer join with stream-stream relations and state value watermark", { + val leftRelation = streamRelation + val rightTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val rightRelation = new TestStreamingRelation(rightTimeWithWatermark) + leftRelation.join( + rightRelation, + 
joinType = LeftOuter, + condition = Some(attribute < rightTimeWithWatermark + 10)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Left semi joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -427,7 +476,7 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftSemi), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Left anti joins: stream-* not allowed testBinaryOperationInStreamingPlan( @@ -435,14 +484,63 @@ class UnsupportedOperationsSuite extends SparkFunSuite { _.join(_, joinType = LeftAnti), streamStreamSupported = false, batchStreamSupported = false, - expectedMsg = "left outer/semi/anti joins") + expectedMsg = "left semi/anti joins") // Right outer joins: stream-* not allowed testBinaryOperationInStreamingPlan( "right outer join", _.join(_, joinType = RightOuter), + streamBatchSupported = false, streamStreamSupported = false, - streamBatchSupported = false) + expectedMsg = "outer join") + + // Right outer joins: stream-stream allowed with join on watermark attribute + // Note that the attribute need not be watermarked on both sides. + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with left watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attributeWithWatermark === attribute)), + OutputMode.Append()) + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on attribute with right watermark", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attributeWithWatermark)), + OutputMode.Append()) + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and join on non-watermarked attribute", + streamRelation.join(streamRelation, joinType = RightOuter, + condition = Some(attribute === attribute)), + OutputMode.Append(), + Seq("watermark in the join keys")) + + // Right outer joins: stream-stream allowed with range condition yielding state value watermark + assertSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 < attribute)) + }, + OutputMode.Append()) + + // Right outer joins: stream-stream not allowed with insufficient range condition + assertNotSupportedInStreamingPlan( + s"right outer join with stream-stream relations and state value watermark", { + val leftTimeWithWatermark = + AttributeReference("b", IntegerType)().withMetadata(watermarkMetadata) + val leftRelation = new TestStreamingRelation(leftTimeWithWatermark) + val rightRelation = streamRelation + leftRelation.join( + rightRelation, + joinType = RightOuter, + condition = Some(leftTimeWithWatermark + 10 > attribute)) + }, + OutputMode.Append(), + Seq("appropriate range condition")) // Cogroup: only batch-batch is allowed testBinaryOperationInStreamingPlan( @@ -516,6 +614,14 @@ class UnsupportedOperationsSuite extends SparkFunSuite { testOutputMode(Update, shouldSupportAggregation = true, shouldSupportNonAggregation = true) testOutputMode(Complete, shouldSupportAggregation = true, 
shouldSupportNonAggregation = false) + // Unsupported expressions in streaming plan + assertNotSupportedInStreamingPlan( + "MonotonicallyIncreasingID", + streamRelation.select(MonotonicallyIncreasingID()), + outputMode = Append, + expectedMsgs = Seq("monotonically_increasing_id")) + + /* ======================================================================================= TESTING FUNCTIONS diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala index a1000a0e80799..12eddf557109f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionSetSuite.scala @@ -175,20 +175,14 @@ class ExpressionSetSuite extends SparkFunSuite { aUpper > bUpper || aUpper <= Rand(1L) || aUpper <= 10, aUpper <= Rand(1L) || aUpper <= 10 || aUpper > bUpper) - // Partial reorder case: we don't reorder non-deterministic expressions, - // but we can reorder sub-expressions in deterministic AND/OR expressions. - // There are two predicates: - // (aUpper > bUpper || bUpper > 100) => we can reorder sub-expressions in it. - // (aUpper === Rand(1L)) - setTest(1, + // Keep all the non-deterministic expressions even they are semantically equal. + setTest(2, Rand(1L), Rand(1L)) + + setTest(2, (aUpper > bUpper || bUpper > 100) && aUpper === Rand(1L), (bUpper > 100 || aUpper > bUpper) && aUpper === Rand(1L)) - // There are three predicates: - // (Rand(1L) > aUpper) - // (aUpper <= Rand(1L) && aUpper > bUpper) - // (aUpper > 10 && bUpper > 10) => we can reorder sub-expressions in it. - setTest(1, + setTest(2, Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (aUpper > 10 && bUpper > 10), Rand(1L) > aUpper || (aUpper <= Rand(1L) && aUpper > bUpper) || (bUpper > 10 && aUpper > 10)) @@ -219,4 +213,39 @@ class ExpressionSetSuite extends SparkFunSuite { assert((initialSet ++ setToAddWithSameExpression).size == 2) assert((initialSet ++ setToAddWithOutSameExpression).size == 3) } + + test("add single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet + (aUpper + 1)).size == 2) + assert((initialSet + Rand(0)).size == 3) + assert((initialSet + (aUpper + 2)).size == 3) + } + + test("remove single element to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + + assert((initialSet - (aUpper + 1)).size == 1) + assert((initialSet - Rand(0)).size == 2) + assert((initialSet - (aUpper + 2)).size == 2) + } + + test("add multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToAddWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet ++ setToAddWithSameDeterministicExpression).size == 3) + assert((initialSet ++ setToAddWithOutSameExpression).size == 4) + } + + test("remove multiple elements to set with non-deterministic expressions") { + val initialSet = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithSameDeterministicExpression = ExpressionSet(aUpper + 1 :: Rand(0) :: Nil) + val setToRemoveWithOutSameExpression = ExpressionSet(aUpper + 3 :: aUpper + 4 :: Nil) + + assert((initialSet -- 
setToRemoveWithSameDeterministicExpression).size == 1) + assert((initialSet -- setToRemoveWithOutSameExpression).size == 2) + } + } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala index d6c38c3608bf8..73f18d4feef3f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproxCountDistinctForIntervalsSuite.scala @@ -32,7 +32,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { val wrongColumnTypes = Seq(BinaryType, BooleanType, StringType, ArrayType(IntegerType), MapType(IntegerType, IntegerType), StructType(Seq(StructField("s", IntegerType)))) wrongColumnTypes.foreach { dataType => - val wrongColumn = new ApproxCountDistinctForIntervals( + val wrongColumn = ApproxCountDistinctForIntervals( AttributeReference("a", dataType)(), endpointsExpression = CreateArray(Seq(1, 10).map(Literal(_)))) assert( @@ -43,7 +43,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { }) } - var wrongEndpoints = new ApproxCountDistinctForIntervals( + var wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = Literal(0.5d)) assert( @@ -52,19 +52,19 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { case _ => false }) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Seq(AttributeReference("b", DoubleType)()))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The endpoints provided must be constant literals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array(10L).map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == TypeCheckFailure("The number of endpoints must be >= 2 to construct intervals")) - wrongEndpoints = new ApproxCountDistinctForIntervals( + wrongEndpoints = ApproxCountDistinctForIntervals( AttributeReference("a", DoubleType)(), endpointsExpression = CreateArray(Array("foobar").map(Literal(_)))) assert(wrongEndpoints.checkInputDataTypes() == @@ -75,25 +75,18 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { private def createEstimator[T]( endpoints: Array[T], dt: DataType, - rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, InternalRow) = { + rsd: Double = 0.05): (ApproxCountDistinctForIntervals, InternalRow, Array[Long]) = { val input = new SpecificInternalRow(Seq(dt)) val aggFunc = ApproxCountDistinctForIntervals( BoundReference(0, dt, nullable = true), CreateArray(endpoints.map(Literal(_))), rsd) - val buffer = createBuffer(aggFunc) - (aggFunc, input, buffer) - } - - private def createBuffer(aggFunc: ApproxCountDistinctForIntervals): InternalRow = { - val buffer = new SpecificInternalRow(aggFunc.aggBufferAttributes.map(_.dataType)) - aggFunc.initialize(buffer) - buffer + (aggFunc, input, aggFunc.createAggregationBuffer()) } test("merging ApproxCountDistinctForIntervals instances") { val (aggFunc, input, buffer1a) = createEstimator(Array[Int](0, 10, 2000, 345678, 1000000), IntegerType) - val 
buffer1b = createBuffer(aggFunc) - val buffer2 = createBuffer(aggFunc) + val buffer1b = aggFunc.createAggregationBuffer() + val buffer2 = aggFunc.createAggregationBuffer() // Add the lower half to `buffer1a`. var i = 0 @@ -123,7 +116,7 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { } // Check if the buffers are equal. - assert(buffer2 == buffer1a, "Buffers should be equal") + assert(buffer2.sameElements(buffer1a), "Buffers should be equal") } test("test findHllppIndex(value) for values in the range") { @@ -152,6 +145,13 @@ class ApproxCountDistinctForIntervalsSuite extends SparkFunSuite { checkHllppIndex(endpoints = Array(1, 3, 5, 7, 7, 9), value = 7, expectedIntervalIndex = 2) } + test("round trip serialization") { + val (aggFunc, _, _) = createEstimator(Array(1, 2), DoubleType) + val longArray = (1L to 100L).toArray + val roundtrip = aggFunc.deserialize(aggFunc.serialize(longArray)) + assert(roundtrip.sameElements(longArray)) + } + test("basic operations: update, merge, eval...") { val endpoints = Array[Double](0, 0.33, 0.6, 0.6, 0.6, 1.0) val data: Seq[Double] = Seq(0, 0.6, 0.3, 1, 0.6, 0.5, 0.6, 0.33) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala new file mode 100644 index 0000000000000..1167d2f3f3891 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/BufferHolderSparkSubmitSuite.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen + +import org.scalatest.{BeforeAndAfterEach, Matchers} +import org.scalatest.concurrent.Timeouts + +import org.apache.spark.{SparkFunSuite, TestUtils} +import org.apache.spark.deploy.SparkSubmitSuite +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.array.ByteArrayMethods +import org.apache.spark.util.ResetSystemProperties + +// A test for growing the buffer holder to nearly 2GB. Due to the heap size limitation of the Spark +// unit tests JVM, the actually test code is running as a submit job. 
+class BufferHolderSparkSubmitSuite + extends SparkFunSuite + with Matchers + with BeforeAndAfterEach + with ResetSystemProperties + with Timeouts { + + test("SPARK-22222: Buffer holder should be able to allocate memory larger than 1GB") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + + val argsForSparkSubmit = Seq( + "--class", BufferHolderSparkSubmitSuite.getClass.getName.stripSuffix("$"), + "--name", "SPARK-22222", + "--master", "local-cluster[2,1,1024]", + "--driver-memory", "4g", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--conf", "spark.driver.extraJavaOptions=-ea", + unusedJar.toString) + SparkSubmitSuite.runSparkSubmit(argsForSparkSubmit, "../..") + } +} + +object BufferHolderSparkSubmitSuite { + + def main(args: Array[String]): Unit = { + + val ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + + val holder = new BufferHolder(new UnsafeRow(1000)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2)) + + holder.reset() + holder.grow(roundToWord(ARRAY_MAX / 2 + 8)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE / 2)) + + holder.reset() + holder.grow(roundToWord(Integer.MAX_VALUE)) + } + + private def roundToWord(len: Int): Int = { + ByteArrayMethods.roundNumberOfBytesToNearestWord(len) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 582b3ead5e54a..de0e7c7ee49ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -94,6 +94,21 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("combine redundant deterministic filters") { + val originalQuery = + testRelation + .where(Rand(0) > 0.1 && 'a === 1) + .where(Rand(0) > 0.1 && 'a === 1) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = + testRelation + .where(Rand(0) > 0.1 && 'a === 1 && Rand(0) > 0.1) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("SPARK-16164: Filter pushdown should keep the ordering in the logical plan") { val originalQuery = testRelation diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala index eaad1e32a8aba..d7acd139225cd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala @@ -175,4 +175,20 @@ class OptimizeInSuite extends PlanTest { } } } + + test("OptimizedIn test: In empty list gets transformed to FalseLiteral " + + "when value is not nullable") { + val originalQuery = + testRelation + .where(In(Literal("a"), Nil)) + .analyze + + val optimized = Optimize.execute(originalQuery) + val correctAnswer = + testRelation + .where(Literal(false)) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 306e6f2cfbd37..d34a83c42c67e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ 
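Aside, not part of the patch: a DataFrame-level sketch of the behaviour exercised by the "combine redundant deterministic filters" test above, where repeated deterministic conjuncts are deduplicated while each non-deterministic rand(0) > 0.1 conjunct is kept. The session setup is illustrative only.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, rand}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val df = Seq(1, 2, 3).toDF("a")
  .where(rand(0) > 0.1 && col("a") === 1)
  .where(rand(0) > 0.1 && col("a") === 1)

// Under the optimization tested above, the optimized plan is expected to show a single Filter
// whose condition keeps both rand(0) > 0.1 predicates but only one copy of a = 1.
df.explain(true)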
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -110,6 +110,7 @@ class PlanParserSuite extends AnalysisTest { assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) assertEqual("select from tbl", OneRowRelation().select('from.as("tbl"))) + assertEqual("select a from 1k.2m", table("1k", "2m").select('a)) } test("reverse select query") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 10bdfafd6f933..82c5307d54360 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans +import org.scalatest.Suite + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer @@ -29,7 +31,13 @@ import org.apache.spark.sql.internal.SQLConf /** * Provides helper methods for comparing plans. */ -trait PlanTest extends SparkFunSuite with PredicateHelper { +trait PlanTest extends SparkFunSuite with PlanTestBase + +/** + * Provides helper methods for comparing plans, but without the overhead of + * mandating a FunSuite. + */ +trait PlanTestBase extends PredicateHelper { self: Suite => // TODO(gatorsmile): remove this from PlanTest and all the analyzer rules protected def conf = SQLConf.get diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala index df579d5ec1ddf..650813975d75c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/QuantileSummariesSuite.scala @@ -57,8 +57,14 @@ class QuantileSummariesSuite extends SparkFunSuite { private def checkQuantile(quant: Double, data: Seq[Double], summary: QuantileSummaries): Unit = { if (data.nonEmpty) { val approx = summary.query(quant).get - // The rank of the approximation. - val rank = data.count(_ < approx) // has to be <, not <= to be exact + // Get the rank of the approximation. + val rankOfValue = data.count(_ <= approx) + val rankOfPreValue = data.count(_ < approx) + // `rankOfValue` is the last position of the quantile value. If the input repeats the value + // chosen as the quantile, e.g. in (1,2,2,2,2,2,3), the 50% quantile is 2, then it's + // improper to choose the last position as its rank. Instead, we get the rank by averaging + // `rankOfValue` and `rankOfPreValue`. + val rank = math.ceil((rankOfValue + rankOfPreValue) / 2.0) val lower = math.floor((quant - summary.relativeError) * data.size) val upper = math.ceil((quant + summary.relativeError) * data.size) val msg = diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java new file mode 100644 index 0000000000000..f1785853a94ae --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/columnar/ColumnDictionary.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar; + +import org.apache.spark.sql.execution.vectorized.Dictionary; + +public final class ColumnDictionary implements Dictionary { + private int[] intDictionary; + private long[] longDictionary; + + public ColumnDictionary(int[] dictionary) { + this.intDictionary = dictionary; + } + + public ColumnDictionary(long[] dictionary) { + this.longDictionary = dictionary; + } + + @Override + public int decodeToInt(int id) { + return intDictionary[id]; + } + + @Override + public long decodeToLong(int id) { + return longDictionary[id]; + } + + @Override + public float decodeToFloat(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support float"); + } + + @Override + public double decodeToDouble(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support double"); + } + + @Override + public byte[] decodeToBinary(int id) { + throw new UnsupportedOperationException("Dictionary encoding does not support String"); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index e782756a3e781..bc546c7c425b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -462,6 +462,11 @@ public int numValidRows() { return numRows - numRowsFiltered; } + /** + * Returns the schema that makes up this batch. + */ + public StructType schema() { return schema; } + /** * Returns the max capacity (in number of rows) for this batch. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index e1d36858d4eee..a7522ebf5821a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -85,6 +85,7 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); Platform.freeMemory(nulls); Platform.freeMemory(data); Platform.freeMemory(lengthData); @@ -227,6 +228,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { null, data + 2 * rowId, count * 2); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -267,6 +274,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { null, data + 4 * rowId, count * 4); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { @@ -333,6 +346,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { null, data + 8 * rowId, count * 8); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, + null, data + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { if (!bigEndianPlatform) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 96a452978cb35..166a39e0fabd9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -90,6 +90,16 @@ public long nullsNativeAddress() { @Override public void close() { + super.close(); + nulls = null; + byteData = null; + shortData = null; + intData = null; + longData = null; + floatData = null; + doubleData = null; + arrayLengths = null; + arrayOffsets = null; } // @@ -223,6 +233,12 @@ public void putShorts(int rowId, int count, short[] src, int srcIndex) { System.arraycopy(src, srcIndex, shortData, rowId, count); } + @Override + public void putShorts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, shortData, + Platform.SHORT_ARRAY_OFFSET + rowId * 2, count * 2); + } + @Override public short getShort(int rowId) { if (dictionary == null) { @@ -262,6 +278,12 @@ public void putInts(int rowId, int count, int[] src, int srcIndex) { System.arraycopy(src, srcIndex, intData, rowId, count); } + @Override + public void putInts(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, intData, + Platform.INT_ARRAY_OFFSET + rowId * 4, count * 4); + } + @Override public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + 
Platform.BYTE_ARRAY_OFFSET; @@ -322,6 +344,12 @@ public void putLongs(int rowId, int count, long[] src, int srcIndex) { System.arraycopy(src, srcIndex, longData, rowId, count); } + @Override + public void putLongs(int rowId, int count, byte[] src, int srcIndex) { + Platform.copyMemory(src, Platform.BYTE_ARRAY_OFFSET + srcIndex, longData, + Platform.LONG_ARRAY_OFFSET + rowId * 8, count * 8); + } + @Override public void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 0bddc351e1bed..d3a14b9d8bd74 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; /** @@ -59,6 +60,24 @@ public void reset() { } } + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + childColumns[i] = null; + } + childColumns = null; + } + if (dictionaryIds != null) { + dictionaryIds.close(); + dictionaryIds = null; + } + dictionary = null; + resultStruct = null; + resultArray = null; + } + public void reserve(int requiredCapacity) { if (requiredCapacity > capacity) { int newCapacity = (int) Math.min(MAX_CAPACITY, requiredCapacity * 2L); @@ -95,138 +114,156 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { protected abstract void reserveInternal(int capacity); /** - * Sets the value at rowId to null/not null. + * Sets null/not null to the value at rowId. */ public abstract void putNotNull(int rowId); public abstract void putNull(int rowId); /** - * Sets the values from [rowId, rowId + count) to null/not null. + * Sets null/not null to the values at [rowId, rowId + count). */ public abstract void putNulls(int rowId, int count); public abstract void putNotNulls(int rowId, int count); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putBoolean(int rowId, boolean value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBooleans(int rowId, int count, boolean value); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putByte(int rowId, byte value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putBytes(int rowId, int count, byte value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putBytes(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putShort(int rowId, short value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). 
*/ public abstract void putShorts(int rowId, int count, short value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putShorts(int rowId, int count, short[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets values from [src[srcIndex], src[srcIndex + count * 2]) to [rowId, rowId + count) + * The data in src must be 2-byte platform native endian shorts. + */ + public abstract void putShorts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets `value` to the value at rowId. */ public abstract void putInt(int rowId, int value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putInts(int rowId, int count, int value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putInts(int rowId, int count, int[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be 4-byte platform native endian ints. + */ + public abstract void putInts(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) * The data in src must be 4-byte little endian ints. */ public abstract void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putLong(int rowId, long value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putLongs(int rowId, int count, long value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putLongs(int rowId, int count, long[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be 8-byte platform native endian longs. + */ + public abstract void putLongs(int rowId, int count, byte[] src, int srcIndex); + + /** + * Sets values from [src + srcIndex, src + srcIndex + count * 8) to [rowId, rowId + count) * The data in src must be 8-byte little endian longs. */ public abstract void putLongsLittleEndian(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putFloat(int rowId, float value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). 
*/ public abstract void putFloats(int rowId, int count, float value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putFloats(int rowId, int count, float[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted floats. + * Sets values from [src[srcIndex], src[srcIndex + count * 4]) to [rowId, rowId + count) + * The data in src must be ieee formatted floats in platform native endian. */ public abstract void putFloats(int rowId, int count, byte[] src, int srcIndex); /** - * Sets the value at rowId to `value`. + * Sets `value` to the value at rowId. */ public abstract void putDouble(int rowId, double value); /** - * Sets values from [rowId, rowId + count) to value. + * Sets value to [rowId, rowId + count). */ public abstract void putDoubles(int rowId, int count, double value); /** - * Sets values from [rowId, rowId + count) to [src + srcIndex, src + srcIndex + count) + * Sets values from [src[srcIndex], src[srcIndex + count]) to [rowId, rowId + count) */ public abstract void putDoubles(int rowId, int count, double[] src, int srcIndex); /** - * Sets values from [rowId, rowId + count) to [src[srcIndex], src[srcIndex + count]) - * The data in src must be ieee formatted doubles. + * Sets values from [src[srcIndex], src[srcIndex + count * 8]) to [rowId, rowId + count) + * The data in src must be ieee formatted doubles in platform native endian. */ public abstract void putDoubles(int rowId, int count, byte[] src, int srcIndex); @@ -236,7 +273,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { public abstract void putArray(int rowId, int offset, int length); /** - * Sets the value at rowId to `value`. + * Sets values from [value + offset, value + offset + count) to the values at rowId. */ public abstract int putByteArray(int rowId, byte[] value, int offset, int count); public final int putByteArray(int rowId, byte[] value) { @@ -559,7 +596,7 @@ public final int appendStruct(boolean isNull) { * Upper limit for the maximum capacity for this column. */ @VisibleForTesting - protected int MAX_CAPACITY = Integer.MAX_VALUE - 8; + protected int MAX_CAPACITY = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; /** * Number of nulls in this column. This is an optimization for the reader, to skip NULL checks. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java index ab5254a688d5a..ee489ad0f608f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupport.java @@ -30,9 +30,8 @@ public interface ReadSupport { /** * Creates a {@link DataSourceV2Reader} to scan the data from this data source. * - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. 
*/ DataSourceV2Reader createReader(DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java index c13aeca2ef36f..74e81a2c84d68 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/ReadSupportWithSchema.java @@ -39,9 +39,8 @@ public interface ReadSupportWithSchema { * physical schema of the underlying storage of this data source reader, e.g. * CSV files, JSON files, etc, while this reader may not read data with full * schema, as column pruning or other optimizations may happen. - * @param options the options for this data source reader, which is an immutable case-insensitive - * string-to-string map. - * @return a reader that implements the actual read logic. + * @param options the options for the returned data source reader, which is an immutable + * case-insensitive string-to-string map. */ DataSourceV2Reader createReader(StructType schema, DataSourceV2Options options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java new file mode 100644 index 0000000000000..a8a961598bde3 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/WriteSupport.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import java.util.Optional; + +import org.apache.spark.annotation.InterfaceStability; + +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.writer.DataSourceV2Writer; +import org.apache.spark.sql.types.StructType; + +/** + * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * provide data writing ability and save the data to the data source. + */ +@InterfaceStability.Evolving +public interface WriteSupport { + + /** + * Creates an optional {@link DataSourceV2Writer} to save the data to this data source. Data + * sources can return None if no writing is needed according to the save mode. + * + * @param jobId A unique string for the writing job. It's possible that there are many writing + * jobs running at the same time, and the returned {@link DataSourceV2Writer} should + * use this job id to distinguish itself from writers of other jobs. + * @param schema the schema of the data to be written. + * @param mode the save mode which determines what to do when the data are already in this data + * source; please refer to {@link SaveMode} for more details.
+ * @param options the options for the returned data source writer, which is an immutable + * case-insensitive string-to-string map. + */ + Optional createWriter( + String jobId, StructType schema, SaveMode mode, DataSourceV2Options options); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java index cfafc1a576793..95e091569b614 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataReader.java @@ -24,6 +24,10 @@ /** * A data reader returned by {@link ReadTask#createReader()} and is responsible for outputting data * for a RDD partition. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source readers, or {@link org.apache.spark.sql.catalyst.expressions.UnsafeRow} for data source + * readers that mix in {@link SupportsScanUnsafeRow}. */ @InterfaceStability.Evolving public interface DataReader extends Closeable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java index fb4d5c0d7ae41..5989a4ac8440b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/DataSourceV2Reader.java @@ -30,7 +30,7 @@ * {@link org.apache.spark.sql.sources.v2.ReadSupportWithSchema#createReader( * StructType, org.apache.spark.sql.sources.v2.DataSourceV2Options)}. * It can mix in various query optimization interfaces to speed up the data scan. The actual scan - * logic should be delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. + * logic is delegated to {@link ReadTask}s that are returned by {@link #createReadTasks()}. * * There are mainly 3 kinds of query optimizations: * 1. Operators push-down. E.g., filter push-down, required columns push-down(aka column diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java index 7885bfcdd49e4..01362df0978cb 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ReadTask.java @@ -27,7 +27,8 @@ * is similar to the relationship between {@link Iterable} and {@link java.util.Iterator}. * * Note that, the read task will be serialized and sent to executors, then the data reader will be - * created on executors and do the actual reading. + * created on executors and do the actual reading. So {@link ReadTask} must be serializable and + * {@link DataReader} doesn't need to be. 
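The note just above is the key contract of the scan path: the ReadTask crosses the wire, the DataReader it creates never leaves the executor and may hold non-serializable state. A minimal, hypothetical task that serves a range of integers as Rows, written against the interfaces added in this patch rather than copied from it (the generic type parameters were dropped by the diff rendering above):

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.reader.{DataReader, ReadTask}

class RangeReadTask(start: Int, end: Int) extends ReadTask[Row] {
  // Runs on an executor; free to hold non-serializable state such as open connections.
  override def createReader(): DataReader[Row] = new DataReader[Row] {
    private var current = start - 1
    override def next(): Boolean = { current += 1; current < end }
    override def get(): Row = Row(current)
    override def close(): Unit = ()
  }
}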
*/ @InterfaceStability.Evolving public interface ReadTask extends Serializable { diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java index 19d706238ec8e..d6091774d75aa 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownCatalystFilters.java @@ -40,4 +40,12 @@ public interface SupportsPushDownCatalystFilters { * Pushes down filters, and returns unsupported filters. */ Expression[] pushCatalystFilters(Expression[] filters); + + /** + * Returns the catalyst filters that are pushed in {@link #pushCatalystFilters(Expression[])}. + * It's possible that there is no filters in the query and + * {@link #pushCatalystFilters(Expression[])} is never called, empty array should be returned for + * this case. + */ + Expression[] pushedCatalystFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index d4b509e7080f2..6b0c9d417eeae 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -35,4 +35,11 @@ public interface SupportsPushDownFilters { * Pushes down filters, and returns unsupported filters. */ Filter[] pushFilters(Filter[] filters); + + /** + * Returns the filters that are pushed in {@link #pushFilters(Filter[])}. + * It's possible that there is no filters in the query and {@link #pushFilters(Filter[])} + * is never called, empty array should be returned for this case. + */ + Filter[] pushedFilters(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java new file mode 100644 index 0000000000000..8d8e33633fb0d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataSourceV2Writer.java @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
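The new pushedFilters()/pushedCatalystFilters() methods let Spark ask a reader, after planning, which predicates it actually absorbed. A hypothetical reader that accepts only IsNotNull filters could look like the following; readSchema and createReadTasks are stubbed for brevity, and the generic parameters were again dropped by the diff rendering:

import java.util.Collections

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.{Filter, IsNotNull}
import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, ReadTask, SupportsPushDownFilters}
import org.apache.spark.sql.types.StructType

class NotNullOnlyReader(schema: StructType) extends DataSourceV2Reader with SupportsPushDownFilters {
  private var pushed: Array[Filter] = Array.empty   // stays empty if pushFilters is never called

  override def readSchema(): StructType = schema
  override def createReadTasks(): java.util.List[ReadTask[Row]] = Collections.emptyList()

  override def pushFilters(filters: Array[Filter]): Array[Filter] = {
    val (supported, unsupported) = filters.partition(_.isInstanceOf[IsNotNull])
    pushed = supported
    unsupported          // Spark evaluates the rejected filters itself
  }

  override def pushedFilters(): Array[Filter] = pushed
}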
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.SaveMode; +import org.apache.spark.sql.sources.v2.DataSourceV2Options; +import org.apache.spark.sql.sources.v2.WriteSupport; +import org.apache.spark.sql.types.StructType; + +/** + * A data source writer that is returned by + * {@link WriteSupport#createWriter(String, StructType, SaveMode, DataSourceV2Options)}. + * It can mix in various writing optimization interfaces to speed up the data saving. The actual + * writing logic is delegated to {@link DataWriter}. + * + * The writing procedure is: + * 1. Create a writer factory by {@link #createWriterFactory()}, serialize and send it to all the + * partitions of the input data(RDD). + * 2. For each partition, create the data writer, and write the data of the partition with this + * writer. If all the data are written successfully, call {@link DataWriter#commit()}. If + * exception happens during the writing, call {@link DataWriter#abort()}. + * 3. If all writers are successfully committed, call {@link #commit(WriterCommitMessage[])}. If + * some writers are aborted, or the job failed with an unknown reason, call + * {@link #abort(WriterCommitMessage[])}. + * + * Spark won't retry failed writing jobs, users should do it manually in their Spark applications if + * they want to retry. + * + * Please refer to the document of commit/abort methods for detailed specifications. + * + * Note that, this interface provides a protocol between Spark and data sources for transactional + * data writing, but the transaction here is Spark-level transaction, which may not be the + * underlying storage transaction. For example, Spark successfully writes data to a Cassandra data + * source, but Cassandra may need some more time to reach consistency at storage level. + */ +@InterfaceStability.Evolving +public interface DataSourceV2Writer { + + /** + * Creates a writer factory which will be serialized and sent to executors. + */ + DataWriterFactory createWriterFactory(); + + /** + * Commits this writing job with a list of commit messages. The commit messages are collected from + * successful data writers and are produced by {@link DataWriter#commit()}. If this method + * fails(throw exception), this writing job is considered to be failed, and + * {@link #abort(WriterCommitMessage[])} will be called. The written data should only be visible + * to data source readers if this method succeeds. + * + * Note that, one partition may have multiple committed data writers because of speculative tasks. + * Spark will pick the first successful one and get its commit message. Implementations should be + * aware of this and handle it correctly, e.g., have a mechanism to make sure only one data writer + * can commit successfully, or have a way to clean up the data of already-committed writers. + */ + void commit(WriterCommitMessage[] messages); + + /** + * Aborts this writing job because some data writers are failed to write the records and aborted, + * or the Spark job fails with some unknown reasons, or {@link #commit(WriterCommitMessage[])} + * fails. If this method fails(throw exception), the underlying data source may have garbage that + * need to be cleaned manually, but these garbage should not be visible to data source readers. 
+ * + * Unless the abort is triggered by the failure of commit, the given messages should have some + * null slots as there maybe only a few data writers that are committed before the abort + * happens, or some data writers were committed but their commit messages haven't reached the + * driver when the abort is triggered. So this is just a "best effort" for data sources to + * clean up the data left by data writers. + */ + void abort(WriterCommitMessage[] messages); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java new file mode 100644 index 0000000000000..14261419af6f6 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A data writer returned by {@link DataWriterFactory#createWriter(int, int)} and is + * responsible for writing data for an input RDD partition. + * + * One Spark task has one exclusive data writer, so there is no thread-safe concern. + * + * {@link #write(Object)} is called for each record in the input RDD partition. If one record fails + * the {@link #write(Object)}, {@link #abort()} is called afterwards and the remaining records will + * not be processed. If all records are successfully written, {@link #commit()} is called. + * + * If this data writer succeeds(all records are successfully written and {@link #commit()} + * succeeds), a {@link WriterCommitMessage} will be sent to the driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} with commit messages from other data + * writers. If this data writer fails(one record fails to write or {@link #commit()} fails), an + * exception will be sent to the driver side, and Spark will retry this writing task for some times, + * each time {@link DataWriterFactory#createWriter(int, int)} gets a different `attemptNumber`, + * and finally call {@link DataSourceV2Writer#abort(WriterCommitMessage[])} if all retry fail. + * + * Besides the retry mechanism, Spark may launch speculative tasks if the existing writing task + * takes too long to finish. Different from retried tasks, which are launched one by one after the + * previous one fails, speculative tasks are running simultaneously. It's possible that one input + * RDD partition has multiple data writers with different `attemptNumber` running at the same time, + * and data sources should guarantee that these data writers don't conflict and can work together. 
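To make the driver-side half of this commit/abort protocol concrete, here is a hypothetical DataSourceV2Writer whose tasks report a temporary location and whose commit publishes those locations in one step. The class and message names are invented for illustration, and the generic parameters were dropped by the diff rendering above:

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataSourceV2Writer, DataWriterFactory, WriterCommitMessage}

case class TaskCommit(partitionId: Int, tempPath: String) extends WriterCommitMessage

class PublishOnCommitWriter(factory: DataWriterFactory[Row]) extends DataSourceV2Writer {
  override def createWriterFactory(): DataWriterFactory[Row] = factory

  override def commit(messages: Array[WriterCommitMessage]): Unit = {
    // One message per partition: Spark keeps only the first successful attempt's message.
    val commits = messages.collect { case c: TaskCommit => c }
    commits.foreach(c => println(s"publishing ${c.tempPath} for partition ${c.partitionId}"))
  }

  override def abort(messages: Array[WriterCommitMessage]): Unit = {
    // Some slots may be null (attempts that never committed); clean up the rest, best effort.
    messages.collect { case c: TaskCommit => c }.foreach(c => println(s"deleting ${c.tempPath}"))
  }
}

Publishing through a single rename or metadata update is what keeps the written data invisible until commit() succeeds.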
+ * Implementations can coordinate with driver during {@link #commit()} to make sure only one of + * these data writers can commit successfully. Or implementations can allow all of them to commit + * successfully, and have a way to revert committed data writers without the commit message, because + * Spark only accepts the commit message that arrives first and ignore others. + * + * Note that, Currently the type `T` can only be {@link org.apache.spark.sql.Row} for normal data + * source writers, or {@link org.apache.spark.sql.catalyst.InternalRow} for data source writers + * that mix in {@link SupportsWriteInternalRow}. + */ +@InterfaceStability.Evolving +public interface DataWriter { + + /** + * Writes one record. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + void write(T record); + + /** + * Commits this writer after all records are written successfully, returns a commit message which + * will be send back to driver side and pass to + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * The written data should only be visible to data source readers after + * {@link DataSourceV2Writer#commit(WriterCommitMessage[])} succeeds, which means this method + * should still "hide" the written data and ask the {@link DataSourceV2Writer} at driver side to + * do the final commitment via {@link WriterCommitMessage}. + * + * If this method fails(throw exception), {@link #abort()} will be called and this data writer is + * considered to be failed. + */ + WriterCommitMessage commit(); + + /** + * Aborts this writer if it is failed. Implementations should clean up the data for already + * written records. + * + * This method will only be called if there is one record failed to write, or {@link #commit()} + * failed. + * + * If this method fails(throw exception), the underlying data source may have garbage that need + * to be cleaned by {@link DataSourceV2Writer#abort(WriterCommitMessage[])} or manually, but + * these garbage should not be visible to data source readers. + */ + void abort(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java new file mode 100644 index 0000000000000..f812d102bda1a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
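The executor-side counterpart follows the per-record contract spelled out above: buffer, expose nothing until commit(), drop everything on abort(). A toy sketch under the same caveats; buffering rows inside the commit message is only workable for tiny data, and the names are illustrative:

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.Row
import org.apache.spark.sql.sources.v2.writer.{DataWriter, DataWriterFactory, WriterCommitMessage}

case class BufferedRows(partitionId: Int, attemptNumber: Int, rows: Seq[Row]) extends WriterCommitMessage

class BufferingWriterFactory extends DataWriterFactory[Row] {
  override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] =
    new DataWriter[Row] {
      private val buffer = ArrayBuffer.empty[Row]
      override def write(record: Row): Unit = buffer += record
      override def commit(): WriterCommitMessage =
        BufferedRows(partitionId, attemptNumber, buffer.toList)   // visible only via the driver's commit
      override def abort(): Unit = buffer.clear()                 // nothing published, just drop the state
    }
}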
+ */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A factory of {@link DataWriter} returned by {@link DataSourceV2Writer#createWriterFactory()}, + * which is responsible for creating and initializing the actual data writer at executor side. + * + * Note that, the writer factory will be serialized and sent to executors, then the data writer + * will be created on executors and do the actual writing. So {@link DataWriterFactory} must be + * serializable and {@link DataWriter} doesn't need to be. + */ +@InterfaceStability.Evolving +public interface DataWriterFactory extends Serializable { + + /** + * Returns a data writer to do the actual writing work. + * + * @param partitionId A unique id of the RDD partition that the returned writer will process. + * Usually Spark processes many RDD partitions at the same time, + * implementations should use the partition id to distinguish writers for + * different partitions. + * @param attemptNumber Spark may launch multiple tasks with the same task id. For example, a task + * failed, Spark launches a new task wth the same task id but different + * attempt number. Or a task is too slow, Spark launches new tasks wth the + * same task id but different attempt number, which means there are multiple + * tasks with the same task id running at the same time. Implementations can + * use this attempt number to distinguish writers of different task attempts. + */ + DataWriter createWriter(int partitionId, int attemptNumber); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java new file mode 100644 index 0000000000000..a8e95901f3b07 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsWriteInternalRow.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.annotation.InterfaceStability; +import org.apache.spark.sql.Row; +import org.apache.spark.sql.catalyst.InternalRow; + +/** + * A mix-in interface for {@link DataSourceV2Writer}. Data source writers can implement this + * interface to write {@link InternalRow} directly and avoid the row conversion at Spark side. + * This is an experimental and unstable interface, as {@link InternalRow} is not public and may get + * changed in the future Spark versions. 
+ */ + +@InterfaceStability.Evolving +@Experimental +@InterfaceStability.Unstable +public interface SupportsWriteInternalRow extends DataSourceV2Writer { + + @Override + default DataWriterFactory createWriterFactory() { + throw new IllegalStateException( + "createWriterFactory should not be called with SupportsWriteInternalRow."); + } + + DataWriterFactory createInternalRowWriterFactory(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java new file mode 100644 index 0000000000000..082d6b5dc409f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import java.io.Serializable; + +import org.apache.spark.annotation.InterfaceStability; + +/** + * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side + * as the input parameter of {@link DataSourceV2Writer#commit(WriterCommitMessage[])}. + * + * This is an empty interface, data sources should define their own message class and use it in + * their {@link DataWriter#commit()} and {@link DataSourceV2Writer#commit(WriterCommitMessage[])} + * implementations. 
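SupportsWriteInternalRow swaps the factory type rather than the writer protocol: the default createWriterFactory() above deliberately throws, and Spark is expected to call createInternalRowWriterFactory() instead. A minimal, hypothetical mix-in, with the generic parameters that the diff rendering dropped restored:

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriterFactory, SupportsWriteInternalRow, WriterCommitMessage}

class InternalRowPassthroughWriter(factory: DataWriterFactory[InternalRow]) extends SupportsWriteInternalRow {
  // createWriterFactory() is inherited as the throwing default above and left alone.
  override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = factory

  override def commit(messages: Array[WriterCommitMessage]): Unit = ()   // no-op for the sketch
  override def abort(messages: Array[WriterCommitMessage]): Unit = ()
}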
+ */ +@InterfaceStability.Evolving +public interface WriterCommitMessage extends Serializable {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 78b668c04fd5c..17966eecfc051 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -184,7 +184,6 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val cls = DataSource.lookupDataSource(source) if (classOf[DataSourceV2].isAssignableFrom(cls)) { - val dataSource = cls.newInstance() val options = new DataSourceV2Options(extraOptions.asJava) val reader = (cls.newInstance(), userSpecifiedSchema) match { @@ -194,8 +193,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case (ds: ReadSupport, None) => ds.createReader(options) - case (_: ReadSupportWithSchema, None) => - throw new AnalysisException(s"A schema needs to be specified when using $dataSource.") + case (ds: ReadSupportWithSchema, None) => + throw new AnalysisException(s"A schema needs to be specified when using $ds.") case (ds: ReadSupport, Some(schema)) => val reader = ds.createReader(options) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 07347d2748544..8d95b24c00619 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql -import java.util.{Locale, Properties} +import java.text.SimpleDateFormat +import java.util.{Date, Locale, Properties, UUID} import scala.collection.JavaConverters._ @@ -29,7 +30,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.sources.v2.{DataSourceV2, DataSourceV2Options, WriteSupport} import org.apache.spark.sql.types.StructType /** @@ -231,12 +234,33 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { assertNotBucketed("save") - runCommand(df.sparkSession, "save") { - DataSource( - sparkSession = df.sparkSession, - className = source, - partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + val cls = DataSource.lookupDataSource(source) + if (classOf[DataSourceV2].isAssignableFrom(cls)) { + cls.newInstance() match { + case ds: WriteSupport => + val options = new DataSourceV2Options(extraOptions.asJava) + // Using a timestamp and a random UUID to distinguish different writing jobs. This is good + // enough as there won't be tons of writing jobs created at the same second. 
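The job id handed to WriteSupport.createWriter is built as the comment above describes and the lines that follow implement: a second-resolution timestamp plus a random UUID, so two jobs collide only if they draw the same UUID within the same second. Reproduced in isolation:

import java.text.SimpleDateFormat
import java.util.{Date, Locale, UUID}

val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US).format(new Date()) +
  "-" + UUID.randomUUID()
println(jobId)   // e.g. 20171101093042-<uuid>; the exact value is time-dependent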
+ val jobId = new SimpleDateFormat("yyyyMMddHHmmss", Locale.US) + .format(new Date()) + "-" + UUID.randomUUID() + val writer = ds.createWriter(jobId, df.logicalPlan.schema, mode, options) + if (writer.isPresent) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(writer.get(), df.logicalPlan) + } + } + + case _ => throw new AnalysisException(s"$cls does not support data writing.") + } + } else { + // Code path for data source v1. + runCommand(df.sparkSession, "save") { + DataSource( + sparkSession = df.sparkSession, + className = source, + partitionColumns = partitioningColumns.getOrElse(Nil), + options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + } } } @@ -520,8 +544,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
  • `compression` (default is the value specified in `spark.sql.orc.compression.codec`): * compression codec to use when saving to file. This can be one of the known case-insensitive * shorten names(`none`, `snappy`, `zlib`, and `lzo`). This will override - * `orc.compress` and `spark.sql.parquet.compression.codec`. If `orc.compress` is given, - * it overrides `spark.sql.parquet.compression.codec`.
  • + * `orc.compress` and `spark.sql.orc.compression.codec`. If `orc.compress` is given, + * it overrides `spark.sql.orc.compression.codec`. * * * @since 1.5.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index ab0c4126bcbdd..b70dfc05330f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -237,7 +237,7 @@ class Dataset[T] private[sql]( */ private[sql] def showString( _numRows: Int, truncate: Int = 20, vertical: Boolean = false): String = { - val numRows = _numRows.max(0) + val numRows = _numRows.max(0).min(Int.MaxValue - 1) val takeResult = toDF().take(numRows + 1) val hasMoreData = takeResult.length > numRows val data = takeResult.take(numRows) @@ -2083,22 +2083,40 @@ class Dataset[T] private[sql]( * @group untypedrel * @since 2.0.0 */ - def withColumn(colName: String, col: Column): DataFrame = { + def withColumn(colName: String, col: Column): DataFrame = withColumns(Seq(colName), Seq(col)) + + /** + * Returns a new Dataset by adding columns or replacing the existing columns that has + * the same names. + */ + private[spark] def withColumns(colNames: Seq[String], cols: Seq[Column]): DataFrame = { + require(colNames.size == cols.size, + s"The size of column names: ${colNames.size} isn't equal to " + + s"the size of columns: ${cols.size}") + SchemaUtils.checkColumnNameDuplication( + colNames, + "in given column names", + sparkSession.sessionState.conf.caseSensitiveAnalysis) + val resolver = sparkSession.sessionState.analyzer.resolver val output = queryExecution.analyzed.output - val shouldReplace = output.exists(f => resolver(f.name, colName)) - if (shouldReplace) { - val columns = output.map { field => - if (resolver(field.name, colName)) { - col.as(colName) - } else { - Column(field) - } + + val columnMap = colNames.zip(cols).toMap + + val replacedAndExistingColumns = output.map { field => + columnMap.find { case (colName, _) => + resolver(field.name, colName) + } match { + case Some((colName: String, col: Column)) => col.as(colName) + case _ => Column(field) } - select(columns : _*) - } else { - select(Column("*"), col.as(colName)) } + + val newColumns = columnMap.filter { case (colName, col) => + !output.exists(f => resolver(f.name, colName)) + }.map { case (colName, col) => col.as(colName) } + + select(replacedAndExistingColumns ++ newColumns : _*) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index cb42e9e4560cf..6bab21dca0cbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -24,7 +24,6 @@ import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CreateStruct} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.expressions.ReduceAggregator import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode} @@ -564,4 +563,25 @@ class KeyValueGroupedDataset[K, V] private[sql]( encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, 
right.asJava).asScala)(encoder) } + + override def toString: String = { + val builder = new StringBuilder + val kFields = kExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + val vFields = vExprEnc.schema.map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append("KeyValueGroupedDataset: [key: [") + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append("], value: [") + builder.append(vFields.take(2).mkString(", ")) + if (vFields.length > 2) { + builder.append(" ... " + (vFields.length - 2) + " more field(s)") + } + builder.append("]]").toString() + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 147b549964913..6b45790d5ff6e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -27,12 +27,12 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedAlias, UnresolvedAttribute, UnresolvedFunction} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, FlatMapGroupsInR, Pivot} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.usePrettyExpression import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression +import org.apache.spark.sql.execution.python.{PythonUDF, PythonUdfType} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.NumericType -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{NumericType, StructType} /** * A set of methods for aggregations on a `DataFrame`, created by [[Dataset#groupBy groupBy]], @@ -435,6 +435,50 @@ class RelationalGroupedDataset protected[sql]( df.logicalPlan.output, df.logicalPlan)) } + + /** + * Applies a grouped vectorized python user-defined function to each group of data. + * The user-defined function defines a transformation: `pandas.DataFrame` -> `pandas.DataFrame`. + * For each group, all elements in the group are passed as a `pandas.DataFrame` and the results + * for all groups are combined into a new [[DataFrame]]. + * + * This function does not support partial aggregation, and requires shuffling all the data in + * the [[DataFrame]]. + * + * This function uses Apache Arrow as serialization format between Java executors and Python + * workers. 
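The new toString gives grouped datasets a readable REPL rendering instead of the default object hash. Assuming an active SparkSession named spark (for example in spark-shell), the output looks roughly like the comment below; the exact field names depend on the encoders in play:

import spark.implicits._

val grouped = spark.range(10).groupByKey(_ % 3)
println(grouped)
// Something like: KeyValueGroupedDataset: [key: [value: bigint], value: [value: bigint]]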
+ */ + private[sql] def flatMapGroupsInPandas(expr: PythonUDF): DataFrame = { + require(expr.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF, + "Must pass a grouped vectorized python udf") + require(expr.dataType.isInstanceOf[StructType], + "The returnType of the vectorized python udf must be a StructType") + + val groupingNamedExpressions = groupingExprs.map { + case ne: NamedExpression => ne + case other => Alias(other, other.toString)() + } + val groupingAttributes = groupingNamedExpressions.map(_.toAttribute) + val child = df.logicalPlan + val project = Project(groupingNamedExpressions ++ child.output, child) + val output = expr.dataType.asInstanceOf[StructType].toAttributes + val plan = FlatMapGroupsInPandas(groupingAttributes, expr, output, project) + + Dataset.ofRows(df.sparkSession, plan) + } + + override def toString: String = { + val builder = new StringBuilder + builder.append("RelationalGroupedDataset: [grouping expressions: [") + val kFields = groupingExprs.map(_.asInstanceOf[NamedExpression]).map { + case f => s"${f.name}: ${f.dataType.simpleString(2)}" + } + builder.append(kFields.take(2).mkString(", ")) + if (kFields.length > 2) { + builder.append(" ... " + (kFields.length - 2) + " more field(s)") + } + builder.append(s"], value: ${df.toString}, type: $groupType]").toString() + } } private[sql] object RelationalGroupedDataset { @@ -449,7 +493,9 @@ private[sql] object RelationalGroupedDataset { /** * The Grouping Type */ - private[sql] trait GroupType + private[sql] trait GroupType { + override def toString: String = getClass.getSimpleName.stripSuffix("$").stripSuffix("Type") + } /** * To indicate it's the GroupBy diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 1afe83ea3539e..eb01e126bcbef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType @@ -31,8 +30,6 @@ import org.apache.spark.sql.types.DataType */ private[sql] trait ColumnarBatchScan extends CodegenSupport { - val inMemoryTableScan: InMemoryTableScanExec = null - def vectorTypes: Option[Seq[String]] = None override lazy val metrics = Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 4accf54a18232..f404621399cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -119,7 +119,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand, _) => + case ExecutedCommandExec(desc: DescribeTableCommand) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. 
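The GroupType.toString override further down this hunk relies on a small Scala naming trick: a singleton object's simple class name ends in "$", so stripping "$" and then "Type" turns an object named GroupByType into "GroupBy". A standalone illustration with made-up objects, not the private Spark ones:

trait NamedGroupType {
  override def toString: String = getClass.getSimpleName.stripSuffix("$").stripSuffix("Type")
}
object GroupByType extends NamedGroupType
object RollupType extends NamedGroupType

println(GroupByType)   // GroupBy
println(RollupType)    // Rollup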
desc.run(sparkSession).map { @@ -130,7 +130,7 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) { .mkString("\t") } // SHOW TABLES in Hive only output table names, while ours output database, table name, isTemp. - case command @ ExecutedCommandExec(s: ShowTablesCommand, _) if !s.isExtended => + case command @ ExecutedCommandExec(s: ShowTablesCommand) if !s.isExtended => command.executeCollect().map(_.getString(1)) case other => val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index 00ff4c8ac310b..1c8e4050978dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.ExperimentalMethods import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.execution.datasources.PruneFileSourcePartitions +import org.apache.spark.sql.execution.datasources.v2.PushDownOperatorsToDataSource import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate class SparkOptimizer( @@ -31,7 +32,8 @@ class SparkOptimizer( override def batches: Seq[Batch] = (preOptimizationBatches ++ super.batches :+ Batch("Optimize Metadata Only Query", Once, OptimizeMetadataOnlyQuery(catalog)) :+ Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+ - Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions)) ++ + Batch("Prune File Source Table Partitions", Once, PruneFileSourcePartitions) :+ + Batch("Push down operators to data source scan", Once, PushDownOperatorsToDataSource)) ++ postHocOptimizationBatches :+ Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b263f100e6068..2ffd948f984bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -223,7 +223,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also * compressed. 
*/ - private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = { + private def getByteArrayRdd(n: Int = -1): RDD[(Long, Array[Byte])] = { execute().mapPartitionsInternal { iter => var count = 0 val buffer = new Array[Byte](4 << 10) // 4K @@ -239,7 +239,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ out.writeInt(-1) out.flush() out.close() - Iterator(bos.toByteArray) + Iterator((count, bos.toByteArray)) } } @@ -274,19 +274,26 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ val byteArrayRdd = getByteArrayRdd() val results = ArrayBuffer[InternalRow]() - byteArrayRdd.collect().foreach { bytes => - decodeUnsafeRows(bytes).foreach(results.+=) + byteArrayRdd.collect().foreach { countAndBytes => + decodeUnsafeRows(countAndBytes._2).foreach(results.+=) } results.toArray } + private[spark] def executeCollectIterator(): (Long, Iterator[InternalRow]) = { + val countsAndBytes = getByteArrayRdd().collect() + val total = countsAndBytes.map(_._1).sum + val rows = countsAndBytes.iterator.flatMap(countAndBytes => decodeUnsafeRows(countAndBytes._2)) + (total, rows) + } + /** * Runs this query returning the result as an iterator of InternalRow. * * @note Triggers multiple jobs (one for each partition). */ def executeToIterator(): Iterator[InternalRow] = { - getByteArrayRdd().toLocalIterator.flatMap(decodeUnsafeRows) + getByteArrayRdd().map(_._2).toLocalIterator.flatMap(decodeUnsafeRows) } /** @@ -307,7 +314,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ return new Array[InternalRow](0) } - val childRDD = getByteArrayRdd(n) + val childRDD = getByteArrayRdd(n).map(_._2) val buf = new ArrayBuffer[InternalRow] val totalParts = childRDD.partitions.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index b143d44eae17b..74048871f8d42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -33,7 +33,7 @@ class SparkPlanner( def numPartitions: Int = conf.numShufflePartitions - def strategies: Seq[Strategy] = + override def strategies: Seq[Strategy] = experimentalMethods.extraStrategies ++ extraPlanningStrategies ++ ( DataSourceV2Strategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 4da7a73469537..19b858faba6ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.{InMemoryRelation, InMemoryTableScanExec} import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.internal.SQLConf @@ -364,7 +364,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // Can we automate these 'pass through' operations? 
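Returning a (rowCount, bytes) pair per partition is what lets the new executeCollectIterator report the total number of rows without decoding anything, while the rows themselves stay lazily decoded. The same shape in isolation, with a stand-in decode function since decodeUnsafeRows is Spark-internal:

def totalAndRows[T](
    countsAndBytes: Array[(Long, Array[Byte])],
    decode: Array[Byte] => Iterator[T]): (Long, Iterator[T]) = {
  val total = countsAndBytes.map(_._1).sum
  val rows = countsAndBytes.iterator.flatMap { case (_, bytes) => decode(bytes) }
  (total, rows)   // callers can inspect `total` before deciding to materialize `rows`
}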
object BasicOperators extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case r: RunnableCommand => ExecutedCommandExec(r, r.children.map(planLater)) :: Nil + case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => val encoder = RowEncoder(sink.schema) @@ -392,6 +392,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.FlatMapGroupsInR(f, p, b, is, os, key, value, grouping, data, objAttr, child) => execution.FlatMapGroupsInRExec(f, p, b, is, os, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.FlatMapGroupsInPandas(grouping, func, output, child) => + execution.python.FlatMapGroupsInPandasExec(grouping, func, output, planLater(child)) :: Nil case logical.MapElements(f, _, _, objAttr, child) => execution.MapElementsExec(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, _, _, in, out, child) => @@ -411,7 +413,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.Repartition(numPartitions, shuffle, child) => if (shuffle) { - ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil + ShuffleExchangeExec(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil } else { execution.CoalesceExec(numPartitions, planLater(child)) :: Nil } @@ -446,7 +448,7 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: logical.Range => execution.RangeExec(r) :: Nil case logical.RepartitionByExpression(expressions, child, numPartitions) => - exchange.ShuffleExchange(HashPartitioning( + exchange.ShuffleExchangeExec(HashPartitioning( expressions, numPartitions), planLater(child)) :: Nil case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil case r: LogicalRDD => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index 268ccfa4edfa0..e37d133ff336a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -282,6 +282,18 @@ case class InputAdapter(child: SparkPlan) extends UnaryExecNode with CodegenSupp object WholeStageCodegenExec { val PIPELINE_DURATION_METRIC = "duration" + + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + + def isTooManyFields(conf: SQLConf, dataType: DataType): Boolean = { + numOfNestedFields(dataType) > conf.wholeStageMaxNumFields + } } /** @@ -380,16 +392,8 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co override def doExecute(): RDD[InternalRow] = { val (ctx, cleanedSource) = doCodeGen() - if (ctx.isTooLongGeneratedFunction) { - logWarning("Found too long generated codes and JIT optimization might not work, " + - "Whole-stage codegen disabled for this plan, " + - "You can change the config spark.sql.codegen.MaxFunctionLength " + - "to adjust the function length limit:\n " - + s"$treeString") - return child.execute() - } // try to compile and fallback if it failed - try { + val (_, maxCodeSize) = try { CodeGenerator.compile(cleanedSource) 
} catch { case _: Exception if !Utils.isTesting && sqlContext.conf.codegenFallback => @@ -397,6 +401,21 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co logWarning(s"Whole-stage codegen disabled for this plan:\n $treeString") return child.execute() } + + // Check if compiled code has a too large function + if (maxCodeSize > sqlContext.conf.hugeMethodLimit) { + logInfo(s"Found too long generated codes and JIT optimization might not work: " + + s"the bytecode size ($maxCodeSize) is above the limit " + + s"${sqlContext.conf.hugeMethodLimit}, and the whole-stage codegen was disabled " + + s"for this plan. To avoid this, you can raise the limit " + + s"`${SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key}`:\n$treeString") + child match { + // The fallback solution of batch file source scan still uses WholeStageCodegenExec + case f: FileSourceScanExec if f.supportsBatch => // do nothing + case _ => return child.execute() + } + } + val references = ctx.references.toArray val durationMs = longMetric("pipelineTime") @@ -405,7 +424,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co assert(rdds.size <= 2, "Up to two input RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitionsWithIndex { (index, iter) => - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(iter)) new Iterator[InternalRow] { @@ -424,7 +443,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co // a small hack to obtain the correct partition index }.mapPartitionsWithIndex { (index, zippedIter) => val (leftIter, rightIter) = zippedIter.next() - val clazz = CodeGenerator.compile(cleanedSource) + val (clazz, _) = CodeGenerator.compile(cleanedSource) val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator] buffer.init(index, Array(leftIter, rightIter)) new Iterator[InternalRow] { @@ -483,22 +502,14 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } - private def numOfNestedFields(dataType: DataType): Int = dataType match { - case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum - case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) - case a: ArrayType => numOfNestedFields(a.elementType) - case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) - case _ => 1 - } - private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val hasTooManyOutputFields = - numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + WholeStageCodegenExec.isTooManyFields(conf, plan.schema) val hasTooManyInputFields = - plan.children.map(p => numOfNestedFields(p.schema)).exists(_ > conf.wholeStageMaxNumFields) + plan.children.exists(p => WholeStageCodegenExec.isTooManyFields(conf, p.schema)) !willFallback && !hasTooManyOutputFields && !hasTooManyInputFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f424096b330e3..43e5ff89afee6 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -95,11 +95,13 @@ case class HashAggregateExec( val peakMemory = longMetric("peakMemory") val spillSize = longMetric("spillSize") val avgHashProbe = longMetric("avgHashProbe") + val aggTime = longMetric("aggTime") child.execute().mapPartitionsWithIndex { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input iterator is empty, // so return an empty iterator. Iterator.empty @@ -128,6 +130,8 @@ case class HashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } @@ -539,7 +543,7 @@ case class HashAggregateExec( private def enableTwoLevelHashMap(ctx: CodegenContext) = { if (!checkIfFastHashMapSupported(ctx)) { if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) { - logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but" + logInfo("spark.sql.codegen.aggregate.map.twolevel.enabled is set to true, but" + " current version of codegened fast hashmap does not support this aggregate.") } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala index 6316e06a8f34e..66955b8ef723c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ObjectHashAggregateExec.scala @@ -76,7 +76,8 @@ case class ObjectHashAggregateExec( aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows") + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time") ) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -94,13 +95,17 @@ case class ObjectHashAggregateExec( } } + override def outputPartitioning: Partitioning = child.outputPartitioning + protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") { val numOutputRows = longMetric("numOutputRows") + val aggTime = longMetric("aggTime") val fallbackCountThreshold = sqlContext.conf.objectAggSortBasedFallbackThreshold child.execute().mapPartitionsWithIndexInternal { (partIndex, iter) => + val beforeAgg = System.nanoTime() val hasInput = iter.hasNext - if (!hasInput && groupingExpressions.nonEmpty) { + val res = if (!hasInput && groupingExpressions.nonEmpty) { // This is a grouped aggregate and the input kvIterator is empty, // so return an empty kvIterator. 
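The aggTime additions in these hunks use the usual nanoTime-around-the-closure pattern, converting to milliseconds because SQL timing metrics are reported in ms. The pattern extracted into a throwaway helper; the names are illustrative, not Spark API:

def timedMillis[T](record: Long => Unit)(body: => T): T = {
  val start = System.nanoTime()
  val result = body
  record((System.nanoTime() - start) / 1000000)   // ns -> ms
  result
}

var aggMillis = 0L
val rows = timedMillis(aggMillis += _) { (1 to 1000).iterator.filter(_ % 3 == 0) }
println(s"${rows.length} rows, setup took ${aggMillis}ms")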
Iterator.empty @@ -127,6 +132,8 @@ case class ObjectHashAggregateExec( aggregationIterator } } + aggTime += (System.nanoTime() - beforeAgg) / 1000000 + res } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala index 9316ebcdf105c..3718424931b40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala @@ -50,10 +50,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -63,10 +63,10 @@ class RowBasedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 717758fdf716f..aab8cc50b9526 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -127,7 +127,7 @@ case class SimpleTypedAggregateExpression( nullable: Boolean) extends DeclarativeAggregate with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer @@ -221,7 +221,7 @@ case class ComplexTypedAggregateExpression( inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[Any] with TypedAggregateExpression with NonSQLExpression { - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = inputDeserializer.toSeq diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 13f79275cac41..812d405d5ebfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -55,10 +55,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + 
s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") @@ -68,10 +68,10 @@ class VectorizedHashMapGenerator( val keyName = ctx.addReferenceMinorObj(key.name) key.dataType match { case d: DecimalType => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType( + s""".add($keyName, org.apache.spark.sql.types.DataTypes.createDecimalType( |${d.precision}, ${d.scale}))""".stripMargin case _ => - s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})""" + s""".add($keyName, org.apache.spark.sql.types.DataTypes.${key.dataType})""" } }.mkString("\n").concat(";") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fec1add18cbf2..72aa4adff4e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -340,7 +340,7 @@ case class ScalaUDAF( override def dataType: DataType = udaf.dataType - override def deterministic: Boolean = udaf.deterministic + override lazy val deterministic: Boolean = udaf.deterministic override val inputTypes: Seq[DataType] = udaf.inputSchema.map(_.dataType) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 8389e2f3d5be9..d15ece304cac4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration -import org.apache.spark.{InterruptibleIterator, TaskContext} +import org.apache.spark.{InterruptibleIterator, Partition, SparkContext, TaskContext} import org.apache.spark.rdd.{EmptyRDD, PartitionwiseSampledRDD, RDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -554,6 +554,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range) /** * Physical plan for unioning two plans, without a distinct. This is UNION ALL in SQL. + * + * If we change how this is implemented physically, we'd need to update + * [[org.apache.spark.sql.catalyst.plans.logical.Union.maxRowsPerPartition]]. */ case class UnionExec(children: Seq[SparkPlan]) extends SparkPlan { override def output: Seq[Attribute] = @@ -587,8 +590,31 @@ case class CoalesceExec(numPartitions: Int, child: SparkPlan) extends UnaryExecN } protected override def doExecute(): RDD[InternalRow] = { - child.execute().coalesce(numPartitions, shuffle = false) + if (numPartitions == 1 && child.execute().getNumPartitions < 1) { + // Make sure we don't output an RDD with 0 partitions, when claiming that we have a + // `SinglePartition`. + new CoalesceExec.EmptyRDDWithPartitions(sparkContext, numPartitions) + } else { + child.execute().coalesce(numPartitions, shuffle = false) + } + } +} + +object CoalesceExec { + /** A simple RDD with no data, but with the given number of partitions. 
*/ + class EmptyRDDWithPartitions( + @transient private val sc: SparkContext, + numPartitions: Int) extends RDD[InternalRow](sc, Nil) { + + override def getPartitions: Array[Partition] = + Array.tabulate(numPartitions)(i => EmptyPartition(i)) + + override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + Iterator.empty + } } + + case class EmptyPartition(index: Int) extends Partition } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 6241b79d9affc..85c36b7da9498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -24,6 +24,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeMapData, UnsafeRow} import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ /** @@ -122,7 +123,7 @@ private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType) extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType)) with NullableColumnAccessor -private[columnar] object ColumnAccessor { +private[sql] object ColumnAccessor { @tailrec def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = { val buf = buffer.order(ByteOrder.nativeOrder) @@ -149,4 +150,22 @@ private[columnar] object ColumnAccessor { throw new Exception(s"not support type: $other") } } + + def decompress(columnAccessor: ColumnAccessor, columnVector: WritableColumnVector, numRows: Int): + Unit = { + if (columnAccessor.isInstanceOf[NativeColumnAccessor[_]]) { + val nativeAccessor = columnAccessor.asInstanceOf[NativeColumnAccessor[_]] + nativeAccessor.decompress(columnVector, numRows) + } else { + throw new RuntimeException("Not support non-primitive type now") + } + } + + def decompress( + array: Array[Byte], columnVector: WritableColumnVector, dataType: DataType, numRows: Int): + Unit = { + val byteBuffer = ByteBuffer.wrap(array) + val columnAccessor = ColumnAccessor(dataType, byteBuffer) + decompress(columnAccessor, columnVector, numRows) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 5cfb003e4f150..e9b150fd86095 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -43,6 +43,12 @@ import org.apache.spark.unsafe.types.UTF8String * WARNING: This only works with HeapByteBuffer */ private[columnar] object ByteBufferHelper { + def getShort(buffer: ByteBuffer): Short = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.getShort(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) + } + def getInt(buffer: ByteBuffer): Int = { val pos = buffer.position() buffer.position(pos + 4) @@ -66,6 +72,33 @@ private[columnar] object ByteBufferHelper { buffer.position(pos + 8) Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos) } + + def putShort(buffer: ByteBuffer, value: Short): Unit = { + val pos = buffer.position() + buffer.position(pos + 2) + Platform.putShort(buffer.array(), 
Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putInt(buffer: ByteBuffer, value: Int): Unit = { + val pos = buffer.position() + buffer.position(pos + 4) + Platform.putInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def putLong(buffer: ByteBuffer, value: Long): Unit = { + val pos = buffer.position() + buffer.position(pos + 8) + Platform.putLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos, value) + } + + def copyMemory(src: ByteBuffer, dst: ByteBuffer, len: Int): Unit = { + val srcPos = src.position() + val dstPos = dst.position() + src.position(srcPos + len) + dst.position(dstPos + len) + Platform.copyMemory(src.array(), Platform.BYTE_ARRAY_OFFSET + srcPos, + dst.array(), Platform.BYTE_ARRAY_OFFSET + dstPos, len) + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala index da34643281911..ae600c1ffae8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -227,6 +227,7 @@ object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarItera new CodeAndComment(codeBody, ctx.getPlaceHolderToComments())) logDebug(s"Generated ColumnarIterator:\n${CodeFormatter.format(code)}") - CodeGenerator.compile(code).generate(Array.empty).asInstanceOf[ColumnarIterator] + val (clazz, _) = CodeGenerator.compile(code) + clazz.generate(Array.empty).asInstanceOf[ColumnarIterator] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index bc98d8d9d6d61..a1c62a729900e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -62,7 +62,8 @@ case class InMemoryRelation( @transient var _cachedColumnBuffers: RDD[CachedBatch] = null, val batchStats: LongAccumulator = child.sqlContext.sparkContext.longAccumulator) extends logical.LeafNode with MultiInstanceRelation { - override def innerChildren: Seq[SparkPlan] = Seq(child) + + override protected def innerChildren: Seq[SparkPlan] = Seq(child) override def producedAttributes: AttributeSet = outputSet diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index c7ddec55682e1..2ae3f35eb1da1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -23,21 +23,66 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning} -import org.apache.spark.sql.execution.LeafExecNode -import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.UserDefinedType +import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} +import org.apache.spark.sql.execution.vectorized._ +import org.apache.spark.sql.types._ case class 
InMemoryTableScanExec( attributes: Seq[Attribute], predicates: Seq[Expression], @transient relation: InMemoryRelation) - extends LeafExecNode { + extends LeafExecNode with ColumnarBatchScan { + + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + + override def vectorTypes: Option[Seq[String]] = + Option(Seq.fill(attributes.length)(classOf[OnHeapColumnVector].getName)) + + /** + * If true, get data from ColumnVector in ColumnarBatch, which are generally faster. + * If false, get data from UnsafeRow build from ColumnVector + */ + override val supportCodegen: Boolean = { + // In the initial implementation, for ease of review + // support only primitive data types and # of fields is less than wholeStageMaxNumFields + relation.schema.fields.forall(f => f.dataType match { + case BooleanType | ByteType | ShortType | IntegerType | LongType | + FloatType | DoubleType => true + case _ => false + }) && !WholeStageCodegenExec.isTooManyFields(conf, relation.schema) + } + + private val columnIndices = + attributes.map(a => relation.output.map(o => o.exprId).indexOf(a.exprId)).toArray + + private val relationSchema = relation.schema.toArray - override def innerChildren: Seq[QueryPlan[_]] = Seq(relation) ++ super.innerChildren + private lazy val columnarBatchSchema = new StructType(columnIndices.map(i => relationSchema(i))) - override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + private def createAndDecompressColumn(cachedColumnarBatch: CachedBatch): ColumnarBatch = { + val rowCount = cachedColumnarBatch.numRows + val columnVectors = OnHeapColumnVector.allocateColumns(rowCount, columnarBatchSchema) + val columnarBatch = new ColumnarBatch( + columnarBatchSchema, columnVectors.asInstanceOf[Array[ColumnVector]], rowCount) + columnarBatch.setNumRows(rowCount) + + for (i <- 0 until attributes.length) { + ColumnAccessor.decompress( + cachedColumnarBatch.buffers(columnIndices(i)), + columnarBatch.column(i).asInstanceOf[WritableColumnVector], + columnarBatchSchema.fields(i).dataType, rowCount) + } + columnarBatch + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + assert(supportCodegen) + val buffers = filteredCachedBatches() + // HACK ALERT: This is actually an RDD[ColumnarBatch]. + // We're taking advantage of Scala's type erasure here to pass these batches along. 
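A stand-alone illustration of the type-erasure trick the HACK ALERT above relies on (the names here are illustrative, not part of the patch): the cast compiles and appears to succeed because element types are erased at runtime, and it only fails once an element is consumed with the wrong expected type, which the cached-batch codegen path never does.

object ErasureDemo {
  def main(args: Array[String]): Unit = {
    val strings: Seq[String] = Seq("a", "b")
    // Seq[String] and Seq[Int] erase to the same runtime class, so this "succeeds".
    val pretendInts = strings.asInstanceOf[Seq[Int]]
    println(pretendInts.size)  // 2; size never inspects the element type
    // pretendInts.head + 1    // would throw ClassCastException only at this point
  }
}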
+ Seq(buffers.map(createAndDecompressColumn(_)).asInstanceOf[RDD[InternalRow]]) + } override def output: Seq[Attribute] = attributes @@ -102,7 +147,8 @@ case class InMemoryTableScanExec( case IsNull(a: Attribute) => statsFor(a).nullCount > 0 case IsNotNull(a: Attribute) => statsFor(a).count - statsFor(a).nullCount > 0 - case In(a: AttributeReference, list: Seq[Expression]) if list.forall(_.isInstanceOf[Literal]) => + case In(a: AttributeReference, list: Seq[Expression]) + if list.forall(_.isInstanceOf[Literal]) && list.nonEmpty => list.map(l => statsFor(a).lowerBound <= l.asInstanceOf[Literal] && l.asInstanceOf[Literal] <= statsFor(a).upperBound).reduce(_ || _) } @@ -134,19 +180,11 @@ case class InMemoryTableScanExec( private val inMemoryPartitionPruningEnabled = sqlContext.conf.inMemoryPartitionPruning - protected override def doExecute(): RDD[InternalRow] = { - val numOutputRows = longMetric("numOutputRows") - - if (enableAccumulators) { - readPartitions.setValue(0) - readBatches.setValue(0) - } - + private def filteredCachedBatches(): RDD[CachedBatch] = { // Using these variables here to avoid serialization of entire objects (if referenced directly) // within the map Partitions closure. val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsWithIndexInternal { (index, cachedBatchIterator) => @@ -155,35 +193,49 @@ case class InMemoryTableScanExec( schema) partitionFilter.initialize(index) + // Do partition batch pruning if enabled + if (inMemoryPartitionPruningEnabled) { + cachedBatchIterator.filter { cachedBatch => + if (!partitionFilter.eval(cachedBatch.stats)) { + logDebug { + val statsString = schemaIndex.map { case (a, i) => + val value = cachedBatch.stats.get(i, a.dataType) + s"${a.name}: $value" + }.mkString(", ") + s"Skipping partition based on stats $statsString" + } + false + } else { + true + } + } + } else { + cachedBatchIterator + } + } + } + + protected override def doExecute(): RDD[InternalRow] = { + val numOutputRows = longMetric("numOutputRows") + + if (enableAccumulators) { + readPartitions.setValue(0) + readBatches.setValue(0) + } + + // Using these variables here to avoid serialization of entire objects (if referenced directly) + // within the map Partitions closure. + val relOutput: AttributeSeq = relation.output + + filteredCachedBatches().mapPartitionsInternal { cachedBatchIterator => // Find the ordinals and data types of the requested columns. 
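A minimal sketch of the ordinal lookup this comment introduces, using plain column names instead of Catalyst exprIds (relationOutputNames and requestedNames are made-up values):

val relationOutputNames = Seq("id", "name", "age")
val requestedNames = Seq("age", "id")
val requestedOrdinals = requestedNames.map(name => relationOutputNames.indexOf(name))
assert(requestedOrdinals == Seq(2, 0))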
val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => relOutput.indexOf(a.exprId) -> a.dataType }.unzip - // Do partition batch pruning if enabled - val cachedBatchesToScan = - if (inMemoryPartitionPruningEnabled) { - cachedBatchIterator.filter { cachedBatch => - if (!partitionFilter.eval(cachedBatch.stats)) { - logDebug { - val statsString = schemaIndex.map { case (a, i) => - val value = cachedBatch.stats.get(i, a.dataType) - s"${a.name}: $value" - }.mkString(", ") - s"Skipping partition based on stats $statsString" - } - false - } else { - true - } - } - } else { - cachedBatchIterator - } - // update SQL metrics - val withMetrics = cachedBatchesToScan.map { batch => + val withMetrics = cachedBatchIterator.map { batch => if (enableAccumulators) { readBatches.add(1) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala index e1d13ad0e94e5..774011f1e3de8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressibleColumnAccessor.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnAccessor, NativeColumnAccessor} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends ColumnAccessor { @@ -36,4 +37,7 @@ private[columnar] trait CompressibleColumnAccessor[T <: AtomicType] extends Colu override def extractSingle(row: InternalRow, ordinal: Int): Unit = { decoder.next(row, ordinal) } + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = + decoder.decompress(columnVector, capacity) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala index 6e4f1c5b80684..f8aeba44257d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/CompressionScheme.scala @@ -21,6 +21,7 @@ import java.nio.{ByteBuffer, ByteOrder} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.columnar.{ColumnType, NativeColumnType} +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types.AtomicType private[columnar] trait Encoder[T <: AtomicType] { @@ -41,6 +42,8 @@ private[columnar] trait Decoder[T <: AtomicType] { def next(row: InternalRow, ordinal: Int): Unit def hasNext: Boolean + + def decompress(columnVector: WritableColumnVector, capacity: Int): Unit } private[columnar] trait CompressionScheme { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala index ee99c90a751d9..bf00ad997c76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/compression/compressionSchemes.scala @@ -18,12 +18,14 @@ package org.apache.spark.sql.execution.columnar.compression import java.nio.ByteBuffer +import java.nio.ByteOrder import scala.collection.mutable import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ @@ -61,6 +63,101 @@ private[columnar] case object PassThrough extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + private def putBooleans( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + for (i <- 0 until len) { + columnVector.putBoolean(pos + i, (buffer.get(bufferPos + i) != 0)) + } + } + + private def putBytes( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putBytes(pos, len, buffer.array, bufferPos) + } + + private def putShorts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putShorts(pos, len, buffer.array, bufferPos) + } + + private def putInts( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putInts(pos, len, buffer.array, bufferPos) + } + + private def putLongs( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putLongs(pos, len, buffer.array, bufferPos) + } + + private def putFloats( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putFloats(pos, len, buffer.array, bufferPos) + } + + private def putDoubles( + columnVector: WritableColumnVector, pos: Int, bufferPos: Int, len: Int): Unit = { + columnVector.putDoubles(pos, len, buffer.array, bufferPos) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + unitSize: Int, + putFunction: (WritableColumnVector, Int, Int, Int) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else capacity + var pos = 0 + var seenNulls = 0 + var bufferPos = buffer.position + while (pos < capacity) { + if (pos != nextNullIndex) { + val len = nextNullIndex - pos + assert(len * unitSize < Int.MaxValue) + putFunction(columnVector, pos, bufferPos, len) + bufferPos += len * unitSize + pos += len + } else { + seenNulls += 1 + nextNullIndex = if (seenNulls < nullCount) { + ByteBufferHelper.getInt(nullsBuffer) + } else { + capacity + } + columnVector.putNull(pos) + pos += 1 + } + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBooleans) + case _: ByteType => + val unitSize = 1 + decompress0(columnVector, capacity, unitSize, putBytes) + case _: ShortType => + val unitSize = 2 + decompress0(columnVector, capacity, unitSize, putShorts) + case _: IntegerType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, putInts) + case _: LongType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putLongs) + case _: FloatType => + val unitSize = 4 + decompress0(columnVector, capacity, unitSize, 
putFloats) + case _: DoubleType => + val unitSize = 8 + decompress0(columnVector, capacity, unitSize, putDoubles) + } + } } } @@ -169,6 +266,94 @@ private[columnar] case object RunLengthEncoding extends CompressionScheme { } override def hasNext: Boolean = valueCount < run || buffer.hasRemaining + + private def putBoolean(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putBoolean(pos, value == 1) + } + + private def getByte(buffer: ByteBuffer): Long = { + buffer.get().toLong + } + + private def putByte(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putByte(pos, value.toByte) + } + + private def getShort(buffer: ByteBuffer): Long = { + buffer.getShort().toLong + } + + private def putShort(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putShort(pos, value.toShort) + } + + private def getInt(buffer: ByteBuffer): Long = { + buffer.getInt().toLong + } + + private def putInt(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putInt(pos, value.toInt) + } + + private def getLong(buffer: ByteBuffer): Long = { + buffer.getLong() + } + + private def putLong(columnVector: WritableColumnVector, pos: Int, value: Long): Unit = { + columnVector.putLong(pos, value) + } + + private def decompress0( + columnVector: WritableColumnVector, + capacity: Int, + getFunction: (ByteBuffer) => Long, + putFunction: (WritableColumnVector, Int, Long) => Unit): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + var runLocal = 0 + var valueCountLocal = 0 + var currentValueLocal: Long = 0 + + while (valueCountLocal < runLocal || (pos < capacity)) { + if (pos != nextNullIndex) { + if (valueCountLocal == runLocal) { + currentValueLocal = getFunction(buffer) + runLocal = ByteBufferHelper.getInt(buffer) + valueCountLocal = 1 + } else { + valueCountLocal += 1 + } + putFunction(columnVector, pos, currentValueLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + columnType.dataType match { + case _: BooleanType => + decompress0(columnVector, capacity, getByte, putBoolean) + case _: ByteType => + decompress0(columnVector, capacity, getByte, putByte) + case _: ShortType => + decompress0(columnVector, capacity, getShort, putShort) + case _: IntegerType => + decompress0(columnVector, capacity, getInt, putInt) + case _: LongType => + decompress0(columnVector, capacity, getLong, putLong) + case _ => throw new IllegalStateException("Not supported type in RunLengthEncoding.") + } + } } } @@ -266,11 +451,32 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]) - extends compression.Decoder[T] { - - private val dictionary: Array[Any] = { - val elementNum = ByteBufferHelper.getInt(buffer) - Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any]) + extends compression.Decoder[T] { + val elementNum = ByteBufferHelper.getInt(buffer) + private val dictionary: Array[Any] = new Array[Any](elementNum) + private var intDictionary: Array[Int] = null 
+ private var longDictionary: Array[Long] = null + + columnType.dataType match { + case _: IntegerType => + intDictionary = new Array[Int](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Int] + intDictionary(i) = v + dictionary(i) = v + } + case _: LongType => + longDictionary = new Array[Long](elementNum) + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Long] + longDictionary(i) = v + dictionary(i) = v + } + case _: StringType => + for (i <- 0 until elementNum) { + val v = columnType.extract(buffer).asInstanceOf[Any] + dictionary(i) = v + } } override def next(row: InternalRow, ordinal: Int): Unit = { @@ -278,6 +484,46 @@ private[columnar] case object DictionaryEncoding extends CompressionScheme { } override def hasNext: Boolean = buffer.hasRemaining + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + columnType.dataType match { + case _: IntegerType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(intDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + columnVector.putNull(pos) + } + pos += 1 + } + case _: LongType => + val dictionaryIds = columnVector.reserveDictionaryIds(capacity) + columnVector.setDictionary(new ColumnDictionary(longDictionary)) + while (pos < capacity) { + if (pos != nextNullIndex) { + dictionaryIds.putInt(pos, buffer.getShort()) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + case _ => throw new IllegalStateException("Not supported type in DictionaryEncoding.") + } + } } } @@ -368,6 +614,38 @@ private[columnar] case object BooleanBitSet extends CompressionScheme { } override def hasNext: Boolean = visited < count + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + val countLocal = count + var currentWordLocal: Long = 0 + var visitedLocal: Int = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (visitedLocal < countLocal) { + if (pos != nextNullIndex) { + val bit = visitedLocal % BITS_PER_LONG + + visitedLocal += 1 + if (bit == 0) { + currentWordLocal = ByteBufferHelper.getLong(buffer) + } + + columnVector.putBoolean(pos, ((currentWordLocal >> bit) & 1) != 0) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -448,6 +726,32 @@ private[columnar] case object IntDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer) row.setInt(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Int = 0 + val nullsBuffer = 
buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind() + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getInt(buffer) } + columnVector.putInt(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } @@ -528,5 +832,31 @@ private[columnar] case object LongDelta extends CompressionScheme { prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer) row.setLong(ordinal, prev) } + + override def decompress(columnVector: WritableColumnVector, capacity: Int): Unit = { + var prevLocal: Long = 0 + val nullsBuffer = buffer.duplicate().order(ByteOrder.nativeOrder()) + nullsBuffer.rewind + val nullCount = ByteBufferHelper.getInt(nullsBuffer) + var nextNullIndex = if (nullCount > 0) ByteBufferHelper.getInt(nullsBuffer) else -1 + var pos = 0 + var seenNulls = 0 + + while (pos < capacity) { + if (pos != nextNullIndex) { + val delta = buffer.get() + prevLocal = if (delta > Byte.MinValue) { prevLocal + delta } else + { ByteBufferHelper.getLong(buffer) } + columnVector.putLong(pos, prevLocal) + } else { + seenNulls += 1 + if (seenNulls < nullCount) { + nextNullIndex = ByteBufferHelper.getInt(nullsBuffer) + } + columnVector.putNull(pos) + } + pos += 1 + } + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala index 4e1c5e4846f36..2cf06982e25f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DataWritingCommand.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.hadoop.conf.Configuration import org.apache.spark.SparkContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.datasources.BasicWriteJobStatsTracker import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.util.SerializableConfiguration @@ -30,6 +31,18 @@ import org.apache.spark.util.SerializableConfiguration */ trait DataWritingCommand extends RunnableCommand { + /** + * The input query plan that produces the data to be written. + */ + def query: LogicalPlan + + // We make the input `query` an inner child instead of a child in order to hide it from the + // optimizer. This is because optimizer may not preserve the output schema names' case, and we + // have to keep the original analyzed plan here so that we can pass the corrected schema to the + // writer. The schema of analyzed plan is what user expects(or specifies), so we should respect + // it when writing. 
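The decompress() implementations added to the compression schemes above all walk the same null layout: an int null count followed by the sorted row indexes that are null, with contiguous non-null runs in between (the layout itself is written by the corresponding encoders, not shown here). A stand-alone sketch of that walk over an Array[Int]; nonNullRuns is an illustrative helper, not part of the patch.

def nonNullRuns(nullIndexes: Array[Int], capacity: Int): Seq[(Int, Int)] = {
  val boundaries = nullIndexes :+ capacity
  var pos = 0
  val runs = Seq.newBuilder[(Int, Int)]
  for (next <- boundaries) {
    if (next > pos) runs += ((pos, next - pos))  // (start, length) of a non-null run
    pos = next + 1
  }
  runs.result()
}

// Nulls at rows 2 and 5 in an 8-row batch: bulk-copyable runs are (0,2), (3,2), (6,2).
assert(nonNullRuns(Array(2, 5), 8) == Seq((0, 2), (3, 2), (6, 2)))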
+ override protected def innerChildren: Seq[LogicalPlan] = query :: Nil + override lazy val metrics: Map[String, SQLMetric] = { val sparkContext = SparkContext.getActive.get Map( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala index 633de4c37af94..9e3519073303c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/InsertIntoDataSourceDirCommand.scala @@ -21,7 +21,6 @@ import org.apache.spark.SparkException import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.datasources._ /** @@ -45,10 +44,9 @@ case class InsertIntoDataSourceDirCommand( query: LogicalPlan, overwrite: Boolean) extends RunnableCommand { - override def children: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = query :: Nil - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty, "Directory path is required") assert(provider.nonEmpty, "Data source is required") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala index 792290bef0163..140f920eaafae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/cache.scala @@ -30,7 +30,7 @@ case class CacheTableCommand( require(plan.isEmpty || tableIdent.database.isEmpty, "Database name is not allowed in CACHE TABLE AS SELECT") - override def innerChildren: Seq[QueryPlan[_]] = plan.toSeq + override protected def innerChildren: Seq[QueryPlan[_]] = plan.toSeq override def run(sparkSession: SparkSession): Seq[Row] = { plan.foreach { logicalPlan => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 7cd4baef89e75..e28b5eb2e2a2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -24,9 +24,9 @@ import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.{logical, QueryPlan} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.catalyst.plans.QueryPlan +import org.apache.spark.sql.catalyst.plans.logical.{Command, LogicalPlan} +import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata} @@ -37,19 +37,13 @@ import org.apache.spark.sql.types._ * A logical command that is 
executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. */ -trait RunnableCommand extends logical.Command { +trait RunnableCommand extends Command { // The map used to record the metrics of running the command. This will be passed to // `ExecutedCommand` during query planning. lazy val metrics: Map[String, SQLMetric] = Map.empty - def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - throw new NotImplementedError - } - - def run(sparkSession: SparkSession): Seq[Row] = { - throw new NotImplementedError - } + def run(sparkSession: SparkSession): Seq[Row] } /** @@ -57,9 +51,8 @@ trait RunnableCommand extends logical.Command { * saves the result to prevent multiple executions. * * @param cmd the `RunnableCommand` this operator will run. - * @param children the children physical plans ran by the `RunnableCommand`. */ -case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) extends SparkPlan { +case class ExecutedCommandExec(cmd: RunnableCommand) extends LeafExecNode { override lazy val metrics: Map[String, SQLMetric] = cmd.metrics @@ -74,19 +67,14 @@ case class ExecutedCommandExec(cmd: RunnableCommand, children: Seq[SparkPlan]) e */ protected[sql] lazy val sideEffectResult: Seq[InternalRow] = { val converter = CatalystTypeConverters.createToCatalystConverter(schema) - val rows = if (children.isEmpty) { - cmd.run(sqlContext.sparkSession) - } else { - cmd.run(sqlContext.sparkSession, children) - } - rows.map(converter(_).asInstanceOf[InternalRow]) + cmd.run(sqlContext.sparkSession).map(converter(_).asInstanceOf[InternalRow]) } - override def innerChildren: Seq[QueryPlan[_]] = cmd.innerChildren + override protected def innerChildren: Seq[QueryPlan[_]] = cmd :: Nil override def output: Seq[Attribute] = cmd.output - override def nodeName: String = cmd.nodeName + override def nodeName: String = "Execute " + cmd.nodeName override def executeCollect(): Array[InternalRow] = sideEffectResult.toArray diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 04b2534ca5eb1..9e3907996995c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -120,7 +120,7 @@ case class CreateDataSourceTableAsSelectCommand( query: LogicalPlan) extends RunnableCommand { - override def innerChildren: Seq[LogicalPlan] = Seq(query) + override protected def innerChildren: Seq[LogicalPlan] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { assert(table.tableType != CatalogTableType.VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 8d95ca6921cf8..38f91639c0422 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -235,11 +235,10 @@ case class AlterTableAddColumnsCommand( DataSource.lookupDataSource(catalogTable.provider.get).newInstance() match { // For datasource table, this command can only support the following File format. // TextFileFormat only default to one column "value" - // OrcFileFormat can not handle difference between user-specified schema and - // inferred schema yet. 
TODO, once this issue is resolved , we can add Orc back. // Hive type is already considered as hive serde table, so the logic will not // come in here. case _: JsonFileFormat | _: CSVFileFormat | _: ParquetFileFormat => + case s if s.getClass.getCanonicalName.endsWith("OrcFileFormat") => case s => throw new AnalysisException( s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index ffdfd527fa701..5172f32ec7b9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -98,7 +98,7 @@ case class CreateViewCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(child) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(child) if (viewType == PersistedView) { require(originalText.isDefined, "'originalText' must be provided to create permanent view") @@ -267,7 +267,7 @@ case class AlterViewAsCommand( import ViewHelper._ - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(session: SparkSession): Seq[Row] = { // If the plan cannot be analyzed, throw an exception and don't proceed. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala index b8f7d130d569f..11af0aaa7b206 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/BasicWriteStatsTracker.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution.datasources +import java.io.FileNotFoundException + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.SparkContext +import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -44,20 +47,32 @@ case class BasicWriteTaskStats( * @param hadoopConf */ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) - extends WriteTaskStatsTracker { + extends WriteTaskStatsTracker with Logging { private[this] var numPartitions: Int = 0 private[this] var numFiles: Int = 0 + private[this] var submittedFiles: Int = 0 private[this] var numBytes: Long = 0L private[this] var numRows: Long = 0L - private[this] var curFile: String = null - + private[this] var curFile: Option[String] = None - private def getFileSize(filePath: String): Long = { + /** + * Get the size of the file expected to have been written by a worker. + * @param filePath path to the file + * @return the file size or None if the file was not found. 
+ */ + private def getFileSize(filePath: String): Option[Long] = { val path = new Path(filePath) val fs = path.getFileSystem(hadoopConf) - fs.getFileStatus(path).getLen() + try { + Some(fs.getFileStatus(path).getLen()) + } catch { + case e: FileNotFoundException => + // may arise against eventually consistent object stores + logDebug(s"File $path is not yet visible", e) + None + } } @@ -70,12 +85,19 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def newFile(filePath: String): Unit = { - if (numFiles > 0) { - // we assume here that we've finished writing to disk the previous file by now - numBytes += getFileSize(curFile) + statCurrentFile() + curFile = Some(filePath) + submittedFiles += 1 + } + + private def statCurrentFile(): Unit = { + curFile.foreach { path => + getFileSize(path).foreach { len => + numBytes += len + numFiles += 1 + } + curFile = None } - curFile = filePath - numFiles += 1 } override def newRow(row: InternalRow): Unit = { @@ -83,8 +105,11 @@ class BasicWriteTaskStatsTracker(hadoopConf: Configuration) } override def getFinalStats(): WriteTaskStats = { - if (numFiles > 0) { - numBytes += getFileSize(curFile) + statCurrentFile() + if (submittedFiles != numFiles) { + logInfo(s"Expected $submittedFiles files, but only saw $numFiles. " + + "This could be due to the output format not writing empty files, " + + "or files being not immediately visible in the filesystem.") } BasicWriteTaskStats(numPartitions, numFiles, numBytes, numRows) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b9502a95a7c08..b43d282bd434c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -453,6 +453,17 @@ case class DataSource( val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis PartitioningUtils.validatePartitionColumn(data.schema, partitionColumns, caseSensitive) + + // SPARK-17230: Resolve the partition columns so InsertIntoHadoopFsRelationCommand does + // not need to have the query as child, to avoid to analyze an optimized query, + // because InsertIntoHadoopFsRelationCommand will be optimized first. 
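The partitionAttributes resolution that follows matches each partition column name against the query output using the session's resolver (case sensitive or not). A plain-strings sketch of the same lookup; resolve and its error message are illustrative, not Spark API:

def resolve(columns: Seq[String], name: String, caseSensitive: Boolean): String = {
  val equality: (String, String) => Boolean =
    if (caseSensitive) _ == _ else _.equalsIgnoreCase(_)
  columns.find(equality(_, name)).getOrElse(
    throw new IllegalArgumentException(
      s"Unable to resolve $name given [${columns.mkString(", ")}]"))
}

// With case-insensitive analysis, "year" resolves to the output column "Year".
assert(resolve(Seq("Year", "month"), "year", caseSensitive = false) == "Year")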
+ val partitionAttributes = partitionColumns.map { name => + data.output.find(a => equality(a.name, name)).getOrElse { + throw new AnalysisException( + s"Unable to resolve $name given [${data.output.map(_.name).mkString(", ")}]") + } + } + val fileIndex = catalogTable.map(_.identifier).map { tableIdent => sparkSession.table(tableIdent).queryExecution.analyzed.collect { case LogicalRelation(t: HadoopFsRelation, _, _, _) => t.location @@ -465,7 +476,7 @@ case class DataSource( outputPath = outputPath, staticPartitions = Map.empty, ifPartitionNotExists = false, - partitionColumns = partitionColumns.map(UnresolvedAttribute.quoted), + partitionColumns = partitionAttributes, bucketSpec = bucketSpec, fileFormat = format, options = options, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 514969715091a..1fac01a2c26c6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, _} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} -import org.apache.spark.sql.execution.{SortExec, SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.StringType import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -101,7 +101,7 @@ object FileFormatWriter extends Logging { */ def write( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, fileFormat: FileFormat, committer: FileCommitProtocol, outputSpec: OutputSpec, @@ -117,7 +117,9 @@ object FileFormatWriter extends Logging { job.setOutputValueClass(classOf[InternalRow]) FileOutputFormat.setOutputPath(job, new Path(outputSpec.outputPath)) - val allColumns = plan.output + // Pick the attributes from analyzed plan, as optimizer may not preserve the output schema + // names' case. + val allColumns = queryExecution.analyzed.output val partitionSet = AttributeSet(partitionColumns) val dataColumns = allColumns.filterNot(partitionSet.contains) @@ -158,7 +160,7 @@ object FileFormatWriter extends Logging { // We should first sort by partition columns, then bucket id, and finally sorting columns. val requiredOrdering = partitionColumns ++ bucketIdExpression ++ sortColumns // the sort order doesn't matter - val actualOrdering = plan.outputOrdering.map(_.child) + val actualOrdering = queryExecution.executedPlan.outputOrdering.map(_.child) val orderingMatched = if (requiredOrdering.length > actualOrdering.length) { false } else { @@ -176,12 +178,17 @@ object FileFormatWriter extends Logging { try { val rdd = if (orderingMatched) { - plan.execute() + queryExecution.toRdd } else { + // SPARK-21165: the `requiredOrdering` is based on the attributes from analyzed plan, and + // the physical plan may have different attribute ids due to optimizer removing some + // aliases. Here we bind the expression ahead to avoid potential attribute ids mismatch. 
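For context on the write path being refactored here: the extra SortExec is only added when the required ordering (partition columns, then bucket id, then sort columns) is not already a prefix of the child plan's output ordering. A simplified version of that check on strings; the real code compares expressions with semanticEquals:

def orderingMatched(required: Seq[String], actual: Seq[String]): Boolean =
  required.length <= actual.length &&
    required.zip(actual).forall { case (r, a) => r == a }

assert(orderingMatched(Seq("part", "bucketId"), Seq("part", "bucketId", "sortCol")))
assert(!orderingMatched(Seq("part", "bucketId"), Seq("part")))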
+ val orderingExpr = requiredOrdering + .map(SortOrder(_, Ascending)).map(BindReferences.bindReference(_, allColumns)) SortExec( - requiredOrdering.map(SortOrder(_, Ascending)), + orderingExpr, global = false, - child = plan).execute() + child = queryExecution.executedPlan).execute() } val ret = new Array[WriteTaskResult](rdd.partitions.length) sparkSession.sparkContext.runJob( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala index 203d449717512..318ada0ceefc5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InMemoryFileIndex.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs._ import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.spark.SparkContext import org.apache.spark.internal.Logging import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession @@ -187,42 +188,56 @@ object InMemoryFileIndex extends Logging { // in case of large #defaultParallelism. val numParallelism = Math.min(paths.size, parallelPartitionDiscoveryParallelism) - val statusMap = sparkContext - .parallelize(serializedPaths, numParallelism) - .mapPartitions { pathStrings => - val hadoopConf = serializableConfiguration.value - pathStrings.map(new Path(_)).toSeq.map { path => - (path, listLeafFiles(path, hadoopConf, filter, None)) - }.iterator - }.map { case (path, statuses) => - val serializableStatuses = statuses.map { status => - // Turn FileStatus into SerializableFileStatus so we can send it back to the driver - val blockLocations = status match { - case f: LocatedFileStatus => - f.getBlockLocations.map { loc => - SerializableBlockLocation( - loc.getNames, - loc.getHosts, - loc.getOffset, - loc.getLength) - } - - case _ => - Array.empty[SerializableBlockLocation] - } - - SerializableFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime, - blockLocations) + val previousJobDescription = sparkContext.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION) + val statusMap = try { + val description = paths.size match { + case 0 => + s"Listing leaf files and directories 0 paths" + case 1 => + s"Listing leaf files and directories for 1 path:
    ${paths(0)}" + case s => + s"Listing leaf files and directories for $s paths:
    ${paths(0)}, ..." } - (path.toString, serializableStatuses) - }.collect() + sparkContext.setJobDescription(description) + sparkContext + .parallelize(serializedPaths, numParallelism) + .mapPartitions { pathStrings => + val hadoopConf = serializableConfiguration.value + pathStrings.map(new Path(_)).toSeq.map { path => + (path, listLeafFiles(path, hadoopConf, filter, None)) + }.iterator + }.map { case (path, statuses) => + val serializableStatuses = statuses.map { status => + // Turn FileStatus into SerializableFileStatus so we can send it back to the driver + val blockLocations = status match { + case f: LocatedFileStatus => + f.getBlockLocations.map { loc => + SerializableBlockLocation( + loc.getNames, + loc.getHosts, + loc.getOffset, + loc.getLength) + } + + case _ => + Array.empty[SerializableBlockLocation] + } + + SerializableFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime, + blockLocations) + } + (path.toString, serializableStatuses) + }.collect() + } finally { + sparkContext.setJobDescription(previousJobDescription) + } // turn SerializableFileStatus back to Status statusMap.map { case (path, serializableStatuses) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index 08b2f4f31170f..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -33,7 +33,7 @@ case class InsertIntoDataSourceCommand( overwrite: Boolean) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { val relation = logicalRelation.relation.asInstanceOf[InsertableRelation] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index 64e5a57adc37c..675bee85bf61e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogT import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ import org.apache.spark.sql.util.SchemaUtils @@ -57,11 +56,7 @@ case class InsertIntoHadoopFsRelationCommand( extends DataWritingCommand { import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { // Most formats don't do well with duplicate columns, so lets not allow that 
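Back in the InMemoryFileIndex change above, the listing job's description is set and then restored in a finally block, so a slow distributed file listing shows up meaningfully in the UI without leaking the description to later jobs. The same pattern written as a small helper; withJobDescription is illustrative only, and SPARK_JOB_DESCRIPTION is reachable only from Spark's own packages:

import org.apache.spark.SparkContext

def withJobDescription[T](sc: SparkContext, description: String)(body: => T): T = {
  val previous = sc.getLocalProperty(SparkContext.SPARK_JOB_DESCRIPTION)
  sc.setJobDescription(description)
  try body finally sc.setJobDescription(previous)
}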
SchemaUtils.checkSchemaColumnNameDuplication( query.schema, @@ -144,7 +139,7 @@ case class InsertIntoHadoopFsRelationCommand( val updatedPartitionPaths = FileFormatWriter.write( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala index 17a61074d3b5c..3e98cb28453a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala @@ -34,17 +34,6 @@ case class LogicalRelation( override val isStreaming: Boolean) extends LeafNode with MultiInstanceRelation { - // Logical Relations are distinct if they have different output for the sake of transformations. - override def equals(other: Any): Boolean = other match { - case l @ LogicalRelation(otherRelation, _, _, isStreaming) => - relation == otherRelation && output == l.output && isStreaming == l.isStreaming - case _ => false - } - - override def hashCode: Int = { - com.google.common.base.Objects.hashCode(relation, output) - } - // Only care about relation when canonicalizing. override lazy val canonicalized: LogicalPlan = copy( output = output.map(QueryPlan.normalizeExprId(_, output)), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala index 5eb6a8471be0d..96c84eab1c894 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SaveIntoDataSourceCommand.scala @@ -38,7 +38,7 @@ case class SaveIntoDataSourceCommand( options: Map[String, String], mode: SaveMode) extends RunnableCommand { - override def innerChildren: Seq[QueryPlan[_]] = Seq(query) + override protected def innerChildren: Seq[QueryPlan[_]] = Seq(query) override def run(sparkSession: SparkSession): Seq[Row] = { dataSource.createRelation( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index 71133666b3249..9debc4ff82748 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -230,7 +230,6 @@ object JdbcUtils extends Logging { case java.sql.Types.TIMESTAMP => TimestampType case java.sql.Types.TIMESTAMP_WITH_TIMEZONE => TimestampType - case -101 => TimestampType // Value for Timestamp with Time Zone in Oracle case java.sql.Types.TINYINT => IntegerType case java.sql.Types.VARBINARY => BinaryType case java.sql.Types.VARCHAR => StringType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index e1e740500205a..c1535babbae1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -86,7 +86,7 @@ class ParquetFileFormat conf.getClass( SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key, classOf[ParquetOutputCommitter], - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) if (conf.get(SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key) == null) { logInfo("Using default output committer for Parquet: " + @@ -98,7 +98,7 @@ class ParquetFileFormat conf.setClass( SQLConf.OUTPUT_COMMITTER_CLASS.key, committerClass, - classOf[ParquetOutputCommitter]) + classOf[OutputCommitter]) // We're not really using `ParquetOutputFormat[Row]` for writing data here, because we override // it in `ParquetOutputWriter` to support appending and dynamic partitioning. The reason why @@ -138,6 +138,14 @@ class ParquetFileFormat conf.setBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) } + if (conf.getBoolean(ParquetOutputFormat.ENABLE_JOB_SUMMARY, false) + && !classOf[ParquetOutputCommitter].isAssignableFrom(committerClass)) { + // output summary is requested, but the class is not a Parquet Committer + logWarning(s"Committer $committerClass is not a ParquetOutputCommitter and cannot" + + s" create job summaries. " + + s"Set Parquet option ${ParquetOutputFormat.ENABLE_JOB_SUMMARY} to false.") + } + new OutputWriterFactory { // This OutputWriterFactory instance is deserialized when writing Parquet files on the // executor side without constructing or deserializing ParquetFileFormat. Therefore, we hold diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala new file mode 100644 index 0000000000000..6093df26630cd --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceReaderHolder.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util.Objects + +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.sources.v2.reader._ + +/** + * A base class for data source reader holder with customized equals/hashCode methods. + */ +trait DataSourceReaderHolder { + + /** + * The full output of the data source reader, without column pruning. + */ + def fullOutput: Seq[AttributeReference] + + /** + * The held data source reader. + */ + def reader: DataSourceV2Reader + + /** + * The metadata of this data source reader that can be used for equality test. 
+ */ + private def metadata: Seq[Any] = { + val filters: Any = reader match { + case s: SupportsPushDownCatalystFilters => s.pushedCatalystFilters().toSet + case s: SupportsPushDownFilters => s.pushedFilters().toSet + case _ => Nil + } + Seq(fullOutput, reader.getClass, reader.readSchema(), filters) + } + + def canEqual(other: Any): Boolean + + override def equals(other: Any): Boolean = other match { + case other: DataSourceReaderHolder => + canEqual(other) && metadata.length == other.metadata.length && + metadata.zip(other.metadata).forall { case (l, r) => l == r } + case _ => false + } + + override def hashCode(): Int = { + metadata.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } + + lazy val output: Seq[Attribute] = reader.readSchema().map(_.name).map { name => + fullOutput.find(_.name == name).get + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 3c9b598fd07c9..7eb99a645001a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} -import org.apache.spark.sql.sources.v2.reader.{DataSourceV2Reader, SupportsReportStatistics} +import org.apache.spark.sql.sources.v2.reader._ case class DataSourceV2Relation( - output: Seq[AttributeReference], - reader: DataSourceV2Reader) extends LeafNode { + fullOutput: Seq[AttributeReference], + reader: DataSourceV2Reader) extends LeafNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2Relation] override def computeStats(): Statistics = reader match { case r: SupportsReportStatistics => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala index 7999c0ceb5749..addc12a3f0901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala @@ -29,20 +29,14 @@ import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.types.StructType +/** + * Physical plan node for scanning data from a data source. + */ case class DataSourceV2ScanExec( - fullOutput: Array[AttributeReference], - @transient reader: DataSourceV2Reader, - // TODO: these 3 parameters are only used to determine the equality of the scan node, however, - // the reader also have this information, and ideally we can just rely on the equality of the - // reader. The only concern is, the reader implementation is outside of Spark and we have no - // control. 
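// For illustration only: a self-contained sketch of the equality-by-metadata idea used by
// DataSourceReaderHolder above. FakeReader and Holder are stand-ins invented for this sketch,
// not Spark classes; the real trait derives the metadata from the reader's class, its read
// schema and the pushed filters.
object ReaderEqualitySketch {
  final case class FakeReader(schema: Seq[String], pushedFilters: Set[String])

  final class Holder(val fullOutput: Seq[String], val reader: FakeReader) {
    // Everything that should participate in the equality test, collected in one sequence.
    private def metadata: Seq[Any] =
      Seq(fullOutput, reader.getClass, reader.schema, reader.pushedFilters)

    override def equals(other: Any): Boolean = other match {
      case h: Holder => metadata == h.metadata
      case _ => false
    }

    override def hashCode(): Int =
      metadata.map(_.hashCode()).foldLeft(0)((a, b) => 31 * a + b)
  }

  def main(args: Array[String]): Unit = {
    val a = new Holder(Seq("i", "j"), FakeReader(Seq("i", "j"), Set("i > 1")))
    val b = new Holder(Seq("i", "j"), FakeReader(Seq("i", "j"), Set("i > 1")))
    // Distinct instances with the same output, schema and pushed state compare equal.
    assert(a == b && a.hashCode == b.hashCode)
  }
}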
- readSchema: StructType, - @transient filters: ExpressionSet, - hashPartitionKeys: Seq[String]) extends LeafExecNode { - - def output: Seq[Attribute] = readSchema.map(_.name).map { name => - fullOutput.find(_.name == name).get - } + fullOutput: Seq[AttributeReference], + @transient reader: DataSourceV2Reader) extends LeafExecNode with DataSourceReaderHolder { + + override def canEqual(other: Any): Boolean = other.isInstanceOf[DataSourceV2ScanExec] override def references: AttributeSet = AttributeSet.empty @@ -74,7 +68,7 @@ class RowToUnsafeRowReadTask(rowReadTask: ReadTask[Row], schema: StructType) override def preferredLocations: Array[String] = rowReadTask.preferredLocations override def createReader: DataReader[UnsafeRow] = { - new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema)) + new RowToUnsafeDataReader(rowReadTask.createReader, RowEncoder.apply(schema).resolveAndBind()) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b80f695b2a87f..df5b524485f54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -18,75 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} -import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.execution.SparkPlan object DataSourceV2Strategy extends Strategy { - // TODO: write path override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case PhysicalOperation(projects, filters, DataSourceV2Relation(output, reader)) => - val stayUpFilters: Seq[Expression] = reader match { - case r: SupportsPushDownCatalystFilters => - r.pushCatalystFilters(filters.toArray) + case DataSourceV2Relation(output, reader) => + DataSourceV2ScanExec(output, reader) :: Nil - case r: SupportsPushDownFilters => - // A map from original Catalyst expressions to corresponding translated data source - // filters. If a predicate is not in this map, it means it cannot be pushed down. - val translatedMap: Map[Expression, Filter] = filters.flatMap { p => - DataSourceStrategy.translateFilter(p).map(f => p -> f) - }.toMap - - // Catalyst predicate expressions that cannot be converted to data source filters. - val nonConvertiblePredicates = filters.filterNot(translatedMap.contains) - - // Data source filters that cannot be pushed down. An unhandled filter means - // the data source cannot guarantee the rows returned can pass the filter. - // As a result we must return it so Spark can plan an extra filter operator. 
- val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet - val unhandledPredicates = translatedMap.filter { case (_, f) => - unhandledFilters.contains(f) - }.keys - - nonConvertiblePredicates ++ unhandledPredicates - - case _ => filters - } - - val attrMap = AttributeMap(output.zip(output)) - val projectSet = AttributeSet(projects.flatMap(_.references)) - val filterSet = AttributeSet(stayUpFilters.flatMap(_.references)) - - // Match original case of attributes. - // TODO: nested fields pruning - val requiredColumns = (projectSet ++ filterSet).toSeq.map(attrMap) - reader match { - case r: SupportsPushDownRequiredColumns => - r.pruneColumns(requiredColumns.toStructType) - case _ => - } - - val scan = DataSourceV2ScanExec( - output.toArray, - reader, - reader.readSchema(), - ExpressionSet(filters), - Nil) - - val filterCondition = stayUpFilters.reduceLeftOption(And) - val withFilter = filterCondition.map(FilterExec(_, scan)).getOrElse(scan) - - val withProject = if (projects == withFilter.output) { - withFilter - } else { - ProjectExec(projects, withFilter) - } - - withProject :: Nil + case WriteToDataSourceV2(writer, query) => + WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case _ => Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala new file mode 100644 index 0000000000000..0c1708131ae46 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/PushDownOperatorsToDataSource.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, Expression, NamedExpression, PredicateHelper} +import org.apache.spark.sql.catalyst.optimizer.RemoveRedundantProject +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.datasources.DataSourceStrategy +import org.apache.spark.sql.sources +import org.apache.spark.sql.sources.v2.reader._ + +/** + * Pushes down various operators to the underlying data source for better performance. Operators are + * being pushed down with a specific order. As an example, given a LIMIT has a FILTER child, you + * can't push down LIMIT if FILTER is not completely pushed down. When both are pushed down, the + * data source should execute FILTER before LIMIT. And required columns are calculated at the end, + * because when more operators are pushed down, we may need less columns at Spark side. 
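// For illustration only: a minimal, self-contained sketch of the filter-split logic applied by
// the pushdown rule below. SimpleSource, translate and the string "filters" are stand-ins
// invented for this sketch; the real code works on Catalyst Expressions and
// org.apache.spark.sql.sources.Filter via DataSourceStrategy.translateFilter and
// SupportsPushDownFilters.pushFilters.
object FilterPushdownSketch {
  // A pretend source that can only evaluate filters on column "a".
  final class SimpleSource {
    def pushFilters(filters: Seq[String]): Seq[String] = filters.filterNot(_.startsWith("a"))
  }

  // Pretend translation: predicates mentioning "rand" cannot be expressed as source filters.
  private def translate(p: String): Option[String] = if (p.contains("rand")) None else Some(p)

  /** Returns (predicates that must stay in Spark, filters handed to the source). */
  def split(predicates: Seq[String], source: SimpleSource): (Seq[String], Seq[String]) = {
    val translated = predicates.flatMap(p => translate(p).map(p -> _)).toMap
    // Predicates that cannot be translated at all stay in Spark.
    val nonConvertible = predicates.filterNot(translated.contains)
    // Filters the source accepted but cannot guarantee must be re-checked by Spark as well.
    val unhandled = source.pushFilters(translated.values.toSeq).toSet
    val unhandledPredicates = translated.collect { case (p, f) if unhandled.contains(f) => p }
    (nonConvertible ++ unhandledPredicates, translated.values.toSeq)
  }

  def main(args: Array[String]): Unit = {
    val (stayUp, pushed) = split(Seq("a > 1", "b < 2", "rand() > 0.5"), new SimpleSource)
    // "a > 1" is fully handled by the source; "b < 2" is pushed but still re-checked by Spark;
    // "rand() > 0.5" cannot be translated, so Spark keeps it entirely.
    assert(pushed.toSet == Set("a > 1", "b < 2") && stayUp.toSet == Set("b < 2", "rand() > 0.5"))
  }
}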
+ */ +object PushDownOperatorsToDataSource extends Rule[LogicalPlan] with PredicateHelper { + override def apply(plan: LogicalPlan): LogicalPlan = { + // Note that, we need to collect the target operator along with PROJECT node, as PROJECT may + // appear in many places for column pruning. + // TODO: Ideally column pruning should be implemented via a plan property that is propagated + // top-down, then we can simplify the logic here and only collect target operators. + val filterPushed = plan transformUp { + case FilterAndProject(fields, condition, r @ DataSourceV2Relation(_, reader)) => + // Non-deterministic expressions are stateful and we must keep the input sequence unchanged + // to avoid changing the result. This means, we can't evaluate the filter conditions that + // are after the first non-deterministic condition ahead. Here we only try to push down + // deterministic conditions that are before the first non-deterministic condition. + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(condition).span(_.deterministic) + + val stayUpFilters: Seq[Expression] = reader match { + case r: SupportsPushDownCatalystFilters => + r.pushCatalystFilters(candidates.toArray) + + case r: SupportsPushDownFilters => + // A map from original Catalyst expressions to corresponding translated data source + // filters. If a predicate is not in this map, it means it cannot be pushed down. + val translatedMap: Map[Expression, sources.Filter] = candidates.flatMap { p => + DataSourceStrategy.translateFilter(p).map(f => p -> f) + }.toMap + + // Catalyst predicate expressions that cannot be converted to data source filters. + val nonConvertiblePredicates = candidates.filterNot(translatedMap.contains) + + // Data source filters that cannot be pushed down. An unhandled filter means + // the data source cannot guarantee the rows returned can pass the filter. + // As a result we must return it so Spark can plan an extra filter operator. + val unhandledFilters = r.pushFilters(translatedMap.values.toArray).toSet + val unhandledPredicates = translatedMap.filter { case (_, f) => + unhandledFilters.contains(f) + }.keys + + nonConvertiblePredicates ++ unhandledPredicates + + case _ => candidates + } + + val filterCondition = (stayUpFilters ++ containingNonDeterministic).reduceLeftOption(And) + val withFilter = filterCondition.map(Filter(_, r)).getOrElse(r) + if (withFilter.output == fields) { + withFilter + } else { + Project(fields, withFilter) + } + } + + // TODO: add more push down rules. + + // TODO: nested fields pruning + def pushDownRequiredColumns(plan: LogicalPlan, requiredByParent: Seq[Attribute]): Unit = { + plan match { + case Project(projectList, child) => + val required = projectList.filter(requiredByParent.contains).flatMap(_.references) + pushDownRequiredColumns(child, required) + + case Filter(condition, child) => + val required = requiredByParent ++ condition.references + pushDownRequiredColumns(child, required) + + case DataSourceV2Relation(fullOutput, reader) => reader match { + case r: SupportsPushDownRequiredColumns => + // Match original case of attributes. + val attrMap = AttributeMap(fullOutput.zip(fullOutput)) + val requiredColumns = requiredByParent.map(attrMap) + r.pruneColumns(requiredColumns.toStructType) + case _ => + } + + // TODO: there may be more operators can be used to calculate required columns, we can add + // more and more in the future. 
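// For illustration only: a compact, self-contained model of the column collection that
// pushDownRequiredColumns (above) performs. The Plan/Project/Filter/Relation case classes here
// are toy stand-ins where columns are plain strings; the real rule walks Catalyst plans and
// finally calls SupportsPushDownRequiredColumns.pruneColumns with the surviving attributes.
object RequiredColumnsSketch {
  sealed trait Plan
  final case class Relation(columns: Seq[String]) extends Plan
  final case class Project(projectList: Seq[String], child: Plan) extends Plan
  final case class Filter(references: Set[String], child: Plan) extends Plan

  /** Columns the relation must actually produce, given what the parent needs. */
  def requiredColumns(plan: Plan, requiredByParent: Seq[String]): Seq[String] = plan match {
    case Project(projectList, child) =>
      // Only the projected columns the parent asked for are still required below the project.
      requiredColumns(child, projectList.filter(requiredByParent.contains))
    case Filter(references, child) =>
      // A filter needs its own references in addition to whatever the parent needs.
      requiredColumns(child, (requiredByParent ++ references).distinct)
    case Relation(columns) =>
      columns.filter(requiredByParent.contains)
  }

  def main(args: Array[String]): Unit = {
    val plan = Project(Seq("a"), Filter(Set("b"), Relation(Seq("a", "b", "c"))))
    // The query only returns "a", but the filter still needs "b"; "c" can be pruned at the source.
    assert(requiredColumns(plan, Seq("a")) == Seq("a", "b"))
  }
}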
+ case _ => plan.children.foreach(child => pushDownRequiredColumns(child, child.output)) + } + } + + pushDownRequiredColumns(filterPushed, filterPushed.output) + // After column pruning, we may have redundant PROJECT nodes in the query plan, remove them. + RemoveRedundantProject(filterPushed) + } + + /** + * Finds a Filter node(with an optional Project child) above data source relation. + */ + object FilterAndProject { + // returns the project list, the filter condition and the data source relation. + def unapply(plan: LogicalPlan) + : Option[(Seq[NamedExpression], Expression, DataSourceV2Relation)] = plan match { + + case Filter(condition, r: DataSourceV2Relation) => Some((r.output, condition, r)) + + case Filter(condition, Project(fields, r: DataSourceV2Relation)) + if fields.forall(_.deterministic) => + val attributeMap = AttributeMap(fields.map(e => e.toAttribute -> e)) + val substituted = condition.transform { + case a: Attribute => attributeMap.getOrElse(a, a) + } + Some((fields, substituted, r)) + + case _ => None + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala new file mode 100644 index 0000000000000..92c1e1f4a3383 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.internal.Logging +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +/** + * The logical plan for writing data into data source v2. + */ +case class WriteToDataSourceV2(writer: DataSourceV2Writer, query: LogicalPlan) extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil +} + +/** + * The physical plan for writing data into data source v2. 
+ */ +case class WriteToDataSourceV2Exec(writer: DataSourceV2Writer, query: SparkPlan) extends SparkPlan { + override def children: Seq[SparkPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + override protected def doExecute(): RDD[InternalRow] = { + val writeTask = writer match { + case w: SupportsWriteInternalRow => w.createInternalRowWriterFactory() + case _ => new RowToInternalRowDataWriterFactory(writer.createWriterFactory(), query.schema) + } + + val rdd = query.execute() + val messages = new Array[WriterCommitMessage](rdd.partitions.length) + + logInfo(s"Start processing data source writer: $writer. " + + s"The input RDD has ${messages.length} partitions.") + + try { + sparkContext.runJob( + rdd, + (context: TaskContext, iter: Iterator[InternalRow]) => + DataWritingSparkTask.run(writeTask, context, iter), + rdd.partitions.indices, + (index, message: WriterCommitMessage) => messages(index) = message + ) + + logInfo(s"Data source writer $writer is committing.") + writer.commit(messages) + logInfo(s"Data source writer $writer committed.") + } catch { + case cause: Throwable => + logError(s"Data source writer $writer is aborting.") + try { + writer.abort(messages) + } catch { + case t: Throwable => + logError(s"Data source writer $writer failed to abort.") + cause.addSuppressed(t) + throw new SparkException("Writing job failed.", cause) + } + logError(s"Data source writer $writer aborted.") + throw new SparkException("Writing job aborted.", cause) + } + + sparkContext.emptyRDD + } +} + +object DataWritingSparkTask extends Logging { + def run( + writeTask: DataWriterFactory[InternalRow], + context: TaskContext, + iter: Iterator[InternalRow]): WriterCommitMessage = { + val dataWriter = writeTask.createWriter(context.partitionId(), context.attemptNumber()) + + // write the data and commit this writer. 
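// For illustration only: a self-contained sketch of the commit/abort handshake used by this v2
// write path, with plain Scala stand-ins (TaskWriter, JobWriter, String messages) instead of
// DataWriter, DataSourceV2Writer and WriterCommitMessage. The shape is the same: each task
// writes its rows and commits locally, returning a message; the driver commits the whole job
// only with the full set of messages, and any failure triggers abort at the matching level.
object WriteProtocolSketch {
  trait TaskWriter {
    def write(row: Int): Unit
    def commit(): String
    def abort(): Unit
  }

  trait JobWriter {
    def newTask(partition: Int): TaskWriter
    def commit(messages: Seq[String]): Unit
    def abort(messages: Seq[String]): Unit
  }

  def runTask(writer: TaskWriter, rows: Iterator[Int]): String =
    try {
      rows.foreach(writer.write)
      writer.commit()                     // the task-level commit message goes back to the driver
    } catch {
      case e: Throwable =>
        writer.abort()                    // a failed task aborts only its own writer
        throw e
    }

  def runJob(writer: JobWriter, partitions: Seq[Seq[Int]]): Unit = {
    val messages = new Array[String](partitions.length)
    try {
      partitions.zipWithIndex.foreach { case (rows, i) =>
        messages(i) = runTask(writer.newTask(i), rows.iterator)
      }
      writer.commit(messages.toSeq)       // job-level commit only after every task succeeded
    } catch {
      case e: Throwable =>
        writer.abort(messages.toSeq)      // otherwise the whole job is aborted
        throw e
    }
  }

  def main(args: Array[String]): Unit = {
    val committed = scala.collection.mutable.ArrayBuffer.empty[Int]
    val writer = new JobWriter {
      def newTask(partition: Int): TaskWriter = new TaskWriter {
        private val buffer = scala.collection.mutable.ArrayBuffer.empty[Int]
        def write(row: Int): Unit = buffer += row
        def commit(): String = { committed ++= buffer; s"partition $partition: ${buffer.size} rows" }
        def abort(): Unit = buffer.clear()
      }
      def commit(messages: Seq[String]): Unit = println(messages.mkString("; "))
      def abort(messages: Seq[String]): Unit = committed.clear()
    }
    runJob(writer, Seq(Seq(1, 2), Seq(3)))
    assert(committed.sorted == Seq(1, 2, 3))
  }
}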
+ Utils.tryWithSafeFinallyAndFailureCallbacks(block = { + iter.foreach(dataWriter.write) + logInfo(s"Writer for partition ${context.partitionId()} is committing.") + val msg = dataWriter.commit() + logInfo(s"Writer for partition ${context.partitionId()} committed.") + msg + })(catchBlock = { + // If there is an error, abort this writer + logError(s"Writer for partition ${context.partitionId()} is aborting.") + dataWriter.abort() + logError(s"Writer for partition ${context.partitionId()} aborted.") + }) + } +} + +class RowToInternalRowDataWriterFactory( + rowWriterFactory: DataWriterFactory[Row], + schema: StructType) extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + new RowToInternalRowDataWriter( + rowWriterFactory.createWriter(partitionId, attemptNumber), + RowEncoder.apply(schema).resolveAndBind()) + } +} + +class RowToInternalRowDataWriter(rowWriter: DataWriter[Row], encoder: ExpressionEncoder[Row]) + extends DataWriter[InternalRow] { + + override def write(record: InternalRow): Unit = rowWriter.write(encoder.fromRow(record)) + + override def commit(): WriterCommitMessage = rowWriter.commit() + + override def abort(): Unit = rowWriter.abort() +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala index 9c859e41f8762..880e18c6808b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning} import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} +import org.apache.spark.sql.execution.joins.HashedRelation import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.ThreadUtils @@ -72,26 +72,39 @@ case class BroadcastExchangeExec( SQLExecution.withExecutionId(sparkContext, executionId) { try { val beforeCollect = System.nanoTime() - // Note that we use .executeCollect() because we don't want to convert data to Scala types - val input: Array[InternalRow] = child.executeCollect() - if (input.length >= 512000000) { + // Use executeCollect/executeCollectIterator to avoid conversion to Scala types + val (numRows, input) = child.executeCollectIterator() + if (numRows >= 512000000) { throw new SparkException( - s"Cannot broadcast the table with more than 512 millions rows: ${input.length} rows") + s"Cannot broadcast the table with more than 512 millions rows: $numRows rows") } + val beforeBuild = System.nanoTime() longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000 - val dataSize = input.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + + // Construct the relation. 
+ val relation = mode.transform(input, Some(numRows)) + + val dataSize = relation match { + case map: HashedRelation => + map.estimatedSize + case arr: Array[InternalRow] => + arr.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum + case _ => + throw new SparkException("[BUG] BroadcastMode.transform returned unexpected type: " + + relation.getClass.getName) + } + longMetric("dataSize") += dataSize if (dataSize >= (8L << 30)) { throw new SparkException( s"Cannot broadcast the table that is larger than 8GB: ${dataSize >> 30} GB") } - // Construct and broadcast the relation. - val relation = mode.transform(input) val beforeBroadcast = System.nanoTime() longMetric("buildTime") += (beforeBroadcast - beforeBuild) / 1000000 + // Broadcast the relation val broadcasted = sparkContext.broadcast(relation) longMetric("broadcastTime") += (System.nanoTime() - beforeBroadcast) / 1000000 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala index 1da72f2e92329..4e2ca37bc1a59 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.internal.SQLConf * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]] * of input data meets the * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for - * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the - * input partition ordering requirements are met. + * each operator by inserting [[ShuffleExchangeExec]] Operators where required. Also ensure that + * the input partition ordering requirements are met. */ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { private def defaultNumPreShufflePartitions: Int = conf.numShufflePartitions @@ -44,30 +44,33 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { /** * Given a required distribution, returns a partitioning that satisfies that distribution. + * @param requiredDistribution The distribution that is required by the operator + * @param numPartitions Used when the distribution doesn't require a specific number of partitions */ private def createPartitioning( requiredDistribution: Distribution, numPartitions: Int): Partitioning = { requiredDistribution match { case AllTuples => SinglePartition - case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions) + case ClusteredDistribution(clustering, desiredPartitions) => + HashPartitioning(clustering, desiredPartitions.getOrElse(numPartitions)) case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions) case dist => sys.error(s"Do not know how to satisfy distribution $dist") } } /** - * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled - * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]]. + * Adds [[ExchangeCoordinator]] to [[ShuffleExchangeExec]]s if adaptive query execution is enabled + * and partitioning schemes of these [[ShuffleExchangeExec]]s support [[ExchangeCoordinator]]. 
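// For illustration only: a small, self-contained sketch of the size check added to
// BroadcastExchangeExec above. HashedLike and RowsLike are stand-ins for the two shapes
// BroadcastMode.transform can return (a HashedRelation with an estimatedSize, or a plain array
// of unsafe rows whose byte sizes are summed); the 8 GB limit mirrors the real check.
object BroadcastSizeSketch {
  sealed trait BuiltRelation
  final case class HashedLike(estimatedSize: Long) extends BuiltRelation
  final case class RowsLike(rowSizesInBytes: Seq[Long]) extends BuiltRelation

  private val MaxBroadcastBytes = 8L << 30

  def checkedSize(relation: BuiltRelation): Long = {
    val dataSize = relation match {
      case HashedLike(size) => size          // hashed relations already track their size estimate
      case RowsLike(sizes)  => sizes.sum     // otherwise sum the per-row byte sizes
    }
    require(dataSize < MaxBroadcastBytes,
      s"Cannot broadcast a table larger than 8GB: ${dataSize >> 30} GB")
    dataSize
  }

  def main(args: Array[String]): Unit = {
    assert(checkedSize(HashedLike(1L << 20)) == (1L << 20))
    assert(checkedSize(RowsLike(Seq(100L, 200L))) == 300L)
  }
}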
*/ private def withExchangeCoordinator( children: Seq[SparkPlan], requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = { val supportsCoordinator = - if (children.exists(_.isInstanceOf[ShuffleExchange])) { + if (children.exists(_.isInstanceOf[ShuffleExchangeExec])) { // Right now, ExchangeCoordinator only support HashPartitionings. children.forall { - case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true + case e @ ShuffleExchangeExec(hash: HashPartitioning, _, _) => true case child => child.outputPartitioning match { case hash: HashPartitioning => true @@ -94,7 +97,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { targetPostShuffleInputSize, minNumPostShufflePartitions) children.zip(requiredChildDistributions).map { - case (e: ShuffleExchange, _) => + case (e: ShuffleExchangeExec, _) => // This child is an Exchange, we need to add the coordinator. e.copy(coordinator = Some(coordinator)) case (child, distribution) => @@ -138,7 +141,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { val targetPartitioning = createPartitioning(distribution, defaultNumPreShufflePartitions) assert(targetPartitioning.isInstanceOf[HashPartitioning]) - ShuffleExchange(targetPartitioning, child, Some(coordinator)) + ShuffleExchangeExec(targetPartitioning, child, Some(coordinator)) } } else { // If we do not need ExchangeCoordinator, the original children are returned. @@ -162,7 +165,7 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { case (child, BroadcastDistribution(mode)) => BroadcastExchangeExec(mode, child) case (child, distribution) => - ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child) + ShuffleExchangeExec(createPartitioning(distribution, defaultNumPreShufflePartitions), child) } // If the operator has multiple children and specifies child output distributions (e.g. join), @@ -215,8 +218,8 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { child match { // If child is an exchange, we replace it with // a new one having targetPartitioning. 
- case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c) - case _ => ShuffleExchange(targetPartitioning, child) + case ShuffleExchangeExec(_, c, _) => ShuffleExchangeExec(targetPartitioning, c) + case _ => ShuffleExchangeExec(targetPartitioning, child) } } } @@ -246,9 +249,9 @@ case class EnsureRequirements(conf: SQLConf) extends Rule[SparkPlan] { } def apply(plan: SparkPlan): SparkPlan = plan.transformUp { - case operator @ ShuffleExchange(partitioning, child, _) => + case operator @ ShuffleExchangeExec(partitioning, child, _) => child.children match { - case ShuffleExchange(childPartitioning, baseChild, _)::Nil => + case ShuffleExchangeExec(childPartitioning, baseChild, _)::Nil => if (childPartitioning.guarantees(partitioning)) child else operator case _ => operator } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala index 9fc4ffb651ec8..78f11ca8d8c78 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala @@ -35,9 +35,9 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * * A coordinator is constructed with three parameters, `numExchanges`, * `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`. - * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered - * to this coordinator. So, when we start to do any actual work, we have a way to make sure that - * we have got expected number of [[ShuffleExchange]]s. + * - `numExchanges` is used to indicated that how many [[ShuffleExchangeExec]]s that will be + * registered to this coordinator. So, when we start to do any actual work, we have a way to + * make sure that we have got expected number of [[ShuffleExchangeExec]]s. * - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's * input data size. With this parameter, we can estimate the number of post-shuffle partitions. * This parameter is configured through @@ -47,28 +47,28 @@ import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan} * partitions. * * The workflow of this coordinator is described as follows: - * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchange]] operator, + * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchangeExec]] operator, * if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator. * This happens in the `doPrepare` method. - * - Once we start to execute a physical plan, a [[ShuffleExchange]] registered to this + * - Once we start to execute a physical plan, a [[ShuffleExchangeExec]] registered to this * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle * [[ShuffledRowRDD]]. - * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]] + * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchangeExec]] * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]]. * - If this coordinator has not made the decision on how to shuffle data, it will ask those - * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the + * registered [[ShuffleExchangeExec]]s to submit their pre-shuffle stages. 
Then, based on the * size statistics of pre-shuffle partitions, this coordinator will determine the number of * post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices * to a single post-shuffle partition whenever necessary. * - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered - * [[ShuffleExchange]]s. So, when a [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator - * can lookup the corresponding [[RDD]]. + * [[ShuffleExchangeExec]]s. So, when a [[ShuffleExchangeExec]] calls `postShuffleRDD`, this + * coordinator can lookup the corresponding [[RDD]]. * * The strategy used to determine the number of post-shuffle partitions is described as follows. * To determine the number of post-shuffle partitions, we have a target input size for a * post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages - * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and - * pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until + * corresponding to the registered [[ShuffleExchangeExec]]s, we will do a pass of those statistics + * and pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until * adding another pre-shuffle partition would cause the size of a post-shuffle partition to be * greater than the target size. * @@ -89,11 +89,11 @@ class ExchangeCoordinator( extends Logging { // The registered Exchange operators. - private[this] val exchanges = ArrayBuffer[ShuffleExchange]() + private[this] val exchanges = ArrayBuffer[ShuffleExchangeExec]() // This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator. - private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] = - new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + private[this] val postShuffleRDDs: JMap[ShuffleExchangeExec, ShuffledRowRDD] = + new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // A boolean that indicates if this coordinator has made decision on how to shuffle data. // This variable will only be updated by doEstimationIfNecessary, which is protected by @@ -101,11 +101,11 @@ class ExchangeCoordinator( @volatile private[this] var estimated: Boolean = false /** - * Registers a [[ShuffleExchange]] operator to this coordinator. This method is only allowed to - * be called in the `doPrepare` method of a [[ShuffleExchange]] operator. + * Registers a [[ShuffleExchangeExec]] operator to this coordinator. This method is only allowed + * to be called in the `doPrepare` method of a [[ShuffleExchangeExec]] operator. */ @GuardedBy("this") - def registerExchange(exchange: ShuffleExchange): Unit = synchronized { + def registerExchange(exchange: ShuffleExchangeExec): Unit = synchronized { exchanges += exchange } @@ -200,7 +200,7 @@ class ExchangeCoordinator( // Make sure we have the expected number of registered Exchange operators. 
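// For illustration only: a self-contained sketch of the packing strategy described in the
// ExchangeCoordinator Scaladoc above. It walks the pre-shuffle partition sizes in order and
// starts a new post-shuffle partition whenever adding the next pre-shuffle partition would push
// the current one past the target input size. This models the described strategy, not the
// coordinator's actual implementation.
object CoalesceShufflePartitionsSketch {
  /** Returns the start indices of the post-shuffle partitions. */
  def packPartitions(preShuffleSizes: Array[Long], targetSize: Long): Seq[Int] = {
    val startIndices = scala.collection.mutable.ArrayBuffer(0)
    var currentSize = 0L
    preShuffleSizes.zipWithIndex.foreach { case (size, i) =>
      if (i > 0 && currentSize + size > targetSize) {
        startIndices += i          // close the current post-shuffle partition, start a new one
        currentSize = size
      } else {
        currentSize += size
      }
    }
    startIndices.toSeq
  }

  def main(args: Array[String]): Unit = {
    // With a 100-byte target, five pre-shuffle partitions of sizes 40, 50, 90, 10, 20 are packed
    // into three post-shuffle partitions: [0, 1] (90), [2, 3] (100) and [4] (20).
    assert(packPartitions(Array(40L, 50L, 90L, 10L, 20L), targetSize = 100L) == Seq(0, 2, 4))
  }
}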
assert(exchanges.length == numExchanges) - val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges) + val newPostShuffleRDDs = new JHashMap[ShuffleExchangeExec, ShuffledRowRDD](numExchanges) // Submit all map stages val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]() @@ -255,7 +255,7 @@ class ExchangeCoordinator( } } - def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = { + def postShuffleRDD(exchange: ShuffleExchangeExec): ShuffledRowRDD = { doEstimationIfNecessary() if (!postShuffleRDDs.containsKey(exchange)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala similarity index 96% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala index 0d06d83fb2f3c..5a1e217082bc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchangeExec.scala @@ -30,12 +30,13 @@ import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.MutablePair /** * Performs a shuffle that will result in the desired `newPartitioning`. */ -case class ShuffleExchange( +case class ShuffleExchangeExec( var newPartitioning: Partitioning, child: SparkPlan, @transient coordinator: Option[ExchangeCoordinator]) extends Exchange { @@ -84,7 +85,7 @@ case class ShuffleExchange( */ private[exchange] def prepareShuffleDependency() : ShuffleDependency[Int, InternalRow, InternalRow] = { - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( child.execute(), child.output, newPartitioning, serializer) } @@ -129,9 +130,9 @@ case class ShuffleExchange( } } -object ShuffleExchange { - def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = { - ShuffleExchange(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) +object ShuffleExchangeExec { + def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchangeExec = { + ShuffleExchangeExec(newPartitioning, child, coordinator = Option.empty[ExchangeCoordinator]) } /** @@ -218,7 +219,11 @@ object ShuffleExchange { iter.map(row => mutablePair.update(row.copy(), null)) } implicit val ordering = new LazilyGeneratedOrdering(sortingExpressions, outputAttributes) - new RangePartitioner(numPartitions, rddForSampling, ascending = true) + new RangePartitioner( + numPartitions, + rddForSampling, + ascending = true, + samplePointsPerPartitionHint = SQLConf.get.rangeExchangeSampleSizePerPartition) case SinglePartition => new Partitioner { override def numPartitions: Int = 1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index f8058b2f7813b..b2dcbe5aa9877 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -866,7 +866,18 @@ 
private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression]) extends BroadcastMode { override def transform(rows: Array[InternalRow]): HashedRelation = { - HashedRelation(rows.iterator, canonicalized.key, rows.length) + transform(rows.iterator, Some(rows.length)) + } + + override def transform( + rows: Iterator[InternalRow], + sizeHint: Option[Long]): HashedRelation = { + sizeHint match { + case Some(numRows) => + HashedRelation(rows, canonicalized.key, numRows.toInt) + case None => + HashedRelation(rows, canonicalized.key) + } } override lazy val canonicalized: HashedRelationBroadcastMode = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala index 14de2dc23e3c0..4e02803552e82 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoinExec.scala @@ -402,7 +402,7 @@ case class SortMergeJoinExec( } } - private def genComparision(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { + private def genComparison(ctx: CodegenContext, a: Seq[ExprCode], b: Seq[ExprCode]): String = { val comparisons = a.zip(b).zipWithIndex.map { case ((l, r), i) => s""" |if (comp == 0) { @@ -463,7 +463,7 @@ case class SortMergeJoinExec( | continue; | } | if (!$matches.isEmpty()) { - | ${genComparision(ctx, leftKeyVars, matchedKeyVars)} + | ${genComparison(ctx, leftKeyVars, matchedKeyVars)} | if (comp == 0) { | return true; | } @@ -484,7 +484,7 @@ case class SortMergeJoinExec( | } | ${rightKeyVars.map(_.code).mkString("\n")} | } - | ${genComparision(ctx, leftKeyVars, rightKeyVars)} + | ${genComparison(ctx, leftKeyVars, rightKeyVars)} | if (comp > 0) { | $rightRow = null; | } else if (comp < 0) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 1f515e29b4af5..13da4b26a5dcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, LazilyGeneratedOrdering} import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.util.Utils /** @@ -40,7 +40,7 @@ case class CollectLimitExec(limit: Int, child: SparkPlan) extends UnaryExecNode protected override def doExecute(): RDD[InternalRow] = { val locallyLimited = child.execute().mapPartitionsInternal(_.take(limit)) val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( locallyLimited, child.output, SinglePartition, serializer)) shuffled.mapPartitionsInternal(_.take(limit)) } @@ -153,7 +153,7 @@ case class TakeOrderedAndProjectExec( } } val shuffled = new ShuffledRowRDD( - ShuffleExchange.prepareShuffleDependency( + ShuffleExchangeExec.prepareShuffleDependency( localTopK, child.output, SinglePartition, serializer)) shuffled.mapPartitions { iter => val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 5a3fcad38888e..d861109436a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.objects.Invoke -import org.apache.spark.sql.catalyst.plans.logical.{FunctionUtils, LogicalGroupState} +import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, FunctionUtils, LogicalGroupState} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.streaming.GroupStateImpl import org.apache.spark.sql.streaming.GroupStateTimeout @@ -361,8 +361,12 @@ object MapGroupsExec { outputObjAttr: Attribute, timeoutConf: GroupStateTimeout, child: SparkPlan): MapGroupsExec = { + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false + } val f = (key: Any, values: Iterator[Any]) => { - func(key, values, GroupStateImpl.createForBatch(timeoutConf)) + func(key, values, GroupStateImpl.createForBatch(timeoutConf, watermarkPresent)) } new MapGroupsExec(f, keyDeserializer, valueDeserializer, groupingAttributes, dataAttributes, outputObjAttr, child) @@ -394,7 +398,11 @@ case class FlatMapGroupsInRExec( override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(groupingAttributes.map(SortOrder(_, Ascending))) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index 5e72cd255873a..81896187ecc46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -17,14 +17,44 @@ package org.apache.spark.sql.execution.python +import scala.collection.JavaConverters._ + import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.{ArrowConverters, ArrowPayload} import org.apache.spark.sql.types.StructType +/** + * Grouped a iterator into batches. + * This is similar to iter.grouped but returns Iterator[T] instead of Seq[T]. + * This is necessary because sometimes we cannot hold reference of input rows + * because the some input rows are mutable and can be reused. 
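// For illustration only: a self-contained demonstration of the pitfall the comment above
// describes. Many Spark iterators reuse one mutable row object, so buffering a batch with
// iter.grouped() keeps several references to the same object and every buffered "row" ends up
// holding the last value, while streaming over the rows one by one sees each value correctly.
// MutableCell and the numbers are invented for this sketch.
object ReusedRowPitfallSketch {
  final class MutableCell { var value: Int = 0 }

  // An iterator that, like an unsafe-row iterator, returns the same object on every call.
  private def reusingIterator(values: Seq[Int]): Iterator[MutableCell] = {
    val cell = new MutableCell
    values.iterator.map { v => cell.value = v; cell }
  }

  def main(args: Array[String]): Unit = {
    // Buffering keeps three references to the single shared cell: every entry reads as 3.
    val buffered = reusingIterator(Seq(1, 2, 3)).grouped(3).next().map(_.value)
    assert(buffered == Seq(3, 3, 3))

    // Consuming lazily, one row at a time (what BatchIterator's nested iterator allows), is fine.
    val streamed = reusingIterator(Seq(1, 2, 3)).map(_.value).toList
    assert(streamed == List(1, 2, 3))
  }
}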
+ */ +private class BatchIterator[T](iter: Iterator[T], batchSize: Int) + extends Iterator[Iterator[T]] { + + override def hasNext: Boolean = iter.hasNext + + override def next(): Iterator[T] = { + new Iterator[T] { + var count = 0 + + override def hasNext: Boolean = iter.hasNext && count < batchSize + + override def next(): T = { + if (!hasNext) { + Iterator.empty.next() + } else { + count += 1 + iter.next() + } + } + } + } +} + /** * A physical plan that evaluates a [[PythonUDF]], */ @@ -39,25 +69,40 @@ case class ArrowEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { - val inputIterator = ArrowConverters.toPayloadIterator( - iter, schema, conf.arrowMaxRecordsPerBatch, context).map(_.asPythonSerializable) - - // Output iterator for results from Python. - val outputIterator = new PythonRunner( - funcs, bufferSize, reuseWorker, PythonEvalType.SQL_PANDAS_UDF, argOffsets) - .compute(inputIterator, context.partitionId(), context) - - val outputRowIterator = ArrowConverters.fromPayloadIterator( - outputIterator.map(new ArrowPayload(_)), context) - - // Verify that the output schema is correct - if (outputRowIterator.hasNext) { - val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex - .map { case (attr, i) => attr.withName(s"_$i") }) - assert(schemaOut.equals(outputRowIterator.schema), - s"Invalid schema from pandas_udf: expected $schemaOut, got ${outputRowIterator.schema}") - } - outputRowIterator + val schemaOut = StructType.fromAttributes(output.drop(child.output.length).zipWithIndex + .map { case (attr, i) => attr.withName(s"_$i") }) + + val batchSize = conf.arrowMaxRecordsPerBatch + // DO NOT use iter.grouped(). See BatchIterator. + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + + val columnarBatchIter = new ArrowPythonRunner( + funcs, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_UDF, argOffsets, schema) + .compute(batchIter, context.partitionId(), context) + + new Iterator[InternalRow] { + + private var currentIter = if (columnarBatchIter.hasNext) { + val batch = columnarBatchIter.next() + assert(schemaOut.equals(batch.schema), + s"Invalid schema from pandas_udf: expected $schemaOut, got ${batch.schema}") + batch.rowIterator.asScala + } else { + Iterator.empty + } + + override def hasNext: Boolean = currentIter.hasNext || { + if (columnarBatchIter.hasNext) { + currentIter = columnarBatchIter.next().rowIterator.asScala + hasNext + } else { + false + } + } + + override def next(): InternalRow = currentIter.next() + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala new file mode 100644 index 0000000000000..f6c03c415dc66 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -0,0 +1,180 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.stream.{ArrowStreamReader, ArrowStreamWriter} + +import org.apache.spark._ +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. + */ +class ArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]], + schema: StructType) + extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Iterator[InternalRow]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = ArrowUtils.toArrowSchema(schema) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", 0, Long.MaxValue) + + val root = VectorSchemaRoot.create(arrowSchema, allocator) + val arrowWriter = ArrowWriter.create(root) + + var closed = false + + context.addTaskCompletionListener { _ => + if (!closed) { + root.close() + allocator.close() + } + } + + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + Utils.tryWithSafeFinally { + while (inputIterator.hasNext) { + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + } { + writer.end() + root.close() + allocator.close() + closed = true + } + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[ColumnarBatch] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + + private var closed = false + + context.addTaskCompletionListener { _ => + // todo: we need something like 
`reader.end()`, which release all the resources, but leave + // the input stream open. `reader.close()` will close the socket and we can't reuse worker. + // So here we simply not close the reader, which is problematic. + if (!closed) { + if (root != null) { + root.close() + } + allocator.close() + } + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(schema, vectors, root.getRowCount) + batch.setNumRows(root.getRowCount) + batch + } else { + root.close() + allocator.close() + closed = true + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { + stream.readInt() match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } + } catch handleException + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala index 2978eac50554d..26ee25f633ea4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import net.razorvine.pickle.{Pickler, Unpickler} import org.apache.spark.TaskContext -import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType, PythonRunner} +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan @@ -68,7 +68,7 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], chi }.grouped(100).map(x => pickle.dumps(x.toArray)) // Output iterator for results from Python. 
- val outputIterator = new PythonRunner( + val outputIterator = new PythonUDFRunner( funcs, bufferSize, reuseWorker, PythonEvalType.SQL_BATCHED_UDF, argOffsets) .compute(inputIterator, context.partitionId(), context) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index fec456d86dbe2..d6825369f7378 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.{FilterExec, SparkPlan} +import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} /** @@ -111,6 +110,9 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { } def apply(plan: SparkPlan): SparkPlan = plan transformUp { + // FlatMapGroupsInPandas can be evaluated directly in python worker + // Therefore we don't need to extract the UDFs + case plan: FlatMapGroupsInPandasExec => plan case plan: SparkPlan => extract(plan) } @@ -135,11 +137,15 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { udf.references.subsetOf(child.outputSet) } if (validUdfs.nonEmpty) { + if (validUdfs.exists(_.pythonUdfType == PythonUdfType.PANDAS_GROUPED_UDF)) { + throw new IllegalArgumentException("Can not use grouped vectorized UDFs") + } + val resultAttrs = udfs.zipWithIndex.map { case (u, i) => AttributeReference(s"pythonUDF$i", u.dataType)() } - val evaluation = validUdfs.partition(_.vectorized) match { + val evaluation = validUdfs.partition(_.pythonUdfType == PythonUdfType.PANDAS_UDF) match { case (vectorizedUdfs, plainUdfs) if plainUdfs.isEmpty => ArrowEvalPythonExec(vectorizedUdfs, child.output ++ resultAttrs, child) case (vectorizedUdfs, plainUdfs) if vectorizedUdfs.isEmpty => @@ -169,7 +175,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { val newPlan = extract(rewritten) if (newPlan.output != plan.output) { // Trim away the new UDF value if it was only used for filtering or something. - execution.ProjectExec(plan.output, newPlan) + ProjectExec(plan.output, newPlan) } else { newPlan } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala new file mode 100644 index 0000000000000..5ed88ada428cb --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
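// For illustration only: a sketch of the dispatch that ExtractPythonUDFs performs above. The
// UdfType values and the returned node names are simplified stand-ins; the real rule partitions
// the extracted PythonUDF expressions on pythonUdfType, plans ArrowEvalPythonExec for pandas
// UDFs and BatchEvalPythonExec for plain ones, and rejects grouped vectorized UDFs outside
// FlatMapGroupsInPandas. The error message for the mixed case below is invented for this sketch;
// only the two homogeneous branches are visible in the hunk above.
object PythonUdfDispatchSketch {
  sealed trait UdfType
  case object Batched extends UdfType
  case object Pandas extends UdfType
  case object PandasGrouped extends UdfType

  def planEvaluation(udfTypes: Seq[UdfType]): String = {
    require(!udfTypes.contains(PandasGrouped), "Can not use grouped vectorized UDFs")
    udfTypes.partition(_ == Pandas) match {
      case (_, plain) if plain.isEmpty   => "ArrowEvalPythonExec"
      case (pandas, _) if pandas.isEmpty => "BatchEvalPythonExec"
      case _ =>
        throw new IllegalArgumentException("Mixing pandas and plain Python UDFs is not supported")
    }
  }

  def main(args: Array[String]): Unit = {
    assert(planEvaluation(Seq(Pandas, Pandas)) == "ArrowEvalPythonExec")
    assert(planEvaluation(Seq(Batched)) == "BatchEvalPythonExec")
  }
}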
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ + +import org.apache.spark.TaskContext +import org.apache.spark.api.python.{ChainedPythonFunctions, PythonEvalType} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} +import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} +import org.apache.spark.sql.types.StructType + +/** + * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] + * + * Rows in each group are passed to the Python worker as an Arrow record batch. + * The Python worker turns the record batch to a `pandas.DataFrame`, invoke the + * user-defined function, and passes the resulting `pandas.DataFrame` + * as an Arrow record batch. Finally, each record batch is turned to + * Iterator[InternalRow] using ColumnarBatch. + * + * Note on memory usage: + * Both the Python worker and the Java executor need to have enough memory to + * hold the largest group. The memory on the Java side is used to construct the + * record batch (off heap memory). The memory on the Python side is used for + * holding the `pandas.DataFrame`. It's possible to further split one group into + * multiple record batches to reduce the memory footprint on the Java side, this + * is left as future work. 
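// For illustration only: a self-contained model of the data flow just described, with plain
// Scala collections standing in for Arrow record batches and pandas DataFrames. groupAndApply
// groups the input by the grouping columns, strips those columns from each group (as the real
// node does with an UnsafeProjection), hands each group to a per-group function, and flattens
// the results back into a single row stream. Everything here is invented for this sketch, not
// Spark API.
object FlatMapGroupsSketch {
  type Row = Map[String, Any]

  def groupAndApply(
      rows: Seq[Row],
      groupingCols: Seq[String],
      func: Seq[Row] => Seq[Row]): Seq[Row] = {
    rows
      .groupBy(row => groupingCols.map(row))             // one "record batch" per group
      .values
      .map(group => func(group.map(_ -- groupingCols)))  // the UDF sees rows without the keys
      .toSeq
      .flatten                                           // back to a flat stream of rows
  }

  def main(args: Array[String]): Unit = {
    val rows = Seq(
      Map("id" -> 1, "v" -> 10.0), Map("id" -> 1, "v" -> 20.0), Map("id" -> 2, "v" -> 5.0))
    // A toy "grouped UDF": subtract the group mean from each value.
    val demean: Seq[Row] => Seq[Row] = { group =>
      val mean = group.map(_("v").asInstanceOf[Double]).sum / group.size
      group.map(r => r.updated("v", r("v").asInstanceOf[Double] - mean))
    }
    val out = groupAndApply(rows, Seq("id"), demean)
    assert(out.map(_("v")).toSet == Set(-5.0, 5.0, 0.0))
  }
}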
+ */ +case class FlatMapGroupsInPandasExec( + groupingAttributes: Seq[Attribute], + func: Expression, + output: Seq[Attribute], + child: SparkPlan) + extends UnaryExecNode { + + private val pandasFunction = func.asInstanceOf[PythonUDF].func + + override def outputPartitioning: Partitioning = child.outputPartitioning + + override def producedAttributes: AttributeSet = AttributeSet(output) + + override def requiredChildDistribution: Seq[Distribution] = { + if (groupingAttributes.isEmpty) { + AllTuples :: Nil + } else { + ClusteredDistribution(groupingAttributes) :: Nil + } + } + + override def requiredChildOrdering: Seq[Seq[SortOrder]] = + Seq(groupingAttributes.map(SortOrder(_, Ascending))) + + override protected def doExecute(): RDD[InternalRow] = { + val inputRDD = child.execute() + + val bufferSize = inputRDD.conf.getInt("spark.buffer.size", 65536) + val reuseWorker = inputRDD.conf.getBoolean("spark.python.worker.reuse", defaultValue = true) + val chainedFunc = Seq(ChainedPythonFunctions(Seq(pandasFunction))) + val argOffsets = Array((0 until (child.output.length - groupingAttributes.length)).toArray) + val schema = StructType(child.schema.drop(groupingAttributes.length)) + + inputRDD.mapPartitionsInternal { iter => + val grouped = if (groupingAttributes.isEmpty) { + Iterator(iter) + } else { + val groupedIter = GroupedIterator(iter, groupingAttributes, child.output) + val dropGrouping = + UnsafeProjection.create(child.output.drop(groupingAttributes.length), child.output) + groupedIter.map { + case (_, groupedRowIter) => groupedRowIter.map(dropGrouping) + } + } + + val context = TaskContext.get() + + val columnarBatchIter = new ArrowPythonRunner( + chainedFunc, bufferSize, reuseWorker, + PythonEvalType.SQL_PANDAS_GROUPED_UDF, argOffsets, schema) + .compute(grouped, context.partitionId(), context) + + columnarBatchIter.flatMap(_.rowIterator.asScala).map(UnsafeProjection.create(output, output)) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala index 84a6d9e5be59c..9c07c7638de57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala @@ -29,7 +29,7 @@ case class PythonUDF( func: PythonFunction, dataType: DataType, children: Seq[Expression], - vectorized: Boolean) + pythonUdfType: Int) extends Expression with Unevaluable with NonSQLExpression with UserDefinedExpression { override def toString: String = s"$name(${children.mkString(", ")})" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala new file mode 100644 index 0000000000000..e28def1c4b423 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import java.io._ +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean + +import org.apache.spark._ +import org.apache.spark.api.python._ + +/** + * A helper class to run Python UDFs in Spark. + */ +class PythonUDFRunner( + funcs: Seq[ChainedPythonFunctions], + bufferSize: Int, + reuseWorker: Boolean, + evalType: Int, + argOffsets: Array[Array[Int]]) + extends BasePythonRunner[Array[Byte], Array[Byte]]( + funcs, bufferSize, reuseWorker, evalType, argOffsets) { + + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + PythonRDD.writeIteratorToStream(inputIterator, dataOut) + dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) + } + } + } + + protected override def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + released: AtomicBoolean, + context: TaskContext): Iterator[Array[Byte]] = { + new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) { + + protected override def read(): Array[Byte] = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + stream.readInt() match { + case length if length > 0 => + val obj = new Array[Byte](length) + stream.readFully(obj) + obj + case 0 => Array.empty[Byte] + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null + } + } catch handleException + } + } + } +} + +object PythonUDFRunner { + + def writeUDFs( + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { + dataOut.writeInt(funcs.length) + funcs.zip(argOffsets).foreach { case (chained, offsets) => + dataOut.writeInt(offsets.length) + offsets.foreach { offset => + dataOut.writeInt(offset) + } + dataOut.writeInt(chained.funcs.length) + chained.funcs.foreach { f => + dataOut.writeInt(f.command.length) + dataOut.write(f.command) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala index a30a80acf5c23..b2fe6c300846a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/UserDefinedPythonFunction.scala @@ -22,6 +22,15 @@ import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.types.DataType 
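// Sketch of the byte layout produced by PythonUDFRunner.writeUDFs above, using plain
// java.io streams. The command payloads are fake byte arrays standing in for the pickled
// Python functions; only the framing is the point.
import java.io.{ByteArrayOutputStream, DataOutputStream}

object UdfWireFormatSketch {
  def main(args: Array[String]): Unit = {
    val commandsPerUdf = Seq(Seq(Array[Byte](1, 2, 3)))   // one UDF, one (non-chained) command
    val argOffsets = Array(Array(0, 1))                    // that UDF reads input columns 0 and 1

    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    out.writeInt(commandsPerUdf.length)                    // number of UDFs
    commandsPerUdf.zip(argOffsets).foreach { case (commands, offsets) =>
      out.writeInt(offsets.length)                         // number of argument offsets
      offsets.foreach(out.writeInt)                        // the offsets themselves
      out.writeInt(commands.length)                        // number of chained functions
      commands.foreach { cmd =>
        out.writeInt(cmd.length)                           // length-prefixed command bytes
        out.write(cmd)
      }
    }
    out.flush()
    println(s"${bytes.size()} bytes written")              // 4 + 4 + 2*4 + 4 + 4 + 3 = 27
  }
}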
+private[spark] object PythonUdfType { + // row-at-a-time UDFs + val NORMAL_UDF = 0 + // scalar vectorized UDFs + val PANDAS_UDF = 1 + // grouped vectorized UDFs + val PANDAS_GROUPED_UDF = 2 +} + /** * A user-defined Python function. This is used by the Python API. */ @@ -29,10 +38,10 @@ case class UserDefinedPythonFunction( name: String, func: PythonFunction, dataType: DataType, - vectorized: Boolean) { + pythonUdfType: Int) { def builder(e: Seq[Expression]): PythonUDF = { - PythonUDF(name, func, dataType, e, vectorized) + PythonUDF(name, func, dataType, e, pythonUdfType) } /** Returns a [[Column]] that will evaluate to calling this UDF with the given input. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 72e5ac40bbfed..6bd0696622005 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -121,7 +121,7 @@ class FileStreamSink( FileFormatWriter.write( sparkSession = sparkSession, - plan = data.queryExecution.executedPlan, + queryExecution = data.queryExecution, fileFormat = fileFormat, committer = committer, outputSpec = FileFormatWriter.OutputSpec(path, Map.empty), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala index ab690fd5fbbca..29f38fab3f896 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FlatMapGroupsWithStateExec.scala @@ -23,10 +23,8 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Attribut import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Distribution} import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode} -import org.apache.spark.sql.types.IntegerType import org.apache.spark.util.CompletionIterator /** @@ -62,30 +60,15 @@ case class FlatMapGroupsWithStateExec( import GroupStateImpl._ private val isTimeoutEnabled = timeoutConf != NoTimeout - private val timestampTimeoutAttribute = - AttributeReference("timeoutTimestamp", dataType = IntegerType, nullable = false)() - private val stateAttributes: Seq[Attribute] = { - val encSchemaAttribs = stateEncoder.schema.toAttributes - if (isTimeoutEnabled) encSchemaAttribs :+ timestampTimeoutAttribute else encSchemaAttribs + val stateManager = new FlatMapGroupsWithState_StateManager(stateEncoder, isTimeoutEnabled) + val watermarkPresent = child.output.exists { + case a: Attribute if a.metadata.contains(EventTimeWatermark.delayKey) => true + case _ => false } - // Get the serializer for the state, taking into account whether we need to save timestamps - private val stateSerializer = { - val encoderSerializer = stateEncoder.namedExpressions - if (isTimeoutEnabled) { - encoderSerializer :+ Literal(GroupStateImpl.NO_TIMESTAMP) - } else { - encoderSerializer - } - } - // Get the deserializer for the state. 
Note that this must be done in the driver, as - // resolving and binding of deserializer expressions to the encoded type can be safely done - // only in the driver. - private val stateDeserializer = stateEncoder.resolveAndBind().deserializer - /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(groupingAttributes) :: Nil + ClusteredDistribution(groupingAttributes, stateInfo.map(_.numPartitions)) :: Nil /** Ordering needed for using GroupingIterator */ override def requiredChildOrdering: Seq[Seq[SortOrder]] = @@ -109,11 +92,11 @@ case class FlatMapGroupsWithStateExec( child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, groupingAttributes.toStructType, - stateAttributes.toStructType, + stateManager.stateSchema, indexOrdinal = None, sqlContext.sessionState, Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => - val updater = new StateStoreUpdater(store) + val processor = new InputProcessor(store) // If timeout is based on event time, then filter late data based on watermark val filteredIter = watermarkPredicateForData match { @@ -128,7 +111,7 @@ case class FlatMapGroupsWithStateExec( // all the data has been processed. This is to ensure that the timeout information of all // the keys with data is updated before they are processed for timeouts. val outputIterator = - updater.updateStateForKeysWithData(filteredIter) ++ updater.updateStateForTimedOutKeys() + processor.processNewData(filteredIter) ++ processor.processTimedOutState() // Return an iterator of all the rows generated by all the keys, such that when fully // consumed, all the state updates will be committed by the state store @@ -143,7 +126,7 @@ case class FlatMapGroupsWithStateExec( } /** Helper class to update the state store */ - class StateStoreUpdater(store: StateStore) { + class InputProcessor(store: StateStore) { // Converters for translating input keys, values, output data between rows and Java objects private val getKeyObj = @@ -152,14 +135,6 @@ case class FlatMapGroupsWithStateExec( ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes) private val getOutputRow = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType) - // Converters for translating state between rows and Java objects - private val getStateObjFromRow = ObjectOperator.deserializeRowToObject( - stateDeserializer, stateAttributes) - private val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) - - // Index of the additional metadata fields in the state row - private val timeoutTimestampIndex = stateAttributes.indexOf(timestampTimeoutAttribute) - // Metrics private val numUpdatedStateRows = longMetric("numUpdatedStateRows") private val numOutputRows = longMetric("numOutputRows") @@ -168,20 +143,19 @@ case class FlatMapGroupsWithStateExec( * For every group, get the key, values and corresponding state and call the function, * and return an iterator of rows */ - def updateStateForKeysWithData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { + def processNewData(dataIter: Iterator[InternalRow]): Iterator[InternalRow] = { val groupedIter = GroupedIterator(dataIter, groupingAttributes, child.output) groupedIter.flatMap { case (keyRow, valueRowIter) => val keyUnsafeRow = keyRow.asInstanceOf[UnsafeRow] callFunctionAndUpdateState( - keyUnsafeRow, + stateManager.getState(store, keyUnsafeRow), valueRowIter, - store.get(keyUnsafeRow), hasTimedOut = false) } } /** Find the groups that have timeout set and are 
timing out right now, and call the function */ - def updateStateForTimedOutKeys(): Iterator[InternalRow] = { + def processTimedOutState(): Iterator[InternalRow] = { if (isTimeoutEnabled) { val timeoutThreshold = timeoutConf match { case ProcessingTimeTimeout => batchTimestampMs.get @@ -190,12 +164,11 @@ case class FlatMapGroupsWithStateExec( throw new IllegalStateException( s"Cannot filter timed out keys for $timeoutConf") } - val timingOutKeys = store.getRange(None, None).filter { rowPair => - val timeoutTimestamp = getTimeoutTimestamp(rowPair.value) - timeoutTimestamp != NO_TIMESTAMP && timeoutTimestamp < timeoutThreshold + val timingOutKeys = stateManager.getAllState(store).filter { state => + state.timeoutTimestamp != NO_TIMESTAMP && state.timeoutTimestamp < timeoutThreshold } - timingOutKeys.flatMap { rowPair => - callFunctionAndUpdateState(rowPair.key, Iterator.empty, rowPair.value, hasTimedOut = true) + timingOutKeys.flatMap { stateData => + callFunctionAndUpdateState(stateData, Iterator.empty, hasTimedOut = true) } } else Iterator.empty } @@ -205,72 +178,44 @@ case class FlatMapGroupsWithStateExec( * iterator. Note that the store updating is lazy, that is, the store will be updated only * after the returned iterator is fully consumed. * - * @param keyRow Row representing the key, cannot be null + * @param stateData All the data related to the state to be updated * @param valueRowIter Iterator of values as rows, cannot be null, but can be empty - * @param prevStateRow Row representing the previous state, can be null * @param hasTimedOut Whether this function is being called for a key timeout */ private def callFunctionAndUpdateState( - keyRow: UnsafeRow, + stateData: FlatMapGroupsWithState_StateData, valueRowIter: Iterator[InternalRow], - prevStateRow: UnsafeRow, hasTimedOut: Boolean): Iterator[InternalRow] = { - val keyObj = getKeyObj(keyRow) // convert key to objects + val keyObj = getKeyObj(stateData.keyRow) // convert key to objects val valueObjIter = valueRowIter.map(getValueObj.apply) // convert value rows to objects - val stateObj = getStateObj(prevStateRow) - val keyedState = GroupStateImpl.createForStreaming( - Option(stateObj), + val groupState = GroupStateImpl.createForStreaming( + Option(stateData.stateObj), batchTimestampMs.getOrElse(NO_TIMESTAMP), eventTimeWatermark.getOrElse(NO_TIMESTAMP), timeoutConf, - hasTimedOut) + hasTimedOut, + watermarkPresent) // Call function, get the returned objects and convert them to rows - val mappedIterator = func(keyObj, valueObjIter, keyedState).map { obj => + val mappedIterator = func(keyObj, valueObjIter, groupState).map { obj => numOutputRows += 1 getOutputRow(obj) } // When the iterator is consumed, then write changes to state def onIteratorCompletion: Unit = { - - val currentTimeoutTimestamp = keyedState.getTimeoutTimestamp - // If the state has not yet been set but timeout has been set, then - // we have to generate a row to save the timeout. However, attempting serialize - // null using case class encoder throws - - // java.lang.NullPointerException: Null value appeared in non-nullable field: - // If the schema is inferred from a Scala tuple / case class, or a Java bean, please - // try to use scala.Option[_] or other nullable types. 
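// Illustrative aside: the timed-out-state pass (processTimedOutState above), reduced to plain
// Scala. State rows are modeled as a case class; NoTimestamp and the user function are
// stand-ins for GroupStateImpl.NO_TIMESTAMP and the real flatMapGroupsWithState function.
object TimedOutStateSketch {
  final val NoTimestamp = -1L

  case class StateData(key: String, stateObj: String, timeoutTimestamp: Long)

  def main(args: Array[String]): Unit = {
    val allState = Seq(
      StateData("a", "s1", timeoutTimestamp = 1000L),
      StateData("b", "s2", timeoutTimestamp = NoTimestamp),  // no timeout set, never times out
      StateData("c", "s3", timeoutTimestamp = 5000L))
    val timeoutThreshold = 2000L  // batch timestamp or event-time watermark, per timeoutConf

    // Only keys with a timeout set and strictly below the threshold are handed to the
    // user function, with an empty value iterator and hasTimedOut = true.
    val timedOut = allState.filter { s =>
      s.timeoutTimestamp != NoTimestamp && s.timeoutTimestamp < timeoutThreshold
    }
    timedOut.foreach(s => println(s"calling func for key=${s.key}, values=empty, hasTimedOut=true"))
    // prints only key=a
  }
}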
- if (!keyedState.exists && currentTimeoutTimestamp != NO_TIMESTAMP) { - throw new IllegalStateException( - "Cannot set timeout when state is not defined, that is, state has not been" + - "initialized or has been removed") - } - - if (keyedState.hasRemoved) { - store.remove(keyRow) + if (groupState.hasRemoved && groupState.getTimeoutTimestamp == NO_TIMESTAMP) { + stateManager.removeState(store, stateData.keyRow) numUpdatedStateRows += 1 - } else { - val previousTimeoutTimestamp = getTimeoutTimestamp(prevStateRow) - val stateRowToWrite = if (keyedState.hasUpdated) { - getStateRow(keyedState.get) - } else { - prevStateRow - } - - val hasTimeoutChanged = currentTimeoutTimestamp != previousTimeoutTimestamp - val shouldWriteState = keyedState.hasUpdated || hasTimeoutChanged + val currentTimeoutTimestamp = groupState.getTimeoutTimestamp + val hasTimeoutChanged = currentTimeoutTimestamp != stateData.timeoutTimestamp + val shouldWriteState = groupState.hasUpdated || groupState.hasRemoved || hasTimeoutChanged if (shouldWriteState) { - if (stateRowToWrite == null) { - // This should never happen because checks in GroupStateImpl should avoid cases - // where empty state would need to be written - throw new IllegalStateException("Attempting to write empty state") - } - setTimeoutTimestamp(stateRowToWrite, currentTimeoutTimestamp) - store.put(keyRow, stateRowToWrite) + val updatedStateObj = if (groupState.exists) groupState.get else null + stateManager.putState(store, stateData.keyRow, updatedStateObj, currentTimeoutTimestamp) numUpdatedStateRows += 1 } } @@ -279,28 +224,5 @@ case class FlatMapGroupsWithStateExec( // Return an iterator of rows such that fully consumed, the updated state value will be saved CompletionIterator[InternalRow, Iterator[InternalRow]](mappedIterator, onIteratorCompletion) } - - /** Returns the state as Java object if defined */ - def getStateObj(stateRow: UnsafeRow): Any = { - if (stateRow != null) getStateObjFromRow(stateRow) else null - } - - /** Returns the row for an updated state */ - def getStateRow(obj: Any): UnsafeRow = { - assert(obj != null) - getStateRowFromObj(obj) - } - - /** Returns the timeout timestamp of a state row is set */ - def getTimeoutTimestamp(stateRow: UnsafeRow): Long = { - if (isTimeoutEnabled && stateRow != null) { - stateRow.getLong(timeoutTimestampIndex) - } else NO_TIMESTAMP - } - - /** Set the timestamp in a state row */ - def setTimeoutTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { - if (isTimeoutEnabled) stateRow.setLong(timeoutTimestampIndex, timeoutTimestamps) - } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala index 4401e86936af9..7f65e3ea9dd5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/GroupStateImpl.scala @@ -43,7 +43,8 @@ private[sql] class GroupStateImpl[S] private( batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - override val hasTimedOut: Boolean) extends GroupState[S] { + override val hasTimedOut: Boolean, + watermarkPresent: Boolean) extends GroupState[S] { private var value: S = optionalValue.getOrElse(null.asInstanceOf[S]) private var defined: Boolean = optionalValue.isDefined @@ -90,7 +91,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != ProcessingTimeTimeout) { throw new 
UnsupportedOperationException( "Cannot set timeout duration without enabling processing time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMap]GroupsWithState") } if (durationMs <= 0) { throw new IllegalArgumentException("Timeout duration must be positive") @@ -102,10 +103,6 @@ private[sql] class GroupStateImpl[S] private( setTimeoutDuration(parseDuration(duration)) } - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long): Unit = { checkTimeoutTimestampAllowed() if (timestampMs <= 0) { @@ -119,32 +116,34 @@ private[sql] class GroupStateImpl[S] private( timeoutTimestamp = timestampMs } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(parseDuration(additionalDuration) + timestampMs) } - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime) } - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") override def setTimeoutTimestamp(timestamp: Date, additionalDuration: String): Unit = { checkTimeoutTimestampAllowed() setTimeoutTimestamp(timestamp.getTime + parseDuration(additionalDuration)) } + override def getCurrentWatermarkMs(): Long = { + if (!watermarkPresent) { + throw new UnsupportedOperationException( + "Cannot get event time watermark timestamp without setting watermark before " + + "[map|flatMap]GroupsWithState") + } + eventTimeWatermarkMs + } + + override def getCurrentProcessingTimeMs(): Long = { + batchProcessingTimeMs + } + override def toString: String = { s"GroupState(${getOption.map(_.toString).getOrElse("")})" } @@ -187,7 +186,7 @@ private[sql] class GroupStateImpl[S] private( if (timeoutConf != EventTimeTimeout) { throw new UnsupportedOperationException( "Cannot set timeout timestamp without enabling event time timeout in " + - "map/flatMapGroupsWithState") + "[map|flatMapGroupsWithState") } } } @@ -202,17 +201,22 @@ private[sql] object GroupStateImpl { batchProcessingTimeMs: Long, eventTimeWatermarkMs: Long, timeoutConf: GroupStateTimeout, - hasTimedOut: Boolean): GroupStateImpl[S] = { + hasTimedOut: Boolean, + watermarkPresent: Boolean): GroupStateImpl[S] = { new GroupStateImpl[S]( - optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, timeoutConf, hasTimedOut) + optionalValue, batchProcessingTimeMs, eventTimeWatermarkMs, + timeoutConf, hasTimedOut, watermarkPresent) } - def createForBatch(timeoutConf: GroupStateTimeout): GroupStateImpl[Any] = 
{ + def createForBatch( + timeoutConf: GroupStateTimeout, + watermarkPresent: Boolean): GroupStateImpl[Any] = { new GroupStateImpl[Any]( optionalValue = None, - batchProcessingTimeMs = NO_TIMESTAMP, + batchProcessingTimeMs = System.currentTimeMillis, eventTimeWatermarkMs = NO_TIMESTAMP, timeoutConf, - hasTimedOut = false) + hasTimedOut = false, + watermarkPresent) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index 8e0aae39cabb6..a10ed5f2df1b5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode /** @@ -61,6 +62,10 @@ class IncrementalExecution( StreamingDeduplicationStrategy :: Nil } + private val numStateStores = offsetSeqMetadata.conf.get(SQLConf.SHUFFLE_PARTITIONS.key) + .map(SQLConf.SHUFFLE_PARTITIONS.valueConverter) + .getOrElse(sparkSession.sessionState.conf.numShufflePartitions) + /** * See [SPARK-18339] * Walk the optimized logical plan and replace CurrentBatchTimestamp @@ -83,7 +88,11 @@ class IncrementalExecution( /** Get the state info of the next stateful operator */ private def nextStatefulOperationStateInfo(): StatefulOperatorStateInfo = { StatefulOperatorStateInfo( - checkpointLocation, runId, statefulOperatorId.getAndIncrement(), currentBatchId) + checkpointLocation, + runId, + statefulOperatorId.getAndIncrement(), + currentBatchId, + numStateStores) } /** Locates save/restore pairs surrounding aggregation. 
*/ @@ -124,40 +133,14 @@ class IncrementalExecution( eventTimeWatermark = Some(offsetSeqMetadata.batchWatermarkMs), stateWatermarkPredicates = StreamingSymmetricHashJoinHelper.getStateWatermarkPredicates( - j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition, + j.left.output, j.right.output, j.leftKeys, j.rightKeys, j.condition.full, Some(offsetSeqMetadata.batchWatermarkMs)) ) } } - override def preparations: Seq[Rule[SparkPlan]] = - Seq(state, EnsureStatefulOpPartitioning) ++ super.preparations + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations /** No need assert supported, as this check has already been done */ override def assertSupported(): Unit = { } } - -object EnsureStatefulOpPartitioning extends Rule[SparkPlan] { - // Needs to be transformUp to avoid extra shuffles - override def apply(plan: SparkPlan): SparkPlan = plan transformUp { - case so: StatefulOperator => - val numPartitions = plan.sqlContext.sessionState.conf.numShufflePartitions - val distributions = so.requiredChildDistribution - val children = so.children.zip(distributions).map { case (child, reqDistribution) => - val expectedPartitioning = reqDistribution match { - case AllTuples => SinglePartition - case ClusteredDistribution(keys) => HashPartitioning(keys, numPartitions) - case _ => throw new AnalysisException("Unexpected distribution expected for " + - s"Stateful Operator: $so. Expect AllTuples or ClusteredDistribution but got " + - s"$reqDistribution.") - } - if (child.outputPartitioning.guarantees(expectedPartitioning) && - child.execute().getNumPartitions == expectedPartitioning.numPartitions) { - child - } else { - ShuffleExchange(expectedPartitioning, child) - } - } - so.withNewChildren(children) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index ab716052c28ba..6b82c78ea653d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -44,6 +44,14 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: extends LeafNode { override def isStreaming: Boolean = true override def toString: String = sourceName + + // There's no sensible value here. On the execution path, this relation will be + // swapped out with microbatches. But some dataframe operations (in particular explain) do lead + // to this node surviving analysis. So we satisfy the LeafNode contract with the session default + // value. 
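// Illustrative aside on the IncrementalExecution change above: the number of state stores is
// taken from the conf recorded in the offset metadata when available, falling back to the
// current session setting, so a restarted query keeps the partitioning it started with.
// Sketched with a plain Map standing in for the persisted metadata; names are illustrative.
object NumStateStoresSketch {
  def numStateStores(restoredConf: Map[String, String], sessionShufflePartitions: Int): Int =
    restoredConf.get("spark.sql.shuffle.partitions").map(_.toInt).getOrElse(sessionShufflePartitions)

  def main(args: Array[String]): Unit = {
    // Restarted query: the recorded value wins, even if the session now uses 200 partitions.
    println(numStateStores(Map("spark.sql.shuffle.partitions" -> "10"), sessionShufflePartitions = 200)) // 10
    // Fresh query: nothing recorded yet, so the session default is used.
    println(numStateStores(Map.empty, sessionShufflePartitions = 200)) // 200
  }
}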
+ override def computeStats(): Statistics = Statistics( + sizeInBytes = BigInt(dataSource.sparkSession.sessionState.conf.defaultSizeInBytes) + ) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala index 44f1fa58599d2..c351f658cb955 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala @@ -21,7 +21,7 @@ import java.util.concurrent.TimeUnit.NANOSECONDS import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, BindReferences, Expression, JoinedRow, Literal, NamedExpression, PreciseTimestampConversion, UnsafeProjection, UnsafeRow} +import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, GenericInternalRow, JoinedRow, Literal, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ import org.apache.spark.sql.catalyst.plans.physical._ @@ -29,7 +29,6 @@ import org.apache.spark.sql.execution.{BinaryExecNode, SparkPlan} import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper._ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.internal.SessionState -import org.apache.spark.sql.types.{LongType, TimestampType} import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} @@ -115,7 +114,8 @@ import org.apache.spark.util.{CompletionIterator, SerializableConfiguration} * @param leftKeys Expression to generate key rows for joining from left input * @param rightKeys Expression to generate key rows for joining from right input * @param joinType Type of join (inner, left outer, etc.) - * @param condition Optional, additional condition to filter output of the equi-join + * @param condition Conditions to filter rows, split by left, right, and joined. 
See + * [[JoinConditionSplitPredicates]] * @param stateInfo Version information required to read join state (buffered rows) * @param eventTimeWatermark Watermark of input event, same for both sides * @param stateWatermarkPredicates Predicates for removal of state, see @@ -127,7 +127,7 @@ case class StreamingSymmetricHashJoinExec( leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression], + condition: JoinConditionSplitPredicates, stateInfo: Option[StatefulOperatorStateInfo], eventTimeWatermark: Option[Long], stateWatermarkPredicates: JoinStateWatermarkPredicates, @@ -141,12 +141,21 @@ case class StreamingSymmetricHashJoinExec( condition: Option[Expression], left: SparkPlan, right: SparkPlan) = { + this( - leftKeys, rightKeys, joinType, condition, stateInfo = None, eventTimeWatermark = None, + leftKeys, rightKeys, joinType, JoinConditionSplitPredicates(condition, left, right), + stateInfo = None, eventTimeWatermark = None, stateWatermarkPredicates = JoinStateWatermarkPredicates(), left, right) } - require(joinType == Inner, s"${getClass.getSimpleName} should not take $joinType as the JoinType") + private def throwBadJoinTypeException(): Nothing = { + throw new IllegalArgumentException( + s"${getClass.getSimpleName} should not take $joinType as the JoinType") + } + + require( + joinType == Inner || joinType == LeftOuter || joinType == RightOuter, + s"${getClass.getSimpleName} should not take $joinType as the JoinType") require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType)) private val storeConf = new StateStoreConf(sqlContext.conf) @@ -154,14 +163,24 @@ case class StreamingSymmetricHashJoinExec( new SerializableConfiguration(SessionState.newHadoopConf( sparkContext.hadoopConfiguration, sqlContext.conf))) + val nullLeft = new GenericInternalRow(left.output.map(_.withNullability(true)).length) + val nullRight = new GenericInternalRow(right.output.map(_.withNullability(true)).length) + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil - override def output: Seq[Attribute] = left.output ++ right.output + override def output: Seq[Attribute] = joinType match { + case _: InnerLike => left.output ++ right.output + case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) + case RightOuter => left.output.map(_.withNullability(true)) ++ right.output + case _ => throwBadJoinTypeException() + } override def outputPartitioning: Partitioning = joinType match { case _: InnerLike => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftOuter => PartitioningCollection(Seq(left.outputPartitioning)) + case RightOuter => PartitioningCollection(Seq(right.outputPartitioning)) case x => throw new IllegalArgumentException( s"${getClass.getSimpleName} should not take $x as the JoinType") @@ -192,10 +211,15 @@ case class StreamingSymmetricHashJoinExec( val updateStartTimeNs = System.nanoTime val joinedRow = new JoinedRow + + val postJoinFilter = + newPredicate(condition.bothSides.getOrElse(Literal(true)), left.output ++ right.output).eval _ val leftSideJoiner = new OneSideHashJoiner( - LeftSide, left.output, leftKeys, leftInputIter, stateWatermarkPredicates.left) + LeftSide, left.output, leftKeys, leftInputIter, + condition.leftSideOnly, postJoinFilter, stateWatermarkPredicates.left) val rightSideJoiner = new OneSideHashJoiner( - RightSide, right.output, rightKeys, rightInputIter, stateWatermarkPredicates.right) + RightSide, 
right.output, rightKeys, rightInputIter, + condition.rightSideOnly, postJoinFilter, stateWatermarkPredicates.right) // Join one side input using the other side's buffered/state rows. Here is how it is done. // @@ -207,31 +231,102 @@ case class StreamingSymmetricHashJoinExec( // matching new left input with new right input, since the new left input has become stored // by that point. This tiny asymmetry is necessary to avoid duplication. val leftOutputIter = leftSideJoiner.storeAndJoinWithOtherSide(rightSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(inputRow).withRight(matchedRow) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(input).withRight(matched) } val rightOutputIter = rightSideJoiner.storeAndJoinWithOtherSide(leftSideJoiner) { - (inputRow: UnsafeRow, matchedRow: UnsafeRow) => - joinedRow.withLeft(matchedRow).withRight(inputRow) + (input: InternalRow, matched: InternalRow) => joinedRow.withLeft(matched).withRight(input) } - // Filter the joined rows based on the given condition. - val outputFilterFunction = - newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output).eval _ - val filteredOutputIter = - (leftOutputIter ++ rightOutputIter).filter(outputFilterFunction).map { row => - numOutputRows += 1 - row - } + // We need to save the time that the inner join output iterator completes, since outer join + // output counts as both update and removal time. + var innerOutputCompletionTimeNs: Long = 0 + def onInnerOutputCompletion = { + innerOutputCompletionTimeNs = System.nanoTime + } + // This is the iterator which produces the inner join rows. For outer joins, this will be + // prepended to a second iterator producing outer join rows; for inner joins, this is the full + // output. + val innerOutputIter = CompletionIterator[InternalRow, Iterator[InternalRow]]( + (leftOutputIter ++ rightOutputIter), onInnerOutputCompletion) + + + val outputIter: Iterator[InternalRow] = joinType match { + case Inner => + innerOutputIter + case LeftOuter => + // We generate the outer join input by: + // * Getting an iterator over the rows that have aged out on the left side. These rows are + // candidates for being null joined. Note that to avoid doing two passes, this iterator + // removes the rows from the state manager as they're processed. + // * Checking whether the current row matches a key in the right side state, and that key + // has any value which satisfies the filter function when joined. If it doesn't, + // we know we can join with null, since there was never (including this batch) a match + // within the watermark period. If it does, there must have been a match at some point, so + // we know we can't join with null. + def matchesWithRightSideState(leftKeyValue: UnsafeRowPair) = { + rightSideJoiner.get(leftKeyValue.key).exists { rightValue => + postJoinFilter(joinedRow.withLeft(leftKeyValue.value).withRight(rightValue)) + } + } + val removedRowIter = leftSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithRightSideState(pair)) + .map(pair => joinedRow.withLeft(pair.value).withRight(nullRight)) + + innerOutputIter ++ outerOutputIter + case RightOuter => + // See comments for left outer case. 
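// Illustrative aside on the left outer case above: once a buffered left row ages out of the
// state (it can no longer match future right rows), it is emitted joined with nulls unless
// some buffered right row for the same key already satisfied the join condition. Plain
// collections stand in for the state managers; the filter is a stand-in for postJoinFilter.
object OuterNullPaddingSketch {
  def main(args: Array[String]): Unit = {
    // (key, value) rows evicted from the left state by the watermark
    val evictedLeft = Seq(("k1", "l1"), ("k2", "l2"))
    // right-side state: key -> buffered values
    val rightState = Map("k1" -> Seq("r1"))

    val postJoinFilter: (String, String) => Boolean = (_, _) => true  // stand-in for the condition

    val outerRows = evictedLeft
      .filterNot { case (key, leftValue) =>
        rightState.getOrElse(key, Seq.empty).exists(rightValue => postJoinFilter(leftValue, rightValue))
      }
      .map { case (_, leftValue) => (leftValue, null) }  // pad the right side with nulls

    println(outerRows)  // List((l2,null)): k1 matched at some point, so only k2 is null-joined
  }
}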
+ def matchesWithLeftSideState(rightKeyValue: UnsafeRowPair) = { + leftSideJoiner.get(rightKeyValue.key).exists { leftValue => + postJoinFilter(joinedRow.withLeft(leftValue).withRight(rightKeyValue.value)) + } + } + val removedRowIter = rightSideJoiner.removeOldState() + val outerOutputIter = removedRowIter + .filterNot(pair => matchesWithLeftSideState(pair)) + .map(pair => joinedRow.withLeft(nullLeft).withRight(pair.value)) + + innerOutputIter ++ outerOutputIter + case _ => throwBadJoinTypeException() + } + + val outputProjection = UnsafeProjection.create(left.output ++ right.output, output) + val outputIterWithMetrics = outputIter.map { row => + numOutputRows += 1 + outputProjection(row) + } // Function to remove old state after all the input has been consumed and output generated def onOutputCompletion = { + // All processing time counts as update time. allUpdatesTimeMs += math.max(NANOSECONDS.toMillis(System.nanoTime - updateStartTimeNs), 0) - // Remove old state if needed + // Processing time between inner output completion and here comes from the outer portion of a + // join, and thus counts as removal time as we remove old state from one side while iterating. + if (innerOutputCompletionTimeNs != 0) { + allRemovalsTimeMs += + math.max(NANOSECONDS.toMillis(System.nanoTime - innerOutputCompletionTimeNs), 0) + } + allRemovalsTimeMs += timeTakenMs { - leftSideJoiner.removeOldState() - rightSideJoiner.removeOldState() + // Remove any remaining state rows which aren't needed because they're below the watermark. + // + // For inner joins, we have to remove unnecessary state rows from both sides if possible. + // For outer joins, we have already removed unnecessary state rows from the outer side + // (e.g., left side for left outer join) while generating the outer "null" outputs. Now, we + // have to remove unnecessary state rows from the other side (e.g., right side for the left + // outer join) if possible. In all cases, nothing needs to be outputted, hence the removal + // needs to be done greedily by immediately consuming the returned iterator. + val cleanupIter = joinType match { + case Inner => + leftSideJoiner.removeOldState() ++ rightSideJoiner.removeOldState() + case LeftOuter => rightSideJoiner.removeOldState() + case RightOuter => leftSideJoiner.removeOldState() + case _ => throwBadJoinTypeException() + } + while (cleanupIter.hasNext) { + cleanupIter.next() + } } // Commit all state changes and update state store metrics @@ -251,20 +346,43 @@ case class StreamingSymmetricHashJoinExec( } } - CompletionIterator[InternalRow, Iterator[InternalRow]](filteredOutputIter, onOutputCompletion) + CompletionIterator[InternalRow, Iterator[InternalRow]]( + outputIterWithMetrics, onOutputCompletion) } /** * Internal helper class to consume input rows, generate join output rows using other sides * buffered state rows, and finally clean up this sides buffered state rows + * + * @param joinSide The JoinSide - either left or right. + * @param inputAttributes The input attributes for this side of the join. + * @param joinKeys The join keys. + * @param inputIter The iterator of input rows on this side to be joined. + * @param preJoinFilterExpr A filter over rows on this side. This filter rejects rows that could + * never pass the overall join condition no matter what other side row + * they're joined with. + * @param postJoinFilter A filter over joined rows. 
This filter completes the application of + * the overall join condition, assuming that preJoinFilter on both sides + * of the join has already been passed. + * Passed as a function rather than expression to avoid creating the + * predicate twice; we also need this filter later on in the parent exec. + * @param stateWatermarkPredicate The state watermark predicate. See + * [[StreamingSymmetricHashJoinExec]] for further description of + * state watermarks. */ private class OneSideHashJoiner( joinSide: JoinSide, inputAttributes: Seq[Attribute], joinKeys: Seq[Expression], inputIter: Iterator[InternalRow], + preJoinFilterExpr: Option[Expression], + postJoinFilter: (InternalRow) => Boolean, stateWatermarkPredicate: Option[JoinStateWatermarkPredicate]) { + // Filter the joined rows based on the given condition. + val preJoinFilter = + newPredicate(preJoinFilterExpr.getOrElse(Literal(true)), inputAttributes).eval _ + private val joinStateManager = new SymmetricHashJoinStateManager( joinSide, inputAttributes, joinKeys, stateInfo, storeConf, hadoopConfBcast.value.value) private[this] val keyGenerator = UnsafeProjection.create(joinKeys, inputAttributes) @@ -296,8 +414,8 @@ case class StreamingSymmetricHashJoinExec( */ def storeAndJoinWithOtherSide( otherSideJoiner: OneSideHashJoiner)( - generateJoinedRow: (UnsafeRow, UnsafeRow) => JoinedRow): Iterator[InternalRow] = { - + generateJoinedRow: (InternalRow, InternalRow) => JoinedRow): + Iterator[InternalRow] = { val watermarkAttribute = inputAttributes.find(_.metadata.contains(delayKey)) val nonLateRows = WatermarkSupport.watermarkExpression(watermarkAttribute, eventTimeWatermark) match { @@ -310,28 +428,60 @@ case class StreamingSymmetricHashJoinExec( nonLateRows.flatMap { row => val thisRow = row.asInstanceOf[UnsafeRow] - val key = keyGenerator(thisRow) - val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => - generateJoinedRow(thisRow, thatRow) + // If this row fails the pre join filter, that means it can never satisfy the full join + // condition no matter what other side row it's matched with. This allows us to avoid + // adding it to the state, and generate an outer join row immediately (or do nothing in + // the case of inner join). + if (preJoinFilter(thisRow)) { + val key = keyGenerator(thisRow) + val outputIter = otherSideJoiner.joinStateManager.get(key).map { thatRow => + generateJoinedRow(thisRow, thatRow) + }.filter(postJoinFilter) + val shouldAddToState = // add only if both removal predicates do not match + !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) + if (shouldAddToState) { + joinStateManager.append(key, thisRow) + updatedStateRowsCount += 1 + } + outputIter + } else { + joinSide match { + case LeftSide if joinType == LeftOuter => + Iterator(generateJoinedRow(thisRow, nullRight)) + case RightSide if joinType == RightOuter => + Iterator(generateJoinedRow(thisRow, nullLeft)) + case _ => Iterator() + } } - val shouldAddToState = // add only if both removal predicates do not match - !stateKeyWatermarkPredicateFunc(key) && !stateValueWatermarkPredicateFunc(thisRow) - if (shouldAddToState) { - joinStateManager.append(key, thisRow) - updatedStateRowsCount += 1 - } - outputIter } } - /** Remove old buffered state rows using watermarks for state keys and values */ - def removeOldState(): Unit = { + /** + * Get an iterator over the values stored in this joiner's state manager for the given key. + * + * Should not be interleaved with mutations. 
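// Illustrative aside: the pre-join / post-join split described above, over plain predicates.
// Conjuncts that reference only one side run before buffering ("pre-join"); everything else
// (cross-side or nondeterministic conjuncts) runs on the joined row ("post-join"), so that
// leftOnly AND rightOnly AND both is equivalent to the full condition. Values are illustrative.
object JoinConditionSplitSketch {
  case class Row(key: String, ts: Long)

  def main(args: Array[String]): Unit = {
    // Full condition: left.ts > 10 AND right.ts > 20 AND left.ts < right.ts + 5
    val leftOnly: Row => Boolean = l => l.ts > 10               // evaluated on left input rows alone
    val rightOnly: Row => Boolean = r => r.ts > 20              // evaluated on right input rows alone
    val both: (Row, Row) => Boolean = (l, r) => l.ts < r.ts + 5 // needs the joined row

    val left = Row("k", 12)
    val right = Row("k", 25)
    // A left row failing leftOnly is never buffered and can immediately become an outer row;
    // rows passing the one-side filters still need the post-join filter after matching by key.
    val joined = leftOnly(left) && rightOnly(right) && both(left, right)
    println(joined)  // true: 12 > 10, 25 > 20, 12 < 30
  }
}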
+ */ + def get(key: UnsafeRow): Iterator[UnsafeRow] = { + joinStateManager.get(key) + } + + /** + * Builds an iterator over old state key-value pairs, removing them lazily as they're produced. + * + * @note This iterator must be consumed fully before any other operations are made + * against this joiner's join state manager. For efficiency reasons, the intermediate states of + * the iterator leave the state manager in an undefined state. + * + * We do this to avoid requiring either two passes or full materialization when + * processing the rows for outer join. + */ + def removeOldState(): Iterator[UnsafeRowPair] = { stateWatermarkPredicate match { case Some(JoinStateKeyWatermarkPredicate(expr)) => joinStateManager.removeByKeyCondition(stateKeyWatermarkPredicateFunc) case Some(JoinStateValueWatermarkPredicate(expr)) => joinStateManager.removeByValueCondition(stateValueWatermarkPredicateFunc) - case _ => + case _ => Iterator.empty } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala index e50274a1baba1..167e991ca62f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinHelper.scala @@ -23,8 +23,10 @@ import scala.util.control.NonFatal import org.apache.spark.{Partition, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ZippedPartitionsRDD2} -import org.apache.spark.sql.catalyst.expressions.{Add, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{Add, And, Attribute, AttributeReference, AttributeSet, BoundReference, Cast, CheckOverflow, Expression, ExpressionSet, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Multiply, NamedExpression, PreciseTimestampConversion, PredicateHelper, Subtract, TimeAdd, TimeSub, UnaryMinus} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark._ +import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.WatermarkSupport.watermarkExpression import org.apache.spark.sql.execution.streaming.state.{StateStoreCoordinatorRef, StateStoreProvider, StateStoreProviderId} import org.apache.spark.sql.types._ @@ -34,7 +36,7 @@ import org.apache.spark.unsafe.types.CalendarInterval /** * Helper object for [[StreamingSymmetricHashJoinExec]]. See that object for more details. */ -object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { +object StreamingSymmetricHashJoinHelper extends Logging { sealed trait JoinSide case object LeftSide extends JoinSide { override def toString(): String = "left" } @@ -65,6 +67,73 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { } } + /** + * Wrapper around various useful splits of the join condition. + * left AND right AND joined is equivalent to full. + * + * Note that left and right do not necessarily contain *all* conjuncts which satisfy + * their condition. 
Any conjuncts after the first nondeterministic one are treated as + * nondeterministic for purposes of the split. + * + * @param leftSideOnly Deterministic conjuncts which reference only the left side of the join. + * @param rightSideOnly Deterministic conjuncts which reference only the right side of the join. + * @param bothSides Conjuncts which are nondeterministic, occur after a nondeterministic conjunct, + * or reference both left and right sides of the join. + * @param full The full join condition. + */ + case class JoinConditionSplitPredicates( + leftSideOnly: Option[Expression], + rightSideOnly: Option[Expression], + bothSides: Option[Expression], + full: Option[Expression]) { + override def toString(): String = { + s"condition = [ leftOnly = ${leftSideOnly.map(_.toString).getOrElse("null")}, " + + s"rightOnly = ${rightSideOnly.map(_.toString).getOrElse("null")}, " + + s"both = ${bothSides.map(_.toString).getOrElse("null")}, " + + s"full = ${full.map(_.toString).getOrElse("null")} ]" + } + } + + object JoinConditionSplitPredicates extends PredicateHelper { + def apply(condition: Option[Expression], left: SparkPlan, right: SparkPlan): + JoinConditionSplitPredicates = { + // Split the condition into 3 parts: + // * Conjuncts that can be evaluated on only the left input. + // * Conjuncts that can be evaluated on only the right input. + // * Conjuncts that require both left and right input. + // + // Note that we treat nondeterministic conjuncts as though they require both left and right + // input. To maintain their semantics, they need to be evaluated exactly once per joined row. + val (leftCondition, rightCondition, joinedCondition) = { + if (condition.isEmpty) { + (None, None, None) + } else { + // Span rather than partition, because nondeterministic expressions don't commute + // across AND. + val (deterministicConjuncts, nonDeterministicConjuncts) = + splitConjunctivePredicates(condition.get).span(_.deterministic) + + val (leftConjuncts, nonLeftConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(left.outputSet) + } + + val (rightConjuncts, nonRightConjuncts) = deterministicConjuncts.partition { cond => + cond.references.subsetOf(right.outputSet) + } + + ( + leftConjuncts.reduceOption(And), + rightConjuncts.reduceOption(And), + (nonLeftConjuncts.intersect(nonRightConjuncts) ++ nonDeterministicConjuncts) + .reduceOption(And) + ) + } + } + + JoinConditionSplitPredicates(leftCondition, rightCondition, joinedCondition, condition) + } + } + /** Get the predicates defining the state watermarks for both sides of the join */ def getStateWatermarkPredicates( leftAttributes: Seq[Attribute], @@ -111,7 +180,7 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { expr.map(JoinStateKeyWatermarkPredicate.apply _) } else if (isWatermarkDefinedOnInput) { // case 2 in the StreamingSymmetricHashJoinExec docs - val stateValueWatermark = getStateValueWatermark( + val stateValueWatermark = StreamingJoinHelper.getStateValueWatermark( attributesToFindStateWatermarkFor = AttributeSet(oneSideInputAttributes), attributesWithEventWatermark = AttributeSet(otherSideInputAttributes), condition, @@ -132,242 +201,6 @@ object StreamingSymmetricHashJoinHelper extends PredicateHelper with Logging { JoinStateWatermarkPredicates(leftStateWatermarkPredicate, rightStateWatermarkPredicate) } - /** - * Get state value watermark (see [[StreamingSymmetricHashJoinExec]] for context about it) - * given the join condition and the event time watermark. 
This is how it works. - * - The condition is split into conjunctive predicates, and we find the predicates of the - * form `leftTime + c1 < rightTime + c2` (or <=, >, >=). - * - We canoncalize the predicate and solve it with the event time watermark value to find the - * value of the state watermark. - * This function is supposed to make best-effort attempt to get the state watermark. If there is - * any error, it will return None. - * - * @param attributesToFindStateWatermarkFor attributes of the side whose state watermark - * is to be calculated - * @param attributesWithEventWatermark attributes of the other side which has a watermark column - * @param joinCondition join condition - * @param eventWatermark watermark defined on the input event data - * @return state value watermark in milliseconds, is possible. - */ - def getStateValueWatermark( - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - joinCondition: Option[Expression], - eventWatermark: Option[Long]): Option[Long] = { - - // If condition or event time watermark is not provided, then cannot calculate state watermark - if (joinCondition.isEmpty || eventWatermark.isEmpty) return None - - // If there is not watermark attribute, then cannot define state watermark - if (!attributesWithEventWatermark.exists(_.metadata.contains(delayKey))) return None - - def getStateWatermarkSafely(l: Expression, r: Expression): Option[Long] = { - try { - getStateWatermarkFromLessThenPredicate( - l, r, attributesToFindStateWatermarkFor, attributesWithEventWatermark, eventWatermark) - } catch { - case NonFatal(e) => - logWarning(s"Error trying to extract state constraint from condition $joinCondition", e) - None - } - } - - val allStateWatermarks = splitConjunctivePredicates(joinCondition.get).flatMap { predicate => - - // The generated the state watermark cleanup expression is inclusive of the state watermark. - // If state watermark is W, all state where timestamp <= W will be cleaned up. - // Now when the canonicalized join condition solves to leftTime >= W, we dont want to clean - // up leftTime <= W. Rather we should clean up leftTime <= W - 1. Hence the -1 below. - val stateWatermark = predicate match { - case LessThan(l, r) => getStateWatermarkSafely(l, r) - case LessThanOrEqual(l, r) => getStateWatermarkSafely(l, r).map(_ - 1) - case GreaterThan(l, r) => getStateWatermarkSafely(r, l) - case GreaterThanOrEqual(l, r) => getStateWatermarkSafely(r, l).map(_ - 1) - case _ => None - } - if (stateWatermark.nonEmpty) { - logInfo(s"Condition $joinCondition generated watermark constraint = ${stateWatermark.get}") - } - stateWatermark - } - allStateWatermarks.reduceOption((x, y) => Math.min(x, y)) - } - - /** - * Extract the state value watermark (milliseconds) from the condition - * `LessThan(leftExpr, rightExpr)` where . For example: if we want to find the constraint for - * leftTime using the watermark on the rightTime. 
Example: - * - * Input: rightTime-with-watermark + c1 < leftTime + c2 - * Canonical form: rightTime-with-watermark + c1 + (-c2) + (-leftTime) < 0 - * Solving for rightTime: rightTime-with-watermark + c1 + (-c2) < leftTime - * With watermark value: watermark-value + c1 + (-c2) < leftTime - */ - private def getStateWatermarkFromLessThenPredicate( - leftExpr: Expression, - rightExpr: Expression, - attributesToFindStateWatermarkFor: AttributeSet, - attributesWithEventWatermark: AttributeSet, - eventWatermark: Option[Long]): Option[Long] = { - - val attributesInCondition = AttributeSet( - leftExpr.collect { case a: AttributeReference => a } ++ - rightExpr.collect { case a: AttributeReference => a } - ) - if (attributesInCondition.filter { attributesToFindStateWatermarkFor.contains(_) }.size > 1 || - attributesInCondition.filter { attributesWithEventWatermark.contains(_) }.size > 1) { - // If more than attributes present in condition from one side, then it cannot be solved - return None - } - - def containsAttributeToFindStateConstraintFor(e: Expression): Boolean = { - e.collectLeaves().collectFirst { - case a @ AttributeReference(_, TimestampType, _, _) - if attributesToFindStateWatermarkFor.contains(a) => a - }.nonEmpty - } - - // Canonicalization step 1: convert to (rightTime-with-watermark + c1) - (leftTime + c2) < 0 - val allOnLeftExpr = Subtract(leftExpr, rightExpr) - logDebug(s"All on Left:\n${allOnLeftExpr.treeString(true)}\n${allOnLeftExpr.asCode}") - - // Canonicalization step 2: extract commutative terms - // rightTime-with-watermark, c1, -leftTime, -c2 - val terms = ExpressionSet(collectTerms(allOnLeftExpr)) - logDebug("Terms extracted from join condition:\n\t" + terms.mkString("\n\t")) - - - - // Find the term that has leftTime (i.e. the one present in attributesToFindConstraintFor - val constraintTerms = terms.filter(containsAttributeToFindStateConstraintFor) - - // Verify there is only one correct constraint term and of the correct type - if (constraintTerms.size > 1) { - logWarning("Failed to extract state constraint terms: multiple time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - if (constraintTerms.isEmpty) { - logDebug("Failed to extract state constraint terms: no time terms in condition\n\t" + - terms.mkString("\n\t")) - return None - } - val constraintTerm = constraintTerms.head - if (constraintTerm.collectFirst { case u: UnaryMinus => u }.isEmpty) { - // Incorrect condition. We want the constraint term in canonical form to be `-leftTime` - // so that resolve for it as `-leftTime + watermark + c < 0` ==> `watermark + c < leftTime`. - // Now, if the original conditions is `rightTime-with-watermark > leftTime` and watermark - // condition is `rightTime-with-watermark > watermarkValue`, then no constraint about - // `leftTime` can be inferred. In this case, after canonicalization and collection of terms, - // the constraintTerm would be `leftTime` and not `-leftTime`. Hence, we return None. - return None - } - - // Replace watermark attribute with watermark value, and generate the resolved expression - // from the other terms. 
That is, - // rightTime-with-watermark, c1, -c2 => watermark, c1, -c2 => watermark + c1 + (-c2) - logDebug(s"Constraint term from join condition:\t$constraintTerm") - val exprWithWatermarkSubstituted = (terms - constraintTerm).map { term => - term.transform { - case a @ AttributeReference(_, TimestampType, _, metadata) - if attributesWithEventWatermark.contains(a) && metadata.contains(delayKey) => - Multiply(Literal(eventWatermark.get.toDouble), Literal(1000.0)) - } - }.reduceLeft(Add) - - // Calculate the constraint value - logInfo(s"Final expression to evaluate constraint:\t$exprWithWatermarkSubstituted") - val constraintValue = exprWithWatermarkSubstituted.eval().asInstanceOf[java.lang.Double] - Some((Double2double(constraintValue) / 1000.0).toLong) - } - - /** - * Collect all the terms present in an expression after converting it into the form - * a + b + c + d where each term be either an attribute or a literal casted to long, - * optionally wrapped in a unary minus. - */ - private def collectTerms(exprToCollectFrom: Expression): Seq[Expression] = { - var invalid = false - - /** Wrap a term with UnaryMinus if its needs to be negated. */ - def negateIfNeeded(expr: Expression, minus: Boolean): Expression = { - if (minus) UnaryMinus(expr) else expr - } - - /** - * Recursively split the expression into its leaf terms contains attributes or literals. - * Returns terms only of the forms: - * Cast(AttributeReference), UnaryMinus(Cast(AttributeReference)), - * Cast(AttributeReference, Double), UnaryMinus(Cast(AttributeReference, Double)) - * Multiply(Literal), UnaryMinus(Multiply(Literal)) - * Multiply(Cast(Literal)), UnaryMinus(Multiple(Cast(Literal))) - * - * Note: - * - If term needs to be negated for making it a commutative term, - * then it will be wrapped in UnaryMinus(...) - * - Each terms will be representing timestamp value or time interval in microseconds, - * typed as doubles. - */ - def collect(expr: Expression, negate: Boolean): Seq[Expression] = { - expr match { - case Add(left, right) => - collect(left, negate) ++ collect(right, negate) - case Subtract(left, right) => - collect(left, negate) ++ collect(right, !negate) - case TimeAdd(left, right, _) => - collect(left, negate) ++ collect(right, negate) - case TimeSub(left, right, _) => - collect(left, negate) ++ collect(right, !negate) - case UnaryMinus(child) => - collect(child, !negate) - case CheckOverflow(child, _) => - collect(child, negate) - case Cast(child, dataType, _) => - dataType match { - case _: NumericType | _: TimestampType => collect(child, negate) - case _ => - invalid = true - Seq.empty - } - case a: AttributeReference => - val castedRef = if (a.dataType != DoubleType) Cast(a, DoubleType) else a - Seq(negateIfNeeded(castedRef, negate)) - case lit: Literal => - // If literal of type calendar interval, then explicitly convert to millis - // Convert other number like literal to doubles representing millis (by x1000) - val castedLit = lit.dataType match { - case CalendarIntervalType => - val calendarInterval = lit.value.asInstanceOf[CalendarInterval] - if (calendarInterval.months > 0) { - invalid = true - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom " + - s"as imprecise intervals like months and years cannot be used for" + - s"watermark calculation. 
Use interval in terms of day instead.") - Literal(0.0) - } else { - Literal(calendarInterval.microseconds.toDouble) - } - case DoubleType => - Multiply(lit, Literal(1000000.0)) - case _: NumericType => - Multiply(Cast(lit, DoubleType), Literal(1000000.0)) - case _: TimestampType => - Multiply(PreciseTimestampConversion(lit, TimestampType, LongType), Literal(1000000.0)) - } - Seq(negateIfNeeded(castedLit, negate)) - case a @ _ => - logWarning( - s"Failed to extract state value watermark from condition $exprToCollectFrom due to $a") - invalid = true - Seq.empty - } - } - - val terms = collect(exprToCollectFrom, negate = false) - if (!invalid) terms else Seq.empty - } - /** * A custom RDD that allows partitions to be "zipped" together, while ensuring the tasks' * preferred location is based on which executors have the required join state stores already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala new file mode 100644 index 0000000000000..d077836da847c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/FlatMapGroupsWithState_StateManager.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.state + +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, BoundReference, CaseWhen, CreateNamedStruct, GetStructField, IsNull, Literal, UnsafeRow} +import org.apache.spark.sql.execution.ObjectOperator +import org.apache.spark.sql.execution.streaming.GroupStateImpl +import org.apache.spark.sql.execution.streaming.GroupStateImpl.NO_TIMESTAMP +import org.apache.spark.sql.types.{IntegerType, LongType, StructType} + + +/** + * Class to serialize/write/read/deserialize state for + * [[org.apache.spark.sql.execution.streaming.FlatMapGroupsWithStateExec]]. 
+ */ +class FlatMapGroupsWithState_StateManager( + stateEncoder: ExpressionEncoder[Any], + shouldStoreTimestamp: Boolean) extends Serializable { + + /** Schema of the state rows saved in the state store */ + val stateSchema = { + val schema = new StructType().add("groupState", stateEncoder.schema, nullable = true) + if (shouldStoreTimestamp) schema.add("timeoutTimestamp", LongType) else schema + } + + /** Get deserialized state and corresponding timeout timestamp for a key */ + def getState(store: StateStore, keyRow: UnsafeRow): FlatMapGroupsWithState_StateData = { + val stateRow = store.get(keyRow) + stateDataForGets.withNew( + keyRow, stateRow, getStateObj(stateRow), getTimestamp(stateRow)) + } + + /** Put state and timeout timestamp for a key */ + def putState(store: StateStore, keyRow: UnsafeRow, state: Any, timestamp: Long): Unit = { + val stateRow = getStateRow(state) + setTimestamp(stateRow, timestamp) + store.put(keyRow, stateRow) + } + + /** Removed all information related to a key */ + def removeState(store: StateStore, keyRow: UnsafeRow): Unit = { + store.remove(keyRow) + } + + /** Get all the keys and corresponding state rows in the state store */ + def getAllState(store: StateStore): Iterator[FlatMapGroupsWithState_StateData] = { + val stateDataForGetAllState = FlatMapGroupsWithState_StateData() + store.getRange(None, None).map { pair => + stateDataForGetAllState.withNew( + pair.key, pair.value, getStateObjFromRow(pair.value), getTimestamp(pair.value)) + } + } + + // Ordinals of the information stored in the state row + private lazy val nestedStateOrdinal = 0 + private lazy val timeoutTimestampOrdinal = 1 + + // Get the serializer for the state, taking into account whether we need to save timestamps + private val stateSerializer = { + val nestedStateExpr = CreateNamedStruct( + stateEncoder.namedExpressions.flatMap(e => Seq(Literal(e.name), e))) + if (shouldStoreTimestamp) { + Seq(nestedStateExpr, Literal(GroupStateImpl.NO_TIMESTAMP)) + } else { + Seq(nestedStateExpr) + } + } + + // Get the deserializer for the state. Note that this must be done in the driver, as + // resolving and binding of deserializer expressions to the encoded type can be safely done + // only in the driver. 
+ private val stateDeserializer = { + val boundRefToNestedState = BoundReference(nestedStateOrdinal, stateEncoder.schema, true) + val deser = stateEncoder.resolveAndBind().deserializer.transformUp { + case BoundReference(ordinal, _, _) => GetStructField(boundRefToNestedState, ordinal) + } + CaseWhen(Seq(IsNull(boundRefToNestedState) -> Literal(null)), elseValue = deser).toCodegen() + } + + // Converters for translating state between rows and Java objects + private lazy val getStateObjFromRow = ObjectOperator.deserializeRowToObject( + stateDeserializer, stateSchema.toAttributes) + private lazy val getStateRowFromObj = ObjectOperator.serializeObjectToRow(stateSerializer) + + // Reusable instance for returning state information + private lazy val stateDataForGets = FlatMapGroupsWithState_StateData() + + /** Returns the state as Java object if defined */ + private def getStateObj(stateRow: UnsafeRow): Any = { + if (stateRow == null) null + else getStateObjFromRow(stateRow) + } + + /** Returns the row for an updated state */ + private def getStateRow(obj: Any): UnsafeRow = { + val row = getStateRowFromObj(obj) + if (obj == null) { + row.setNullAt(nestedStateOrdinal) + } + row + } + + /** Returns the timeout timestamp of a state row is set */ + private def getTimestamp(stateRow: UnsafeRow): Long = { + if (shouldStoreTimestamp && stateRow != null) { + stateRow.getLong(timeoutTimestampOrdinal) + } else NO_TIMESTAMP + } + + /** Set the timestamp in a state row */ + private def setTimestamp(stateRow: UnsafeRow, timeoutTimestamps: Long): Unit = { + if (shouldStoreTimestamp) stateRow.setLong(timeoutTimestampOrdinal, timeoutTimestamps) + } +} + +/** + * Class to capture deserialized state and timestamp return by the state manager. + * This is intended for reuse. + */ +case class FlatMapGroupsWithState_StateData( + var keyRow: UnsafeRow = null, + var stateRow: UnsafeRow = null, + var stateObj: Any = null, + var timeoutTimestamp: Long = -1) { + def withNew( + newKeyRow: UnsafeRow, + newStateRow: UnsafeRow, + newStateObj: Any, + newTimeout: Long): this.type = { + keyRow = newKeyRow + stateRow = newStateRow + stateObj = newStateObj + timeoutTimestamp = newTimeout + this + } +} + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala index 37648710dfc2a..6b386308c79fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala @@ -76,7 +76,7 @@ class SymmetricHashJoinStateManager( /** Get all the values of a key */ def get(key: UnsafeRow): Iterator[UnsafeRow] = { val numValues = keyToNumValues.get(key) - keyWithIndexToValue.getAll(key, numValues) + keyWithIndexToValue.getAll(key, numValues).map(_.value) } /** Append a new value to the key */ @@ -87,70 +87,163 @@ class SymmetricHashJoinStateManager( } /** - * Remove using a predicate on keys. See class docs for more context and implement details. + * Remove using a predicate on keys. + * + * This produces an iterator over the (key, value) pairs satisfying condition(key), where the + * underlying store is updated as a side-effect of producing next. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. 
*/ - def removeByKeyCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator - - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - if (condition(keyToNumValue.key)) { - keyToNumValues.remove(keyToNumValue.key) - keyWithIndexToValue.removeAllValues(keyToNumValue.key, keyToNumValue.numValue) + def removeByKeyCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { + + private val allKeyToNumValues = keyToNumValues.iterator + + private var currentKeyToNumValue: KeyAndNumValues = null + private var currentValues: Iterator[KeyWithIndexAndValue] = null + + private def currentKey = currentKeyToNumValue.key + + private val reusedPair = new UnsafeRowPair() + + private def getAndRemoveValue() = { + val keyWithIndexAndValue = currentValues.next() + keyWithIndexToValue.remove(currentKey, keyWithIndexAndValue.valueIndex) + reusedPair.withRows(currentKey, keyWithIndexAndValue.value) + } + + override def getNext(): UnsafeRowPair = { + // If there are more values for the current key, remove and return the next one. + if (currentValues != null && currentValues.hasNext) { + return getAndRemoveValue() + } + + // If there weren't any values left, try and find the next key that satisfies the removal + // condition and has values. + while (allKeyToNumValues.hasNext) { + currentKeyToNumValue = allKeyToNumValues.next() + if (removalCondition(currentKey)) { + currentValues = keyWithIndexToValue.getAll( + currentKey, currentKeyToNumValue.numValue) + keyToNumValues.remove(currentKey) + + if (currentValues.hasNext) { + return getAndRemoveValue() + } + } + } + + // We only reach here if there were no satisfying keys left, which means we're done. + finished = true + return null } + + override def close: Unit = {} } } /** - * Remove using a predicate on values. See class docs for more context and implementation details. + * Remove using a predicate on values. + * + * At a high level, this produces an iterator over the (key, value) pairs such that value + * satisfies the predicate, where producing an element removes the value from the state store + * and producing all elements with a given key updates it accordingly. + * + * This implies the iterator must be consumed fully without any other operations on this manager + * or the underlying store being interleaved. */ - def removeByValueCondition(condition: UnsafeRow => Boolean): Unit = { - val allKeyToNumValues = keyToNumValues.iterator + def removeByValueCondition(removalCondition: UnsafeRow => Boolean): Iterator[UnsafeRowPair] = { + new NextIterator[UnsafeRowPair] { - while (allKeyToNumValues.hasNext) { - val keyToNumValue = allKeyToNumValues.next - val key = keyToNumValue.key + // Reuse this object to avoid creation+GC overhead. + private val reusedPair = new UnsafeRowPair() - var numValues: Long = keyToNumValue.numValue - var index: Long = 0L - var valueRemoved: Boolean = false - var valueForIndex: UnsafeRow = null + private val allKeyToNumValues = keyToNumValues.iterator - while (index < numValues) { - if (valueForIndex == null) { - valueForIndex = keyWithIndexToValue.get(key, index) + private var currentKey: UnsafeRow = null + private var numValues: Long = 0L + private var index: Long = 0L + private var valueRemoved: Boolean = false + + // Push the data for the current key to the numValues store, and reset the tracking variables + // to their empty state. 
+ private def updateNumValueForCurrentKey(): Unit = { + if (valueRemoved) { + if (numValues >= 1) { + keyToNumValues.put(currentKey, numValues) + } else { + keyToNumValues.remove(currentKey) + } } - if (condition(valueForIndex)) { - if (numValues > 1) { - val valueAtMaxIndex = keyWithIndexToValue.get(key, numValues - 1) - keyWithIndexToValue.put(key, index, valueAtMaxIndex) - keyWithIndexToValue.remove(key, numValues - 1) - valueForIndex = valueAtMaxIndex + + currentKey = null + numValues = 0 + index = 0 + valueRemoved = false + } + + // Find the next value satisfying the condition, updating `currentKey` and `numValues` if + // needed. Returns null when no value can be found. + private def findNextValueForIndex(): UnsafeRow = { + // Loop across all values for the current key, and then all other keys, until we find a + // value satisfying the removal condition. + def hasMoreValuesForCurrentKey = currentKey != null && index < numValues + def hasMoreKeys = allKeyToNumValues.hasNext + while (hasMoreValuesForCurrentKey || hasMoreKeys) { + if (hasMoreValuesForCurrentKey) { + // First search the values for the current key. + val currentValue = keyWithIndexToValue.get(currentKey, index) + if (removalCondition(currentValue)) { + return currentValue + } else { + index += 1 + } + } else if (hasMoreKeys) { + // If we can't find a value for the current key, cleanup and start looking at the next. + // This will also happen the first time the iterator is called. + updateNumValueForCurrentKey() + + val currentKeyToNumValue = allKeyToNumValues.next() + currentKey = currentKeyToNumValue.key + numValues = currentKeyToNumValue.numValue } else { - keyWithIndexToValue.remove(key, 0) - valueForIndex = null + // Should be unreachable, but in any case means a value couldn't be found. + return null } - numValues -= 1 - valueRemoved = true - } else { - valueForIndex = null - index += 1 } + + // We tried and failed to find the next value. + return null } - if (valueRemoved) { - if (numValues >= 1) { - keyToNumValues.put(key, numValues) + + override def getNext(): UnsafeRowPair = { + val currentValue = findNextValueForIndex() + + // If there's no value, clean up and finish. There aren't any more available. + if (currentValue == null) { + updateNumValueForCurrentKey() + finished = true + return null + } + + // The backing store is arraylike - we as the caller are responsible for filling back in + // any hole. So we swap the last element into the hole and decrement numValues to shorten. + // clean + if (numValues > 1) { + val valueAtMaxIndex = keyWithIndexToValue.get(currentKey, numValues - 1) + keyWithIndexToValue.put(currentKey, index, valueAtMaxIndex) + keyWithIndexToValue.remove(currentKey, numValues - 1) } else { - keyToNumValues.remove(key) + keyWithIndexToValue.remove(currentKey, 0) } + numValues -= 1 + valueRemoved = true + + return reusedPair.withRows(currentKey, currentValue) } - } - } - def iterator(): Iterator[UnsafeRowPair] = { - val pair = new UnsafeRowPair() - keyWithIndexToValue.iterator.map { x => - pair.withRows(x.key, x.value) + override def close: Unit = {} } } @@ -291,7 +384,7 @@ class SymmetricHashJoinStateManager( } /** A wrapper around a [[StateStore]] that stores [(key, index) -> value]. 
*/ - private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValuesType) { + private class KeyWithIndexToValueStore extends StateStoreHandler(KeyWithIndexToValueType) { private val keyWithIndexExprs = keyAttributes :+ Literal(1L) private val keyWithIndexSchema = keySchema.add("index", LongType) private val indexOrdinalInKeyWithIndexRow = keyAttributes.size @@ -309,19 +402,24 @@ class SymmetricHashJoinStateManager( stateStore.get(keyWithIndexRow(key, valueIndex)) } - /** Get all the values for key and all indices. */ - def getAll(key: UnsafeRow, numValues: Long): Iterator[UnsafeRow] = { + /** + * Get all values and indices for the provided key. + * Should not return null. + */ + def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = { + val keyWithIndexAndValue = new KeyWithIndexAndValue() var index = 0 - new NextIterator[UnsafeRow] { - override protected def getNext(): UnsafeRow = { + new NextIterator[KeyWithIndexAndValue] { + override protected def getNext(): KeyWithIndexAndValue = { if (index >= numValues) { finished = true null } else { val keyWithIndex = keyWithIndexRow(key, index) val value = stateStore.get(keyWithIndex) + keyWithIndexAndValue.withNew(key, index, value) index += 1 - value + keyWithIndexAndValue } } @@ -373,7 +471,7 @@ class SymmetricHashJoinStateManager( object SymmetricHashJoinStateManager { def allStateStoreNames(joinSides: JoinSide*): Seq[String] = { - val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValuesType) + val allStateStoreTypes: Seq[StateStoreType] = Seq(KeyToNumValuesType, KeyWithIndexToValueType) for (joinSide <- joinSides; stateStoreType <- allStateStoreTypes) yield { getStateStoreName(joinSide, stateStoreType) } @@ -385,8 +483,8 @@ object SymmetricHashJoinStateManager { override def toString(): String = "keyToNumValues" } - private case object KeyWithIndexToValuesType extends StateStoreType { - override def toString(): String = "keyWithIndexToNumValues" + private case object KeyWithIndexToValueType extends StateStoreType { + override def toString(): String = "keyWithIndexToValue" } private def getStateStoreName(joinSide: JoinSide, storeType: StateStoreType): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index fb960fbdde8b3..b9b07a2e688f9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -43,10 +43,11 @@ case class StatefulOperatorStateInfo( checkpointLocation: String, queryRunId: UUID, operatorId: Long, - storeVersion: Long) { + storeVersion: Long, + numPartitions: Int) { override def toString(): String = { s"state info [ checkpoint = $checkpointLocation, runId = $queryRunId, " + - s"opId = $operatorId, ver = $storeVersion]" + s"opId = $operatorId, ver = $storeVersion, numPartitions = $numPartitions]" } } @@ -225,7 +226,7 @@ case class StateStoreRestoreExec( val key = getKey(row) val savedState = store.get(key) numOutputRows += 1 - row +: Option(savedState).toSeq + Option(savedState).toSeq :+ row } } } @@ -239,7 +240,7 @@ case class StateStoreRestoreExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -386,7 +387,7 @@ case class 
StateStoreSaveExec( if (keyExpressions.isEmpty) { AllTuples :: Nil } else { - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil } } } @@ -401,7 +402,7 @@ case class StreamingDeduplicateExec( /** Distribute by grouping attributes */ override def requiredChildDistribution: Seq[Distribution] = - ClusteredDistribution(keyExpressions) :: Nil + ClusteredDistribution(keyExpressions, stateInfo.map(_.numPartitions)) :: Nil override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala index 41929ed474fa7..f9c69864a3361 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/AllExecutionsPage.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.ui import javax.servlet.http.HttpServletRequest import scala.collection.mutable -import scala.xml.Node +import scala.xml.{Node, NodeSeq} import org.apache.commons.lang3.StringEscapeUtils @@ -38,19 +38,19 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L if (listener.getRunningExecutions.nonEmpty) { _content ++= new RunningExecutionTable( - parent, "Running Queries", currentTime, + parent, s"Running Queries (${listener.getRunningExecutions.size})", currentTime, listener.getRunningExecutions.sortBy(_.submissionTime).reverse).toNodeSeq } if (listener.getCompletedExecutions.nonEmpty) { _content ++= new CompletedExecutionTable( - parent, "Completed Queries", currentTime, + parent, s"Completed Queries (${listener.getCompletedExecutions.size})", currentTime, listener.getCompletedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq } if (listener.getFailedExecutions.nonEmpty) { _content ++= new FailedExecutionTable( - parent, "Failed Queries", currentTime, + parent, s"Failed Queries (${listener.getFailedExecutions.size})", currentTime, listener.getFailedExecutions.sortBy(_.submissionTime).reverse).toNodeSeq } _content @@ -61,7 +61,36 @@ private[ui] class AllExecutionsPage(parent: SQLTab) extends WebUIPage("") with L details.parentNode.querySelector('.stage-details').classList.toggle('collapsed') }} - UIUtils.headerSparkPage("SQL", content, parent, Some(5000)) + val summary: NodeSeq = +
      <div>
+        <ul>
+          {
+            if (listener.getRunningExecutions.nonEmpty) {
+              <li>
+                Running Queries:
+                {listener.getRunningExecutions.size}
+              </li>
+            }
+          }
+          {
+            if (listener.getCompletedExecutions.nonEmpty) {
+              <li>
+                Completed Queries:
+                {listener.getCompletedExecutions.size}
+              </li>
+            }
+          }
+          {
+            if (listener.getFailedExecutions.nonEmpty) {
+              <li>
+                Failed Queries:
+                {listener.getFailedExecutions.size}
+              </li>
+            }
+          }
+        </ul>
+      </div>
    + UIUtils.headerSparkPage("SQL", summary ++ content, parent, Some(5000)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 4e756084bbdbb..2867b4cd7da5e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -266,7 +266,8 @@ abstract class BaseSessionStateBuilder( * This gets cloned from parent if available, otherwise is a new instance is created. */ protected def listenerManager: ExecutionListenerManager = { - parentState.map(_.listenerManager.clone()).getOrElse(new ExecutionListenerManager) + parentState.map(_.listenerManager.clone()).getOrElse( + new ExecutionListenerManager(session.sparkContext.conf)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 142b005850a49..fdd25330c5e67 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -474,13 +474,20 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def refreshTable(tableName: String): Unit = { val tableIdent = sparkSession.sessionState.sqlParser.parseTableIdentifier(tableName) - // Temp tables: refresh (or invalidate) any metadata/data cached in the plan recursively. - // Non-temp tables: refresh the metadata cache. - sessionCatalog.refreshTable(tableIdent) + val tableMetadata = sessionCatalog.getTempViewOrPermanentTableMetadata(tableIdent) + val table = sparkSession.table(tableIdent) + + if (tableMetadata.tableType == CatalogTableType.VIEW) { + // Temp or persistent views: refresh (or invalidate) any metadata/data cached + // in the plan recursively. + table.queryExecution.analyzed.foreach(_.refresh()) + } else { + // Non-temp tables: refresh the metadata cache. + sessionCatalog.refreshTable(tableIdent) + } // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val table = sparkSession.table(tableIdent) if (isCached(table)) { // Uncache the logicalPlan. 
sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 3b44c1de93a61..e3f106c41c7ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -23,30 +23,36 @@ import org.apache.spark.sql.types._ private case object OracleDialect extends JdbcDialect { + private[jdbc] val BINARY_FLOAT = 100 + private[jdbc] val BINARY_DOUBLE = 101 + private[jdbc] val TIMESTAMPTZ = -101 override def canHandle(url: String): Boolean = url.startsWith("jdbc:oracle") override def getCatalystType( sqlType: Int, typeName: String, size: Int, md: MetadataBuilder): Option[DataType] = { - if (sqlType == Types.NUMERIC) { - val scale = if (null != md) md.build().getLong("scale") else 0L - size match { - // Handle NUMBER fields that have no precision/scale in special way - // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale - // For more details, please see - // https://github.com/apache/spark/pull/8780#issuecomment-145598968 - // and - // https://github.com/apache/spark/pull/8780#issuecomment-144541760 - case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts - // this to NUMERIC with -127 scale - // Not sure if there is a more robust way to identify the field as a float (or other - // numeric types that do not specify a scale. - case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) - case _ => None - } - } else { - None + sqlType match { + case Types.NUMERIC => + val scale = if (null != md) md.build().getLong("scale") else 0L + size match { + // Handle NUMBER fields that have no precision/scale in special way + // because JDBC ResultSetMetaData converts this to 0 precision and -127 scale + // For more details, please see + // https://github.com/apache/spark/pull/8780#issuecomment-145598968 + // and + // https://github.com/apache/spark/pull/8780#issuecomment-144541760 + case 0 => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + // Handle FLOAT fields in a special way because JDBC ResultSetMetaData converts + // this to NUMERIC with -127 scale + // Not sure if there is a more robust way to identify the field as a float (or other + // numeric types that do not specify a scale. + case _ if scale == -127L => Option(DecimalType(DecimalType.MAX_PRECISION, 10)) + case _ => None + } + case TIMESTAMPTZ => Some(TimestampType) // Value for Timestamp with Time Zone in Oracle + case BINARY_FLOAT => Some(FloatType) // Value for OracleTypes.BINARY_FLOAT + case BINARY_DOUBLE => Some(DoubleType) // Value for OracleTypes.BINARY_DOUBLE + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala index 04a956b70b022..e9510c903acae 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/GroupState.scala @@ -205,11 +205,7 @@ trait GroupState[S] extends LogicalGroupState[S] { /** Get the state value as a scala Option. */ def getOption: Option[S] - /** - * Update the value of the state. Note that `null` is not a valid value, and it throws - * IllegalArgumentException. 
- */ - @throws[IllegalArgumentException]("when updating with null") + /** Update the value of the state. */ def update(newState: S): Unit /** Remove this state. */ @@ -217,80 +213,114 @@ trait GroupState[S] extends LogicalGroupState[S] { /** * Whether the function has been called because the key has timed out. - * @note This can return true only when timeouts are enabled in `[map/flatmap]GroupsWithStates`. + * @note This can return true only when timeouts are enabled in `[map/flatMap]GroupsWithState`. */ def hasTimedOut: Boolean + /** * Set the timeout duration in ms for this key. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'durationMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(durationMs: Long): Unit + /** * Set the timeout duration for this key as a string. For example, "1 hour", "2 days", etc. * - * @note ProcessingTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Processing time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. */ @throws[IllegalArgumentException]("if 'duration' is not a valid duration") - @throws[IllegalStateException]("when state is either not initialized, or already removed") @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutDuration(duration: String): Unit - @throws[IllegalArgumentException]("if 'timestampMs' is not positive") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no effect when used in a batch query. 
*/ + @throws[IllegalArgumentException]( + "if 'timestampMs' is not positive or less than the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if processing time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as milliseconds in epoch time and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]( + "if 'additionalDuration' is invalid or the final timeout timestamp is less than " + + "the current watermark in a streaming query") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestampMs: Long, additionalDuration: String): Unit - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date. * This timestamp cannot be older than the current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date): Unit - @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") - @throws[IllegalStateException]("when state is either not initialized, or already removed") - @throws[UnsupportedOperationException]( - "if 'timeout' has not been enabled in [map|flatMap]GroupsWithState in a streaming query") + /** * Set the timeout timestamp for this key as a java.sql.Date and an additional * duration as a string (e.g. "1 hour", "2 days", etc.). * The final timestamp (including the additional duration) cannot be older than the * current watermark. * - * @note EventTimeTimeout must be enabled in `[map/flatmap]GroupsWithStates`. + * @note [[GroupStateTimeout Event time timeout]] must be enabled in + * `[map/flatMap]GroupsWithState` for calling this method. + * @note This method has no side effect when used in a batch query. */ + @throws[IllegalArgumentException]("if 'additionalDuration' is invalid") + @throws[UnsupportedOperationException]( + "if event time timeout has not been enabled in [map|flatMap]GroupsWithState") def setTimeoutTimestamp(timestamp: java.sql.Date, additionalDuration: String): Unit + + + /** + * Get the current event time watermark as milliseconds in epoch time. 
+ * + * @note In a streaming query, this can be called only when watermark is set before calling + * `[map/flatMap]GroupsWithState`. In a batch query, this method always returns -1. + */ + @throws[UnsupportedOperationException]( + "if watermark has not been set before in [map|flatMap]GroupsWithState") + def getCurrentWatermarkMs(): Long + + + /** + * Get the current processing time as milliseconds in epoch time. + * @note In a streaming query, this will return a constant value throughout the duration of a + * trigger, even if the trigger is re-executed. + */ + def getCurrentProcessingTimeMs(): Long } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala index f6240d85fba6f..2b46233e1a5df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/QueryExecutionListener.scala @@ -22,9 +22,12 @@ import java.util.concurrent.locks.ReentrantReadWriteLock import scala.collection.mutable.ListBuffer import scala.util.control.NonFatal +import org.apache.spark.SparkConf import org.apache.spark.annotation.{DeveloperApi, Experimental, InterfaceStability} import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ +import org.apache.spark.util.Utils /** * :: Experimental :: @@ -72,7 +75,14 @@ trait QueryExecutionListener { */ @Experimental @InterfaceStability.Evolving -class ExecutionListenerManager private[sql] () extends Logging { +class ExecutionListenerManager private extends Logging { + + private[sql] def this(conf: SparkConf) = { + this() + conf.get(QUERY_EXECUTION_LISTENERS).foreach { classNames => + Utils.loadExtensions(classOf[QueryExecutionListener], classNames, conf).foreach(register) + } + } /** * Registers the specified [[QueryExecutionListener]]. 
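As a quick illustration of the GroupState timeout and time-query methods documented in the GroupState.scala hunk above (setTimeoutDuration, hasTimedOut, and the newly added getCurrentProcessingTimeMs), the following sketch shows one way they might be used from mapGroupsWithState. It is a minimal, hypothetical example: the Event/SessionInfo/SessionUpdate types, the rate-source wiring, and the 10-minute timeout are illustrative assumptions, not part of this patch.

// Hypothetical usage sketch for the GroupState API above; names and schema are illustrative only.
import java.sql.Timestamp

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

case class Event(sessionId: String, timestamp: Timestamp)
case class SessionInfo(numEvents: Long, lastSeenMs: Long)
case class SessionUpdate(sessionId: String, numEvents: Long, expired: Boolean)

object SessionizeSketch {

  // Combine per-key events with existing state; relies on processing-time timeouts.
  def updateSession(
      sessionId: String,
      events: Iterator[Event],
      state: GroupState[SessionInfo]): SessionUpdate = {
    if (state.hasTimedOut) {
      // The key timed out: emit a final update and drop the state.
      val info = state.get
      state.remove()
      SessionUpdate(sessionId, info.numEvents, expired = true)
    } else {
      val newCount = state.getOption.map(_.numEvents).getOrElse(0L) + events.size
      // getCurrentProcessingTimeMs() is constant within a trigger, per the docs above.
      state.update(SessionInfo(newCount, state.getCurrentProcessingTimeMs()))
      // Requires ProcessingTimeTimeout to be enabled in mapGroupsWithState below.
      state.setTimeoutDuration("10 minutes")
      SessionUpdate(sessionId, newCount, expired = false)
    }
  }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sessionize-sketch").getOrCreate()
    import spark.implicits._

    // The rate source and column renaming are only stand-ins for a real event stream.
    val events = spark.readStream.format("rate").load()
      .selectExpr("CAST(value AS STRING) AS sessionId", "timestamp")
      .as[Event]

    val sessions = events
      .groupByKey(_.sessionId)
      .mapGroupsWithState(GroupStateTimeout.ProcessingTimeTimeout)(updateSession)

    sessions.writeStream
      .outputMode(OutputMode.Update())
      .format("console")
      .start()
      .awaitTermination()
  }
}

Separately, the ExecutionListenerManager change above suggests that QueryExecutionListener implementations can now also be wired up through a static SQL configuration (the QUERY_EXECUTION_LISTENERS entry loaded via Utils.loadExtensions) rather than only by calling register() on a live session.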
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 7aacf0346d2fb..da2c13f70c52a 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -54,6 +54,11 @@ public Filter[] pushFilters(Filter[] filters) { return new Filter[0]; } + @Override + public Filter[] pushedFilters() { + return filters; + } + @Override public List> createReadTasks() { List> res = new ArrayList<>(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql index 8aff4cb524199..9721f8c60ebce 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/group-analytics.sql @@ -38,11 +38,11 @@ SELECT course, year, GROUPING(course), GROUPING(year), GROUPING_ID(course, year) GROUP BY CUBE(course, year); SELECT course, year, GROUPING(course) FROM courseSales GROUP BY course, year; SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY course, year; -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year); +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- GROUPING/GROUPING_ID in having clause SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0; +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING(course) > 0; SELECT course, year FROM courseSales GROUP BY course, year HAVING GROUPING_ID(course) > 0; SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0; @@ -54,7 +54,7 @@ SELECT course, year, GROUPING_ID(course, year) FROM courseSales GROUP BY CUBE(co ORDER BY GROUPING(course), GROUPING(year), course, year; SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING(course); SELECT course, year FROM courseSales GROUP BY course, year ORDER BY GROUPING_ID(course); -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id; +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year; -- Aliases in SELECT could be used in ROLLUP/CUBE/GROUPING SETS SELECT a + b AS k1, b AS k2, SUM(a - b) FROM testData GROUP BY CUBE(k1, k2); diff --git a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out index ce7a16a4d0c81..3439a05727f95 100644 --- a/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/group-analytics.sql.out @@ -223,22 +223,29 @@ grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 16 -SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) +SELECT course, year, grouping__id FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 16 schema -struct<> +struct -- !query 16 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 0 +Java 2013 0 +dotNET 2012 0 +dotNET 2013 0 +Java NULL 1 +dotNET NULL 1 
+NULL 2012 2 +NULL 2013 2 +NULL NULL 3 -- !query 17 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) -HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 +HAVING GROUPING(year) = 1 AND GROUPING_ID(course, year) > 0 ORDER BY course, year -- !query 17 schema struct -- !query 17 output -Java NULL NULL NULL +Java NULL dotNET NULL @@ -263,10 +270,13 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 20 SELECT course, year FROM courseSales GROUP BY CUBE(course, year) HAVING grouping__id > 0 -- !query 20 schema -struct<> +struct -- !query 20 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java NULL +NULL 2012 +NULL 2013 +NULL NULL +dotNET NULL -- !query 21 @@ -322,12 +332,19 @@ grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup; -- !query 25 -SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id +SELECT course, year FROM courseSales GROUP BY CUBE(course, year) ORDER BY grouping__id, course, year -- !query 25 schema -struct<> +struct -- !query 25 output -org.apache.spark.sql.AnalysisException -grouping__id is deprecated; use grouping_id() instead; +Java 2012 +Java 2013 +dotNET 2012 +dotNET 2013 +Java NULL +dotNET NULL +NULL 2012 +NULL 2013 +NULL NULL -- !query 26 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out index e4b1a2dbc675c..2586f26f71c35 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out @@ -63,7 +63,7 @@ WHERE t1a IN (SELECT min(t2a) struct<> -- !query 4 output org.apache.spark.sql.AnalysisException -resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]); +Resolved attribute(s) t2b#x missing from min(t2a)#x,t2c#x in operator !Filter t2c#x IN (list#x [t2b#x]).; -- !query 5 diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql new file mode 100755 index 0000000000000..79dd3d516e8c7 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q10.sql @@ -0,0 +1,70 @@ +-- start query 10 in stream 0 using template query10.tpl +with +v1 as ( + select + ws_bill_customer_sk as customer_sk + from web_sales, + date_dim + where ws_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 + union all + select + cs_ship_customer_sk as customer_sk + from catalog_sales, + date_dim + where cs_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +), +v2 as ( + select + ss_customer_sk as customer_sk + from store_sales, + date_dim + where ss_sold_date_sk = d_date_sk + and d_year = 2002 + and d_moy between 4 and 4+3 +) +select + cd_gender, + cd_marital_status, + cd_education_status, + count(*) cnt1, + cd_purchase_estimate, + count(*) cnt2, + cd_credit_rating, + count(*) cnt3, + cd_dep_count, + count(*) cnt4, + cd_dep_employed_count, + count(*) cnt5, + cd_dep_college_count, + count(*) cnt6 +from customer c +join customer_address ca on (c.c_current_addr_sk = ca.ca_address_sk) +join customer_demographics on (cd_demo_sk = c.c_current_cdemo_sk) +left semi join v1 on (v1.customer_sk = c.c_customer_sk) +left semi join v2 on (v2.customer_sk = c.c_customer_sk) 
+where + ca_county in ('Walker County','Richland County','Gaines County','Douglas County','Dona Ana County') +group by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +order by + cd_gender, + cd_marital_status, + cd_education_status, + cd_purchase_estimate, + cd_credit_rating, + cd_dep_count, + cd_dep_employed_count, + cd_dep_college_count +limit 100 +-- end query 10 in stream 0 using template query10.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql new file mode 100755 index 0000000000000..1799827762916 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q19.sql @@ -0,0 +1,38 @@ +-- start query 19 in stream 0 using template query19.tpl +select + i_brand_id brand_id, + i_brand brand, + i_manufact_id, + i_manufact, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item, + customer, + customer_address, + store +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 7 + and d_moy = 11 + and d_year = 1999 + and ss_customer_sk = c_customer_sk + and c_current_addr_sk = ca_address_sk + and substr(ca_zip, 1, 5) <> substr(s_zip, 1, 5) + and ss_store_sk = s_store_sk + and ss_sold_date_sk between 2451484 and 2451513 -- partition key filter +group by + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +order by + ext_price desc, + i_brand, + i_brand_id, + i_manufact_id, + i_manufact +limit 100 +-- end query 19 in stream 0 using template query19.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql new file mode 100755 index 0000000000000..dedbc62a2ab2e --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q27.sql @@ -0,0 +1,43 @@ +-- start query 27 in stream 0 using template query27.tpl + with results as + (select i_item_id, + s_state, + ss_quantity agg1, + ss_list_price agg2, + ss_coupon_amt agg3, + ss_sales_price agg4 + --0 as g_state, + --avg(ss_quantity) agg1, + --avg(ss_list_price) agg2, + --avg(ss_coupon_amt) agg3, + --avg(ss_sales_price) agg4 + from store_sales, customer_demographics, date_dim, store, item + where ss_sold_date_sk = d_date_sk and + ss_sold_date_sk between 2451545 and 2451910 and + ss_item_sk = i_item_sk and + ss_store_sk = s_store_sk and + ss_cdemo_sk = cd_demo_sk and + cd_gender = 'F' and + cd_marital_status = 'D' and + cd_education_status = 'Primary' and + d_year = 2000 and + s_state in ('TN','AL', 'SD', 'SD', 'SD', 'SD') + --group by i_item_id, s_state + ) + + select i_item_id, + s_state, g_state, agg1, agg2, agg3, agg4 + from ( + select i_item_id, s_state, 0 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, avg(agg4) agg4 from results + group by i_item_id, s_state + union all + select i_item_id, NULL AS s_state, 1 AS g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + group by i_item_id + union all + select NULL AS i_item_id, NULL as s_state, 1 as g_state, avg(agg1) agg1, avg(agg2) agg2, avg(agg3) agg3, + avg(agg4) agg4 from results + ) foo + order by i_item_id, s_state + limit 100 +-- end query 27 in stream 0 using template query27.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql new file mode 100755 index 0000000000000..35b0a20f80a4e --- /dev/null +++ 
b/sql/core/src/test/resources/tpcds-modifiedQueries/q3.sql @@ -0,0 +1,228 @@ +-- start query 3 in stream 0 using template query3.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_net_profit) sum_agg +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manufact_id = 436 + and dt.d_moy = 12 + -- partition key filters + and ( +ss_sold_date_sk between 2415355 and 2415385 +or ss_sold_date_sk between 2415720 and 2415750 +or ss_sold_date_sk between 2416085 and 2416115 +or ss_sold_date_sk between 2416450 and 2416480 +or ss_sold_date_sk between 2416816 and 2416846 +or ss_sold_date_sk between 2417181 and 2417211 +or ss_sold_date_sk between 2417546 and 2417576 +or ss_sold_date_sk between 2417911 and 2417941 +or ss_sold_date_sk between 2418277 and 2418307 +or ss_sold_date_sk between 2418642 and 2418672 +or ss_sold_date_sk between 2419007 and 2419037 +or ss_sold_date_sk between 2419372 and 2419402 +or ss_sold_date_sk between 2419738 and 2419768 +or ss_sold_date_sk between 2420103 and 2420133 +or ss_sold_date_sk between 2420468 and 2420498 +or ss_sold_date_sk between 2420833 and 2420863 +or ss_sold_date_sk between 2421199 and 2421229 +or ss_sold_date_sk between 2421564 and 2421594 +or ss_sold_date_sk between 2421929 and 2421959 +or ss_sold_date_sk between 2422294 and 2422324 +or ss_sold_date_sk between 2422660 and 2422690 +or ss_sold_date_sk between 2423025 and 2423055 +or ss_sold_date_sk between 2423390 and 2423420 +or ss_sold_date_sk between 2423755 and 2423785 +or ss_sold_date_sk between 2424121 and 2424151 +or ss_sold_date_sk between 2424486 and 2424516 +or ss_sold_date_sk between 2424851 and 2424881 +or ss_sold_date_sk between 2425216 and 2425246 +or ss_sold_date_sk between 2425582 and 2425612 +or ss_sold_date_sk between 2425947 and 2425977 +or ss_sold_date_sk between 2426312 and 2426342 +or ss_sold_date_sk between 2426677 and 2426707 +or ss_sold_date_sk between 2427043 and 2427073 +or ss_sold_date_sk between 2427408 and 2427438 +or ss_sold_date_sk between 2427773 and 2427803 +or ss_sold_date_sk between 2428138 and 2428168 +or ss_sold_date_sk between 2428504 and 2428534 +or ss_sold_date_sk between 2428869 and 2428899 +or ss_sold_date_sk between 2429234 and 2429264 +or ss_sold_date_sk between 2429599 and 2429629 +or ss_sold_date_sk between 2429965 and 2429995 +or ss_sold_date_sk between 2430330 and 2430360 +or ss_sold_date_sk between 2430695 and 2430725 +or ss_sold_date_sk between 2431060 and 2431090 +or ss_sold_date_sk between 2431426 and 2431456 +or ss_sold_date_sk between 2431791 and 2431821 +or ss_sold_date_sk between 2432156 and 2432186 +or ss_sold_date_sk between 2432521 and 2432551 +or ss_sold_date_sk between 2432887 and 2432917 +or ss_sold_date_sk between 2433252 and 2433282 +or ss_sold_date_sk between 2433617 and 2433647 +or ss_sold_date_sk between 2433982 and 2434012 +or ss_sold_date_sk between 2434348 and 2434378 +or ss_sold_date_sk between 2434713 and 2434743 +or ss_sold_date_sk between 2435078 and 2435108 +or ss_sold_date_sk between 2435443 and 2435473 +or ss_sold_date_sk between 2435809 and 2435839 +or ss_sold_date_sk between 2436174 and 2436204 +or ss_sold_date_sk between 2436539 and 2436569 +or ss_sold_date_sk between 2436904 and 2436934 +or ss_sold_date_sk between 2437270 and 2437300 +or ss_sold_date_sk between 2437635 and 2437665 +or ss_sold_date_sk between 2438000 and 2438030 +or ss_sold_date_sk between 2438365 and 2438395 +or ss_sold_date_sk 
between 2438731 and 2438761 +or ss_sold_date_sk between 2439096 and 2439126 +or ss_sold_date_sk between 2439461 and 2439491 +or ss_sold_date_sk between 2439826 and 2439856 +or ss_sold_date_sk between 2440192 and 2440222 +or ss_sold_date_sk between 2440557 and 2440587 +or ss_sold_date_sk between 2440922 and 2440952 +or ss_sold_date_sk between 2441287 and 2441317 +or ss_sold_date_sk between 2441653 and 2441683 +or ss_sold_date_sk between 2442018 and 2442048 +or ss_sold_date_sk between 2442383 and 2442413 +or ss_sold_date_sk between 2442748 and 2442778 +or ss_sold_date_sk between 2443114 and 2443144 +or ss_sold_date_sk between 2443479 and 2443509 +or ss_sold_date_sk between 2443844 and 2443874 +or ss_sold_date_sk between 2444209 and 2444239 +or ss_sold_date_sk between 2444575 and 2444605 +or ss_sold_date_sk between 2444940 and 2444970 +or ss_sold_date_sk between 2445305 and 2445335 +or ss_sold_date_sk between 2445670 and 2445700 +or ss_sold_date_sk between 2446036 and 2446066 +or ss_sold_date_sk between 2446401 and 2446431 +or ss_sold_date_sk between 2446766 and 2446796 +or ss_sold_date_sk between 2447131 and 2447161 +or ss_sold_date_sk between 2447497 and 2447527 +or ss_sold_date_sk between 2447862 and 2447892 +or ss_sold_date_sk between 2448227 and 2448257 +or ss_sold_date_sk between 2448592 and 2448622 +or ss_sold_date_sk between 2448958 and 2448988 +or ss_sold_date_sk between 2449323 and 2449353 +or ss_sold_date_sk between 2449688 and 2449718 +or ss_sold_date_sk between 2450053 and 2450083 +or ss_sold_date_sk between 2450419 and 2450449 +or ss_sold_date_sk between 2450784 and 2450814 +or ss_sold_date_sk between 2451149 and 2451179 +or ss_sold_date_sk between 2451514 and 2451544 +or ss_sold_date_sk between 2451880 and 2451910 +or ss_sold_date_sk between 2452245 and 2452275 +or ss_sold_date_sk between 2452610 and 2452640 +or ss_sold_date_sk between 2452975 and 2453005 +or ss_sold_date_sk between 2453341 and 2453371 +or ss_sold_date_sk between 2453706 and 2453736 +or ss_sold_date_sk between 2454071 and 2454101 +or ss_sold_date_sk between 2454436 and 2454466 +or ss_sold_date_sk between 2454802 and 2454832 +or ss_sold_date_sk between 2455167 and 2455197 +or ss_sold_date_sk between 2455532 and 2455562 +or ss_sold_date_sk between 2455897 and 2455927 +or ss_sold_date_sk between 2456263 and 2456293 +or ss_sold_date_sk between 2456628 and 2456658 +or ss_sold_date_sk between 2456993 and 2457023 +or ss_sold_date_sk between 2457358 and 2457388 +or ss_sold_date_sk between 2457724 and 2457754 +or ss_sold_date_sk between 2458089 and 2458119 +or ss_sold_date_sk between 2458454 and 2458484 +or ss_sold_date_sk between 2458819 and 2458849 +or ss_sold_date_sk between 2459185 and 2459215 +or ss_sold_date_sk between 2459550 and 2459580 +or ss_sold_date_sk between 2459915 and 2459945 +or ss_sold_date_sk between 2460280 and 2460310 +or ss_sold_date_sk between 2460646 and 2460676 +or ss_sold_date_sk between 2461011 and 2461041 +or ss_sold_date_sk between 2461376 and 2461406 +or ss_sold_date_sk between 2461741 and 2461771 +or ss_sold_date_sk between 2462107 and 2462137 +or ss_sold_date_sk between 2462472 and 2462502 +or ss_sold_date_sk between 2462837 and 2462867 +or ss_sold_date_sk between 2463202 and 2463232 +or ss_sold_date_sk between 2463568 and 2463598 +or ss_sold_date_sk between 2463933 and 2463963 +or ss_sold_date_sk between 2464298 and 2464328 +or ss_sold_date_sk between 2464663 and 2464693 +or ss_sold_date_sk between 2465029 and 2465059 +or ss_sold_date_sk between 2465394 and 2465424 +or ss_sold_date_sk 
between 2465759 and 2465789 +or ss_sold_date_sk between 2466124 and 2466154 +or ss_sold_date_sk between 2466490 and 2466520 +or ss_sold_date_sk between 2466855 and 2466885 +or ss_sold_date_sk between 2467220 and 2467250 +or ss_sold_date_sk between 2467585 and 2467615 +or ss_sold_date_sk between 2467951 and 2467981 +or ss_sold_date_sk between 2468316 and 2468346 +or ss_sold_date_sk between 2468681 and 2468711 +or ss_sold_date_sk between 2469046 and 2469076 +or ss_sold_date_sk between 2469412 and 2469442 +or ss_sold_date_sk between 2469777 and 2469807 +or ss_sold_date_sk between 2470142 and 2470172 +or ss_sold_date_sk between 2470507 and 2470537 +or ss_sold_date_sk between 2470873 and 2470903 +or ss_sold_date_sk between 2471238 and 2471268 +or ss_sold_date_sk between 2471603 and 2471633 +or ss_sold_date_sk between 2471968 and 2471998 +or ss_sold_date_sk between 2472334 and 2472364 +or ss_sold_date_sk between 2472699 and 2472729 +or ss_sold_date_sk between 2473064 and 2473094 +or ss_sold_date_sk between 2473429 and 2473459 +or ss_sold_date_sk between 2473795 and 2473825 +or ss_sold_date_sk between 2474160 and 2474190 +or ss_sold_date_sk between 2474525 and 2474555 +or ss_sold_date_sk between 2474890 and 2474920 +or ss_sold_date_sk between 2475256 and 2475286 +or ss_sold_date_sk between 2475621 and 2475651 +or ss_sold_date_sk between 2475986 and 2476016 +or ss_sold_date_sk between 2476351 and 2476381 +or ss_sold_date_sk between 2476717 and 2476747 +or ss_sold_date_sk between 2477082 and 2477112 +or ss_sold_date_sk between 2477447 and 2477477 +or ss_sold_date_sk between 2477812 and 2477842 +or ss_sold_date_sk between 2478178 and 2478208 +or ss_sold_date_sk between 2478543 and 2478573 +or ss_sold_date_sk between 2478908 and 2478938 +or ss_sold_date_sk between 2479273 and 2479303 +or ss_sold_date_sk between 2479639 and 2479669 +or ss_sold_date_sk between 2480004 and 2480034 +or ss_sold_date_sk between 2480369 and 2480399 +or ss_sold_date_sk between 2480734 and 2480764 +or ss_sold_date_sk between 2481100 and 2481130 +or ss_sold_date_sk between 2481465 and 2481495 +or ss_sold_date_sk between 2481830 and 2481860 +or ss_sold_date_sk between 2482195 and 2482225 +or ss_sold_date_sk between 2482561 and 2482591 +or ss_sold_date_sk between 2482926 and 2482956 +or ss_sold_date_sk between 2483291 and 2483321 +or ss_sold_date_sk between 2483656 and 2483686 +or ss_sold_date_sk between 2484022 and 2484052 +or ss_sold_date_sk between 2484387 and 2484417 +or ss_sold_date_sk between 2484752 and 2484782 +or ss_sold_date_sk between 2485117 and 2485147 +or ss_sold_date_sk between 2485483 and 2485513 +or ss_sold_date_sk between 2485848 and 2485878 +or ss_sold_date_sk between 2486213 and 2486243 +or ss_sold_date_sk between 2486578 and 2486608 +or ss_sold_date_sk between 2486944 and 2486974 +or ss_sold_date_sk between 2487309 and 2487339 +or ss_sold_date_sk between 2487674 and 2487704 +or ss_sold_date_sk between 2488039 and 2488069 +) +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + sum_agg desc, + brand_id +limit 100 +-- end query 3 in stream 0 using template query3.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql new file mode 100755 index 0000000000000..d11696e5e0c34 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q34.sql @@ -0,0 +1,45 @@ +-- start query 34 in stream 0 using template query34.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + 
ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (date_dim.d_dom between 1 and 3 + or date_dim.d_dom between 25 and 28) + and (household_demographics.hd_buy_potential = '>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and (case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end) > 1.2 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Saginaw County', 'Sumner County', 'Appanoose County', 'Daviess County', 'Fairfield County', 'Raleigh County', 'Ziebach County', 'Williamson County') + and ss_sold_date_sk between 2450816 and 2451910 -- partition key filter + group by + ss_ticket_number, + ss_customer_sk + ) dn, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 15 and 20 +order by + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag desc +-- end query 34 in stream 0 using template query34.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql new file mode 100755 index 0000000000000..b6332a8afbebe --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q42.sql @@ -0,0 +1,28 @@ +-- start query 42 in stream 0 using template query42.tpl +select + dt.d_year, + item.i_category_id, + item.i_category, + sum(ss_ext_sales_price) +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- partition key filter +group by + dt.d_year, + item.i_category_id, + item.i_category +order by + sum(ss_ext_sales_price) desc, + dt.d_year, + item.i_category_id, + item.i_category +limit 100 +-- end query 42 in stream 0 using template query42.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql new file mode 100755 index 0000000000000..cc2040b2fdb7c --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q43.sql @@ -0,0 +1,36 @@ +-- start query 43 in stream 0 using template query43.tpl +select + s_store_name, + s_store_id, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales +from + date_dim, + store_sales, + store +where + d_date_sk = ss_sold_date_sk + and s_store_sk = ss_store_sk + and s_gmt_offset = -5 + and d_year = 1998 + and ss_sold_date_sk between 2450816 and 2451179 -- partition key filter +group by + s_store_name, + s_store_id +order by + s_store_name, + s_store_id, + 
sun_sales, + mon_sales, + tue_sales, + wed_sales, + thu_sales, + fri_sales, + sat_sales +limit 100 +-- end query 43 in stream 0 using template query43.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql new file mode 100755 index 0000000000000..52b7ba4f4b86b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q46.sql @@ -0,0 +1,80 @@ +-- start query 46 in stream 0 using template query46.tpl +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_dow in (6, 0) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Concord', 'Spring Hill', 'Brownsville', 'Greenville') + -- partition key filter + and ss_sold_date_sk in (2451181, 2451182, 2451188, 2451189, 2451195, 2451196, 2451202, 2451203, 2451209, 2451210, 2451216, 2451217, + 2451223, 2451224, 2451230, 2451231, 2451237, 2451238, 2451244, 2451245, 2451251, 2451252, 2451258, 2451259, + 2451265, 2451266, 2451272, 2451273, 2451279, 2451280, 2451286, 2451287, 2451293, 2451294, 2451300, 2451301, + 2451307, 2451308, 2451314, 2451315, 2451321, 2451322, 2451328, 2451329, 2451335, 2451336, 2451342, 2451343, + 2451349, 2451350, 2451356, 2451357, 2451363, 2451364, 2451370, 2451371, 2451377, 2451378, 2451384, 2451385, + 2451391, 2451392, 2451398, 2451399, 2451405, 2451406, 2451412, 2451413, 2451419, 2451420, 2451426, 2451427, + 2451433, 2451434, 2451440, 2451441, 2451447, 2451448, 2451454, 2451455, 2451461, 2451462, 2451468, 2451469, + 2451475, 2451476, 2451482, 2451483, 2451489, 2451490, 2451496, 2451497, 2451503, 2451504, 2451510, 2451511, + 2451517, 2451518, 2451524, 2451525, 2451531, 2451532, 2451538, 2451539, 2451545, 2451546, 2451552, 2451553, + 2451559, 2451560, 2451566, 2451567, 2451573, 2451574, 2451580, 2451581, 2451587, 2451588, 2451594, 2451595, + 2451601, 2451602, 2451608, 2451609, 2451615, 2451616, 2451622, 2451623, 2451629, 2451630, 2451636, 2451637, + 2451643, 2451644, 2451650, 2451651, 2451657, 2451658, 2451664, 2451665, 2451671, 2451672, 2451678, 2451679, + 2451685, 2451686, 2451692, 2451693, 2451699, 2451700, 2451706, 2451707, 2451713, 2451714, 2451720, 2451721, + 2451727, 2451728, 2451734, 2451735, 2451741, 2451742, 2451748, 2451749, 2451755, 2451756, 2451762, 2451763, + 2451769, 2451770, 2451776, 2451777, 2451783, 2451784, 2451790, 2451791, 2451797, 2451798, 2451804, 2451805, + 2451811, 2451812, 2451818, 2451819, 2451825, 2451826, 2451832, 2451833, 2451839, 2451840, 2451846, 2451847, + 2451853, 2451854, 2451860, 2451861, 2451867, 2451868, 2451874, 2451875, 2451881, 2451882, 2451888, 2451889, + 2451895, 2451896, 2451902, 2451903, 2451909, 2451910, 2451916, 2451917, 2451923, 2451924, 2451930, 2451931, + 2451937, 2451938, 2451944, 2451945, 2451951, 2451952, 2451958, 2451959, 2451965, 2451966, 2451972, 2451973, + 2451979, 2451980, 2451986, 2451987, 2451993, 2451994, 2452000, 2452001, 2452007, 2452008, 
2452014, 2452015, + 2452021, 2452022, 2452028, 2452029, 2452035, 2452036, 2452042, 2452043, 2452049, 2452050, 2452056, 2452057, + 2452063, 2452064, 2452070, 2452071, 2452077, 2452078, 2452084, 2452085, 2452091, 2452092, 2452098, 2452099, + 2452105, 2452106, 2452112, 2452113, 2452119, 2452120, 2452126, 2452127, 2452133, 2452134, 2452140, 2452141, + 2452147, 2452148, 2452154, 2452155, 2452161, 2452162, 2452168, 2452169, 2452175, 2452176, 2452182, 2452183, + 2452189, 2452190, 2452196, 2452197, 2452203, 2452204, 2452210, 2452211, 2452217, 2452218, 2452224, 2452225, + 2452231, 2452232, 2452238, 2452239, 2452245, 2452246, 2452252, 2452253, 2452259, 2452260, 2452266, 2452267, + 2452273, 2452274) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number +limit 100 +-- end query 46 in stream 0 using template query46.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql new file mode 100755 index 0000000000000..a510eefb13e17 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q52.sql @@ -0,0 +1,27 @@ +-- start query 52 in stream 0 using template query52.tpl +select + dt.d_year, + item.i_brand_id brand_id, + item.i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim dt, + store_sales, + item +where + dt.d_date_sk = store_sales.ss_sold_date_sk + and store_sales.ss_item_sk = item.i_item_sk + and item.i_manager_id = 1 + and dt.d_moy = 12 + and dt.d_year = 1998 + and ss_sold_date_sk between 2451149 and 2451179 -- added for partition pruning +group by + dt.d_year, + item.i_brand, + item.i_brand_id +order by + dt.d_year, + ext_price desc, + brand_id +limit 100 +-- end query 52 in stream 0 using template query52.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql new file mode 100755 index 0000000000000..fb7bb75183858 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q53.sql @@ -0,0 +1,37 @@ +-- start query 53 in stream 0 using template query53.tpl +select + * +from + (select + i_manufact_id, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_manufact_id) avg_quarterly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_month_seq in (1212, 1212 + 1, 1212 + 2, 1212 + 3, 1212 + 4, 1212 + 5, 1212 + 6, 1212 + 7, 1212 + 8, 1212 + 9, 1212 + 10, 1212 + 11) + and ((i_category in ('Books', 'Children', 'Electronics') + and i_class in ('personal', 'portable', 'reference', 'self-help') + and i_brand in ('scholaramalgamalg #14', 'scholaramalgamalg #7', 'exportiunivamalg #9', 'scholaramalgamalg #9')) + or (i_category in ('Women', 'Music', 'Men') + and i_class in ('accessories', 'classical', 'fragrances', 'pants') + and i_brand in ('amalgimporto #1', 'edu packscholar #1', 'exportiimporto #1', 'importoamalg #1'))) + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + i_manufact_id, + d_qoy + ) tmp1 +where + case when avg_quarterly_sales > 0 then abs (sum_sales - avg_quarterly_sales) / avg_quarterly_sales else null end > 0.1 +order by + 
avg_quarterly_sales, + sum_sales, + i_manufact_id +limit 100 +-- end query 53 in stream 0 using template query53.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql new file mode 100755 index 0000000000000..47b1f0292d901 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q55.sql @@ -0,0 +1,24 @@ +-- start query 55 in stream 0 using template query55.tpl +select + i_brand_id brand_id, + i_brand brand, + sum(ss_ext_sales_price) ext_price +from + date_dim, + store_sales, + item +where + d_date_sk = ss_sold_date_sk + and ss_item_sk = i_item_sk + and i_manager_id = 48 + and d_moy = 11 + and d_year = 2001 + and ss_sold_date_sk between 2452215 and 2452244 +group by + i_brand, + i_brand_id +order by + ext_price desc, + i_brand_id +limit 100 +-- end query 55 in stream 0 using template query55.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql new file mode 100755 index 0000000000000..3d5c4e9d64419 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q59.sql @@ -0,0 +1,83 @@ +-- start query 59 in stream 0 using template query59.tpl +with + wss as + (select + d_week_seq, + ss_store_sk, + sum(case when (d_day_name = 'Sunday') then ss_sales_price else null end) sun_sales, + sum(case when (d_day_name = 'Monday') then ss_sales_price else null end) mon_sales, + sum(case when (d_day_name = 'Tuesday') then ss_sales_price else null end) tue_sales, + sum(case when (d_day_name = 'Wednesday') then ss_sales_price else null end) wed_sales, + sum(case when (d_day_name = 'Thursday') then ss_sales_price else null end) thu_sales, + sum(case when (d_day_name = 'Friday') then ss_sales_price else null end) fri_sales, + sum(case when (d_day_name = 'Saturday') then ss_sales_price else null end) sat_sales + from + store_sales, + date_dim + where + d_date_sk = ss_sold_date_sk + group by + d_week_seq, + ss_store_sk + ) +select + s_store_name1, + s_store_id1, + d_week_seq1, + sun_sales1 / sun_sales2, + mon_sales1 / mon_sales2, + tue_sales1 / tue_sales1, + wed_sales1 / wed_sales2, + thu_sales1 / thu_sales2, + fri_sales1 / fri_sales2, + sat_sales1 / sat_sales2 +from + (select + s_store_name s_store_name1, + wss.d_week_seq d_week_seq1, + s_store_id s_store_id1, + sun_sales sun_sales1, + mon_sales mon_sales1, + tue_sales tue_sales1, + wed_sales wed_sales1, + thu_sales thu_sales1, + fri_sales fri_sales1, + sat_sales sat_sales1 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 and 1185 + 11 + ) y, + (select + s_store_name s_store_name2, + wss.d_week_seq d_week_seq2, + s_store_id s_store_id2, + sun_sales sun_sales2, + mon_sales mon_sales2, + tue_sales tue_sales2, + wed_sales wed_sales2, + thu_sales thu_sales2, + fri_sales fri_sales2, + sat_sales sat_sales2 + from + wss, + store, + date_dim d + where + d.d_week_seq = wss.d_week_seq + and ss_store_sk = s_store_sk + and d_month_seq between 1185 + 12 and 1185 + 23 + ) x +where + s_store_id1 = s_store_id2 + and d_week_seq1 = d_week_seq2 - 52 +order by + s_store_name1, + s_store_id1, + d_week_seq1 +limit 100 +-- end query 59 in stream 0 using template query59.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql new file mode 100755 index 0000000000000..b71199ab17d0b --- /dev/null +++ 
b/sql/core/src/test/resources/tpcds-modifiedQueries/q63.sql @@ -0,0 +1,29 @@ +-- start query 63 in stream 0 using template query63.tpl +select * +from (select i_manager_id + ,sum(ss_sales_price) sum_sales + ,avg(sum(ss_sales_price)) over (partition by i_manager_id) avg_monthly_sales + from item + ,store_sales + ,date_dim + ,store + where ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2452123 and 2452487 + and ss_store_sk = s_store_sk + and d_month_seq in (1219,1219+1,1219+2,1219+3,1219+4,1219+5,1219+6,1219+7,1219+8,1219+9,1219+10,1219+11) + and (( i_category in ('Books','Children','Electronics') + and i_class in ('personal','portable','reference','self-help') + and i_brand in ('scholaramalgamalg #14','scholaramalgamalg #7', + 'exportiunivamalg #9','scholaramalgamalg #9')) + or( i_category in ('Women','Music','Men') + and i_class in ('accessories','classical','fragrances','pants') + and i_brand in ('amalgimporto #1','edu packscholar #1','exportiimporto #1', + 'importoamalg #1'))) +group by i_manager_id, d_moy) tmp1 +where case when avg_monthly_sales > 0 then abs (sum_sales - avg_monthly_sales) / avg_monthly_sales else null end > 0.1 +order by i_manager_id + ,avg_monthly_sales + ,sum_sales +limit 100 +-- end query 63 in stream 0 using template query63.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql new file mode 100755 index 0000000000000..7344feeff6a9f --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q65.sql @@ -0,0 +1,58 @@ +-- start query 65 in stream 0 using template query65.tpl +select + s_store_name, + i_item_desc, + sc.revenue, + i_current_price, + i_wholesale_cost, + i_brand +from + store, + item, + (select + ss_store_sk, + avg(revenue) as ave + from + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sa + group by + ss_store_sk + ) sb, + (select + ss_store_sk, + ss_item_sk, + sum(ss_sales_price) as revenue + from + store_sales, + date_dim + where + ss_sold_date_sk = d_date_sk + and d_month_seq between 1212 and 1212 + 11 + and ss_sold_date_sk between 2451911 and 2452275 -- partition key filter + group by + ss_store_sk, + ss_item_sk + ) sc +where + sb.ss_store_sk = sc.ss_store_sk + and sc.revenue <= 0.1 * sb.ave + and s_store_sk = sc.ss_store_sk + and i_item_sk = sc.ss_item_sk +order by + s_store_name, + i_item_desc +limit 100 +-- end query 65 in stream 0 using template query65.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql new file mode 100755 index 0000000000000..94df4b3f57a90 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q68.sql @@ -0,0 +1,62 @@ +-- start query 68 in stream 0 using template query68.tpl +-- changed to match exact same partitions in original query +select + c_last_name, + c_first_name, + ca_city, + bought_city, + ss_ticket_number, + extended_price, + extended_tax, + list_price +from + (select + ss_ticket_number, + ss_customer_sk, + ca_city bought_city, + sum(ss_ext_sales_price) extended_price, + sum(ss_ext_list_price) list_price, + sum(ss_ext_tax) extended_tax + from + store_sales, + date_dim, + store, + household_demographics, + customer_address + where + 
store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and store_sales.ss_addr_sk = customer_address.ca_address_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_dep_count = 5 + or household_demographics.hd_vehicle_count = 3) + and date_dim.d_year in (1999, 1999 + 1, 1999 + 2) + and store.s_city in ('Midway', 'Fairview') + -- partition key filter + and ss_sold_date_sk in (2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, 2451331, + 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, 2451485, + 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, 2451666, + 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, 2451820, + 2451850, 2451851, 2451880, 2451881, 2451911, 2451912, 2451942, 2451943, 2451970, 2451971, 2452001, + 2452002, 2452031, 2452032, 2452062, 2452063, 2452092, 2452093, 2452123, 2452124, 2452154, 2452155, + 2452184, 2452185, 2452215, 2452216, 2452245, 2452246) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + --and d_date between '1999-01-01' and '1999-03-31' + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + ca_city + ) dn, + customer, + customer_address current_addr +where + ss_customer_sk = c_customer_sk + and customer.c_current_addr_sk = current_addr.ca_address_sk + and current_addr.ca_city <> bought_city +order by + c_last_name, + ss_ticket_number +limit 100 +-- end query 68 in stream 0 using template query68.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql new file mode 100755 index 0000000000000..c61a2d0d2a8fa --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q7.sql @@ -0,0 +1,31 @@ +-- start query 7 in stream 0 using template query7.tpl +select + i_item_id, + avg(ss_quantity) agg1, + avg(ss_list_price) agg2, + avg(ss_coupon_amt) agg3, + avg(ss_sales_price) agg4 +from + store_sales, + customer_demographics, + date_dim, + item, + promotion +where + ss_sold_date_sk = d_date_sk + and ss_item_sk = i_item_sk + and ss_cdemo_sk = cd_demo_sk + and ss_promo_sk = p_promo_sk + and cd_gender = 'F' + and cd_marital_status = 'W' + and cd_education_status = 'Primary' + and (p_channel_email = 'N' + or p_channel_event = 'N') + and d_year = 1998 + and ss_sold_date_sk between 2450815 and 2451179 -- partition key filter +group by + i_item_id +order by + i_item_id +limit 100 +-- end query 7 in stream 0 using template query7.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql new file mode 100755 index 0000000000000..8703910b305a8 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q73.sql @@ -0,0 +1,49 @@ +-- start query 73 in stream 0 using template query73.tpl +select + c_last_name, + c_first_name, + c_salutation, + c_preferred_cust_flag, + ss_ticket_number, + cnt +from + (select + ss_ticket_number, + ss_customer_sk, + count(*) cnt + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and date_dim.d_dom between 1 and 2 + and (household_demographics.hd_buy_potential = 
'>10000' + or household_demographics.hd_buy_potential = 'Unknown') + and household_demographics.hd_vehicle_count > 0 + and case when household_demographics.hd_vehicle_count > 0 then household_demographics.hd_dep_count / household_demographics.hd_vehicle_count else null end > 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_county in ('Fairfield County','Ziebach County','Bronx County','Barrow County') + -- partition key filter + and ss_sold_date_sk in (2450815, 2450816, 2450846, 2450847, 2450874, 2450875, 2450905, 2450906, 2450935, 2450936, 2450966, 2450967, + 2450996, 2450997, 2451027, 2451028, 2451058, 2451059, 2451088, 2451089, 2451119, 2451120, 2451149, + 2451150, 2451180, 2451181, 2451211, 2451212, 2451239, 2451240, 2451270, 2451271, 2451300, 2451301, + 2451331, 2451332, 2451361, 2451362, 2451392, 2451393, 2451423, 2451424, 2451453, 2451454, 2451484, + 2451485, 2451514, 2451515, 2451545, 2451546, 2451576, 2451577, 2451605, 2451606, 2451636, 2451637, + 2451666, 2451667, 2451697, 2451698, 2451727, 2451728, 2451758, 2451759, 2451789, 2451790, 2451819, + 2451820, 2451850, 2451851, 2451880, 2451881) + --and ss_sold_date_sk between 2451180 and 2451269 -- partition key filter (3 months) + group by + ss_ticket_number, + ss_customer_sk + ) dj, + customer +where + ss_customer_sk = c_customer_sk + and cnt between 1 and 5 +order by + cnt desc +-- end query 73 in stream 0 using template query73.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql new file mode 100755 index 0000000000000..4254310ecd10b --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q79.sql @@ -0,0 +1,59 @@ +-- start query 79 in stream 0 using template query79.tpl +select + c_last_name, + c_first_name, + substr(s_city, 1, 30), + ss_ticket_number, + amt, + profit +from + (select + ss_ticket_number, + ss_customer_sk, + store.s_city, + sum(ss_coupon_amt) amt, + sum(ss_net_profit) profit + from + store_sales, + date_dim, + store, + household_demographics + where + store_sales.ss_sold_date_sk = date_dim.d_date_sk + and store_sales.ss_store_sk = store.s_store_sk + and store_sales.ss_hdemo_sk = household_demographics.hd_demo_sk + and (household_demographics.hd_dep_count = 8 + or household_demographics.hd_vehicle_count > 0) + and date_dim.d_dow = 1 + and date_dim.d_year in (1998, 1998 + 1, 1998 + 2) + and store.s_number_employees between 200 and 295 + and ss_sold_date_sk between 2450819 and 2451904 + -- partition key filter + --and ss_sold_date_sk in (2450819, 2450826, 2450833, 2450840, 2450847, 2450854, 2450861, 2450868, 2450875, 2450882, 2450889, + -- 2450896, 2450903, 2450910, 2450917, 2450924, 2450931, 2450938, 2450945, 2450952, 2450959, 2450966, 2450973, 2450980, 2450987, + -- 2450994, 2451001, 2451008, 2451015, 2451022, 2451029, 2451036, 2451043, 2451050, 2451057, 2451064, 2451071, 2451078, 2451085, + -- 2451092, 2451099, 2451106, 2451113, 2451120, 2451127, 2451134, 2451141, 2451148, 2451155, 2451162, 2451169, 2451176, 2451183, + -- 2451190, 2451197, 2451204, 2451211, 2451218, 2451225, 2451232, 2451239, 2451246, 2451253, 2451260, 2451267, 2451274, 2451281, + -- 2451288, 2451295, 2451302, 2451309, 2451316, 2451323, 2451330, 2451337, 2451344, 2451351, 2451358, 2451365, 2451372, 2451379, + -- 2451386, 2451393, 2451400, 2451407, 2451414, 2451421, 2451428, 2451435, 2451442, 2451449, 2451456, 2451463, 2451470, 2451477, + -- 2451484, 2451491, 2451498, 2451505, 2451512, 2451519, 2451526, 2451533, 2451540, 2451547, 
2451554, 2451561, 2451568, 2451575, + -- 2451582, 2451589, 2451596, 2451603, 2451610, 2451617, 2451624, 2451631, 2451638, 2451645, 2451652, 2451659, 2451666, 2451673, + -- 2451680, 2451687, 2451694, 2451701, 2451708, 2451715, 2451722, 2451729, 2451736, 2451743, 2451750, 2451757, 2451764, 2451771, + -- 2451778, 2451785, 2451792, 2451799, 2451806, 2451813, 2451820, 2451827, 2451834, 2451841, 2451848, 2451855, 2451862, 2451869, + -- 2451876, 2451883, 2451890, 2451897, 2451904) + group by + ss_ticket_number, + ss_customer_sk, + ss_addr_sk, + store.s_city + ) ms, + customer +where + ss_customer_sk = c_customer_sk +order by + c_last_name, + c_first_name, + substr(s_city, 1, 30), + profit + limit 100 +-- end query 79 in stream 0 using template query79.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql new file mode 100755 index 0000000000000..b1d814af5e57a --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q89.sql @@ -0,0 +1,43 @@ +-- start query 89 in stream 0 using template query89.tpl +select + * +from + (select + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy, + sum(ss_sales_price) sum_sales, + avg(sum(ss_sales_price)) over (partition by i_category, i_brand, s_store_name, s_company_name) avg_monthly_sales + from + item, + store_sales, + date_dim, + store + where + ss_item_sk = i_item_sk + and ss_sold_date_sk = d_date_sk + and ss_store_sk = s_store_sk + and d_year in (2000) + and ((i_category in ('Home', 'Books', 'Electronics') + and i_class in ('wallpaper', 'parenting', 'musical')) + or (i_category in ('Shoes', 'Jewelry', 'Men') + and i_class in ('womens', 'birdal', 'pants'))) + and ss_sold_date_sk between 2451545 and 2451910 -- partition key filter + group by + i_category, + i_class, + i_brand, + s_store_name, + s_company_name, + d_moy + ) tmp1 +where + case when (avg_monthly_sales <> 0) then (abs(sum_sales - avg_monthly_sales) / avg_monthly_sales) else null end > 0.1 +order by + sum_sales - avg_monthly_sales, + s_store_name +limit 100 +-- end query 89 in stream 0 using template query89.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql new file mode 100755 index 0000000000000..f53f2f5f9c5b6 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/q98.sql @@ -0,0 +1,32 @@ +-- start query 98 in stream 0 using template query98.tpl +select + i_item_desc, + i_category, + i_class, + i_current_price, + sum(ss_ext_sales_price) as itemrevenue, + sum(ss_ext_sales_price) * 100 / sum(sum(ss_ext_sales_price)) over (partition by i_class) as revenueratio +from + store_sales, + item, + date_dim +where + ss_item_sk = i_item_sk + and i_category in ('Jewelry', 'Sports', 'Books') + and ss_sold_date_sk = d_date_sk + and ss_sold_date_sk between 2451911 and 2451941 -- partition key filter (1 calendar month) + and d_date between '2001-01-01' and '2001-01-31' +group by + i_item_id, + i_item_desc, + i_category, + i_class, + i_current_price +order by + i_category, + i_class, + i_item_id, + i_item_desc, + revenueratio +--limit 1000; -- added limit +-- end query 98 in stream 0 using template query98.tpl diff --git a/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql new file mode 100755 index 0000000000000..bf58b4bb3c5a5 --- /dev/null +++ b/sql/core/src/test/resources/tpcds-modifiedQueries/ss_max.sql @@ -0,0 +1,14 @@ 
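+-- ss_max: simple profiling query over store_sales: total row count, non-null and distinct counts of ss_sold_date_sk, and the maximum value of each surrogate-key column.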
+select + count(*) as total, + count(ss_sold_date_sk) as not_null_total, + count(distinct ss_sold_date_sk) as unique_days, + max(ss_sold_date_sk) as max_ss_sold_date_sk, + max(ss_sold_time_sk) as max_ss_sold_time_sk, + max(ss_item_sk) as max_ss_item_sk, + max(ss_customer_sk) as max_ss_customer_sk, + max(ss_cdemo_sk) as max_ss_cdemo_sk, + max(ss_hdemo_sk) as max_ss_hdemo_sk, + max(ss_addr_sk) as max_ss_addr_sk, + max(ss_store_sk) as max_ss_store_sk, + max(ss_promo_sk) as max_ss_promo_sk +from store_sales diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala index 7e61a68025158..938d76c9f0837 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala @@ -24,14 +24,14 @@ import org.apache.spark.SparkConf class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "false", "configuration parameter changed in test body") } } @@ -39,14 +39,14 @@ class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with Befo class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter { override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") // adding some checking after each test is run, assuring that the configs are not changed // in test code after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") } } @@ -57,7 +57,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite override protected def sparkConf: SparkConf = super.sparkConf .set("spark.sql.codegen.fallback", "false") - .set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + .set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") .set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") // adding some checking after each test is run, assuring that the configs are not changed @@ -65,7 +65,7 @@ class TwoLevelAggregateHashMapWithVectorizedMapSuite after { assert(sparkConf.get("spark.sql.codegen.fallback") == "false", "configuration parameter changed in test body") - assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true", + assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enabled") == "true", "configuration parameter changed in test body") assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true", 
"configuration parameter changed in test body") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala new file mode 100644 index 0000000000000..c7d86bc955d67 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproxCountDistinctForIntervalsQuerySuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.sql.catalyst.expressions.{Alias, CreateArray, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproxCountDistinctForIntervals +import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.test.SharedSQLContext + +class ApproxCountDistinctForIntervalsQuerySuite extends QueryTest with SharedSQLContext { + import testImplicits._ + + // ApproxCountDistinctForIntervals is used in equi-height histogram generation. An equi-height + // histogram usually contains hundreds of buckets. So we need to test + // ApproxCountDistinctForIntervals with large number of endpoints + // (the number of endpoints == the number of buckets + 1). + test("test ApproxCountDistinctForIntervals with large number of endpoints") { + val table = "approx_count_distinct_for_intervals_tbl" + withTable(table) { + (1 to 100000).toDF("col").createOrReplaceTempView(table) + // percentiles of 0, 0.001, 0.002 ... 0.999, 1 + val endpoints = (0 to 1000).map(_ * 100000 / 1000) + + // Since approx_count_distinct_for_intervals is not a public function, here we do + // the computation by constructing logical plan. + val relation = spark.table(table).logicalPlan + val attr = relation.output.find(_.name == "col").get + val aggFunc = ApproxCountDistinctForIntervals(attr, CreateArray(endpoints.map(Literal(_)))) + val aggExpr = aggFunc.toAggregateExpression() + val namedExpr = Alias(aggExpr, aggExpr.toString)() + val ndvsRow = new QueryExecution(spark, Aggregate(Nil, Seq(namedExpr), relation)) + .executedPlan.executeTake(1).head + val ndvArray = ndvsRow.getArray(0).toLongArray() + assert(endpoints.length == ndvArray.length + 1) + + // Each bucket has 100 distinct values. + val expectedNdv = 100 + for (i <- ndvArray.indices) { + val ndv = ndvArray(i) + val error = math.abs((ndv / expectedNdv.toDouble) - 1.0d) + assert(error <= aggFunc.relativeSD * 3.0d, "Error should be within 3 std. 
errors.") + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala index 1aea33766407f..137c5bea2abb9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala @@ -53,6 +53,21 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { } } + test("percentile_approx, the first element satisfies small percentages") { + withTempView(table) { + (1 to 10).toDF("col").createOrReplaceTempView(table) + checkAnswer( + spark.sql( + s""" + |SELECT + | percentile_approx(col, array(0.01, 0.1, 0.11)) + |FROM $table + """.stripMargin), + Row(Seq(1, 1, 2)) + ) + } + } + test("percentile_approx, array of percentile value") { withTempView(table) { (1 to 1000).toDF("col").createOrReplaceTempView(table) @@ -130,7 +145,7 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSQLContext { (1 to 1000).toDF("col").createOrReplaceTempView(table) checkAnswer( spark.sql(s"SELECT percentile_approx(col, array(0.25 + 0.25D), 200 + 800D) FROM $table"), - Row(Seq(500D)) + Row(Seq(499)) ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 3e4f619431599..1e52445f28fc1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan} import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} @@ -420,7 +420,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext * Verifies that the plan for `df` contains `expected` number of Exchange operators. */ private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = { - assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected) + assert( + df.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => e }.size == expected) } test("A cached table preserves the partitioning and ordering of its cached SparkPlan") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala new file mode 100644 index 0000000000000..cee85ec8af04d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/ConfigBehaviorSuite.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.commons.math3.stat.inference.ChiSquareTest + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + + +class ConfigBehaviorSuite extends QueryTest with SharedSQLContext { + + import testImplicits._ + + test("SPARK-22160 spark.sql.execution.rangeExchange.sampleSizePerPartition") { + // In this test, we run a sort and compute the histogram for partition size post shuffle. + // With a high sample count, the partition size should be more evenly distributed, and has a + // low chi-sq test value. + // Also the whole code path for range partitioning as implemented should be deterministic + // (it uses the partition id as the seed), so this test shouldn't be flaky. + + val numPartitions = 4 + + def computeChiSquareTest(): Double = { + val n = 10000 + // Trigger a sort + val data = spark.range(0, n, 1, 1).sort('id) + .selectExpr("SPARK_PARTITION_ID() pid", "id").as[(Int, Long)].collect() + + // Compute histogram for the number of records per partition post sort + val dist = data.groupBy(_._1).map(_._2.length.toLong).toArray + assert(dist.length == 4) + + new ChiSquareTest().chiSquare( + Array.fill(numPartitions) { n.toDouble / numPartitions }, + dist) + } + + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString) { + // The default chi-sq value should be low + assert(computeChiSquareTest() < 100) + + withSQLConf(SQLConf.RANGE_EXCHANGE_SAMPLE_SIZE_PER_PARTITION.key -> "1") { + // If we only sample one point, the range boundaries will be pretty bad and the + // chi-sq value would be very high. 
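+ // (The statistic compares the observed per-partition row counts against the uniform expectation of n / numPartitions = 10000 / 4 = 2500 rows per partition.)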
+ assert(computeChiSquareTest() > 300) + } + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 8549eac58ee95..06848e4d2b297 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -21,6 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec} +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -636,4 +637,33 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.sql("SELECT 3 AS c, 4 AS d, SUM(b) FROM testData2 GROUP BY c, d"), Seq(Row(3, 4, 9))) } + + test("SPARK-22223: ObjectHashAggregate should not introduce unnecessary shuffle") { + withSQLConf(SQLConf.USE_OBJECT_HASH_AGG.key -> "true") { + val df = Seq(("1", "2", 1), ("1", "2", 2), ("2", "3", 3), ("2", "3", 4)).toDF("a", "b", "c") + .repartition(col("a")) + + val objHashAggDF = df + .withColumn("d", expr("(a, b, c)")) + .groupBy("a", "b").agg(collect_list("d").as("e")) + .withColumn("f", expr("(b, e)")) + .groupBy("a").agg(collect_list("f").as("g")) + val aggPlan = objHashAggDF.queryExecution.executedPlan + + val sortAggPlans = aggPlan.collect { + case sortAgg: SortAggregateExec => sortAgg + } + assert(sortAggPlans.isEmpty) + + val objHashAggPlans = aggPlan.collect { + case objHashAgg: ObjectHashAggregateExec => objHashAgg + } + assert(objHashAggPlans.nonEmpty) + + val exchangePlans = aggPlan.collect { + case shuffle: ShuffleExchangeExec => shuffle + } + assert(exchangePlans.length == 1) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 247c30e2ee65b..46b21c3b64a2e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -141,7 +141,7 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { test("approximate quantile") { val n = 1000 - val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles") + val df = Seq.tabulate(n + 1)(i => (i, 2.0 * i)).toDF("singles", "doubles") val q1 = 0.5 val q2 = 0.8 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 6178661cf7b2b..473c355cf3c7f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.plans.logical.{Filter, OneRowRelation, Union} import org.apache.spark.sql.execution.{FilterExec, QueryExecution} import org.apache.spark.sql.execution.aggregate.HashAggregateExec -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext} @@ -368,6 +368,8 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { checkAnswer( testData.select('key).coalesce(1).select('key), testData.select('key).collect().toSeq) + + assert(spark.emptyDataFrame.coalesce(1).rdd.partitions.size === 1) } test("convert $\"attribute name\" into unresolved attribute") { @@ -641,6 +643,49 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(df.schema.map(_.name) === Seq("key", "value", "newCol")) } + test("withColumns") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCol2"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCol2")) + + val err = intercept[IllegalArgumentException] { + testData.toDF().withColumns(Seq("newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert( + err.getMessage.contains("The size of column names: 1 isn't equal to the size of columns: 2")) + + val err2 = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err2.getMessage.contains("Found duplicate column(s)")) + } + + test("withColumns: case sensitive") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val df = testData.toDF().withColumns(Seq("newCol1", "newCOL1"), + Seq(col("key") + 1, col("key") + 2)) + checkAnswer( + df, + testData.collect().map { case Row(key: Int, value: String) => + Row(key, value, key + 1, key + 2) + }.toSeq) + assert(df.schema.map(_.name) === Seq("key", "value", "newCol1", "newCOL1")) + + val err = intercept[AnalysisException] { + testData.toDF().withColumns(Seq("newCol1", "newCol1"), + Seq(col("key") + 1, col("key") + 2)) + } + assert(err.getMessage.contains("Found duplicate column(s)")) + } + } + test("replace column using withColumn") { val df2 = sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) @@ -649,6 +694,15 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row(2) :: Row(3) :: Row(4) :: Nil) } + test("replace column using withColumns") { + val df2 = sparkContext.parallelize(Array((1, 2), (2, 3), (3, 4))).toDF("x", "y") + val df3 = df2.withColumns(Seq("x", "newCol1", "newCol2"), + Seq(df2("x") + 1, df2("y"), df2("y") + 1)) + checkAnswer( + df3.select("x", "newCol1", "newCol2"), + Row(2, 2, 3) :: Row(3, 3, 4) :: Row(4, 4, 5) :: Nil) + } + test("drop column using drop") { val df = testData.drop("key") checkAnswer( @@ -803,7 +857,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { Row("mean", null, "33.0", "178.0"), Row("stddev", null, "19.148542155126762", "11.547005383792516"), Row("min", "Alice", "16", "164"), - Row("25%", null, "24", "176"), + Row("25%", null, "16", "164"), Row("50%", null, "24", "176"), Row("75%", null, "32", "180"), Row("max", "David", "60", "192")) @@ -993,6 +1047,18 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { assert(testData.select($"*").showString(0) === expectedAnswer) } + test("showString(Int.MaxValue)") { + val df = Seq((1, 2), (3, 4)).toDF("a", "b") + val expectedAnswer = """+---+---+ + || a| b| + |+---+---+ + || 1| 2| + || 3| 4| + |+---+---+ + |""".stripMargin + assert(df.showString(Int.MaxValue) === expectedAnswer) + } + test("showString(0), vertical = true") { val expectedAnswer = "(0 rows)\n" 
assert(testData.select($"*").showString(0, vertical = true) === expectedAnswer) @@ -1529,7 +1595,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { fail("Should not have back to back Aggregates") } atFirstAgg = true - case e: ShuffleExchange => atFirstAgg = false + case e: ShuffleExchangeExec => atFirstAgg = false case _ => } } @@ -1710,19 +1776,19 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val plan = join.queryExecution.executedPlan checkAnswer(join, df) assert( - join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size === 1) assert( join.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 1) val broadcasted = broadcast(join) val join2 = join.join(broadcasted, "id").join(broadcasted, "id") checkAnswer(join2, df) assert( - join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1) + join2.queryExecution.executedPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) assert( join2.queryExecution.executedPlan .collect { case e: BroadcastExchangeExec => true }.size === 1) assert( - join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size === 4) + join2.queryExecution.executedPlan.collect { case e: ReusedExchangeExec => true }.size == 4) } } @@ -2039,4 +2105,13 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { testData2.select(lit(7), 'a, 'b).orderBy(lit(1), lit(2), lit(3)), Seq(Row(7, 1, 1), Row(7, 1, 2), Row(7, 2, 1), Row(7, 2, 2), Row(7, 3, 1), Row(7, 3, 2))) } + + test("SPARK-22271: mean overflows and returns null for some decimal variables") { + val d = 0.034567890 + val df = Seq(d, d, d, d, d, d, d, d, d, d).toDF("DecimalCol") + val result = df.select('DecimalCol cast DecimalType(38, 33)) + .select(col("DecimalCol")).describe() + val mean = result.select("DecimalCol").where($"summary" === "mean") + assert(mean.collect().toSet === Set(Row("0.0345678900000000000000000000000000000"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala index fe6ba83b4cbfb..0881212a64de8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala @@ -73,4 +73,40 @@ class DataFrameTungstenSuite extends QueryTest with SharedSQLContext { val df = spark.createDataFrame(data, schema) assert(df.select("b").first() === Row(outerStruct)) } + + test("primitive data type accesses in persist data") { + val data = Seq(true, 1.toByte, 3.toShort, 7, 15.toLong, + 31.25.toFloat, 63.75, null) + val dataTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, + FloatType, DoubleType, IntegerType) + val schemas = dataTypes.zipWithIndex.map { case (dataType, index) => + StructField(s"col$index", dataType, true) + } + val rdd = sparkContext.makeRDD(Seq(Row.fromSeq(data))) + val df = spark.createDataFrame(rdd, StructType(schemas)) + val row = df.persist.take(1).apply(0) + checkAnswer(df, row) + } + + test("access cache multiple times") { + val df0 = sparkContext.parallelize(Seq(1, 2, 3), 1).toDF("x").cache + df0.count + val df1 = df0.filter("x > 1") + checkAnswer(df1, Seq(Row(2), Row(3))) + val df2 = df0.filter("x > 2") + checkAnswer(df2, Row(3)) + + val df10 = sparkContext.parallelize(Seq(3, 4, 5, 6), 1).toDF("x").cache + for 
(_ <- 0 to 2) { + val df11 = df10.filter("x > 5") + checkAnswer(df11, Row(6)) + } + } + + test("access only some column of the all of columns") { + val df = spark.range(1, 10).map(i => (i, (i + 1).toDouble)).toDF("l", "d") + df.cache + df.count + assert(df.filter("d < 3").count == 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 5015f3709f131..1537ce3313c09 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.encoders.{OuterScopes, RowEncoder} import org.apache.spark.sql.catalyst.plans.{LeftAnti, LeftSemi} import org.apache.spark.sql.catalyst.util.sideBySide import org.apache.spark.sql.execution.{LogicalRDD, RDDScanExec} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.streaming.MemoryStream import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -1206,7 +1206,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext { val agg = cp.groupBy('id % 2).agg(count('id)) agg.queryExecution.executedPlan.collectFirst { - case ShuffleExchange(_, _: RDDScanExec, _) => + case ShuffleExchangeExec(_, _: RDDScanExec, _) => case BroadcastExchangeExec(_, _: RDDScanExec) => }.foreach { _ => fail( @@ -1341,8 +1341,69 @@ class DatasetSuite extends QueryTest with SharedSQLContext { Seq(1).toDS().map(_ => ("", TestForTypeAlias.seqOfTupleTypeAlias)), ("", Seq((1, 1), (2, 2)))) } + + test("Check RelationalGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").groupBy("id") + val expected = "RelationalGroupedDataset: [" + + "grouping expressions: [id: int], value: [id: int], type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check RelationalGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").groupBy("id") + val expected = "RelationalGroupedDataset:" + + " [grouping expressions: [id: int]," + + " value: [id: int, val1: string ... 
1 more field]," + + " type: GroupBy]" + val actual = kvDataset.toString + assert(expected === actual) + } + + + test("Check KeyValueGroupedDataset toString: Single data") { + val kvDataset = (1 to 3).toDF("id").as[SingleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset: [key: [id: int], value: [id: int]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Unnamed KV-pair") { + val kvDataset = (1 to 3).map(x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => (x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [_1: int, _2: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: Named KV-pair") { + val kvDataset = (1 to 3).map( x => (x, x.toString)) + .toDF("id", "val1").as[DoubleData].groupByKey(x => DoubleData(x.id, x.val1)) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string]," + + " value: [id: int, val1: string]]" + val actual = kvDataset.toString + assert(expected === actual) + } + + test("Check KeyValueGroupedDataset toString: over length schema ") { + val kvDataset = (1 to 3).map( x => (x, x.toString, x.toLong)) + .toDF("id", "val1", "val2").as[TripleData].groupByKey(identity) + val expected = "KeyValueGroupedDataset:" + + " [key: [id: int, val1: string ... 1 more field(s)]," + + " value: [id: int, val1: string ... 1 more field(s)]]" + val actual = kvDataset.toString + assert(expected === actual) + } } +case class SingleData(id: Int) +case class DoubleData(id: Int, val1: String) +case class TripleData(id: Int, val1: String, val2: Long) + case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map[Long, String]) case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 9d50e8be60891..226cc3028b135 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -200,6 +200,14 @@ class JoinSuite extends QueryTest with SharedSQLContext { Nil) } + test("SPARK-22141: Propagate empty relation before checking Cartesian products") { + Seq("inner", "left", "right", "left_outer", "right_outer", "full_outer").foreach { joinType => + val x = testData2.where($"a" === 2 && !($"a" === 2)).as("x") + val y = testData2.where($"a" === 1 && !($"a" === 1)).as("y") + checkAnswer(x.join(y, Seq.empty, joinType), Nil) + } + } + test("big inner join, 4 matches per row") { val bigData = testData.union(testData).union(testData).union(testData) val bigDataX = bigData.as("x") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f9808834df4a5..fcaca3d75b74f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,23 +17,13 @@ package org.apache.spark.sql -import java.util.{ArrayDeque, Locale, TimeZone} +import java.util.{Locale, TimeZone} import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import 
org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.datasources.LogicalRelation -import org.apache.spark.sql.execution.streaming.MemoryPlan -import org.apache.spark.sql.types.{Metadata, ObjectType} abstract class QueryTest extends PlanTest { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 93a7777b70b46..caf332d050d7b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -27,6 +27,7 @@ import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -2646,6 +2647,44 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { } } + test("SPARK-21247: Allow case-insensitive type equality in Set operation") { + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + + withTable("t", "S") { + sql("CREATE TABLE t(c struct<f:int>) USING parquet") + sql("CREATE TABLE S(C struct<F:int>) USING parquet") + Seq(("c", "C"), ("C", "c"), ("c.f", "C.F"), ("C.F", "c.f")).foreach { + case (left, right) => + checkAnswer(sql(s"SELECT * FROM t, S WHERE t.$left = S.$right"), Seq.empty) + } + } + } + + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + val m1 = intercept[AnalysisException] { + sql("SELECT struct(1 a) UNION ALL (SELECT struct(2 A))") + }.message + assert(m1.contains("Union can only be performed on tables with the compatible column types")) + + val m2 = intercept[AnalysisException] { + sql("SELECT struct(1 a) EXCEPT (SELECT struct(2 A))") + }.message + assert(m2.contains("Except can only be performed on tables with the compatible column types")) + + withTable("t", "S") { + sql("CREATE TABLE t(c struct<f:int>) USING parquet") + sql("CREATE TABLE S(C struct<F:int>) USING parquet") + checkAnswer(sql("SELECT * FROM t, S WHERE t.c.f = S.C.F"), Seq.empty) + val m = intercept[AnalysisException] { + sql("SELECT * FROM t, S WHERE c = C") + }.message + assert(m.contains("cannot resolve '(t.`c` = S.`C`)' due to data type mismatch")) + } + } + } + test("SPARK-21335: support un-aliased subquery") { withTempView("v") { Seq(1 -> "a").toDF("i", "j").createOrReplaceTempView("v") @@ -2677,4 +2716,29 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(df, Row(1, 1, 1)) } } + + test("SPARK-22266: the same aggregate function was calculated multiple times") { + val query = "SELECT a, max(b+1), max(b+1) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions +
case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 1) + checkAnswer(df, Row(1, 3, 4) :: Row(2, 3, 4) :: Row(3, 3, 4) :: Nil) + } + + test("Non-deterministic aggregate functions should not be deduplicated") { + val query = "SELECT a, first_value(b), first_value(b) + 1 FROM testData2 GROUP BY a" + val df = sql(query) + val physical = df.queryExecution.sparkPlan + val aggregateExpressions = physical.collectFirst { + case agg : HashAggregateExec => agg.aggregateExpressions + case agg : SortAggregateExec => agg.aggregateExpressions + } + assert (aggregateExpressions.isDefined) + assert (aggregateExpressions.get.size == 2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala new file mode 100644 index 0000000000000..e47d4b0ee25d4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/TPCDSQuerySuite.scala @@ -0,0 +1,372 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.catalyst.util.resourceToString +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.util.Utils + +/** + * This test suite ensures all the TPC-DS queries can be successfully analyzed and optimized + * without hitting the max iteration threshold. + */ +class TPCDSQuerySuite extends QueryTest with SharedSQLContext with BeforeAndAfterAll { + + // When Utils.isTesting is true, the RuleExecutor will issue an exception when hitting + // the max iteration of analyzer/optimizer batches. 
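+ // Utils.isTesting is driven by the SPARK_TESTING environment variable / spark.testing system property that the test runner sets; without it, RuleExecutor only logs a warning when an analyzer/optimizer batch hits its iteration limit, so an unconverged TPC-DS plan could slip through instead of tripping the guard assertion below.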
+ assert(Utils.isTesting, "spark.testing is not set to true") + + /** + * Drop all the tables + */ + protected override def afterAll(): Unit = { + try { + spark.sessionState.catalog.reset() + } finally { + super.afterAll() + } + } + + override def beforeAll() { + super.beforeAll() + sql( + """ + |CREATE TABLE `catalog_page` ( + |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT, + |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT, + |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_returns` ( + |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT, + |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT, + |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT, + |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT, + |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT, + |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2), + |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2), + |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2), + |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2), + |`cr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer` ( + |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT, + |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT, + |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, + |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, + |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, + |`c_email_address` STRING, `c_last_review_date` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_address` ( + |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING, + |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING, + |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING, + |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `customer_demographics` ( + |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING, + |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING, + |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `date_dim` ( + |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, + |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, + |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, + |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, + |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT, + |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING, + |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `household_demographics` ( + |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT, 
+ |`hd_vehicle_count` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT, + |`inv_quantity_on_hand` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, + |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), + |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, + |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, + |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, + |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `promotion` ( + |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT, + |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING, + |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING, + |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING, + |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING, + |`p_purpose` STRING, `p_discount_active` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store` ( + |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, + |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, + |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, + |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, + |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, + |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, + |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, + |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, + |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_returns` ( + |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, + |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, + |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, + |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), + |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), + |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), + |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), + |`sr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `catalog_sales` ( + |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT, + |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT, + |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT, + |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT, + |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT, + |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT, + |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2), + |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2), + |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2), + |`cs_ext_list_price` DECIMAL(7,2), 
`cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2), + |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2), + |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2), + |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_sales` ( + |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT, + |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT, + |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT, + |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT, + |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT, + |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2), + |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2), + |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2), + |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2), + |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2), + |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2), + |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `store_sales` ( + |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT, + |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT, + |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT, + |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2), + |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2), + |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2), + |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2), + |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2), + |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_returns` ( + |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT, + |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT, + |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, + |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, + |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, + |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, + |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), + |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), + |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), + |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2), + |`wr_net_loss` DECIMAL(7,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_site` ( + |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE, + |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT, + |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT, + |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING, + |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, + |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, + |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, + |`web_country` STRING, 
`web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `reason` ( + |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `call_center` ( + |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE, + |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING, + |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING, + |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING, + |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT, + |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING, + |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING, + |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2), + |`cc_tax_percentage` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `warehouse` ( + |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING, + |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING, + |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING, + |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2)) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `ship_mode` ( + |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING, + |`sm_carrier` STRING, `sm_contract` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `income_band` ( + |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `time_dim` ( + |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT, + |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING, + |`t_meal_time` STRING) + |USING parquet + """.stripMargin) + + sql( + """ + |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING, + |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT, + |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT, + |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT, + |`wp_image_count` INT, `wp_max_ad_count` INT) + |USING parquet + """.stripMargin) + } + + val tpcdsQueries = Seq( + "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", + "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", + "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", + "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", + "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", + "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", + "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", + "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", + "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", + "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") + + tpcdsQueries.foreach { name => + val queryString = resourceToString(s"tpcds/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(name) { + withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") { + // Just check the plans can be 
properly generated + sql(queryString).queryExecution.executedPlan + } + } + } + + // These queries are from https://github.com/cloudera/impala-tpcds-kit/tree/master/queries + val modifiedTPCDSQueries = Seq( + "q3", "q7", "q10", "q19", "q27", "q34", "q42", "q43", "q46", "q52", "q53", "q55", "q59", + "q63", "q65", "q68", "q73", "q79", "q89", "q98", "ss_max") + + modifiedTPCDSQueries.foreach { name => + val queryString = resourceToString(s"tpcds-modifiedQueries/$name.sql", + classLoader = Thread.currentThread().getContextClassLoader) + test(s"modified-$name") { + // Just check the plans can be properly generated + sql(queryString).queryExecution.executedPlan + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala index b76f168220d84..c5fb17345222a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TypedImperativeAggregateSuite.scala @@ -268,7 +268,7 @@ object TypedImperativeAggregateSuite { } } - override def deterministic: Boolean = true + override lazy val deterministic: Boolean = true override def children: Seq[Expression] = Seq(child) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index f1b5e3be5b63f..737eeb0af586e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -21,7 +21,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{MapOutputStatistics, SparkConf, SparkFunSuite} import org.apache.spark.sql._ -import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchangeExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -300,13 +300,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = agg.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 1) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -314,7 +314,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 3) case o => @@ -351,13 +351,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. 
val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 2) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -365,7 +365,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { case None => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 2) case o => @@ -407,13 +407,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 4) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => @@ -459,13 +459,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { // Then, let's look at the number of post-shuffle partitions estimated // by the ExchangeCoordinator. val exchanges = join.queryExecution.executedPlan.collect { - case e: ShuffleExchange => e + case e: ShuffleExchangeExec => e } assert(exchanges.length === 3) minNumPostShufflePartitions match { case Some(numPartitions) => exchanges.foreach { - case e: ShuffleExchange => + case e: ShuffleExchangeExec => assert(e.coordinator.isDefined) assert(e.outputPartitioning.numPartitions === 5) case o => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala index 59eaf4d1c29b7..aac8d56ba6201 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions.{Alias, Literal} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition} -import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.test.SharedSQLContext @@ -31,7 +31,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { val input = (1 to 1000).map(Tuple1.apply) checkAnswer( input.toDF(), - plan => ShuffleExchange(SinglePartition, plan), + plan => ShuffleExchangeExec(SinglePartition, plan), input.map(Row.fromTuple) ) } @@ -81,12 +81,12 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext { assert(plan sameResult plan) val part1 = HashPartitioning(output, 1) - val exchange1 = ShuffleExchange(part1, plan) - val exchange2 = ShuffleExchange(part1, plan) + val exchange1 = ShuffleExchangeExec(part1, plan) + val exchange2 = ShuffleExchangeExec(part1, plan) val part2 = HashPartitioning(output, 2) - val exchange3 = 
ShuffleExchange(part2, plan) + val exchange3 = ShuffleExchangeExec(part2, plan) val part3 = HashPartitioning(output ++ output, 2) - val exchange4 = ShuffleExchange(part3, plan) + val exchange4 = ShuffleExchangeExec(part3, plan) val exchange5 = ReusedExchangeExec(output, exchange4) assert(exchange1 sameResult exchange1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 63e17c7f372b0..c25c90d0c70e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftOuter, import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.columnar.InMemoryRelation -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchange} +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchangeExec, ReuseExchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -214,7 +214,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (small.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -229,7 +229,7 @@ class PlannerSuite extends SharedSQLContext { | JOIN tiny ON (normal.key = tiny.key) """.stripMargin ).queryExecution.executedPlan.collect { - case exchange: ShuffleExchange => exchange + case exchange: ShuffleExchangeExec => exchange }.length assert(numExchanges === 5) } @@ -300,7 +300,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -338,7 +338,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.isEmpty) { fail(s"Exchange should have been added:\n$outputPlan") } } @@ -358,7 +358,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.nonEmpty) { fail(s"Exchange should not have been added:\n$outputPlan") } } @@ -381,7 +381,7 @@ class PlannerSuite extends SharedSQLContext { ) val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true 
}.nonEmpty) { fail(s"No Exchanges should have been added:\n$outputPlan") } } @@ -391,7 +391,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -400,7 +400,7 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 2) { fail(s"Topmost Exchange should have been eliminated:\n$outputPlan") } } @@ -411,7 +411,7 @@ class PlannerSuite extends SharedSQLContext { val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val inputPlan = ShuffleExchange(finalPartitioning, + val inputPlan = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -420,17 +420,34 @@ class PlannerSuite extends SharedSQLContext { val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) assertDistributionRequirementsAreSatisfied(outputPlan) - if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) { + if (outputPlan.collect { case e: ShuffleExchangeExec => true }.size == 1) { fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan") } } + test("EnsureRequirements should respect ClusteredDistribution's num partitioning") { + val distribution = ClusteredDistribution(Literal(1) :: Nil, Some(13)) + // Number of partitions differ + val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 13) + val childPartitioning = HashPartitioning(Literal(1) :: Nil, 5) + assert(!childPartitioning.satisfies(distribution)) + val inputPlan = DummySparkPlan( + children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, + requiredChildDistribution = Seq(distribution), + requiredChildOrdering = Seq(Seq.empty)) + + val outputPlan = EnsureRequirements(spark.sessionState.conf).apply(inputPlan) + val shuffle = outputPlan.collect { case e: ShuffleExchangeExec => e } + assert(shuffle.size === 1) + assert(shuffle.head.newPartitioning === finalPartitioning) + } + test("Reuse exchanges") { val distribution = ClusteredDistribution(Literal(1) :: Nil) val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5) val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5) assert(!childPartitioning.satisfies(distribution)) - val shuffle = ShuffleExchange(finalPartitioning, + val shuffle = ShuffleExchangeExec(finalPartitioning, DummySparkPlan( children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil, requiredChildDistribution = Seq(distribution), @@ -449,7 +466,7 @@ class PlannerSuite extends SharedSQLContext { if (outputPlan.collect { case e: ReusedExchangeExec => true }.size != 1) { fail(s"Should re-use the shuffle:\n$outputPlan") } - if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) { + if (outputPlan.collect { case e: 
ShuffleExchangeExec => true }.size != 1) { fail(s"Should have only one shuffle:\n$outputPlan") } @@ -459,14 +476,14 @@ class PlannerSuite extends SharedSQLContext { Literal(1) :: Nil, Inner, None, - ShuffleExchange(finalPartitioning, inputPlan), - ShuffleExchange(finalPartitioning, inputPlan)) + ShuffleExchangeExec(finalPartitioning, inputPlan), + ShuffleExchangeExec(finalPartitioning, inputPlan)) val outputPlan2 = ReuseExchange(spark.sessionState.conf).apply(inputPlan2) if (outputPlan2.collect { case e: ReusedExchangeExec => true }.size != 2) { fail(s"Should re-use the two shuffles:\n$outputPlan2") } - if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) { + if (outputPlan2.collect { case e: ShuffleExchangeExec => true }.size != 2) { fail(s"Should have only two shuffles:\n$outputPlan") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index beeee6a97c8dd..bc05dca578c47 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.{Column, Dataset, Row} -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.{Add, Literal, Stack} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.{QueryTest, Row, SaveMode} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodeAndComment, CodeGenerator} import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec import org.apache.spark.sql.execution.joins.BroadcastHashJoinExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.expressions.scalalang.typed @@ -30,7 +29,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} -class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { +class WholeStageCodegenSuite extends QueryTest with SharedSQLContext { test("range/filter should be combined") { val df = spark.range(10).filter("id = 1").selectExpr("id + 1") @@ -119,6 +118,37 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) } + test("cache for primitive type should be in WholeStageCodegen with InMemoryTableScanExec") { + import testImplicits._ + + val dsInt = spark.range(3).cache + dsInt.count + val dsIntFilter = dsInt.filter(_ > 0) + val planInt = dsIntFilter.queryExecution.executedPlan + assert(planInt.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec] && + p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .asInstanceOf[InMemoryTableScanExec].supportCodegen).isDefined + ) + assert(dsIntFilter.collect() === Array(1, 2)) + + // cache for string type is not supported for InMemoryTableScanExec + val dsString = spark.range(3).map(_.toString).cache + dsString.count + val dsStringFilter = dsString.filter(_ == "1") + val planString = 
dsStringFilter.queryExecution.executedPlan + assert(planString.find(p => + p.isInstanceOf[WholeStageCodegenExec] && + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec] && + !p.asInstanceOf[WholeStageCodegenExec].child.asInstanceOf[FilterExec].child + .isInstanceOf[InMemoryTableScanExec]).isDefined + ) + assert(dsStringFilter.collect() === Array("1")) + } + test("SPARK-19512 codegen for comparing structs is incorrect") { // this would raise CompileException before the fix spark.range(10) @@ -151,7 +181,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { } } - def genGroupByCodeGenContext(caseNum: Int): CodegenContext = { + def genGroupByCode(caseNum: Int): CodeAndComment = { val caseExp = (1 to caseNum).map { i => s"case when id > $i and id <= ${i + 1} then 1 else 0 end as v$i" }.toList @@ -176,34 +206,34 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { }) assert(wholeStageCodeGenExec.isDefined) - wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._1 + wholeStageCodeGenExec.get.asInstanceOf[WholeStageCodegenExec].doCodeGen()._2 } - test("SPARK-21603 check there is a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === true) - } + test("SPARK-21871 check if we can get large code size when compiling too long functions") { + val codeWithShortFunctions = genGroupByCode(3) + val (_, maxCodeSize1) = CodeGenerator.compile(codeWithShortFunctions) + assert(maxCodeSize1 < SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) + val codeWithLongFunctions = genGroupByCode(20) + val (_, maxCodeSize2) = CodeGenerator.compile(codeWithLongFunctions) + assert(maxCodeSize2 > SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.defaultValue.get) } - test("SPARK-21603 check there is not a too long generated function") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "1500") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is not a too long generated function when threshold is Int.Max") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> Int.MaxValue.toString) { - val ctx = genGroupByCodeGenContext(30) - assert(ctx.isTooLongGeneratedFunction === false) - } - } - - test("SPARK-21603 check there is a too long generated function when threshold is 0") { - withSQLConf(SQLConf.WHOLESTAGE_MAX_LINES_PER_FUNCTION.key -> "0") { - val ctx = genGroupByCodeGenContext(1) - assert(ctx.isTooLongGeneratedFunction === true) + test("bytecode of batch file scan exceeds the limit of WHOLESTAGE_HUGE_METHOD_LIMIT") { + import testImplicits._ + withTempPath { dir => + val path = dir.getCanonicalPath + val df = spark.range(10).select(Seq.tabulate(201) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + + withSQLConf(SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> "202", + SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key -> "2000") { + // wide table batch scan causes the byte code of codegen exceeds the limit of + // WHOLESTAGE_HUGE_METHOD_LIMIT + val df2 = spark.read.parquet(path) + val fileScan2 = df2.queryExecution.sparkPlan.find(_.isInstanceOf[FileSourceScanExec]).get + assert(fileScan2.asInstanceOf[FileSourceScanExec].supportsBatch) + checkAnswer(df2, df) + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala index 691fa9ac5e1e7..a834b7cd2c69f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala @@ -24,6 +24,7 @@ import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.joins.LongToUnsafeRowMap import org.apache.spark.sql.execution.vectorized.AggregateHashMap +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.hash.Murmur3_x86_32 @@ -106,14 +107,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -148,14 +149,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true) - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -188,14 +189,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -227,14 +228,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - 
sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -276,14 +277,14 @@ class AggregateBenchmark extends BenchmarkBase { benchmark.addCase(s"codegen = T hashmap = F") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "false") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "false") f() } benchmark.addCase(s"codegen = T hashmap = T") { iter => sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true") + sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enabled", "true") sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true") f() } @@ -301,10 +302,10 @@ class AggregateBenchmark extends BenchmarkBase { */ } - ignore("max function length of wholestagecodegen") { + ignore("max function bytecode size of wholestagecodegen") { val N = 20 << 15 - val benchmark = new Benchmark("max function length of wholestagecodegen", N) + val benchmark = new Benchmark("max function bytecode size", N) def f(): Unit = sparkSession.range(N) .selectExpr( "id", @@ -333,33 +334,34 @@ class AggregateBenchmark extends BenchmarkBase { .sum() .collect() - benchmark.addCase(s"codegen = F") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "false") + benchmark.addCase("codegen = F") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "false") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 10000") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "10000") + benchmark.addCase("codegen = T hugeMethodLimit = 10000") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "10000") f() } - benchmark.addCase(s"codegen = T maxLinesPerFunction = 1500") { iter => - sparkSession.conf.set("spark.sql.codegen.wholeStage", "true") - sparkSession.conf.set("spark.sql.codegen.maxLinesPerFunction", "1500") + benchmark.addCase("codegen = T hugeMethodLimit = 1500") { iter => + sparkSession.conf.set(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key, "true") + sparkSession.conf.set(SQLConf.WHOLESTAGE_HUGE_METHOD_LIMIT.key, "1500") f() } benchmark.run() /* - Java HotSpot(TM) 64-Bit Server VM 1.8.0_111-b14 on Windows 7 6.1 - Intel64 Family 6 Model 58 Stepping 9, GenuineIntel - max function length of wholestagecodegen: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ---------------------------------------------------------------------------------------------- - codegen = F 462 / 533 1.4 704.4 1.0X - codegen = T maxLinesPerFunction = 10000 3444 / 3447 0.2 5255.3 0.1X - codegen = T maxLinesPerFunction = 1500 447 / 478 1.5 682.1 
1.0X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_31-b13 on Mac OS X 10.10.2 + Intel(R) Core(TM) i7-4578U CPU @ 3.00GHz + + max function bytecode size: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------------ + codegen = F 709 / 803 0.9 1082.1 1.0X + codegen = T hugeMethodLimit = 10000 3485 / 3548 0.2 5317.7 0.2X + codegen = T hugeMethodLimit = 1500 636 / 701 1.0 969.9 1.1X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala index 99c6df7389205..69247d7f4e9aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala @@ -20,11 +20,10 @@ package org.apache.spark.sql.execution.benchmark import org.apache.spark.SparkConf import org.apache.spark.internal.Logging import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation -import org.apache.spark.sql.catalyst.expressions.SubqueryExpression -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.catalog.HiveTableRelation +import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.util.Benchmark /** @@ -66,24 +65,15 @@ object TPCDSQueryBenchmark extends Logging { classLoader = Thread.currentThread().getContextClassLoader) // This is an indirect hack to estimate the size of each query's input by traversing the - // logical plan and adding up the sizes of all tables that appear in the plan. Note that this - // currently doesn't take WITH subqueries into account which might lead to fairly inaccurate - // per-row processing time for those cases. + // logical plan and adding up the sizes of all tables that appear in the plan. 
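+ // In the analyzed plan each base table appears as a SubqueryAlias over a LogicalRelation, as a LogicalRelation carrying its CatalogTable, or as a HiveTableRelation, so the pattern match below records every referenced table name once via the HashSet.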
val queryRelations = scala.collection.mutable.HashSet[String]() - spark.sql(queryString).queryExecution.logical.map { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case lp: LogicalPlan => - lp.expressions.foreach { _ foreach { - case subquery: SubqueryExpression => - subquery.plan.foreach { - case UnresolvedRelation(t: TableIdentifier) => - queryRelations.add(t.table) - case _ => - } - case _ => - } - } + spark.sql(queryString).queryExecution.analyzed.foreach { + case SubqueryAlias(alias, _: LogicalRelation) => + queryRelations.add(alias) + case LogicalRelation(_, _, Some(catalogTable), _) => + queryRelations.add(catalogTable.identifier.table) + case HiveTableRelation(tableMeta, _, _) => + queryRelations.add(tableMeta.identifier.table) case _ => } val numRows = queryRelations.map(tableSizes.getOrElse(_, 0L)).sum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 8d411eb191cd9..e662e294228db 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -21,8 +21,9 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.{DataFrame, QueryTest, Row} -import org.apache.spark.sql.catalyst.expressions.AttributeSet +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, In} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.execution.{FilterExec, LocalTableScanExec, WholeStageCodegenExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext @@ -429,4 +430,53 @@ class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext { checkAnswer(agg_without_cache, agg_with_cache) } } + + test("SPARK-22249: IN should work also with cached DataFrame") { + val df = spark.range(10).cache() + // with an empty list + assert(df.filter($"id".isin()).count() == 0) + // with a non-empty list + assert(df.filter($"id".isin(2)).count() == 1) + assert(df.filter($"id".isin(2, 3)).count() == 2) + df.unpersist() + val dfNulls = spark.range(10).selectExpr("null as id").cache() + // with null as value for the attribute + assert(dfNulls.filter($"id".isin()).count() == 0) + assert(dfNulls.filter($"id".isin(2, 3)).count() == 0) + dfNulls.unpersist() + } + + test("SPARK-22249: buildFilter should not throw exception when In contains an empty list") { + val attribute = AttributeReference("a", IntegerType)() + val testRelation = InMemoryRelation(false, 1, MEMORY_ONLY, + LocalTableScanExec(Seq(attribute), Nil), None) + val tableScanExec = InMemoryTableScanExec(Seq(attribute), + Seq(In(attribute, Nil)), testRelation) + assert(tableScanExec.partitionFilters.isEmpty) + } + + test("SPARK-22348: table cache should do partition batch pruning") { + Seq("true", "false").foreach { enabled => + withSQLConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> enabled) { + val df1 = Seq((1, 1), (1, 1), (2, 2)).toDF("x", "y") + df1.unpersist() + df1.cache() + + // Push predicate to the cached table. 
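+ // The cached relation keeps per-batch column statistics, so InMemoryTableScanExec should prune every batch for y = 3 (y only holds 1 and 2) and the scan is expected to return zero rows, with or without whole-stage codegen.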
+ val df2 = df1.where("y = 3") + + val planBeforeFilter = df2.queryExecution.executedPlan.collect { + case f: FilterExec => f.child + } + assert(planBeforeFilter.head.isInstanceOf[InMemoryTableScanExec]) + + val execPlan = if (enabled == "true") { + WholeStageCodegenExec(planBeforeFilter.head) + } else { + planBeforeFilter.head + } + assert(execPlan.executeCollectPublic().length == 0) + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala index d01bf911e3a77..2d71a42628dfb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/BooleanBitSetSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar.{BOOLEAN, NoopColumnStats} import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.BooleanType class BooleanBitSetSuite extends SparkFunSuite { import BooleanBitSet._ @@ -85,6 +87,36 @@ class BooleanBitSetSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(count: Int) { + val builder = TestCompressibleColumnBuilder(new NoopColumnStats, BOOLEAN, BooleanBitSet) + val rows = Seq.fill[InternalRow](count)(makeRandomRow(BOOLEAN)) + val values = rows.map(_.getBoolean(0)) + + rows.foreach(builder.appendFrom(_, 0)) + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(BooleanBitSet.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = BooleanBitSet.decoder(buffer, BOOLEAN) + val columnVector = new OnHeapColumnVector(values.length, BooleanType) + decoder.decompress(columnVector, values.length) + + if (values.nonEmpty) { + values.zipWithIndex.foreach { case (b: Boolean, index: Int) => + assertResult(b, s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + } + } + } + test(s"$BooleanBitSet: empty") { skeleton(0) } @@ -104,4 +136,24 @@ class BooleanBitSetSuite extends SparkFunSuite { test(s"$BooleanBitSet: multiple words and 1 more bit") { skeleton(BITS_PER_LONG * 2 + 1) } + + test(s"$BooleanBitSet: empty for decompression()") { + skeletonForDecompress(0) + } + + test(s"$BooleanBitSet: less than 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG - 1) + } + + test(s"$BooleanBitSet: exactly 1 word for decompression()") { + skeletonForDecompress(BITS_PER_LONG) + } + + test(s"$BooleanBitSet: multiple whole words for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2) + } + + test(s"$BooleanBitSet: multiple words and 1 more bit for decompression()") { + skeletonForDecompress(BITS_PER_LONG * 2 + 1) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala index 67139b13d7882..28950b74cf1c8 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/DictionaryEncodingSuite.scala @@ -23,16 +23,19 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class DictionaryEncodingSuite extends SparkFunSuite { + val nullValue = -1 testDictionaryEncoding(new IntColumnStats, INT) testDictionaryEncoding(new LongColumnStats, LONG) - testDictionaryEncoding(new StringColumnStats, STRING) + testDictionaryEncoding(new StringColumnStats, STRING, false) def testDictionaryEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -113,6 +116,58 @@ class DictionaryEncodingSuite extends SparkFunSuite { } } + def skeletonForDecompress(uniqueValueCount: Int, inputSeq: Seq[Int]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, DictionaryEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val dictValues = stableDistinct(inputSeq) + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(DictionaryEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = DictionaryEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { case (i: Any, index: Int) => + if (i == nullValue) { + assertResult(true, s"Wrong null ${index}-th position") { + columnVector.isNullAt(index) + } + } else { + columnType match { + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + } + } + } + } + test(s"$DictionaryEncoding with $typeName: empty") { skeleton(0, Seq.empty) } @@ -124,5 +179,18 @@ class DictionaryEncodingSuite extends SparkFunSuite { test(s"$DictionaryEncoding with $typeName: dictionary overflow") { skeleton(DictionaryEncoding.MAX_DICT_SIZE + 1, 0 to DictionaryEncoding.MAX_DICT_SIZE) } + + test(s"$DictionaryEncoding with $typeName: empty for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$DictionaryEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0, nullValue, 0, nullValue)) + } + + test(s"$DictionaryEncoding with $typeName: dictionary overflow for decompress()") { + skeletonForDecompress(DictionaryEncoding.MAX_DICT_SIZE + 2, + Seq(nullValue) ++ (0 to 
DictionaryEncoding.MAX_DICT_SIZE - 1) ++ Seq(nullValue)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala index 411d31fa0e29b..0d9f1fb0c02c9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/IntegralDeltaSuite.scala @@ -21,9 +21,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.IntegralType class IntegralDeltaSuite extends SparkFunSuite { + val nullValue = -1 testIntegralDelta(new IntColumnStats, INT, IntDelta) testIntegralDelta(new LongColumnStats, LONG, LongDelta) @@ -109,6 +111,53 @@ class IntegralDeltaSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(input: Seq[I#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, scheme) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(scheme.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = scheme.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => + fail("Unsupported type") + } + } + } + test(s"$scheme: empty column") { skeleton(Seq.empty) } @@ -127,5 +176,28 @@ class IntegralDeltaSuite extends SparkFunSuite { val input = Array.fill[Any](10000)(makeRandomValue(columnType)) skeleton(input.map(_.asInstanceOf[I#InternalType])) } + + + test(s"$scheme: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$scheme: simple case for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, 130: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, 130: Long) + } + + skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } + + test(s"$scheme: simple case with null for decompress()") { + val input = columnType match { + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + } + + 
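// Editor's note (descriptive comment, not part of this patch): `nullValue` (-1) is a
// sentinel in these decompress tests. Any input element equal to -1 is appended as a SQL
// NULL rather than as the literal value -1, and skeletonForDecompress() above asserts
// isNullAt() for those positions. The null count and null positions are recorded in the
// column header, separately from the delta-encoded non-null values, so -1 must not be
// used as a real data value in these inputs.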
skeletonForDecompress(input.map(_.asInstanceOf[I#InternalType])) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala new file mode 100644 index 0000000000000..b6f0b5e6277b4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/PassThroughEncodingSuite.scala @@ -0,0 +1,189 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar.compression + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.execution.columnar._ +import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector +import org.apache.spark.sql.types.AtomicType + +class PassThroughSuite extends SparkFunSuite { + val nullValue = -1 + testPassThrough(new ByteColumnStats, BYTE) + testPassThrough(new ShortColumnStats, SHORT) + testPassThrough(new IntColumnStats, INT) + testPassThrough(new LongColumnStats, LONG) + testPassThrough(new FloatColumnStats, FLOAT) + testPassThrough(new DoubleColumnStats, DOUBLE) + + def testPassThrough[T <: AtomicType]( + columnStats: ColumnStats, + columnType: NativeColumnType[T]) { + + val typeName = columnType.getClass.getSimpleName.stripSuffix("$") + + def skeleton(input: Seq[T#InternalType]) { + // ------------- + // Tests encoder + // ------------- + + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + + input.map { value => + val row = new GenericInternalRow(1) + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + + val buffer = builder.build() + // Column type ID + null count + null positions + val headerSize = CompressionScheme.columnHeaderSize(buffer) + + // Compression scheme ID + compressed contents + val compressedSize = 4 + input.size * columnType.defaultSize + + // 4 extra bytes for compression scheme type ID + assertResult(headerSize + compressedSize, "Wrong buffer capacity")(buffer.capacity) + + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + if (input.nonEmpty) { + input.foreach { value => + assertResult(value, "Wrong value")(columnType.extract(buffer)) + } + } + + // ------------- + // Tests decoder + // ------------- + + // Rewinds, skips column header and 4 more bytes for compression scheme ID + buffer.rewind().position(headerSize + 4) + + val decoder = PassThrough.decoder(buffer, columnType) + val mutableRow = new GenericInternalRow(1) + + if (input.nonEmpty) { + input.foreach{ + 
assert(decoder.hasNext) + assertResult(_, "Wrong decoded value") { + decoder.next(mutableRow, 0) + columnType.getField(mutableRow, 0) + } + } + } + assert(!decoder.hasNext) + } + + def skeletonForDecompress(input: Seq[T#InternalType]) { + val builder = TestCompressibleColumnBuilder(columnStats, columnType, PassThrough) + val row = new GenericInternalRow(1) + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + input.map { value => + if (value == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + columnType.setField(row, 0, value) + builder.appendFrom(row, 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(PassThrough.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = PassThrough.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(input.length, columnType.dataType) + decoder.decompress(columnVector, input.length) + + if (input.nonEmpty) { + input.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (expected: Byte, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case (expected: Short, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case (expected: Int, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case (expected: Long, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded long value") { + columnVector.getLong(index) + } + case (expected: Float, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded float value") { + columnVector.getFloat(index) + } + case (expected: Double, index: Int) => + assertResult(expected, s"Wrong ${index}-th decoded double value") { + columnVector.getDouble(index) + } + case _ => fail("Unsupported type") + } + } + } + + test(s"$PassThrough with $typeName: empty column") { + skeleton(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeleton(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: empty column for decompress()") { + skeletonForDecompress(Seq.empty) + } + + test(s"$PassThrough with $typeName: long random series for decompress()") { + val input = Array.fill[Any](10000)(makeRandomValue(columnType)) + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + + test(s"$PassThrough with $typeName: simple case with null for decompress()") { + val input = columnType match { + case BYTE => Seq(2: Byte, 1: Byte, 2: Byte, nullValue.toByte: Byte, 5: Byte) + case SHORT => Seq(2: Short, 1: Short, 2: Short, nullValue.toShort: Short, 5: Short) + case INT => Seq(2: Int, 1: Int, 2: Int, nullValue: Int, 5: Int) + case LONG => Seq(2: Long, 1: Long, 2: Long, nullValue: Long, 5: Long) + case FLOAT => Seq(2: Float, 1: Float, 2: Float, nullValue: Float, 5: Float) + case DOUBLE => Seq(2: Double, 1: Double, 2: Double, nullValue: Double, 5: Double) + } + + skeletonForDecompress(input.map(_.asInstanceOf[T#InternalType])) + } + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala index dffa9b364ebfe..eb1cdd9bbceff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/RunLengthEncodingSuite.scala @@ -21,19 +21,22 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.columnar.ColumnarTestUtils._ +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.types.AtomicType class RunLengthEncodingSuite extends SparkFunSuite { + val nullValue = -1 testRunLengthEncoding(new NoopColumnStats, BOOLEAN) testRunLengthEncoding(new ByteColumnStats, BYTE) testRunLengthEncoding(new ShortColumnStats, SHORT) testRunLengthEncoding(new IntColumnStats, INT) testRunLengthEncoding(new LongColumnStats, LONG) - testRunLengthEncoding(new StringColumnStats, STRING) + testRunLengthEncoding(new StringColumnStats, STRING, false) def testRunLengthEncoding[T <: AtomicType]( columnStats: ColumnStats, - columnType: NativeColumnType[T]) { + columnType: NativeColumnType[T], + testDecompress: Boolean = true) { val typeName = columnType.getClass.getSimpleName.stripSuffix("$") @@ -95,6 +98,72 @@ class RunLengthEncodingSuite extends SparkFunSuite { assert(!decoder.hasNext) } + def skeletonForDecompress(uniqueValueCount: Int, inputRuns: Seq[(Int, Int)]) { + if (!testDecompress) return + val builder = TestCompressibleColumnBuilder(columnStats, columnType, RunLengthEncoding) + val (values, rows) = makeUniqueValuesAndSingleValueRows(columnType, uniqueValueCount) + val inputSeq = inputRuns.flatMap { case (index, run) => + Seq.fill(run)(index) + } + + val nullRow = new GenericInternalRow(1) + nullRow.setNullAt(0) + inputSeq.foreach { i => + if (i == nullValue) { + builder.appendFrom(nullRow, 0) + } else { + builder.appendFrom(rows(i), 0) + } + } + val buffer = builder.build() + + // ---------------- + // Tests decompress + // ---------------- + // Rewinds, skips column header and 4 more bytes for compression scheme ID + val headerSize = CompressionScheme.columnHeaderSize(buffer) + buffer.position(headerSize) + assertResult(RunLengthEncoding.typeId, "Wrong compression scheme ID")(buffer.getInt()) + + val decoder = RunLengthEncoding.decoder(buffer, columnType) + val columnVector = new OnHeapColumnVector(inputSeq.length, columnType.dataType) + decoder.decompress(columnVector, inputSeq.length) + + if (inputSeq.nonEmpty) { + inputSeq.zipWithIndex.foreach { + case (expected: Any, index: Int) if expected == nullValue => + assertResult(true, s"Wrong null ${index}th-position") { + columnVector.isNullAt(index) + } + case (i: Int, index: Int) => + columnType match { + case BOOLEAN => + assertResult(values(i), s"Wrong ${index}-th decoded boolean value") { + columnVector.getBoolean(index) + } + case BYTE => + assertResult(values(i), s"Wrong ${index}-th decoded byte value") { + columnVector.getByte(index) + } + case SHORT => + assertResult(values(i), s"Wrong ${index}-th decoded short value") { + columnVector.getShort(index) + } + case INT => + assertResult(values(i), s"Wrong ${index}-th decoded int value") { + columnVector.getInt(index) + } + case LONG => + assertResult(values(i), s"Wrong 
${index}-th decoded long value") { + columnVector.getLong(index) + } + case _ => fail("Unsupported type") + } + case _ => fail("Unsupported type") + } + } + } + test(s"$RunLengthEncoding with $typeName: empty column") { skeleton(0, Seq.empty) } @@ -110,5 +179,21 @@ class RunLengthEncodingSuite extends SparkFunSuite { test(s"$RunLengthEncoding with $typeName: single long run") { skeleton(1, Seq(0 -> 1000)) } + + test(s"$RunLengthEncoding with $typeName: empty column for decompress()") { + skeletonForDecompress(0, Seq.empty) + } + + test(s"$RunLengthEncoding with $typeName: simple case for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, 1 -> 2)) + } + + test(s"$RunLengthEncoding with $typeName: single long run for decompress()") { + skeletonForDecompress(1, Seq(0 -> 1000)) + } + + test(s"$RunLengthEncoding with $typeName: single case with null for decompress()") { + skeletonForDecompress(2, Seq(0 -> 2, nullValue -> 2)) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala index 5e078f251375a..310cb0be5f5a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/compression/TestCompressibleColumnBuilder.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.columnar.compression import org.apache.spark.sql.execution.columnar._ -import org.apache.spark.sql.types.AtomicType +import org.apache.spark.sql.types.{AtomicType, DataType} class TestCompressibleColumnBuilder[T <: AtomicType]( override val columnStats: ColumnStats, @@ -42,3 +42,10 @@ object TestCompressibleColumnBuilder { builder } } + +object ColumnBuilderHelper { + def apply( + dataType: DataType, batchSize: Int, name: String, useCompression: Boolean): ColumnBuilder = { + ColumnBuilder(dataType, batchSize, name, useCompression) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index fa5172ca8a3e7..eb7c33590b602 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -525,6 +525,25 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.message.contains("you can only specify one of them.")) } + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + parser.parsePlan(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index d19cfeef7d19f..21a2c62929146 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -795,7 +795,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { checkAnswer(spark.table("teachers"), df) } - test("rename temporary table - destination table with database name") { + test("rename temporary view - destination table with database name") { withTempView("tab1") { sql( """ @@ -812,7 +812,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO default.tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`default`.`tab2`': " + + "RENAME TEMPORARY VIEW from '`tab1`' to '`default`.`tab2`': " + "cannot specify database name 'default' in the destination table")) val catalog = spark.sessionState.catalog @@ -820,7 +820,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table") { + test("rename temporary view") { withTempView("tab1", "tab2") { spark.range(10).createOrReplaceTempView("tab1") sql("ALTER TABLE tab1 RENAME TO tab2") @@ -832,7 +832,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } - test("rename temporary table - destination table already exists") { + test("rename temporary view - destination table already exists") { withTempView("tab1", "tab2") { sql( """ @@ -860,7 +860,7 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { sql("ALTER TABLE tab1 RENAME TO tab2") } assert(e.getMessage.contains( - "RENAME TEMPORARY TABLE from '`tab1`' to '`tab2`': destination table already exists")) + "RENAME TEMPORARY VIEW from '`tab1`' to '`tab2`': destination table already exists")) val catalog = spark.sessionState.catalog assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"), TableIdentifier("tab2"))) @@ -2202,56 +2202,64 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { } } + protected def testAddColumn(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int) USING $provider") + sql("INSERT INTO t1 VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c2 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 is null"), + Seq(Row(1, null)) + ) + + sql("INSERT INTO t1 VALUES (3, 2)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 2"), + Seq(Row(3, 2)) + ) + } + } + + protected def testAddColumnPartitioned(provider: String): Unit = { + withTable("t1") { + sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") + sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") + sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") + checkAnswer( + spark.table("t1"), + Seq(Row(1, null, 2)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 is null"), + Seq(Row(1, null, 2)) + ) + sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") + checkAnswer( + sql("SELECT * FROM t1 WHERE c3 = 3"), + Seq(Row(2, 3, 1)) + ) + checkAnswer( + sql("SELECT * FROM t1 WHERE c2 = 1"), + Seq(Row(2, 3, 1)) + ) + } + } + val supportedNativeFileFormatsForAlterTableAddColumns = Seq("parquet", "json", "csv") supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int) USING $provider") - sql("INSERT INTO t1 VALUES (1)") - sql("ALTER TABLE t1 ADD 
COLUMNS (c2 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 is null"), - Seq(Row(1, null)) - ) - - sql("INSERT INTO t1 VALUES (3, 2)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 2"), - Seq(Row(3, 2)) - ) - } + testAddColumn(provider) } } supportedNativeFileFormatsForAlterTableAddColumns.foreach { provider => test(s"alter datasource table add columns - partitioned - $provider") { - withTable("t1") { - sql(s"CREATE TABLE t1 (c1 int, c2 int) USING $provider PARTITIONED BY (c2)") - sql("INSERT INTO t1 PARTITION(c2 = 2) VALUES (1)") - sql("ALTER TABLE t1 ADD COLUMNS (c3 int)") - checkAnswer( - spark.table("t1"), - Seq(Row(1, null, 2)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 is null"), - Seq(Row(1, null, 2)) - ) - sql("INSERT INTO t1 PARTITION(c2 =1) VALUES (2, 3)") - checkAnswer( - sql("SELECT * FROM t1 WHERE c3 = 3"), - Seq(Row(2, 3, 1)) - ) - checkAnswer( - sql("SELECT * FROM t1 WHERE c2 = 1"), - Seq(Row(2, 3, 1)) - ) - } + testAddColumnPartitioned(provider) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala new file mode 100644 index 0000000000000..bf3c8ede9a980 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/BasicWriteTaskStatsTrackerSuite.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import java.nio.charset.Charset + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path + +import org.apache.spark.SparkFunSuite +import org.apache.spark.util.Utils + +/** + * Test how BasicWriteTaskStatsTracker handles files. + * + * Two different datasets are written (alongside 0), one of + * length 10, one of 3. They were chosen to be distinct enough + * that it is straightforward to determine which file lengths were added + * from the sum of all files added. Lengths like "10" and "5" would + * be less informative. + */ +class BasicWriteTaskStatsTrackerSuite extends SparkFunSuite { + + private val tempDir = Utils.createTempDir() + private val tempDirPath = new Path(tempDir.toURI) + private val conf = new Configuration() + private val localfs = tempDirPath.getFileSystem(conf) + private val data1 = "0123456789".getBytes(Charset.forName("US-ASCII")) + private val data2 = "012".getBytes(Charset.forName("US-ASCII")) + private val len1 = data1.length + private val len2 = data2.length + + /** + * In teardown delete the temp dir. 
+ */ + protected override def afterAll(): Unit = { + Utils.deleteRecursively(tempDir) + } + + /** + * Assert that the stats match that expected. + * @param tracker tracker to check + * @param files number of files expected + * @param bytes total number of bytes expected + */ + private def assertStats( + tracker: BasicWriteTaskStatsTracker, + files: Int, + bytes: Int): Unit = { + val stats = finalStatus(tracker) + assert(files === stats.numFiles, "Wrong number of files") + assert(bytes === stats.numBytes, "Wrong byte count of file size") + } + + private def finalStatus(tracker: BasicWriteTaskStatsTracker): BasicWriteTaskStats = { + tracker.getFinalStats().asInstanceOf[BasicWriteTaskStats] + } + + test("No files in run") { + val tracker = new BasicWriteTaskStatsTracker(conf) + assertStats(tracker, 0, 0) + } + + test("Missing File") { + val missing = new Path(tempDirPath, "missing") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(missing.toString) + assertStats(tracker, 0, 0) + } + + test("Empty filename is forwarded") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile("") + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("Null filename is only picked up in final status") { + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(null) + intercept[IllegalArgumentException] { + finalStatus(tracker) + } + } + + test("0 byte file") { + val file = new Path(tempDirPath, "file0") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + touch(file) + assertStats(tracker, 1, 0) + } + + test("File with data") { + val file = new Path(tempDirPath, "file-with-data") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + write1(file) + assertStats(tracker, 1, len1) + } + + test("Open file") { + val file = new Path(tempDirPath, "file-open") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file.toString) + val stream = localfs.create(file, true) + try { + assertStats(tracker, 1, 0) + stream.write(data1) + stream.flush() + assert(1 === finalStatus(tracker).numFiles, "Wrong number of files") + } finally { + stream.close() + } + } + + test("Two files") { + val file1 = new Path(tempDirPath, "f-2-1") + val file2 = new Path(tempDirPath, "f-2-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + assertStats(tracker, 2, len1 + len2) + } + + test("Three files, last one empty") { + val file1 = new Path(tempDirPath, "f-3-1") + val file2 = new Path(tempDirPath, "f-3-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + tracker.newFile(file1.toString) + write1(file1) + tracker.newFile(file2.toString) + write2(file2) + tracker.newFile(file3.toString) + touch(file3) + assertStats(tracker, 3, len1 + len2) + } + + test("Three files, one not found") { + val file1 = new Path(tempDirPath, "f-4-1") + val file2 = new Path(tempDirPath, "f-4-2") + val file3 = new Path(tempDirPath, "f-3-2") + val tracker = new BasicWriteTaskStatsTracker(conf) + // file 1 + tracker.newFile(file1.toString) + write1(file1) + + // file 2 is noted, but not created + tracker.newFile(file2.toString) + + // file 3 is noted & then created + tracker.newFile(file3.toString) + write2(file3) + + // the expected size is file1 + file3; only two files are reported + // as found + assertStats(tracker, 2, len1 + len2) + } + + /** + * Write a 
0-byte file. + * @param file file path + */ + private def touch(file: Path): Unit = { + localfs.create(file, true).close() + } + + /** + * Write a byte array. + * @param file path to file + * @param data data + * @return bytes written + */ + private def write(file: Path, data: Array[Byte]): Integer = { + val stream = localfs.create(file, true) + try { + stream.write(data) + } finally { + stream.close() + } + data.length + } + + /** + * Write a data1 array. + * @param file file + */ + private def write1(file: Path): Unit = { + write(file, data1) + } + + /** + * Write a data2 array. + * + * @param file file + */ + private def write2(file: Path): Unit = { + write(file, data2) + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala index a0c1ea63d3827..13f0e0bca86c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileFormatWriterSuite.scala @@ -17,10 +17,11 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.sql.QueryTest +import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.test.SharedSQLContext class FileFormatWriterSuite extends QueryTest with SharedSQLContext { + import testImplicits._ test("empty file should be skipped while write to file") { withTempPath { path => @@ -30,4 +31,17 @@ class FileFormatWriterSuite extends QueryTest with SharedSQLContext { assert(partFiles.length === 2) } } + + test("SPARK-22252: FileFormatWriter should respect the input query schema") { + withTable("t1", "t2", "t3", "t4") { + spark.range(1).select('id as 'col1, 'id as 'col2).write.saveAsTable("t1") + spark.sql("select COL1, COL2 from t1").write.saveAsTable("t2") + checkAnswer(spark.table("t2"), Row(0, 0)) + + // Test picking part of the columns when writing. + spark.range(1).select('id, 'id as 'col1, 'id as 'col2).write.saveAsTable("t3") + spark.sql("select COL1, COL2 from t3").write.saveAsTable("t4") + checkAnswer(spark.table("t4"), Row(0, 0)) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala new file mode 100644 index 0000000000000..caa4f6d70c6a9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCommitterSuite.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.FileNotFoundException + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, Path} +import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.output.FileOutputCommitter +import org.apache.parquet.hadoop.{ParquetOutputCommitter, ParquetOutputFormat} + +import org.apache.spark.{LocalSparkContext, SparkFunSuite} +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +/** + * Test logic related to choice of output committers. + */ +class ParquetCommitterSuite extends SparkFunSuite with SQLTestUtils + with LocalSparkContext { + + private val PARQUET_COMMITTER = classOf[ParquetOutputCommitter].getCanonicalName + + protected var spark: SparkSession = _ + + /** + * Create a new [[SparkSession]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + spark = SparkSession.builder() + .master("local-cluster[2,1,1024]") + .appName("testing") + .getOrCreate() + } + + override def afterAll(): Unit = { + try { + if (spark != null) { + spark.stop() + spark = null + } + } finally { + super.afterAll() + } + } + + test("alternative output committer, merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = true, check = true) + } + + test("alternative output committer, no merge schema") { + writeDataFrame(MarkingFileOutput.COMMITTER, summary = false, check = true) + } + + test("Parquet output committer, merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = true, check = false) + } + + test("Parquet output committer, no merge schema") { + writeDataFrame(PARQUET_COMMITTER, summary = false, check = false) + } + + /** + * Write a trivial dataframe as Parquet, using the given committer + * and job summary option. + * @param committer committer to use + * @param summary create a job summary + * @param check look for a marker file + * @return if a marker file was sought, it's file status. + */ + private def writeDataFrame( + committer: String, + summary: Boolean, + check: Boolean): Option[FileStatus] = { + var result: Option[FileStatus] = None + withSQLConf( + SQLConf.PARQUET_OUTPUT_COMMITTER_CLASS.key -> committer, + ParquetOutputFormat.ENABLE_JOB_SUMMARY -> summary.toString) { + withTempPath { dest => + val df = spark.createDataFrame(Seq((1, "4"), (2, "2"))) + val destPath = new Path(dest.toURI) + df.write.format("parquet").save(destPath.toString) + if (check) { + result = Some(MarkingFileOutput.checkMarker( + destPath, + spark.sparkContext.hadoopConfiguration)) + } + } + } + result + } +} + +/** + * A file output committer which explicitly touches a file "marker"; this + * is how tests can verify that this committer was used. + * @param outputPath output path + * @param context task context + */ +private class MarkingFileOutputCommitter( + outputPath: Path, + context: TaskAttemptContext) extends FileOutputCommitter(outputPath, context) { + + override def commitJob(context: JobContext): Unit = { + super.commitJob(context) + MarkingFileOutput.touch(outputPath, context.getConfiguration) + } +} + +private object MarkingFileOutput { + + val COMMITTER = classOf[MarkingFileOutputCommitter].getCanonicalName + + /** + * Touch the marker. 
+ * @param outputPath destination directory + * @param conf configuration to create the FS with + */ + def touch(outputPath: Path, conf: Configuration): Unit = { + outputPath.getFileSystem(conf).create(new Path(outputPath, "marker")).close() + } + + /** + * Get the file status of the marker + * + * @param outputPath destination directory + * @param conf configuration to create the FS with + * @return the status of the marker + * @throws FileNotFoundException if the marker is absent + */ + def checkMarker(outputPath: Path, conf: Configuration): FileStatus = { + outputPath.getFileSystem(conf).getFileStatus(new Path(outputPath, "marker")) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala index 0dc612ef735fa..58a194b8af62b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala @@ -227,8 +227,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLMetricsTestUtils with Shared val df = df1.join(broadcast(df2), "key") testSparkPlanMetrics(df, 2, Map( 1L -> (("BroadcastHashJoin", Map( - "number of output rows" -> 2L, - "avg hash probe (min, med, max)" -> "\n(1, 1, 1)")))) + "number of output rows" -> 2L)))) ) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala index 153e6e1f88c70..95b21fc9f16ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -109,4 +109,4 @@ class MyDummyPythonUDF extends UserDefinedPythonFunction( name = "dummyUDF", func = new DummyUDF, dataType = BooleanType, - vectorized = false) + pythonUdfType = PythonUdfType.NORMAL_UDF) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index defb9ed63a881..65b39f0fbd73d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -214,7 +214,7 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn path: String, queryRunId: UUID = UUID.randomUUID, version: Int = 0): StatefulOperatorStateInfo = { - StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version) + StatefulOperatorStateInfo(path, queryRunId, operatorId = 0, version, numPartitions = 5) } private val increment = (store: StateStore, iter: Iterator[String]) => { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala index ffa4c3c22a194..c0216a2ef3e61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala @@ -137,14 +137,16 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter 
BoundReference( 1, inputValueAttribWithWatermark.dataType, inputValueAttribWithWatermark.nullable), Literal(threshold)) - manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + val iter = manager.removeByKeyCondition(GeneratePredicate.generate(expr).eval _) + while (iter.hasNext) iter.next() } /** Remove values where `time <= threshold` */ def removeByValue(watermark: Long)(implicit manager: SymmetricHashJoinStateManager): Unit = { val expr = LessThanOrEqual(inputValueAttribWithWatermark, Literal(watermark)) - manager.removeByValueCondition( + val iter = manager.removeByValueCondition( GeneratePredicate.generate(expr, inputValueAttribs).eval _) + while (iter.hasNext) iter.next() } def numRows(implicit manager: SymmetricHashJoinStateManager): Long = { @@ -158,7 +160,7 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter withTempDir { file => val storeConf = new StateStoreConf() - val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0) + val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5) val manager = new SymmetricHashJoinStateManager( LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration) try { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index f7b06c97f9db6..c5c8ae3a17c6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -20,24 +20,39 @@ package org.apache.spark.sql.execution.vectorized import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.execution.columnar.ColumnAccessor +import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { + private def withVector( + vector: WritableColumnVector)( + block: WritableColumnVector => Unit): Unit = { + try block(vector) finally vector.close() + } - var testVector: WritableColumnVector = _ - - private def allocate(capacity: Int, dt: DataType): WritableColumnVector = { - new OnHeapColumnVector(capacity, dt) + private def withVectors( + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + withVector(new OnHeapColumnVector(size, dt))(block) + withVector(new OffHeapColumnVector(size, dt))(block) } - override def afterEach(): Unit = { - testVector.close() + private def testVectors( + name: String, + size: Int, + dt: DataType)( + block: WritableColumnVector => Unit): Unit = { + test(name) { + withVectors(size, dt)(block) + } } - test("boolean") { - testVector = allocate(10, BooleanType) + testVectors("boolean", 10, BooleanType) { testVector => (0 until 10).foreach { i => testVector.appendBoolean(i % 2 == 0) } @@ -49,8 +64,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("byte") { - testVector = allocate(10, ByteType) + testVectors("byte", 10, ByteType) { testVector => (0 until 10).foreach { i => testVector.appendByte(i.toByte) } @@ -58,12 +72,11 @@ class ColumnVectorSuite extends SparkFunSuite with 
BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ByteType) === (i.toByte)) + assert(array.get(i, ByteType) === i.toByte) } } - test("short") { - testVector = allocate(10, ShortType) + testVectors("short", 10, ShortType) { testVector => (0 until 10).foreach { i => testVector.appendShort(i.toShort) } @@ -71,12 +84,11 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) (0 until 10).foreach { i => - assert(array.get(i, ShortType) === (i.toShort)) + assert(array.get(i, ShortType) === i.toShort) } } - test("int") { - testVector = allocate(10, IntegerType) + testVectors("int", 10, IntegerType) { testVector => (0 until 10).foreach { i => testVector.appendInt(i) } @@ -88,8 +100,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("long") { - testVector = allocate(10, LongType) + testVectors("long", 10, LongType) { testVector => (0 until 10).foreach { i => testVector.appendLong(i) } @@ -101,8 +112,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("float") { - testVector = allocate(10, FloatType) + testVectors("float", 10, FloatType) { testVector => (0 until 10).foreach { i => testVector.appendFloat(i.toFloat) } @@ -114,8 +124,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("double") { - testVector = allocate(10, DoubleType) + testVectors("double", 10, DoubleType) { testVector => (0 until 10).foreach { i => testVector.appendDouble(i.toDouble) } @@ -127,8 +136,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("string") { - testVector = allocate(10, StringType) + testVectors("string", 10, StringType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -141,8 +149,7 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("binary") { - testVector = allocate(10, BinaryType) + testVectors("binary", 10, BinaryType) { testVector => (0 until 10).map { i => val utf8 = s"str$i".getBytes("utf8") testVector.appendByteArray(utf8, 0, utf8.length) @@ -156,9 +163,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("array") { - val arrayType = ArrayType(IntegerType, true) - testVector = allocate(10, arrayType) + val arrayType: ArrayType = ArrayType(IntegerType, containsNull = true) + testVectors("array", 10, arrayType) { testVector => val data = testVector.arrayData() var i = 0 @@ -181,9 +187,8 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5)) } - test("struct") { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = allocate(10, schema) + val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType) + testVectors("struct", 10, structType) { testVector => val c1 = testVector.getChildColumn(0) val c2 = testVector.getChildColumn(1) c1.putInt(0, 123) @@ -193,35 +198,203 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val array = new ColumnVector.Array(testVector) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) - assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) - assert(array.get(1, 
schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) - assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123) + assert(array.get(0, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456) + assert(array.get(1, structType).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67) } test("[SPARK-22092] off-heap column vector reallocation corrupts array data") { - val arrayType = ArrayType(IntegerType, true) - testVector = new OffHeapColumnVector(8, arrayType) + withVector(new OffHeapColumnVector(8, arrayType)) { testVector => + val data = testVector.arrayData() + (0 until 8).foreach(i => data.putInt(i, i)) + (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + + // Increase vector's capacity and reallocate the data to new bigger buffers. + testVector.reserve(16) + + // Check that none of the values got lost/overwritten. + val array = new ColumnVector.Array(testVector) + (0 until 8).foreach { i => + assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + } + } + } - val data = testVector.arrayData() - (0 until 8).foreach(i => data.putInt(i, i)) - (0 until 8).foreach(i => testVector.putArray(i, i, 1)) + test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { + withVector(new OffHeapColumnVector(8, structType)) { testVector => + (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) + testVector.reserve(16) + (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + } + } - // Increase vector's capacity and reallocate the data to new bigger buffers. - testVector.reserve(16) + test("CachedBatch boolean Apis") { + val dataType = BooleanType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) - // Check that none of the values got lost/overwritten. 
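// Editor's note (descriptive comment, not part of this patch): OffHeapColumnVector keeps
// its data in raw off-heap buffers, so reserve(16) has to allocate larger buffers and copy
// the old contents across. SPARK-22092 was a case where the array offset/length
// bookkeeping was not carried over during that copy; writing eight single-element arrays,
// growing the vector, and re-reading them exercises exactly that path.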
- val array = new ColumnVector.Array(testVector) - (0 until 8).foreach { i => - assert(array.get(i, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(i)) + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setBoolean(0, i % 2 == 0) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getBoolean(i) == (i % 2 == 0)) + } } } - test("[SPARK-22092] off-heap column vector reallocation corrupts struct nullability") { - val structType = new StructType().add("int", IntegerType).add("double", DoubleType) - testVector = new OffHeapColumnVector(8, structType) - (0 until 8).foreach(i => if (i % 2 == 0) testVector.putNull(i) else testVector.putNotNull(i)) - testVector.reserve(16) - (0 until 8).foreach(i => assert(testVector.isNullAt(i) == (i % 2 == 0))) + test("CachedBatch byte Apis") { + val dataType = ByteType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setByte(0, i.toByte) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getByte(i) == i) + } + } + } + + test("CachedBatch short Apis") { + val dataType = ShortType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setShort(0, i.toShort) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getShort(i) == i) + } + } + } + + test("CachedBatch int Apis") { + val dataType = IntegerType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setInt(0, i) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getInt(i) == i) + } + } + } + + test("CachedBatch long Apis") { + val dataType = LongType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setLong(0, i.toLong) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + 
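// Editor's note (descriptive comment, not part of this patch): decompress() is the batch
// path these "CachedBatch ... Apis" tests exercise. It materialises all 16 values (the
// leading NULL plus 15 non-null values) from the compressed CachedBatch buffer straight
// into the writable column vector in one call, instead of pulling them out row by row via
// extractTo(). The assertions that follow check both the null flags and the decoded values.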
ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getLong(i) == i.toLong) + } + } + } + + test("CachedBatch float Apis") { + val dataType = FloatType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setFloat(0, i.toFloat) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getFloat(i) == i.toFloat) + } + } + } + + test("CachedBatch double Apis") { + val dataType = DoubleType + val columnBuilder = ColumnBuilderHelper(dataType, 1024, "col", true) + val row = new SpecificInternalRow(Array(dataType)) + + row.setNullAt(0) + columnBuilder.appendFrom(row, 0) + for (i <- 1 until 16) { + row.setDouble(0, i.toDouble) + columnBuilder.appendFrom(row, 0) + } + + withVectors(16, dataType) { testVector => + val columnAccessor = ColumnAccessor(dataType, columnBuilder.build) + ColumnAccessor.decompress(columnAccessor, testVector, 16) + + assert(testVector.isNullAt(0) == true) + for (i <- 1 until 16) { + assert(testVector.isNullAt(i) == false) + assert(testVector.getDouble(i) == i.toDouble) + } + } } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index ebf76613343ba..0b179aa97c479 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -38,7 +38,7 @@ import org.apache.spark.unsafe.types.CalendarInterval class ColumnarBatchSuite extends SparkFunSuite { - def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { + private def allocate(capacity: Int, dt: DataType, memMode: MemoryMode): WritableColumnVector = { if (memMode == MemoryMode.OFF_HEAP) { new OffHeapColumnVector(capacity, dt) } else { @@ -46,23 +46,36 @@ class ColumnarBatchSuite extends SparkFunSuite { } } - test("Null Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val reference = mutable.ArrayBuffer.empty[Boolean] + private def testVector( + name: String, + size: Int, + dt: DataType)( + block: (WritableColumnVector, MemoryMode) => Unit): Unit = { + test(name) { + Seq(MemoryMode.ON_HEAP, MemoryMode.OFF_HEAP).foreach { mode => + val vector = allocate(size, dt, mode) + try block(vector, mode) finally { + vector.close() + } + } + } + } - val column = allocate(1024, IntegerType, memMode) + testVector("Null APIs", 1024, IntegerType) { + (column, memMode) => + val reference = mutable.ArrayBuffer.empty[Boolean] var idx = 0 - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNull() reference += false - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) assert(column.numNulls() == 0) column.appendNotNulls(3) (1 to 3).foreach(_ => reference += false) - assert(column.anyNullsSet() == false) + assert(!column.anyNullsSet()) 
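Note: the testVector helper defined above in ColumnarBatchSuite replaces the per-test foreach over memory modes: it registers a single named test, runs the body once for ON_HEAP and once for OFF_HEAP, and closes the vector in a finally block, so individual tests no longer call column.close(). A minimal usage sketch inside the suite (the test name and values are illustrative only, not taken from the patch):

    testVector("Example int round trip", 16, IntegerType) { (column, memMode) =>
      column.appendInt(42)            // the body runs once per MemoryMode
      assert(column.getInt(0) == 42)
      // no explicit close(): the helper releases the vector after the block
    }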
assert(column.numNulls() == 0) column.appendNull() @@ -113,16 +126,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == (Platform.getByte(null, addr + v._2) == 1), "index=" + v._2) } } - column.close - }} } - test("Byte Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Byte APIs", 1024, ByteType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[Byte] - val column = allocate(1024, ByteType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toByte).toArray column.appendBytes(2, values, 0) reference += 10.toByte @@ -170,17 +179,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getByte(null, addr + v._2)) } } - }} } - test("Short Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Short APIs", 1024, ShortType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Short] - val column = allocate(1024, ShortType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).map(_.toShort).toArray column.appendShorts(2, values, 0) reference += 10.toShort @@ -248,19 +254,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getShort(null, addr + 2 * v._2)) } } - - column.close - }} } - test("Int Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Int APIs", 1024, IntegerType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Int] - val column = allocate(1024, IntegerType, memMode) - var values = (10 :: 20 :: 30 :: 40 :: 50 :: Nil).toArray column.appendInts(2, values, 0) reference += 10 @@ -334,18 +335,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getInt(null, addr + 4 * v._2)) } } - column.close - }} } - test("Long Apis") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Long APIs", 1024, LongType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Long] - val column = allocate(1024, LongType, memMode) - var values = (10L :: 20L :: 30L :: 40L :: 50L :: Nil).toArray column.appendLongs(2, values, 0) reference += 10L @@ -416,23 +413,20 @@ class ColumnarBatchSuite extends SparkFunSuite { reference.zipWithIndex.foreach { v => assert(v._1 == column.getLong(v._2), "idx=" + v._2 + - " Seed = " + seed + " MemMode=" + memMode) + " Seed = " + seed + " MemMode=" + memMode) if (memMode == MemoryMode.OFF_HEAP) { val addr = column.valuesNativeAddress() assert(v._1 == Platform.getLong(null, addr + 8 * v._2)) } } - }} } - test("Float APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Float APIs", 1024, FloatType) { + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Float] - val column = allocate(1024, FloatType, memMode) - var values = (.1f :: .2f :: .3f :: .4f :: .5f :: Nil).toArray column.appendFloats(2, values, 0) reference += .1f @@ -512,18 +506,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getFloat(null, addr + 4 * v._2)) } } - column.close - }} } - test("Double APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("Double APIs", 1024, DoubleType) 
{ + (column, memMode) => val seed = System.currentTimeMillis() val random = new Random(seed) val reference = mutable.ArrayBuffer.empty[Double] - val column = allocate(1024, DoubleType, memMode) - var values = (.1 :: .2 :: .3 :: .4 :: .5 :: Nil).toArray column.appendDoubles(2, values, 0) reference += .1 @@ -603,15 +593,12 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(v._1 == Platform.getDouble(null, addr + 8 * v._2)) } } - column.close - }} } - test("String APIs") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { + testVector("String APIs", 6, StringType) { + (column, memMode) => val reference = mutable.ArrayBuffer.empty[String] - val column = allocate(6, BinaryType, memMode) assert(column.arrayData().elementsAppended == 0) val str = "string" @@ -663,15 +650,13 @@ class ColumnarBatchSuite extends SparkFunSuite { column.reset() assert(column.arrayData().elementsAppended == 0) - }} } - test("Int Array") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val column = allocate(10, new ArrayType(IntegerType, true), memMode) + testVector("Int Array", 10, new ArrayType(IntegerType, true)) { + (column, _) => // Fill the underlying data with all the arrays back to back. - val data = column.arrayData(); + val data = column.arrayData() var i = 0 while (i < 6) { data.putInt(i, i) @@ -709,7 +694,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(3).getInt(2) == 5) // Add a longer array which requires resizing - column.reset + column.reset() val array = Array(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) assert(data.capacity == 10) data.reserve(array.length) @@ -718,63 +703,67 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(0, 0, array.length) assert(ColumnVectorUtils.toPrimitiveJavaArray(column.getArray(0)).asInstanceOf[Array[Int]] === array) - }} } test("toArray for primitive types") { - // (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => { + (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => val len = 4 val columnBool = allocate(len, new ArrayType(BooleanType, false), memMode) val boolArray = Array(false, true, false, true) - boolArray.zipWithIndex.map { case (v, i) => columnBool.arrayData.putBoolean(i, v) } + boolArray.zipWithIndex.foreach { case (v, i) => columnBool.arrayData.putBoolean(i, v) } columnBool.putArray(0, 0, len) assert(columnBool.getArray(0).toBooleanArray === boolArray) + columnBool.close() val columnByte = allocate(len, new ArrayType(ByteType, false), memMode) val byteArray = Array[Byte](0, 1, 2, 3) - byteArray.zipWithIndex.map { case (v, i) => columnByte.arrayData.putByte(i, v) } + byteArray.zipWithIndex.foreach { case (v, i) => columnByte.arrayData.putByte(i, v) } columnByte.putArray(0, 0, len) assert(columnByte.getArray(0).toByteArray === byteArray) + columnByte.close() val columnShort = allocate(len, new ArrayType(ShortType, false), memMode) val shortArray = Array[Short](0, 1, 2, 3) - shortArray.zipWithIndex.map { case (v, i) => columnShort.arrayData.putShort(i, v) } + shortArray.zipWithIndex.foreach { case (v, i) => columnShort.arrayData.putShort(i, v) } columnShort.putArray(0, 0, len) assert(columnShort.getArray(0).toShortArray === shortArray) + columnShort.close() val columnInt = allocate(len, new ArrayType(IntegerType, false), memMode) val intArray = Array(0, 1, 2, 3) - intArray.zipWithIndex.map { case (v, i) => columnInt.arrayData.putInt(i, v) } + 
intArray.zipWithIndex.foreach { case (v, i) => columnInt.arrayData.putInt(i, v) } columnInt.putArray(0, 0, len) assert(columnInt.getArray(0).toIntArray === intArray) + columnInt.close() val columnLong = allocate(len, new ArrayType(LongType, false), memMode) val longArray = Array[Long](0, 1, 2, 3) - longArray.zipWithIndex.map { case (v, i) => columnLong.arrayData.putLong(i, v) } + longArray.zipWithIndex.foreach { case (v, i) => columnLong.arrayData.putLong(i, v) } columnLong.putArray(0, 0, len) assert(columnLong.getArray(0).toLongArray === longArray) + columnLong.close() val columnFloat = allocate(len, new ArrayType(FloatType, false), memMode) val floatArray = Array(0.0F, 1.1F, 2.2F, 3.3F) - floatArray.zipWithIndex.map { case (v, i) => columnFloat.arrayData.putFloat(i, v) } + floatArray.zipWithIndex.foreach { case (v, i) => columnFloat.arrayData.putFloat(i, v) } columnFloat.putArray(0, 0, len) assert(columnFloat.getArray(0).toFloatArray === floatArray) + columnFloat.close() val columnDouble = allocate(len, new ArrayType(DoubleType, false), memMode) val doubleArray = Array(0.0, 1.1, 2.2, 3.3) - doubleArray.zipWithIndex.map { case (v, i) => columnDouble.arrayData.putDouble(i, v) } + doubleArray.zipWithIndex.foreach { case (v, i) => columnDouble.arrayData.putDouble(i, v) } columnDouble.putArray(0, 0, len) assert(columnDouble.getArray(0).toDoubleArray === doubleArray) - }} + columnDouble.close() + } } - test("Struct Column") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => { - val schema = new StructType().add("int", IntegerType).add("double", DoubleType) - val column = allocate(1024, schema, memMode) - + testVector( + "Struct Column", + 10, + new StructType().add("int", IntegerType).add("double", DoubleType)) { (column, _) => val c1 = column.getChildColumn(0) val c2 = column.getChildColumn(1) assert(c1.dataType() == IntegerType) @@ -797,13 +786,10 @@ class ColumnarBatchSuite extends SparkFunSuite { val s2 = column.getStruct(1) assert(s2.getInt(0) == 456) assert(s2.getDouble(1) == 5.67) - }} } - test("Nest Array in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val column = allocate(10, new ArrayType(new ArrayType(IntegerType, true), true), - memMode) + testVector("Nest Array in Array", 10, new ArrayType(new ArrayType(IntegerType, true), true)) { + (column, _) => val childColumn = column.arrayData() val data = column.arrayData().arrayData() (0 until 6).foreach { @@ -829,13 +815,14 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getArray(2).getArray(1).getInt(1) === 4) assert(column.getArray(2).getArray(1).getInt(2) === 5) assert(column.isNullAt(3)) - } } - test("Nest Struct in Array.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType().add("int", IntegerType).add("long", LongType) - val column = allocate(10, new ArrayType(schema, true), memMode) + private val structType: StructType = new StructType().add("i", IntegerType).add("l", LongType) + + testVector( + "Nest Struct in Array", + 10, + new ArrayType(structType, true)) { (column, _) => val data = column.arrayData() val c0 = data.getChildColumn(0) val c1 = data.getChildColumn(1) @@ -850,22 +837,21 @@ class ColumnarBatchSuite extends SparkFunSuite { column.putArray(1, 1, 3) column.putArray(2, 4, 2) - assert(column.getArray(0).getStruct(0, 2).toSeq(schema) === Seq(0, 0)) - assert(column.getArray(0).getStruct(1, 2).toSeq(schema) === Seq(1, 10)) - assert(column.getArray(1).getStruct(0, 2).toSeq(schema) === 
Seq(1, 10)) - assert(column.getArray(1).getStruct(1, 2).toSeq(schema) === Seq(2, 20)) - assert(column.getArray(1).getStruct(2, 2).toSeq(schema) === Seq(3, 30)) - assert(column.getArray(2).getStruct(0, 2).toSeq(schema) === Seq(4, 40)) - assert(column.getArray(2).getStruct(1, 2).toSeq(schema) === Seq(5, 50)) - } + assert(column.getArray(0).getStruct(0, 2).toSeq(structType) === Seq(0, 0)) + assert(column.getArray(0).getStruct(1, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(0, 2).toSeq(structType) === Seq(1, 10)) + assert(column.getArray(1).getStruct(1, 2).toSeq(structType) === Seq(2, 20)) + assert(column.getArray(1).getStruct(2, 2).toSeq(structType) === Seq(3, 30)) + assert(column.getArray(2).getStruct(0, 2).toSeq(structType) === Seq(4, 40)) + assert(column.getArray(2).getStruct(1, 2).toSeq(structType) === Seq(5, 50)) } - test("Nest Array in Struct.") { - (MemoryMode.ON_HEAP :: MemoryMode.OFF_HEAP :: Nil).foreach { memMode => - val schema = new StructType() - .add("int", IntegerType) - .add("array", new ArrayType(IntegerType, true)) - val column = allocate(10, schema, memMode) + testVector( + "Nest Array in Struct", + 10, + new StructType() + .add("int", IntegerType) + .add("array", new ArrayType(IntegerType, true))) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -886,18 +872,15 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getArray(1).toIntArray() === Array(2)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getArray(1).toIntArray() === Array(3, 4, 5)) - } } - test("Nest Struct in Struct.") { - (MemoryMode.ON_HEAP :: Nil).foreach { memMode => - val subSchema = new StructType() - .add("int", IntegerType) - .add("int", IntegerType) - val schema = new StructType() - .add("int", IntegerType) - .add("struct", subSchema) - val column = allocate(10, schema, memMode) + private val subSchema: StructType = new StructType() + .add("int", IntegerType) + .add("int", IntegerType) + testVector( + "Nest Struct in Struct", + 10, + new StructType().add("int", IntegerType).add("struct", subSchema)) { (column, _) => val c0 = column.getChildColumn(0) val c1 = column.getChildColumn(1) c0.putInt(0, 0) @@ -919,7 +902,6 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(column.getStruct(1).getStruct(1, 2).toSeq(subSchema) === Seq(8, 80)) assert(column.getStruct(2).getInt(0) === 2) assert(column.getStruct(2).getStruct(1, 2).toSeq(subSchema) === Seq(9, 90)) - } } test("ColumnarBatch basic") { @@ -1040,7 +1022,7 @@ class ColumnarBatchSuite extends SparkFunSuite { val it4 = batch.rowIterator() rowEquals(it4.next(), Row(null, 2.2, 2, "abc")) - batch.close + batch.close() }} } @@ -1138,7 +1120,7 @@ class ColumnarBatchSuite extends SparkFunSuite { } batch.close() } - }} + }} /** * This test generates a random schema data, serializes it to column batches and verifies the diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index 34205e0b2bf08..167b3e0190026 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -815,6 +815,12 @@ class JDBCSuite extends SparkFunSuite Some(DecimalType(DecimalType.MAX_PRECISION, 10))) assert(oracleDialect.getCatalystType(java.sql.Types.NUMERIC, "numeric", 0, null) == Some(DecimalType(DecimalType.MAX_PRECISION, 10))) + 
assert(oracleDialect.getCatalystType(OracleDialect.BINARY_FLOAT, "BINARY_FLOAT", 0, null) == + Some(FloatType)) + assert(oracleDialect.getCatalystType(OracleDialect.BINARY_DOUBLE, "BINARY_DOUBLE", 0, null) == + Some(DoubleType)) + assert(oracleDialect.getCatalystType(OracleDialect.TIMESTAMPTZ, "TIMESTAMP", 0, null) == + Some(TimestampType)) } test("table exists query by jdbc dialect") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala index eb9e6458fc61c..ab18905e2ddb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.execution.{DataSourceScanExec, SortExec} import org.apache.spark.sql.execution.datasources.DataSourceStrategy -import org.apache.spark.sql.execution.exchange.ShuffleExchange +import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.SortMergeJoinExec import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -302,10 +302,10 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { // check existence of shuffle assert( - joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft, + joinOperator.left.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleLeft, s"expected shuffle in plan to be $shuffleLeft but found\n${joinOperator.left}") assert( - joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight, + joinOperator.right.find(_.isInstanceOf[ShuffleExchangeExec]).isDefined == shuffleRight, s"expected shuffle in plan to be $shuffleRight but found\n${joinOperator.right}") // check existence of sort @@ -506,7 +506,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } @@ -520,7 +520,7 @@ abstract class BucketedReadSuite extends QueryTest with SQLTestUtils { agged.sort("i", "j"), df1.groupBy("i", "j").agg(max("k")).sort("i", "j")) - assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty) + assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchangeExec]).isEmpty) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 9ce93d7ae926c..092702a1d5173 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -21,6 +21,7 @@ import java.util.{ArrayList, List => JList} import test.org.apache.spark.sql.sources.v2._ +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.sources.{Filter, GreaterThan} @@ -80,6 +81,74 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } } + + test("simple writable data source") { 
+ // TODO: java implementation. + Seq(classOf[SimpleWritableDataSource]).foreach { cls => + withTempPath { file => + val path = file.getCanonicalPath + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).select('id, -'id)) + + // test with different save modes + spark.range(10).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("append").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(10).union(spark.range(10)).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("ignore").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + + val e = intercept[Exception] { + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).mode("error").save() + } + assert(e.getMessage.contains("data already exists")) + + // test transaction + val failingUdf = org.apache.spark.sql.functions.udf { + var count = 0 + (id: Long) => { + if (count > 5) { + throw new RuntimeException("testing error") + } + count += 1 + id + } + } + // this input data will fail to read middle way. + val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i) + val e2 = intercept[SparkException] { + input.write.format(cls.getName).option("path", path).mode("overwrite").save() + } + assert(e2.getMessage.contains("Writing job aborted")) + // make sure we don't have partial data. + assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) + + // test internal row writer + spark.range(5).select('id, -'id).write.format(cls.getName) + .option("path", path).option("internal", "true").mode("overwrite").save() + checkAnswer( + spark.read.format(cls.getName).option("path", path).load(), + spark.range(5).select('id, -'id)) + } + } + } } class SimpleDataSourceV2 extends DataSourceV2 with ReadSupport { @@ -129,6 +198,8 @@ class AdvancedDataSourceV2 extends DataSourceV2 with ReadSupport { Array.empty } + override def pushedFilters(): Array[Filter] = filters + override def readSchema(): StructType = { requiredSchema } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala new file mode 100644 index 0000000000000..6fb60f4d848d7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.io.{BufferedReader, InputStreamReader, IOException} +import java.text.SimpleDateFormat +import java.util.{Collections, Date, List => JList, Locale, Optional, UUID} + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileSystem, FSDataInputStream, Path} + +import org.apache.spark.SparkContext +import org.apache.spark.sql.{Row, SaveMode} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.reader.{DataReader, DataSourceV2Reader, ReadTask} +import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * A HDFS based transactional writable data source. + * Each task writes data to `target/_temporary/jobId/$jobId-$partitionId-$attemptNumber`. + * Each job moves files from `target/_temporary/jobId/` to `target`. + */ +class SimpleWritableDataSource extends DataSourceV2 with ReadSupport with WriteSupport { + + private val schema = new StructType().add("i", "long").add("j", "long") + + class Reader(path: String, conf: Configuration) extends DataSourceV2Reader { + override def readSchema(): StructType = schema + + override def createReadTasks(): JList[ReadTask[Row]] = { + val dataPath = new Path(path) + val fs = dataPath.getFileSystem(conf) + if (fs.exists(dataPath)) { + fs.listStatus(dataPath).filterNot { status => + val name = status.getPath.getName + name.startsWith("_") || name.startsWith(".") + }.map { f => + val serializableConf = new SerializableConfiguration(conf) + new SimpleCSVReadTask(f.getPath.toUri.toString, serializableConf): ReadTask[Row] + }.toList.asJava + } else { + Collections.emptyList() + } + } + } + + class Writer(jobId: String, path: String, conf: Configuration) extends DataSourceV2Writer { + override def createWriterFactory(): DataWriterFactory[Row] = { + new SimpleCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + val finalPath = new Path(path) + val jobPath = new Path(new Path(finalPath, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + try { + for (file <- fs.listStatus(jobPath).map(_.getPath)) { + val dest = new Path(finalPath, file.getName) + if(!fs.rename(file, dest)) { + throw new IOException(s"failed to rename($file, $dest)") + } + } + } finally { + fs.delete(jobPath, true) + } + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val fs = jobPath.getFileSystem(conf) + fs.delete(jobPath, true) + } + } + + class InternalRowWriter(jobId: String, path: String, conf: Configuration) + extends Writer(jobId, path, conf) with SupportsWriteInternalRow { + + override def createWriterFactory(): DataWriterFactory[Row] = { + throw new IllegalArgumentException("not expected!") + } + + override def createInternalRowWriterFactory(): DataWriterFactory[InternalRow] = { + new 
InternalRowCSVDataWriterFactory(path, jobId, new SerializableConfiguration(conf)) + } + } + + override def createReader(options: DataSourceV2Options): DataSourceV2Reader = { + val path = new Path(options.get("path").get()) + val conf = SparkContext.getActive.get.hadoopConfiguration + new Reader(path.toUri.toString, conf) + } + + override def createWriter( + jobId: String, + schema: StructType, + mode: SaveMode, + options: DataSourceV2Options): Optional[DataSourceV2Writer] = { + assert(DataType.equalsStructurally(schema.asNullable, this.schema.asNullable)) + assert(!SparkContext.getActive.get.conf.getBoolean("spark.speculation", false)) + + val path = new Path(options.get("path").get()) + val internal = options.get("internal").isPresent + val conf = SparkContext.getActive.get.hadoopConfiguration + val fs = path.getFileSystem(conf) + + if (mode == SaveMode.ErrorIfExists) { + if (fs.exists(path)) { + throw new RuntimeException("data already exists.") + } + } + if (mode == SaveMode.Ignore) { + if (fs.exists(path)) { + return Optional.empty() + } + } + if (mode == SaveMode.Overwrite) { + fs.delete(path, true) + } + + Optional.of(createWriter(jobId, path, conf, internal)) + } + + private def createWriter( + jobId: String, path: Path, conf: Configuration, internal: Boolean): DataSourceV2Writer = { + val pathStr = path.toUri.toString + if (internal) { + new InternalRowWriter(jobId, pathStr, conf) + } else { + new Writer(jobId, pathStr, conf) + } + } +} + +class SimpleCSVReadTask(path: String, conf: SerializableConfiguration) + extends ReadTask[Row] with DataReader[Row] { + + @transient private var lines: Iterator[String] = _ + @transient private var currentLine: String = _ + @transient private var inputStream: FSDataInputStream = _ + + override def createReader(): DataReader[Row] = { + val filePath = new Path(path) + val fs = filePath.getFileSystem(conf.value) + inputStream = fs.open(filePath) + lines = new BufferedReader(new InputStreamReader(inputStream)) + .lines().iterator().asScala + this + } + + override def next(): Boolean = { + if (lines.hasNext) { + currentLine = lines.next() + true + } else { + false + } + } + + override def get(): Row = Row(currentLine.split(",").map(_.trim.toLong): _*) + + override def close(): Unit = { + inputStream.close() + } +} + +class SimpleCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[Row] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[Row] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = filePath.getFileSystem(conf.value) + new SimpleCSVDataWriter(fs, filePath) + } +} + +class SimpleCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[Row] { + + private val out = fs.create(file) + + override def write(record: Row): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} + +class InternalRowCSVDataWriterFactory(path: String, jobId: String, conf: SerializableConfiguration) + extends DataWriterFactory[InternalRow] { + + override def createWriter(partitionId: Int, attemptNumber: Int): DataWriter[InternalRow] = { + val jobPath = new Path(new Path(path, "_temporary"), jobId) + val filePath = new Path(jobPath, s"$jobId-$partitionId-$attemptNumber") + val fs = 
filePath.getFileSystem(conf.value) + new InternalRowCSVDataWriter(fs, filePath) + } +} + +class InternalRowCSVDataWriter(fs: FileSystem, file: Path) extends DataWriter[InternalRow] { + + private val out = fs.create(file) + + override def write(record: InternalRow): Unit = { + out.writeBytes(s"${record.getLong(0)},${record.getLong(1)}\n") + } + + override def commit(): WriterCommitMessage = { + out.close() + null + } + + override def abort(): Unit = { + try { + out.close() + } finally { + fs.delete(file, false) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index e858b7d9998a8..caf2bab8a5859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -19,12 +19,15 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfterAll +import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, HashPartitioning, SinglePartition} import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ -import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.{MemoryStream, StreamingDeduplicateExec} import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.functions._ -class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class DeduplicateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ @@ -41,6 +44,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a"), CheckLastBatch("a"), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("value"))), AddData(inputData, "a"), CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), @@ -58,6 +63,8 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { AddData(inputData, "a" -> 1), CheckLastBatch("a" -> 1), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StreamingDeduplicateExec](sq, Seq("_1"))), AddData(inputData, "a" -> 2), // Dropped CheckLastBatch(), assertNumStateRows(total = 1, updated = 0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala deleted file mode 100644 index 044bb03480aa4..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EnsureStatefulOpPartitioningSuite.scala +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
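Note: the AssertOnQuery checks added to DeduplicateSuite above call checkChildOutputHashPartitioning from the new StatefulOperatorTest trait, whose body is truncated at the end of this excerpt. The core of such a check is comparing the child operator's outputPartitioning against hash partitioning on the expected grouping columns; a conservative sketch of that comparison (the helper name and the surrounding plan traversal are assumptions, not the patch's exact code) looks like:

    // Assumed imports: org.apache.spark.sql.catalyst.expressions.Attribute,
    // org.apache.spark.sql.catalyst.plans.physical.HashPartitioning,
    // org.apache.spark.sql.execution.SparkPlan.
    private def childIsHashPartitionedOn(child: SparkPlan, colNames: Seq[String]): Boolean = {
      child.outputPartitioning match {
        case HashPartitioning(expressions, _) =>
          // the grouping expressions are expected to be plain attribute references here
          expressions.collect { case a: Attribute => a.name } == colNames
        case _ => false
      }
    }

The real helper additionally has to locate the stateful operator (for example StreamingDeduplicateExec) inside the query's last executed plan before applying a check of this kind.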
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.streaming - -import java.util.UUID - -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest, UnaryExecNode} -import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchange} -import org.apache.spark.sql.execution.streaming.{IncrementalExecution, OffsetSeqMetadata, StatefulOperator, StatefulOperatorStateInfo} -import org.apache.spark.sql.test.SharedSQLContext - -class EnsureStatefulOpPartitioningSuite extends SparkPlanTest with SharedSQLContext { - - import testImplicits._ - - private var baseDf: DataFrame = null - - override def beforeAll(): Unit = { - super.beforeAll() - baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char") - } - - test("ClusteredDistribution generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("ClusteredDistribution with coalesce(1) generates Exchange with HashPartitioning") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = keys => ClusteredDistribution(keys), - expectedPartitioning = - keys => HashPartitioning(keys, spark.sessionState.conf.numShufflePartitions), - expectShuffle = true) - } - - test("AllTuples generates Exchange with SinglePartition") { - testEnsureStatefulOpPartitioning( - baseDf.queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = true) - } - - test("AllTuples with coalesce(1) doesn't need Exchange") { - testEnsureStatefulOpPartitioning( - baseDf.coalesce(1).queryExecution.sparkPlan, - requiredDistribution = _ => AllTuples, - expectedPartitioning = _ => SinglePartition, - expectShuffle = false) - } - - /** - * For `StatefulOperator` with the given `requiredChildDistribution`, and child SparkPlan - * `inputPlan`, ensures that the incremental planner adds exchanges, if required, in order to - * ensure the expected partitioning. - */ - private def testEnsureStatefulOpPartitioning( - inputPlan: SparkPlan, - requiredDistribution: Seq[Attribute] => Distribution, - expectedPartitioning: Seq[Attribute] => Partitioning, - expectShuffle: Boolean): Unit = { - val operator = TestStatefulOperator(inputPlan, requiredDistribution(inputPlan.output.take(1))) - val executed = executePlan(operator, OutputMode.Complete()) - if (expectShuffle) { - val exchange = executed.children.find(_.isInstanceOf[Exchange]) - if (exchange.isEmpty) { - fail(s"Was expecting an exchange but didn't get one in:\n$executed") - } - assert(exchange.get === - ShuffleExchange(expectedPartitioning(inputPlan.output.take(1)), inputPlan), - s"Exchange didn't have expected properties:\n${exchange.get}") - } else { - assert(!executed.children.exists(_.isInstanceOf[Exchange]), - s"Unexpected exchange found in:\n$executed") - } - } - - /** Executes a SparkPlan using the IncrementalPlanner used for Structured Streaming. 
*/ - private def executePlan( - p: SparkPlan, - outputMode: OutputMode = OutputMode.Append()): SparkPlan = { - val execution = new IncrementalExecution( - spark, - null, - OutputMode.Complete(), - "chk", - UUID.randomUUID(), - 0L, - OffsetSeqMetadata()) { - override lazy val sparkPlan: SparkPlan = p transform { - case plan: SparkPlan => - val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap - plan transformExpressions { - case UnresolvedAttribute(Seq(u)) => - inputMap.getOrElse(u, - sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) - } - } - } - execution.executedPlan - } -} - -/** Used to emulate a `StatefulOperator` with the given requiredDistribution. */ -case class TestStatefulOperator( - child: SparkPlan, - requiredDist: Distribution) extends UnaryExecNode with StatefulOperator { - override def output: Seq[Attribute] = child.output - override def doExecute(): RDD[InternalRow] = child.execute() - override def requiredChildDistribution: Seq[Distribution] = requiredDist :: Nil - override def stateInfo: Option[StatefulOperatorStateInfo] = None -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala index 9d74a5c701ef1..b906393a379ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FlatMapGroupsWithStateSuite.scala @@ -21,6 +21,7 @@ import java.sql.Date import java.util.concurrent.ConcurrentHashMap import org.scalatest.BeforeAndAfterAll +import org.scalatest.exceptions.TestFailedException import org.apache.spark.SparkException import org.apache.spark.api.java.function.FlatMapGroupsWithStateFunction @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.streaming.{FlatMapGroupsWithStateExec, GroupStateImpl, MemoryStream} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreId, StateStoreMetrics, UnsafeRowPair} -import org.apache.spark.sql.streaming.FlatMapGroupsWithStateSuite.MemoryStateStore import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.types.{DataType, IntegerType} @@ -41,11 +41,14 @@ case class RunningCount(count: Long) case class Result(key: Long, count: Int) -class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { +class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest + with BeforeAndAfterAll + with StatefulOperatorTest { import testImplicits._ import GroupStateImpl._ import GroupStateTimeout._ + import FlatMapGroupsWithStateSuite._ override def afterAll(): Unit = { super.afterAll() @@ -75,13 +78,15 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // === Tests for state in streaming queries === // Updating empty state - state = GroupStateImpl.createForStreaming(None, 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + None, 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) testState(None) state.update("") testState(Some(""), shouldBeUpdated = true) // Updating exiting state - state = GroupStateImpl.createForStreaming(Some("2"), 1, 1, NoTimeout, hasTimedOut = false) + state = GroupStateImpl.createForStreaming( + Some("2"), 1, 1, NoTimeout, hasTimedOut = false, watermarkPresent = false) 
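Note: the watermarkPresent flag threaded through GroupStateImpl above backs the user-facing GroupState.getCurrentProcessingTimeMs() and getCurrentWatermarkMs() getters that the tests below exercise; per those tests, getCurrentWatermarkMs() throws UnsupportedOperationException when the query has no watermark. As a user-level illustration of how a [flat]mapGroupsWithState function might consume these getters (a sketch, not code from this patch):

    // Assumed import: org.apache.spark.sql.streaming.GroupState
    def countRecentEvents(
        key: String,
        events: Iterator[(String, Long)],    // (key, event time in ms)
        state: GroupState[Long]): (String, Long) = {
      val watermarkMs = state.getCurrentWatermarkMs()   // valid only with withWatermark() on the input
      val fresh = events.count(_._2 >= watermarkMs)     // drop records older than the watermark
      val total = state.getOption.getOrElse(0L) + fresh
      state.update(total)
      (key, total)
    }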
testState(Some("2")) state.update("3") testState(Some("3"), shouldBeUpdated = true) @@ -102,8 +107,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("GroupState - setTimeout - with NoTimeout") { for (initValue <- Seq(None, Some(5))) { val states = Seq( - GroupStateImpl.createForStreaming(initValue, 1000, 1000, NoTimeout, hasTimedOut = false), - GroupStateImpl.createForBatch(NoTimeout) + GroupStateImpl.createForStreaming( + initValue, 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false), + GroupStateImpl.createForBatch(NoTimeout, watermarkPresent = false) ) for (state <- states) { // for streaming queries @@ -120,7 +126,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("GroupState - setTimeout - with ProcessingTimeTimeout") { // for streaming queries var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + None, 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) assert(state.getTimeoutTimestamp === 1500) // can be set without initializing state @@ -141,7 +147,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(ProcessingTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch( + ProcessingTimeTimeout, watermarkPresent = false).asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) state.setTimeoutDuration(500) testTimeoutTimestampNotAllowed[UnsupportedOperationException](state) @@ -158,7 +165,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("GroupState - setTimeout - with EventTimeTimeout") { var state: GroupStateImpl[Int] = GroupStateImpl.createForStreaming( - None, 1000, 1000, EventTimeTimeout, false) + None, 1000, 1000, EventTimeTimeout, false, watermarkPresent = true) assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) @@ -180,7 +187,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testTimeoutDurationNotAllowed[UnsupportedOperationException](state) // for batch queries - state = GroupStateImpl.createForBatch(EventTimeTimeout).asInstanceOf[GroupStateImpl[Int]] + state = GroupStateImpl.createForBatch(EventTimeTimeout, watermarkPresent = false) + .asInstanceOf[GroupStateImpl[Int]] assert(state.getTimeoutTimestamp === NO_TIMESTAMP) testTimeoutDurationNotAllowed[UnsupportedOperationException](state) state.setTimeoutTimestamp(5000) @@ -207,7 +215,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, ProcessingTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutDuration(-1000) } @@ -225,7 +233,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } state = GroupStateImpl.createForStreaming( - Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false) + Some(5), 1000, 1000, EventTimeTimeout, hasTimedOut = false, watermarkPresent = false) testIllegalTimeout { state.setTimeoutTimestamp(-10000) } @@ -257,29 
+265,92 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // for streaming queries for (initState <- Seq(None, Some(5))) { val state1 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = false) + initState, 1000, 1000, timeoutConf, hasTimedOut = false, watermarkPresent = false) assert(state1.hasTimedOut === false) val state2 = GroupStateImpl.createForStreaming( - initState, 1000, 1000, timeoutConf, hasTimedOut = true) + initState, 1000, 1000, timeoutConf, hasTimedOut = true, watermarkPresent = false) assert(state2.hasTimedOut === true) } // for batch queries - assert(GroupStateImpl.createForBatch(timeoutConf).hasTimedOut === false) + assert( + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent = false).hasTimedOut === false) + } + } + + test("GroupState - getCurrentWatermarkMs") { + def streamingState(timeoutConf: GroupStateTimeout, watermark: Option[Long]): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, 1000, watermark.getOrElse(-1), timeoutConf, + hasTimedOut = false, watermark.nonEmpty) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + def assertWrongTimeoutError(test: => Unit): Unit = { + val e = intercept[UnsupportedOperationException] { test } + assert(e.getMessage.contains( + "Cannot get event time watermark timestamp without setting watermark")) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + // Tests for getCurrentWatermarkMs in streaming queries + assertWrongTimeoutError { streamingState(timeoutConf, None).getCurrentWatermarkMs() } + assert(streamingState(timeoutConf, Some(1000)).getCurrentWatermarkMs() === 1000) + assert(streamingState(timeoutConf, Some(2000)).getCurrentWatermarkMs() === 2000) + + // Tests for getCurrentWatermarkMs in batch queries + assertWrongTimeoutError { + batchState(timeoutConf, watermarkPresent = false).getCurrentWatermarkMs() + } + assert(batchState(timeoutConf, watermarkPresent = true).getCurrentWatermarkMs() === -1) } } + test("GroupState - getCurrentProcessingTimeMs") { + def streamingState( + timeoutConf: GroupStateTimeout, + procTime: Long, + watermarkPresent: Boolean): GroupState[Int] = { + GroupStateImpl.createForStreaming( + None, procTime, -1, timeoutConf, hasTimedOut = false, watermarkPresent = false) + } + + def batchState(timeoutConf: GroupStateTimeout, watermarkPresent: Boolean): GroupState[Any] = { + GroupStateImpl.createForBatch(timeoutConf, watermarkPresent) + } + + for (timeoutConf <- Seq(NoTimeout, EventTimeTimeout, ProcessingTimeTimeout)) { + for (watermarkPresent <- Seq(false, true)) { + // Tests for getCurrentProcessingTimeMs in streaming queries + assert(streamingState(timeoutConf, NO_TIMESTAMP, watermarkPresent) + .getCurrentProcessingTimeMs() === -1) + assert(streamingState(timeoutConf, 1000, watermarkPresent) + .getCurrentProcessingTimeMs() === 1000) + assert(streamingState(timeoutConf, 2000, watermarkPresent) + .getCurrentProcessingTimeMs() === 2000) + + // Tests for getCurrentProcessingTimeMs in batch queries + val currentTime = System.currentTimeMillis() + assert(batchState(timeoutConf, watermarkPresent).getCurrentProcessingTimeMs >= currentTime) + } + } + } + + test("GroupState - primitive type") { var intState = GroupStateImpl.createForStreaming[Int]( - None, 1000, 1000, NoTimeout, hasTimedOut = false) + None, 1000, 1000, NoTimeout, hasTimedOut = false, 
watermarkPresent = false) intercept[NoSuchElementException] { intState.get } assert(intState.getOption === None) intState = GroupStateImpl.createForStreaming[Int]( - Some(10), 1000, 1000, NoTimeout, hasTimedOut = false) + Some(10), 1000, 1000, NoTimeout, hasTimedOut = false, watermarkPresent = false) assert(intState.get == 10) intState.update(0) assert(intState.get == 0) @@ -289,20 +360,24 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } } - // Values used for testing StateStoreUpdater + // Values used for testing InputProcessor val currentBatchTimestamp = 1000 val currentBatchWatermark = 1000 val beforeTimeoutThreshold = 999 val afterTimeoutThreshold = 1001 - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout = NoTimeout + // Tests for InputProcessor.processNewData() when timeout = NoTimeout for (priorState <- Seq(None, Some(0))) { val priorStateStr = if (priorState.nonEmpty) "prior state set" else "no prior state" val testName = s"NoTimeout - $priorStateStr - " testStateUpdateWithData( testName + "no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = GroupStateTimeout.NoTimeout, priorState = priorState, expectedState = priorState) // should not change @@ -322,7 +397,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // should be removed } - // Tests for StateStoreUpdater.updateStateForKeysWithData() when timeout != NoTimeout + // Tests for InputProcessor.processTimedOutState() when timeout != NoTimeout for (priorState <- Seq(None, Some(0))) { for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { var testName = "" @@ -340,7 +415,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithData( s"$timeoutConf - $testName - no update", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -365,6 +444,18 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = None) // state should be removed } + // Tests with ProcessingTimeTimeout + if (priorState == None) { + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated without initializing state", + stateUpdates = state => { state.setTimeoutDuration(5000) }, + timeoutConf = ProcessingTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + } + testStateUpdateWithData( s"ProcessingTimeTimeout - $testName - state and timeout duration updated", stateUpdates = @@ -375,10 +466,36 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState = Some(5), // state should change expectedTimeoutTimestamp = currentBatchTimestamp + 5000) // timestamp should change + testStateUpdateWithData( + s"ProcessingTimeTimeout - $testName - timeout updated after state removed", + stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, + timeoutConf = 
ProcessingTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = currentBatchTimestamp + 5000) + + // Tests with EventTimeTimeout + + if (priorState == None) { + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout without init state not allowed", + stateUpdates = state => { + state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = None, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } + testStateUpdateWithData( s"EventTimeTimeout - $testName - state and timeout timestamp updated", stateUpdates = - (state: GroupState[Int]) => { state.update(5); state.setTimeoutTimestamp(5000) }, + (state: GroupState[Int]) => { + state.update(5); state.setTimeoutTimestamp(5000) + }, timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, @@ -397,50 +514,23 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf timeoutConf = EventTimeTimeout, priorState = priorState, priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedState = Some(5), // state should change - expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - } - } + expectedState = Some(5), // state should change + expectedTimeoutTimestamp = NO_TIMESTAMP) // timestamp should not update - // Currently disallowed cases for StateStoreUpdater.updateStateForKeysWithData(), - // Try to remove these cases in the future - for (priorTimeoutTimestamp <- Seq(NO_TIMESTAMP, 1000)) { - val testName = - if (priorTimeoutTimestamp != NO_TIMESTAMP) "prior timeout set" else "no prior timeout" - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"ProcessingTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutDuration(5000) }, - timeoutConf = ProcessingTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout without init state not allowed", - stateUpdates = state => { state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = None, - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) - - testStateUpdateWithData( - s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", - stateUpdates = state => { state.remove(); state.setTimeoutTimestamp(10000) }, - timeoutConf = EventTimeTimeout, - priorState = Some(5), - priorTimeoutTimestamp = priorTimeoutTimestamp, - expectedException = classOf[IllegalStateException]) + testStateUpdateWithData( + s"EventTimeTimeout - $testName - setting timeout with state removal not allowed", + stateUpdates = state => { + state.remove(); state.setTimeoutTimestamp(10000) + }, + timeoutConf = EventTimeTimeout, + priorState = priorState, + priorTimeoutTimestamp = priorTimeoutTimestamp, + expectedState = None, + expectedTimeoutTimestamp = 10000) + } } - // Tests for 
StateStoreUpdater.updateStateForTimedOutKeys() + // Tests for InputProcessor.processTimedOutState() val preTimeoutState = Some(5) for (timeoutConf <- Seq(ProcessingTimeTimeout, EventTimeTimeout)) { testStateUpdateWithTimeout( @@ -453,7 +543,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf testStateUpdateWithTimeout( s"$timeoutConf - should timeout - no update/remove", - stateUpdates = state => { /* do nothing */ }, + stateUpdates = state => { + assert(state.getCurrentProcessingTimeMs() === currentBatchTimestamp) + intercept[Exception] { state.getCurrentWatermarkMs() } // watermark not specified + /* no updates */ + }, timeoutConf = timeoutConf, priorTimeoutTimestamp = beforeTimeoutThreshold, expectedState = preTimeoutState, // state should not change @@ -512,6 +606,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Function to maintain running count up to 2, and then remove the count // Returns the data and the count if state is defined, otherwise does not return anything val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -533,6 +629,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf AddData(inputData, "a"), CheckLastBatch(("a", "1")), assertNumStateRows(total = 1, updated = 1), + AssertOnQuery(sq => checkChildOutputHashPartitioning[FlatMapGroupsWithStateExec]( + sq, Seq("value"))), AddData(inputData, "a", "b"), CheckLastBatch(("a", "2"), ("b", "1")), assertNumStateRows(total = 2, updated = 2), @@ -632,6 +730,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - batch") { // Function that returns running count only if its even, otherwise does not return val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") Iterator((key, values.size)) } @@ -645,6 +746,9 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.hasTimedOut) { state.remove() Iterator((key, "-1")) @@ -698,10 +802,10 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf test("flatMapGroupsWithState - streaming with event time timeout + watermark") { // Function to maintain the max event time // Returns the max event time in the state, or -1 if the state was removed by timeout - val stateFunc = ( - key: String, - values: Iterator[(String, Long)], - state: GroupState[Long]) => { + val stateFunc = (key: String, values: Iterator[(String, Long)], state: GroupState[Long]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCanGetWatermark { 
state.getCurrentWatermarkMs() >= -1 } + val timeoutDelay = 5 if (key != "a") { Iterator.empty @@ -745,6 +849,8 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // Function to maintain running count up to 2, and then remove the count // Returns the data and the count (-1 if count reached beyond 2 and state was just removed) val stateFunc = (key: String, values: Iterator[String], state: GroupState[RunningCount]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() >= 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } val count = state.getOption.map(_.count).getOrElse(0L) + values.size if (count == 3) { @@ -787,7 +893,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf // - no initial state // - timeouts operations work, does not throw any error [SPARK-20792] // - works with primitive state type + // - can get processing time val stateFunc = (key: String, values: Iterator[String], state: GroupState[Int]) => { + assertCanGetProcessingTime { state.getCurrentProcessingTimeMs() > 0 } + assertCannotGetWatermark { state.getCurrentWatermarkMs() } + if (state.exists) throw new IllegalArgumentException("state.exists should be false") state.setTimeoutTimestamp(0, "1 hour") state.update(10) @@ -924,7 +1034,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf if (priorState.isEmpty && priorTimeoutTimestamp != NO_TIMESTAMP) { return // there can be no prior timestamp, when there is no prior state } - test(s"StateStoreUpdater - updates with data - $testName") { + test(s"InputProcessor - process new data - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === false, "hasTimedOut not false") assert(values.nonEmpty, "Some value is expected") @@ -946,7 +1056,7 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf expectedState: Option[Int], expectedTimeoutTimestamp: Long = NO_TIMESTAMP): Unit = { - test(s"StateStoreUpdater - updates for timeout - $testName") { + test(s"InputProcessor - process timed out state - $testName") { val mapGroupsFunc = (key: Int, values: Iterator[Int], state: GroupState[Int]) => { assert(state.hasTimedOut === true, "hasTimedOut not true") assert(values.isEmpty, "values not empty") @@ -973,21 +1083,20 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf val store = newStateStore() val mapGroupsSparkPlan = newFlatMapGroupsWithStateExec( mapGroupsFunc, timeoutConf, currentBatchTimestamp) - val updater = new mapGroupsSparkPlan.StateStoreUpdater(store) + val inputProcessor = new mapGroupsSparkPlan.InputProcessor(store) + val stateManager = mapGroupsSparkPlan.stateManager val key = intToRow(0) // Prepare store with prior state configs - if (priorState.nonEmpty) { - val row = updater.getStateRow(priorState.get) - updater.setTimeoutTimestamp(row, priorTimeoutTimestamp) - store.put(key.copy(), row.copy()) + if (priorState.nonEmpty || priorTimeoutTimestamp != NO_TIMESTAMP) { + stateManager.putState(store, key, priorState.orNull, priorTimeoutTimestamp) } // Call updating function to update state store def callFunction() = { val returnedIter = if (testTimeoutUpdates) { - updater.updateStateForTimedOutKeys() + inputProcessor.processTimedOutState() } else { - updater.updateStateForKeysWithData(Iterator(key)) + inputProcessor.processNewData(Iterator(key)) } returnedIter.size // consume the iterator to force state updates } @@ -998,15 
+1107,11 @@ class FlatMapGroupsWithStateSuite extends StateStoreMetricsTest with BeforeAndAf } else { // Call function to update and verify updated state in store callFunction() - val updatedStateRow = store.get(key) - assert( - Option(updater.getStateObj(updatedStateRow)).map(_.toString.toInt) === expectedState, + val updatedState = stateManager.getState(store, key) + assert(Option(updatedState.stateObj).map(_.toString.toInt) === expectedState, "final state not as expected") - if (updatedStateRow != null) { - assert( - updater.getTimeoutTimestamp(updatedStateRow) === expectedTimeoutTimestamp, - "final timeout timestamp not as expected") - } + assert(updatedState.timeoutTimestamp === expectedTimeoutTimestamp, + "final timeout timestamp not as expected") } } @@ -1080,4 +1185,24 @@ object FlatMapGroupsWithStateSuite { override def metrics: StateStoreMetrics = new StateStoreMetrics(map.size, 0, Map.empty) override def hasCommitted: Boolean = true } + + def assertCanGetProcessingTime(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get processing time", 20) + } + + def assertCanGetWatermark(predicate: => Boolean): Unit = { + if (!predicate) throw new TestFailedException("Could not get watermark", 20) + } + + def assertCannotGetWatermark(func: => Unit): Unit = { + try { + func + } catch { + case u: UnsupportedOperationException => + return + case _: Throwable => + throw new TestFailedException("Unexpected exception when trying to get watermark", 20) + } + throw new TestFailedException("Could get watermark when not expected", 20) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala new file mode 100644 index 0000000000000..45142278993bb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StatefulOperatorTest.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.catalyst.plans.physical._ +import org.apache.spark.sql.execution.streaming._ + +trait StatefulOperatorTest { + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator.
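+ * The expected HashPartitioning here is derived from the given grouping column names and the + * session's configured number of shuffle partitions; for example, a suite can check it with + * `AssertOnQuery(sq => checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value")))`.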
+ */ + protected def checkChildOutputHashPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + colNames: Seq[String]): Boolean = { + val attr = sq.asInstanceOf[StreamExecution].lastExecution.analyzed.output + val partitions = sq.sparkSession.sessionState.conf.numShufflePartitions + val groupingAttr = attr.filter(a => colNames.contains(a.name)) + checkChildOutputPartitioning(sq, HashPartitioning(groupingAttr, partitions)) + } + + /** + * Check that the output partitioning of a child operator of a Stateful operator satisfies the + * distribution that we expect for our Stateful operator. + */ + protected def checkChildOutputPartitioning[T <: StatefulOperator]( + sq: StreamingQuery, + expectedPartitioning: Partitioning): Boolean = { + val operator = sq.asInstanceOf[StreamExecution].lastExecution + .executedPlan.collect { case p: T => p } + operator.head.children.forall( + _.outputPartitioning.numPartitions == expectedPartitioning.numPartitions) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 9c901062d570a..3d687d2214e90 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -76,20 +76,65 @@ class StreamSuite extends StreamTest { CheckAnswer(Row(1, 1, "one"), Row(2, 2, "two"), Row(4, 4, "four"))) } + test("StreamingRelation.computeStats") { + val streamingRelation = spark.readStream.format("rate").load().logicalPlan collect { + case s: StreamingRelation => s + } + assert(streamingRelation.nonEmpty, "cannot find StreamingRelation") + assert( + streamingRelation.head.computeStats.sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) + } - test("explain join") { - // Make a table and ensure it will be broadcast. - val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + test("StreamingExecutionRelation.computeStats") { + val streamingExecutionRelation = MemoryStream[Int].toDF.logicalPlan collect { + case s: StreamingExecutionRelation => s + } + assert(streamingExecutionRelation.nonEmpty, "cannot find StreamingExecutionRelation") + assert(streamingExecutionRelation.head.computeStats.sizeInBytes + == spark.sessionState.conf.defaultSizeInBytes) + } - // Join the input stream with a table. - val inputData = MemoryStream[Int] - val joined = inputData.toDF().join(smallTable, smallTable("number") === $"value") + test("explain join with a normal source") { + // This test triggers CostBasedJoinReorder to call `computeStats` + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. 
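+ // Joining the stream with three small tables (with CBO and join reordering enabled above) + // gives the join reorder rule several relations to consider, which exercises computeStats + // on the streaming relation.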
+ val df = spark.readStream.format("rate").load() + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) + } + } - val outputStream = new java.io.ByteArrayOutputStream() - Console.withOut(outputStream) { - joined.explain() + test("explain join with MemoryStream") { + // This test triggers CostBasedJoinReorder to call `computeStats` + // Because MemoryStream doesn't use DataSource code path, we need a separate test. + withSQLConf(SQLConf.CBO_ENABLED.key -> "true", SQLConf.JOIN_REORDER_ENABLED.key -> "true") { + val smallTable = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable2 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + val smallTable3 = Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word") + + // Join the input stream with a table. + val df = MemoryStream[Int].toDF + val joined = df.join(smallTable, smallTable("number") === $"value") + .join(smallTable2, smallTable2("number") === $"value") + .join(smallTable3, smallTable3("number") === $"value") + + val outputStream = new java.io.ByteArrayOutputStream() + Console.withOut(outputStream) { + joined.explain(true) + } + assert(outputStream.toString.contains("StreamingRelation")) } - assert(outputStream.toString.contains("StreamingRelation")) } test("SPARK-20432: union one stream with itself") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 995cea3b37d4f..1b4d8556f6ae5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -44,7 +44,7 @@ object FailureSingleton { } class StreamingAggregationSuite extends StateStoreMetricsTest - with BeforeAndAfterAll with Assertions { + with BeforeAndAfterAll with Assertions with StatefulOperatorTest { override def afterAll(): Unit = { super.afterAll() @@ -281,6 +281,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest AddData(inputData, 0L, 5L, 5L, 10L), AdvanceManualClock(10 * 1000), CheckLastBatch((0L, 1), (5L, 2), (10L, 1)), + AssertOnQuery(sq => + checkChildOutputHashPartitioning[StateStoreRestoreExec](sq, Seq("value"))), // advance clock to 20 seconds, should retain keys >= 10 AddData(inputData, 15L, 15L, 20L), @@ -455,8 +457,8 @@ class StreamingAggregationSuite extends StateStoreMetricsTest }, AddBlockData(inputSource), // create an empty trigger CheckLastBatch(1), - AssertOnQuery("Verify addition of exchange operator") { se => - checkAggregationChain(se, expectShuffling = true, 1) + AssertOnQuery("Verify that no exchange is required") { se => + checkAggregationChain(se, expectShuffling = false, 1) }, AddBlockData(inputSource, Seq(2, 3)), CheckLastBatch(3), @@ -520,6 +522,22 @@ class StreamingAggregationSuite extends StateStoreMetricsTest } } + test("SPARK-22230: last should change with new batches") { + val input = MemoryStream[Int] + + val aggregated = input.toDF().agg(last('value)) + testStream(aggregated, OutputMode.Complete())( + AddData(input, 1, 2, 3), + CheckLastBatch(3), + AddData(input, 4, 5, 6), + CheckLastBatch(6), + 
AddData(input), + CheckLastBatch(6), + AddData(input, 0), + CheckLastBatch(0) + ) + } + /** Add blocks of data to the `BlockRDDBackedSource`. */ case class AddBlockData(source: BlockRDDBackedSource, data: Seq[Int]*) extends AddData { override def addData(query: Option[StreamExecution]): (Source, Offset) = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala index 533e1165fd59c..54eb863dacc83 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingJoinSuite.scala @@ -24,8 +24,9 @@ import scala.util.Random import org.scalatest.BeforeAndAfter import org.apache.spark.scheduler.ExecutorCacheTaskLocation -import org.apache.spark.sql.{AnalysisException, SparkSession} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.analysis.StreamingJoinHelper +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Literal} import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, Filter} import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.streaming.{MemoryStream, StatefulOperatorStateInfo, StreamingSymmetricHashJoinHelper} @@ -35,7 +36,7 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.Utils -class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { +class StreamingInnerJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { before { SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' @@ -322,111 +323,6 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo assert(e.toString.contains("Stream stream joins without equality predicate is not supported")) } - testQuietly("extract watermark from time condition") { - val attributesToFindConstraintFor = Seq( - AttributeReference("leftTime", TimestampType)(), - AttributeReference("leftOther", IntegerType)()) - val metadataWithWatermark = new MetadataBuilder() - .putLong(EventTimeWatermark.delayKey, 1000) - .build() - val attributesWithWatermark = Seq( - AttributeReference("rightTime", TimestampType, metadata = metadataWithWatermark)(), - AttributeReference("rightOther", IntegerType)()) - - def watermarkFrom( - conditionStr: String, - rightWatermark: Option[Long] = Some(10000)): Option[Long] = { - val conditionExpr = Some(conditionStr).map { str => - val plan = - Filter( - spark.sessionState.sqlParser.parseExpression(str), - LogicalRDD( - attributesToFindConstraintFor ++ attributesWithWatermark, - spark.sparkContext.emptyRDD)(spark)) - plan.queryExecution.optimizedPlan.asInstanceOf[Filter].condition - } - StreamingSymmetricHashJoinHelper.getStateValueWatermark( - AttributeSet(attributesToFindConstraintFor), AttributeSet(attributesWithWatermark), - conditionExpr, rightWatermark) - } - - // Test comparison directionality. E.g. if leftTime < rightTime and rightTime > watermark, - // then cannot define constraint on leftTime. 
- assert(watermarkFrom("leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("leftTime >= rightTime") === Some(9999)) - assert(watermarkFrom("leftTime < rightTime") === None) - assert(watermarkFrom("leftTime <= rightTime") === None) - assert(watermarkFrom("rightTime > leftTime") === None) - assert(watermarkFrom("rightTime >= leftTime") === None) - assert(watermarkFrom("rightTime < leftTime") === Some(10000)) - assert(watermarkFrom("rightTime <= leftTime") === Some(9999)) - - // Test type conversions - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) < CAST(rightTime AS LONG)") === None) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS DOUBLE)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS DOUBLE) > CAST(rightTime AS FLOAT)") === Some(10000)) - assert(watermarkFrom("CAST(leftTime AS STRING) > CAST(rightTime AS STRING)") === None) - - // Test with timestamp type + calendar interval on either side of equation - // Note: timestamptype and calendar interval don't commute, so less valid combinations to test. - assert(watermarkFrom("leftTime > rightTime + interval 1 second") === Some(11000)) - assert(watermarkFrom("leftTime + interval 2 seconds > rightTime ") === Some(8000)) - assert(watermarkFrom("leftTime > rightTime - interval 3 second") === Some(7000)) - assert(watermarkFrom("rightTime < leftTime - interval 3 second") === Some(13000)) - assert(watermarkFrom("rightTime - interval 1 second < leftTime - interval 3 second") - === Some(12000)) - - // Test with casted long type + constants on either side of equation - // Note: long type and constants commute, so more combinations to test. 
- // -- Constants on the right - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) + 1") === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST(rightTime AS LONG) - 1") === Some(9000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > CAST((rightTime + interval 1 second) AS LONG)") - === Some(11000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > 2 + CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) > -0.5 + CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) > 2") === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS DOUBLE) + CAST(leftTime AS LONG) > 0.1") - === Some(10100)) - assert(watermarkFrom("0 > CAST(rightTime AS LONG) - CAST(leftTime AS LONG) + 0.2") - === Some(10200)) - // -- Constants on the left - assert(watermarkFrom("CAST(leftTime AS LONG) + 2 > CAST(rightTime AS LONG)") === Some(8000)) - assert(watermarkFrom("1 + CAST(leftTime AS LONG) > CAST(rightTime AS LONG)") === Some(9000)) - assert(watermarkFrom("CAST((leftTime + interval 3 second) AS LONG) > CAST(rightTime AS LONG)") - === Some(7000)) - assert(watermarkFrom("CAST(leftTime AS LONG) - 2 > CAST(rightTime AS LONG)") === Some(12000)) - assert(watermarkFrom("CAST(leftTime AS LONG) + 0.5 > CAST(rightTime AS LONG)") === Some(9500)) - assert(watermarkFrom("CAST(leftTime AS LONG) - CAST(rightTime AS LONG) - 2 > 0") - === Some(12000)) - assert(watermarkFrom("-CAST(rightTime AS LONG) + CAST(leftTime AS LONG) - 0.1 > 0") - === Some(10100)) - // -- Constants on both sides, mixed types - assert(watermarkFrom("CAST(leftTime AS LONG) - 2.0 > CAST(rightTime AS LONG) + 1") - === Some(13000)) - - // Test multiple conditions, should return minimum watermark - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 2 seconds") === - Some(7000)) // first condition wins - assert(watermarkFrom( - "leftTime > rightTime - interval 3 second AND rightTime < leftTime + interval 4 seconds") === - Some(6000)) // second condition wins - - // Test invalid comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > leftOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther") === None) // non-time attributes - assert(watermarkFrom("leftOther > rightOther AND leftTime > rightTime") === Some(10000)) - assert(watermarkFrom("cast(rightTime AS DOUBLE) < rightOther") === None) // non-time attributes - assert(watermarkFrom("leftTime > rightTime + interval 1 month") === None) // month not allowed - - // Test static comparisons - assert(watermarkFrom("cast(leftTime AS LONG) > 10") === Some(10000)) - } - test("locality preferences of StateStoreAwareZippedRDD") { import StreamingSymmetricHashJoinHelper._ @@ -434,7 +330,7 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo val queryId = UUID.randomUUID val opId = 0 val path = Utils.createDirectory(tempDir.getAbsolutePath, Random.nextString(10)).toString - val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L) + val stateInfo = StatefulOperatorStateInfo(path, queryId, opId, 0L, 5) implicit val sqlContext = spark.sqlContext val coordinatorRef = sqlContext.streams.stateStoreCoordinator @@ -469,4 +365,332 @@ class StreamingJoinSuite extends StreamTest with StateStoreMetricsTest with Befo } } } + + test("join between three streams") { + val input1 = MemoryStream[Int] + val input2 = MemoryStream[Int] + val input3 = MemoryStream[Int] + + val df1 = 
input1.toDF.select('value as "leftKey", ('value * 2) as "leftValue") + val df2 = input2.toDF.select('value as "middleKey", ('value * 3) as "middleValue") + val df3 = input3.toDF.select('value as "rightKey", ('value * 5) as "rightValue") + + val joined = df1.join(df2, expr("leftKey = middleKey")).join(df3, expr("rightKey = middleKey")) + + testStream(joined)( + AddData(input1, 1, 5), + AddData(input2, 1, 5, 10), + AddData(input3, 5, 10), + CheckLastBatch((5, 10, 5, 15, 5, 25))) + } +} + +class StreamingOuterJoinSuite extends StreamTest with StateStoreMetricsTest with BeforeAndAfter { + + import testImplicits._ + import org.apache.spark.sql.functions._ + + before { + SparkSession.setActiveSession(spark) // set this before force initializing 'joinExec' + spark.streams.stateStoreCoordinator // initialize the lazy coordinator + } + + after { + StateStore.stop() + } + + private def setupStream(prefix: String, multiplier: Int): (MemoryStream[Int], DataFrame) = { + val input = MemoryStream[Int] + val df = input.toDF + .select( + 'value as "key", + 'value.cast("timestamp") as s"${prefix}Time", + ('value * multiplier) as s"${prefix}Value") + .withWatermark(s"${prefix}Time", "10 seconds") + + return (input, df) + } + + private def setupWindowedJoin(joinType: String): + (MemoryStream[Int], MemoryStream[Int], DataFrame) = { + val (input1, df1) = setupStream("left", 2) + val (input2, df2) = setupStream("right", 3) + val windowed1 = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val windowed2 = df2.select('key, window('rightTime, "10 second"), 'rightValue) + val joined = windowed1.join(windowed2, Seq("key", "window"), joinType) + .select('key, $"window.end".cast("long"), 'leftValue, 'rightValue) + + (input1, input2, joined) + } + + test("left outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with leftValue <= 4 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + + test("left outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. 
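+ // (leftValue stays an integer while rightValue is cast to a string, so a null row produced on + // the wrong side would also show up with the wrong type.)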
+ val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with value <= 7 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, 8, null), Row(5, 10, 10, null)) + ) + } + + test("right outer early state exclusion on left") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'leftValue > 4, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 3, 4, 5), + // The left rows with value <= 4 should never be added to the state. + CheckLastBatch(Row(3, 10, 6, "9")), + assertNumStateRows(total = 4, updated = 4), + // When the watermark advances, we get the outer join rows just as we would if they + // were added but didn't match the full join condition. + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch(Row(20, 30, 40, "60"), Row(4, 10, null, "12"), Row(5, 10, null, "15")) + ) + } + + test("right outer early state exclusion on right") { + val (leftInput, df1) = setupStream("left", 2) + val (rightInput, df2) = setupStream("right", 3) + // Use different schemas to ensure the null row is being generated from the correct side. + val left = df1.select('key, window('leftTime, "10 second"), 'leftValue) + val right = df2.select('key, window('rightTime, "10 second"), 'rightValue.cast("string")) + + val joined = left.join( + right, + left("key") === right("key") + && left("window") === right("window") + && 'rightValue.cast("int") > 7, + "right_outer") + .select(right("key"), right("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + AddData(leftInput, 3, 4, 5), + AddData(rightInput, 1, 2, 3), + // The right rows with rightValue <= 7 should generate their outer join row now and + // not get added to the state. + CheckLastBatch(Row(3, 10, 6, "9"), Row(1, 10, null, "3"), Row(2, 10, null, "6")), + assertNumStateRows(total = 4, updated = 4), + // We shouldn't get more outer join rows when the watermark advances. 
+ AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + AddData(rightInput, 20), + CheckLastBatch((20, 30, 40, "60")) + ) + } + + test("windowed left outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("left_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(1, 10, 2, null), Row(2, 10, 4, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + test("windowed right outer join") { + val (leftInput, rightInput, joined) = setupWindowedJoin("right_outer") + + testStream(joined)( + // Test inner part of the join. + AddData(leftInput, 1, 2, 3, 4, 5), + AddData(rightInput, 3, 4, 5, 6, 7), + CheckLastBatch((3, 10, 6, 9), (4, 10, 8, 12), (5, 10, 10, 15)), + // Old state doesn't get dropped until the batch *after* it gets introduced, so the + // nulls won't show up until the next batch after the watermark advances. + AddData(leftInput, 21), + AddData(rightInput, 22), + CheckLastBatch(), + assertNumStateRows(total = 12, updated = 2), + AddData(leftInput, 22), + CheckLastBatch(Row(22, 30, 44, 66), Row(6, 10, null, 18), Row(7, 10, null, 21)), + assertNumStateRows(total = 3, updated = 1) + ) + } + + Seq( + ("left_outer", Row(3, null, 5, null)), + ("right_outer", Row(null, 2, null, 5)) + ).foreach { case (joinType: String, outerResult) => + test(s"${joinType.replaceAllLiterally("_", " ")} with watermark range condition") { + import org.apache.spark.sql.functions._ + + val leftInput = MemoryStream[(Int, Int)] + val rightInput = MemoryStream[(Int, Int)] + + val df1 = leftInput.toDF.toDF("leftKey", "time") + .select('leftKey, 'time.cast("timestamp") as "leftTime", ('leftKey * 2) as "leftValue") + .withWatermark("leftTime", "10 seconds") + + val df2 = rightInput.toDF.toDF("rightKey", "time") + .select('rightKey, 'time.cast("timestamp") as "rightTime", ('rightKey * 3) as "rightValue") + .withWatermark("rightTime", "10 seconds") + + val joined = + df1.join( + df2, + expr("leftKey = rightKey AND " + + "leftTime BETWEEN rightTime - interval 5 seconds AND rightTime + interval 5 seconds"), + joinType) + .select('leftKey, 'rightKey, 'leftTime.cast("int"), 'rightTime.cast("int")) + testStream(joined)( + AddData(leftInput, (1, 5), (3, 5)), + CheckAnswer(), + AddData(rightInput, (1, 10), (2, 5)), + CheckLastBatch((1, 1, 5, 10)), + AddData(rightInput, (1, 11)), + CheckLastBatch(), // no match as left time is too low + assertNumStateRows(total = 5, updated = 1), + + // Increase event time watermark to 20s by adding data with time = 30s on both inputs + AddData(leftInput, (1, 7), (1, 30)), + CheckLastBatch((1, 1, 7, 10), (1, 1, 7, 11)), + assertNumStateRows(total = 7, updated = 2), + AddData(rightInput, (0, 30)), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 1), + AddData(rightInput, (0, 30)), + CheckLastBatch(outerResult), + assertNumStateRows(total = 3, updated = 1) + ) + } + } + + // When the join condition isn't true, the outer null rows must be generated, even if the join + // keys themselves have a match. 
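+ // In the test below, rows join on key and window, but the condition additionally requires + // leftValue > 10 and rightValue outside the 300-1000 range, so matching keys alone are not + // enough to avoid the outer-join null rows.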
+ test("left outer join with non-key condition violated") { + val (leftInput, simpleLeftDf) = setupStream("left", 2) + val (rightInput, simpleRightDf) = setupStream("right", 3) + + val left = simpleLeftDf.select('key, window('leftTime, "10 second"), 'leftValue) + val right = simpleRightDf.select('key, window('rightTime, "10 second"), 'rightValue) + + val joined = left.join( + right, + left("key") === right("key") && left("window") === right("window") && + 'leftValue > 10 && ('rightValue < 300 || 'rightValue > 1000), + "left_outer") + .select(left("key"), left("window.end").cast("long"), 'leftValue, 'rightValue) + + testStream(joined)( + // leftValue <= 10 should generate outer join rows even though it matches right keys + AddData(leftInput, 1, 2, 3), + AddData(rightInput, 1, 2, 3), + CheckLastBatch(Row(1, 10, 2, null), Row(2, 10, 4, null), Row(3, 10, 6, null)), + AddData(leftInput, 20), + AddData(rightInput, 21), + CheckLastBatch(), + assertNumStateRows(total = 5, updated = 2), + AddData(rightInput, 20), + CheckLastBatch( + Row(20, 30, 40, 60)), + assertNumStateRows(total = 3, updated = 1), + // leftValue and rightValue both satisfying condition should not generate outer join rows + AddData(leftInput, 40, 41), + AddData(rightInput, 40, 41), + CheckLastBatch((40, 50, 80, 120), (41, 50, 82, 123)), + AddData(leftInput, 70), + AddData(rightInput, 71), + CheckLastBatch(), + assertNumStateRows(total = 6, updated = 2), + AddData(rightInput, 70), + CheckLastBatch((70, 80, 140, 210)), + assertNumStateRows(total = 3, updated = 1), + // rightValue between 300 and 1000 should generate outer join rows even though it matches left + AddData(leftInput, 101, 102, 103), + AddData(rightInput, 101, 102, 103), + CheckLastBatch(), + AddData(leftInput, 1000), + AddData(rightInput, 1001), + CheckLastBatch(), + assertNumStateRows(total = 8, updated = 2), + AddData(rightInput, 1000), + CheckLastBatch( + Row(1000, 1010, 2000, 3000), + Row(101, 110, 202, null), + Row(102, 110, 204, null), + Row(103, 110, 206, null)), + assertNumStateRows(total = 3, updated = 1) + ) + } } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index ab35079dca23f..cc693909270f8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -652,6 +652,19 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi } } + test("SPARK-22238: don't check for RDD partitions during streaming aggregation preparation") { + val stream = MemoryStream[(Int, Int)] + val baseDf = Seq((1, "A"), (2, "b")).toDF("num", "char").where("char = 'A'") + val otherDf = stream.toDF().toDF("num", "numSq") + .join(broadcast(baseDf), "num") + .groupBy('char) + .agg(sum('numSq)) + + testStream(otherDf, OutputMode.Complete())( + AddData(stream, (1, 1), (2, 4)), + CheckLastBatch(("A", 1))) + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) @@ -731,7 +744,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi assert(returnedValue === expectedReturnValue, "Returned value does not match expected") } } - AwaitTerminationTester.test(expectedBehavior, awaitTermFunc) + AwaitTerminationTester.test(expectedBehavior, () 
=> awaitTermFunc()) true // If the control reached here, then everything worked as expected } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala new file mode 100644 index 0000000000000..2a854e37bf0df --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingSymmetricHashJoinHelperSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.sql.Column +import org.apache.spark.sql.catalyst.analysis.SimpleAnalyzer +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.execution.{LeafExecNode, LocalTableScanExec, SparkPlan} +import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec +import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.JoinConditionSplitPredicates +import org.apache.spark.sql.types._ + +class StreamingSymmetricHashJoinHelperSuite extends StreamTest { + import org.apache.spark.sql.functions._ + + val leftAttributeA = AttributeReference("a", IntegerType)() + val leftAttributeB = AttributeReference("b", IntegerType)() + val rightAttributeC = AttributeReference("c", IntegerType)() + val rightAttributeD = AttributeReference("d", IntegerType)() + val leftColA = new Column(leftAttributeA) + val leftColB = new Column(leftAttributeB) + val rightColC = new Column(rightAttributeC) + val rightColD = new Column(rightAttributeD) + + val left = new LocalTableScanExec(Seq(leftAttributeA, leftAttributeB), Seq()) + val right = new LocalTableScanExec(Seq(rightAttributeC, rightAttributeD), Seq()) + + test("empty") { + val split = JoinConditionSplitPredicates(None, left, right) + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.isEmpty) + } + + test("only literals") { + // Literal-only conjuncts reference neither input, so they end up in both the left-side-only + // and right-side-only buckets. There's no semantic reason they couldn't be in any bucket.
+ val predicate = (lit(1) < lit(5) && lit(6) < lit(7) && lit(0) === lit(-1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only left") { + val predicate = (leftColA > lit(1) && leftColB > lit(5) && leftColA < leftColB).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains(predicate)) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("only right") { + val predicate = (rightColC > lit(1) && rightColD > lit(5) && rightColD < rightColC).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.contains(predicate)) + assert(split.bothSides.isEmpty) + assert(split.full.contains(predicate)) + } + + test("mixed conjuncts") { + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC).expr)) + assert(split.full.contains(predicate)) + } + + test("conjuncts after nondeterministic") { + // All conjuncts after a nondeterministic conjunct shouldn't be split because they don't + // commute across it. + val predicate = + (rand() > lit(0) + && leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.isEmpty) + assert(split.rightSideOnly.isEmpty) + assert(split.bothSides.contains(predicate)) + assert(split.full.contains(predicate)) + } + + + test("conjuncts before nondeterministic") { + val randCol = rand() + val predicate = + (leftColA > leftColB + && rightColC > rightColD + && leftColA === rightColC + && lit(1) === lit(1) + && randCol > lit(0)).expr + val split = JoinConditionSplitPredicates(Some(predicate), left, right) + + assert(split.leftSideOnly.contains((leftColA > leftColB && lit(1) === lit(1)).expr)) + assert(split.rightSideOnly.contains((rightColC > rightColD && lit(1) === lit(1)).expr)) + assert(split.bothSides.contains((leftColA === rightColC && randCol > lit(0)).expr)) + assert(split.full.contains(predicate)) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index 569bac156b531..a5d7e6257a6df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -21,10 +21,14 @@ import java.io.File import java.util.Locale import java.util.concurrent.ConcurrentLinkedQueue +import scala.collection.JavaConverters._ + import org.scalatest.BeforeAndAfter +import org.apache.spark.SparkContext import org.apache.spark.internal.io.FileCommitProtocol.TaskCommitMessage import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol +import org.apache.spark.scheduler.{SparkListener, 
SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.internal.SQLConf @@ -775,4 +779,31 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } } } + + test("use Spark jobs to list files") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "1") { + withTempDir { dir => + val jobDescriptions = new ConcurrentLinkedQueue[String]() + val jobListener = new SparkListener { + override def onJobStart(jobStart: SparkListenerJobStart): Unit = { + jobDescriptions.add(jobStart.properties.getProperty(SparkContext.SPARK_JOB_DESCRIPTION)) + } + } + sparkContext.addSparkListener(jobListener) + try { + spark.range(0, 3).map(i => (i, i)) + .write.partitionBy("_1").mode("overwrite").parquet(dir.getCanonicalPath) + // normal file paths + checkDatasetUnorderly( + spark.read.parquet(dir.getCanonicalPath).as[(Long, Long)], + 0L -> 0L, 1L -> 1L, 2L -> 2L) + sparkContext.listenerBus.waitUntilEmpty(10000) + assert(jobDescriptions.asScala.toList.exists( + _.contains("Listing leaf files and directories for 3 paths"))) + } finally { + sparkContext.removeSparkListener(jobListener) + } + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala new file mode 100644 index 0000000000000..6179585a0d39a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFlatSpecSuite.scala @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.test + +import org.scalatest.FlatSpec + +/** + * The purpose of this suite is to make sure that generic FlatSpec-based scala + * tests work with a shared spark session + */ +class GenericFlatSpecSuite extends FlatSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" should "have the specified number of elements" in { + assert(8 === ds.count) + } + it should "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + it should "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it should "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala new file mode 100644 index 0000000000000..15139ee8b3047 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericFunSpecSuite.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.FunSpec + +/** + * The purpose of this suite is to make sure that generic FunSpec-based scala + * tests work with a shared spark session + */ +class GenericFunSpecSuite extends FunSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + describe("Simple Dataset") { + it("should have the specified number of elements") { + assert(8 === ds.count) + } + it("should have the specified number of unique elements") { + assert(8 === ds.distinct.count) + } + it("should have the specified number of elements in each column") { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + it("should have the correct number of distinct elements in each column") { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala new file mode 100644 index 0000000000000..b6548bf95fec8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/GenericWordSpecSuite.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import org.scalatest.WordSpec + +/** + * The purpose of this suite is to make sure that generic WordSpec-based scala + * tests work with a shared spark session + */ +class GenericWordSpecSuite extends WordSpec with SharedSparkSession { + import testImplicits._ + initializeSession() + val ds = Seq((1, 1), (2, 1), (3, 2), (4, 2), (5, 3), (6, 3), (7, 4), (8, 4)).toDS + + "A Simple Dataset" when { + "looked at as complete rows" should { + "have the specified number of elements" in { + assert(8 === ds.count) + } + "have the specified number of unique elements" in { + assert(8 === ds.distinct.count) + } + } + "refined to specific columns" should { + "have the specified number of elements in each column" in { + assert(8 === ds.select("_1").count) + assert(8 === ds.select("_2").count) + } + "have the correct number of distinct elements in each column" in { + assert(8 === ds.select("_1").distinct.count) + assert(4 === ds.select("_2").distinct.count) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index a14a1441a4313..b4248b74f50ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -27,7 +27,7 @@ import scala.language.implicitConversions import scala.util.control.NonFatal import org.apache.hadoop.fs.Path -import org.scalatest.BeforeAndAfterAll +import org.scalatest.{BeforeAndAfterAll, Suite} import org.scalatest.concurrent.Eventually import org.apache.spark.SparkFunSuite @@ -36,14 +36,17 @@ import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchTableException import org.apache.spark.sql.catalyst.catalog.SessionCatalog.DEFAULT_DATABASE import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.PlanTestBase import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.FilterExec import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.util.{UninterruptibleThread, Utils} +import org.apache.spark.util.UninterruptibleThread +import org.apache.spark.util.Utils /** - * Helper trait that should be extended by all SQL test suites. + * Helper trait that should be extended by all SQL test suites within the Spark + * code base. * * This allows subclasses to plugin a custom `SQLContext`. It comes with test data * prepared in advance as well as all implicit conversions used extensively by dataframes. 
@@ -52,17 +55,99 @@ import org.apache.spark.util.{UninterruptibleThread, Utils} * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. */ -private[sql] trait SQLTestUtils - extends SparkFunSuite with Eventually +private[sql] trait SQLTestUtils extends SparkFunSuite with SQLTestUtilsBase with PlanTest { + // Whether to materialize all test data before the first test is run + private var loadTestDataBeforeTests = false + + protected override def beforeAll(): Unit = { + super.beforeAll() + if (loadTestDataBeforeTests) { + loadTestData() + } + } + + /** + * Materialize the test data immediately after the `SQLContext` is set up. + * This is necessary if the data is accessed by name but not through direct reference. + */ + protected def setupTestData(): Unit = { + loadTestDataBeforeTests = true + } + + /** + * Disable stdout and stderr when running the test. To not output the logs to the console, + * ConsoleAppender's `follow` should be set to `true` so that it will honor reassignments of + * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if + * we change System.out and System.err. + */ + protected def testQuietly(name: String)(f: => Unit): Unit = { + test(name) { + quietly { + f + } + } + } + + /** + * Run a test on a separate `UninterruptibleThread`. + */ + protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) + (body: => Unit): Unit = { + val timeoutMillis = 10000 + @transient var ex: Throwable = null + + def runOnThread(): Unit = { + val thread = new UninterruptibleThread(s"Testing thread for test $name") { + override def run(): Unit = { + try { + body + } catch { + case NonFatal(e) => + ex = e + } + } + } + thread.setDaemon(true) + thread.start() + thread.join(timeoutMillis) + if (thread.isAlive) { + thread.interrupt() + // If this interrupt does not work, then this thread is most likely running something that + // is not interruptible. There is not much point in waiting for the thread to terminate, and + // we would rather let the JVM terminate the thread on exit. + fail( + s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + + s" $timeoutMillis ms") + } else if (ex != null) { + throw ex + } + } + + if (quietly) { + testQuietly(name) { runOnThread() } + } else { + test(name) { runOnThread() } + } + } +} + +/** + * Helper trait that can be extended by all external SQL test suites. + * + * This allows subclasses to plugin a custom `SQLContext`. + * To use implicit methods, import `testImplicits._` instead of through the `SQLContext`. + * + * Subclasses should *not* create `SQLContext`s in the test suite constructor, which is + * prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM. + */ +private[sql] trait SQLTestUtilsBase + extends Eventually with BeforeAndAfterAll with SQLTestData - with PlanTest { self => + with PlanTestBase { self: Suite => protected def sparkContext = spark.sparkContext - // Whether to materialize all test data before the first test is run - private var loadTestDataBeforeTests = false - // Shorthand for running a query using our SQLContext protected lazy val sql = spark.sql _ @@ -77,21 +162,6 @@ private[sql] trait SQLTestUtils protected override def _sqlContext: SQLContext = self.spark.sqlContext } - /** - * Materialize the test data immediately after the `SQLContext` is set up.
- * This is necessary if the data is accessed by name but not through direct reference. - */ - protected def setupTestData(): Unit = { - loadTestDataBeforeTests = true - } - - protected override def beforeAll(): Unit = { - super.beforeAll() - if (loadTestDataBeforeTests) { - loadTestData() - } - } - protected override def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = { SparkSession.setActiveSession(spark) super.withSQLConf(pairs: _*)(f) @@ -297,61 +367,6 @@ private[sql] trait SQLTestUtils Dataset.ofRows(spark, plan) } - /** - * Disable stdout and stderr when running the test. To not output the logs to the console, - * ConsoleAppender's `follow` should be set to `true` so that it will honors reassignments of - * System.out or System.err. Otherwise, ConsoleAppender will still output to the console even if - * we change System.out and System.err. - */ - protected def testQuietly(name: String)(f: => Unit): Unit = { - test(name) { - quietly { - f - } - } - } - - /** - * Run a test on a separate `UninterruptibleThread`. - */ - protected def testWithUninterruptibleThread(name: String, quietly: Boolean = false) - (body: => Unit): Unit = { - val timeoutMillis = 10000 - @transient var ex: Throwable = null - - def runOnThread(): Unit = { - val thread = new UninterruptibleThread(s"Testing thread for test $name") { - override def run(): Unit = { - try { - body - } catch { - case NonFatal(e) => - ex = e - } - } - } - thread.setDaemon(true) - thread.start() - thread.join(timeoutMillis) - if (thread.isAlive) { - thread.interrupt() - // If this interrupt does not work, then this thread is most likely running something that - // is not interruptible. There is not much point to wait for the thread to termniate, and - // we rather let the JVM terminate the thread on exit. - fail( - s"Test '$name' running on o.a.s.util.UninterruptibleThread timed out after" + - s" $timeoutMillis ms") - } else if (ex != null) { - throw ex - } - } - - if (quietly) { - testQuietly(name) { runOnThread() } - } else { - test(name) { runOnThread() } - } - } /** * This method is used to make the given path qualified, when a path diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala index cd8d0708d8a32..4d578e21f5494 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala @@ -17,86 +17,4 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - -import org.scalatest.BeforeAndAfterEach -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.sql.{SparkSession, SQLContext} -import org.apache.spark.sql.internal.SQLConf - -/** - * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. - */ -trait SharedSQLContext extends SQLTestUtils with BeforeAndAfterEach with Eventually { - - protected def sparkConf = { - new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set("spark.unsafe.exceptionOnMemoryLeak", "true") - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. 
- */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected implicit def spark: SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - new TestSparkSession(sparkConf) - } - - /** - * Initialize the [[TestSparkSession]]. - */ - protected override def beforeAll(): Unit = { - SparkSession.sqlListener.set(null) - if (_spark == null) { - _spark = createSparkSession - } - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - super.afterAll() - if (_spark != null) { - _spark.sessionState.catalog.reset() - _spark.stop() - _spark = null - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } -} +trait SharedSQLContext extends SQLTestUtils with SharedSparkSession diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala new file mode 100644 index 0000000000000..e0568a3c5c99f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.test + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.sql.{SparkSession, SQLContext} +import org.apache.spark.sql.internal.SQLConf + +/** + * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. + */ +trait SharedSparkSession + extends SQLTestUtilsBase + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set("spark.unsafe.exceptionOnMemoryLeak", "true") + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. 
+ * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: TestSparkSession = null + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + */ + protected implicit def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: TestSparkSession = { + new TestSparkSession(sparkConf) + } + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + SparkSession.sqlListener.set(null) + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + super.afterAll() + if (_spark != null) { + _spark.sessionState.catalog.reset() + _spark.stop() + _spark = null + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala new file mode 100644 index 0000000000000..4205e23ae240a --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/ExecutionListenerManagerSuite.scala @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
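A minimal sketch of the pattern the initializeSession comment in SharedSparkSession above describes: in a FunSpec-style suite the session can be brought up between a describe and its it blocks instead of waiting for beforeAll. The suite below is hypothetical.

class ExampleFunSpecSuite extends org.scalatest.FunSpec with SharedSparkSession {
  describe("a shared session") {
    // Make the session available while the test group itself is being constructed.
    initializeSession()
    val df = spark.range(5).toDF("id")

    it("can be used from code outside the individual tests") {
      assert(df.count() === 5L)
    }
  }
}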
+ */ + +package org.apache.spark.sql.util + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.spark._ +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.internal.StaticSQLConf._ + +class ExecutionListenerManagerSuite extends SparkFunSuite { + + import CountingQueryExecutionListener._ + + test("register query execution listeners using configuration") { + val conf = new SparkConf(false) + .set(QUERY_EXECUTION_LISTENERS, Seq(classOf[CountingQueryExecutionListener].getName())) + + val mgr = new ExecutionListenerManager(conf) + assert(INSTANCE_COUNT.get() === 1) + mgr.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 1) + + val clone = mgr.clone() + assert(INSTANCE_COUNT.get() === 1) + + clone.onSuccess(null, null, 42L) + assert(CALLBACK_COUNT.get() === 2) + } + +} + +private class CountingQueryExecutionListener extends QueryExecutionListener { + + import CountingQueryExecutionListener._ + + INSTANCE_COUNT.incrementAndGet() + + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { + CALLBACK_COUNT.incrementAndGet() + } + +} + +private object CountingQueryExecutionListener { + + val CALLBACK_COUNT = new AtomicInteger() + val INSTANCE_COUNT = new AtomicInteger() + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index b256ffc27b199..1f11adbd4f62e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,8 +94,15 @@ private[sql] class HiveSessionCatalog( } } catch { case NonFatal(e) => - val analysisException = - new AnalysisException(s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e") + val noHandlerMsg = s"No handler for UDF/UDAF/UDTF '${clazz.getCanonicalName}': $e" + val errorMsg = + if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + s"$noHandlerMsg\nPlease make sure your function overrides " + + "`public StructObjectInspector initialize(ObjectInspector[] args)`." 
+ } else { + noHandlerMsg + } + val analysisException = new AnalysisException(errorMsg) analysisException.setStackTrace(e.getStackTrace) throw analysisException } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 805b3171cdaab..3592b8f4846d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -189,12 +189,12 @@ case class RelationConversions( private def convert(relation: HiveTableRelation): LogicalRelation = { val serde = relation.tableMeta.storage.serde.getOrElse("").toLowerCase(Locale.ROOT) if (serde.contains("parquet")) { - val options = Map(ParquetOptions.MERGE_SCHEMA -> + val options = relation.tableMeta.storage.properties + (ParquetOptions.MERGE_SCHEMA -> conf.getConf(HiveUtils.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING).toString) sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[ParquetFileFormat], "parquet") } else { - val options = Map[String, String]() + val options = relation.tableMeta.storage.properties sessionCatalog.metastoreCatalog .convertToLogicalRelation(relation, options, classOf[OrcFileFormat], "orc") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c4e48c9360db7..16c95c53b4201 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -111,12 +111,6 @@ private[hive] class HiveClientImpl( if (clientLoader.isolationOn) { // Switch to the initClassLoader. Thread.currentThread().setContextClassLoader(initClassLoader) - // Set up kerberos credentials for UserGroupInformation.loginUser within current class loader - if (sparkConf.contains("spark.yarn.principal") && sparkConf.contains("spark.yarn.keytab")) { - val principal = sparkConf.get("spark.yarn.principal") - val keytab = sparkConf.get("spark.yarn.keytab") - SparkHadoopUtil.get.loginUserFromKeytab(principal, keytab) - } try { newState() } finally { @@ -461,7 +455,7 @@ private[hive] class HiveClientImpl( // in table properties. This means, if we have bucket spec in both hive metastore and // table properties, we will trust the one in table properties. 
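Returning to the sharper UDTF error added in HiveSessionCatalog above: the user-side fix it points at is to override the array-based initialize. A minimal hypothetical UDTF against Hive's GenericUDTF API might look roughly like this in Scala:

import java.util.{Arrays => JArrays}

import org.apache.hadoop.hive.ql.udf.generic.GenericUDTF
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory, StructObjectInspector}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory

class EchoUDTF extends GenericUDTF {
  // Spark resolves UDTFs through this signature, not the StructObjectInspector overload.
  override def initialize(args: Array[ObjectInspector]): StructObjectInspector = {
    ObjectInspectorFactory.getStandardStructObjectInspector(
      JArrays.asList("value"),
      JArrays.asList[ObjectInspector](PrimitiveObjectInspectorFactory.javaStringObjectInspector))
  }

  // Emit one row per call with the first argument rendered as a string.
  override def process(args: Array[AnyRef]): Unit = {
    forward(Array[AnyRef](String.valueOf(args(0))))
  }

  override def close(): Unit = {}
}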
bucketSpec = bucketSpec, - owner = h.getOwner, + owner = Option(h.getOwner).getOrElse(""), createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( @@ -638,12 +632,13 @@ private[hive] class HiveClientImpl( table: CatalogTable, spec: Option[TablePartitionSpec]): Seq[CatalogTablePartition] = withHiveState { val hiveTable = toHiveTable(table, Some(userName)) - val parts = spec match { - case None => shim.getAllPartitions(client, hiveTable).map(fromHivePartition) + val partSpec = spec match { + case None => CatalogTypes.emptyTablePartitionSpec case Some(s) => assert(s.values.forall(_.nonEmpty), s"partition spec '$s' is invalid") - client.getPartitions(hiveTable, s.asJava).asScala.map(fromHivePartition) + s } + val parts = client.getPartitions(hiveTable, partSpec.asJava).asScala.map(fromHivePartition) HiveCatalogMetrics.incrementFetchedPartitions(parts.length) parts } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index cde20da186acd..5c1ff2b76fdaa 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -585,6 +585,35 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { * Unsupported predicates are skipped. */ def convertFilters(table: Table, filters: Seq[Expression]): String = { + if (SQLConf.get.advancedPartitionPredicatePushdownEnabled) { + convertComplexFilters(table, filters) + } else { + convertBasicFilters(table, filters) + } + } + + private def convertBasicFilters(table: Table, filters: Seq[Expression]): String = { + // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. + lazy val varcharKeys = table.getPartitionKeys.asScala + .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || + col.getType.startsWith(serdeConstants.CHAR_TYPE_NAME)) + .map(col => col.getName).toSet + + filters.collect { + case op @ BinaryComparison(a: Attribute, Literal(v, _: IntegralType)) => + s"${a.name} ${op.symbol} $v" + case op @ BinaryComparison(Literal(v, _: IntegralType), a: Attribute) => + s"$v ${op.symbol} ${a.name}" + case op @ BinaryComparison(a: Attribute, Literal(v, _: StringType)) + if !varcharKeys.contains(a.name) => + s"""${a.name} ${op.symbol} ${quoteStringLiteral(v.toString)}""" + case op @ BinaryComparison(Literal(v, _: StringType), a: Attribute) + if !varcharKeys.contains(a.name) => + s"""${quoteStringLiteral(v.toString)} ${op.symbol} ${a.name}""" + }.mkString(" and ") + } + + private def convertComplexFilters(table: Table, filters: Seq[Expression]): String = { // hive varchar is treated as catalyst string, but hive varchar can't be pushed down. 
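To make the basic/advanced split in convertFilters concrete, here is a rough sketch of how the shim handles the same predicates. It assumes the code sits alongside FiltersSuite in the hive client test package (so the package-private Shim_v0_13 is visible); the expected strings follow from the rules above and from the FiltersSuite changes later in this patch.

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Literal}
import org.apache.spark.sql.types.IntegerType

val shim = new Shim_v0_13
val table = new org.apache.hadoop.hive.ql.metadata.Table("default", "test")
val intCol = AttributeReference("intcol", IntegerType)()

// A plain comparison is pushed by both paths, producing "intcol > 1".
shim.convertFilters(table, Seq(intCol > Literal(1)))

// A disjunction is only pushed when ADVANCED_PARTITION_PREDICATE_PUSHDOWN is enabled,
// yielding "(1 = intcol or 2 = intcol)"; with the flag off the result is an empty string.
shim.convertFilters(table, Seq(Literal(1) === intCol || Literal(2) === intCol))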
lazy val varcharKeys = table.getPartitionKeys.asScala .filter(col => col.getType.startsWith(serdeConstants.VARCHAR_TYPE_NAME) || diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 48d0b4a63e54a..4f8dab9cd6172 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -162,21 +162,19 @@ case class HiveTableScanExec( // exposed for tests @transient lazy val rawPartitions = { - val prunedPartitions = if (sparkSession.sessionState.conf.metastorePartitionPruning) { - // Retrieve the original attributes based on expression ID so that capitalization matches. - val normalizedFilters = partitionPruningPred.map(_.transform { - case a: AttributeReference => originalAttributes(a) - }) - sparkSession.sharedState.externalCatalog.listPartitionsByFilter( - relation.tableMeta.database, - relation.tableMeta.identifier.table, - normalizedFilters, - sparkSession.sessionState.conf.sessionLocalTimeZone) - } else { - sparkSession.sharedState.externalCatalog.listPartitions( - relation.tableMeta.database, - relation.tableMeta.identifier.table) - } + val prunedPartitions = + if (sparkSession.sessionState.conf.metastorePartitionPruning && + partitionPruningPred.size > 0) { + // Retrieve the original attributes based on expression ID so that capitalization matches. + val normalizedFilters = partitionPruningPred.map(_.transform { + case a: AttributeReference => originalAttributes(a) + }) + sparkSession.sessionState.catalog.listPartitionsByFilter( + relation.tableMeta.identifier, + normalizedFilters) + } else { + sparkSession.sessionState.catalog.listPartitions(relation.tableMeta.identifier) + } prunedPartitions.map(HiveClientImpl.toHivePartition(_, hiveQlTable)) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala index 918c8be00d69d..1c6f8dd77fc2c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveDirCommand.scala @@ -27,11 +27,10 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.hadoop.mapred._ import org.apache.spark.SparkException -import org.apache.spark.sql.{Row, SparkSession} +import org.apache.spark.sql.{Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.hive.client.HiveClientImpl /** @@ -57,10 +56,7 @@ case class InsertIntoHiveDirCommand( query: LogicalPlan, overwrite: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) + override def run(sparkSession: SparkSession): Seq[Row] = { assert(storage.locationUri.nonEmpty) val hiveTable = HiveClientImpl.toHiveTable(CatalogTable( @@ -102,7 +98,7 @@ case class InsertIntoHiveDirCommand( try { saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = 
Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpPath.toString) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index e5b59ed7a1a6b..56e10bc457a00 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -17,20 +17,16 @@ package org.apache.spark.sql.hive.execution -import scala.util.control.NonFatal - import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.ql.ErrorMsg import org.apache.hadoop.hive.ql.plan.TableDesc import org.apache.spark.SparkException -import org.apache.spark.sql.{AnalysisException, Row, SparkSession} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SparkSession} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command.CommandUtils -import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.HiveShim.{ShimFileSinkDesc => FileSinkDesc} import org.apache.spark.sql.hive.client.HiveClientImpl @@ -72,16 +68,12 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifPartitionNotExists: Boolean) extends SaveAsHiveFile { - override def children: Seq[LogicalPlan] = query :: Nil - /** * Inserts all the rows in the table into Hive. Row objects are properly serialized with the * `org.apache.hadoop.hive.serde2.SerDe` and the * `org.apache.hadoop.mapred.OutputFormat` provided by the table definition. 
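Both insert commands above now derive a QueryExecution from the analyzed query (Dataset.ofRows(sparkSession, query).queryExecution) instead of receiving a pre-planned SparkPlan. At the SQL level, the path they implement is exercised by statements like the following sketch, which mirrors the InsertSuite test added later in this patch; table names are illustrative, a Hive-enabled session (e.g. spark-shell) is assumed, and dynamic partitioning needs hive.exec.dynamic.partition.mode=nonstrict.

import spark.implicits._

Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1")

spark.sql("CREATE TABLE tab2 (word STRING, length INT) PARTITIONED BY (first STRING)")
spark.sql(
  """
    |INSERT INTO TABLE tab2 PARTITION(first)
    |SELECT word, length, cast(first as string) as first FROM tab1
  """.stripMargin)

spark.table("tab2").show()  // expected row: a, 3, b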
*/ - override def run(sparkSession: SparkSession, children: Seq[SparkPlan]): Seq[Row] = { - assert(children.length == 1) - + override def run(sparkSession: SparkSession): Seq[Row] = { val externalCatalog = sparkSession.sharedState.externalCatalog val hadoopConf = sparkSession.sessionState.newHadoopConf() @@ -170,7 +162,7 @@ case class InsertIntoHiveTable( saveAsHiveFile( sparkSession = sparkSession, - plan = children.head, + queryExecution = Dataset.ofRows(sparkSession, query).queryExecution, hadoopConf = hadoopConf, fileSinkConf = fileSinkConf, outputLocation = tmpLocation.toString, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 2d74ef040ef5a..63657590e5e79 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -33,7 +33,7 @@ import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.DataWritingCommand import org.apache.spark.sql.execution.datasources.FileFormatWriter import org.apache.spark.sql.hive.HiveExternalCatalog @@ -47,7 +47,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { protected def saveAsHiveFile( sparkSession: SparkSession, - plan: SparkPlan, + queryExecution: QueryExecution, hadoopConf: Configuration, fileSinkConf: FileSinkDesc, outputLocation: String, @@ -75,7 +75,7 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { FileFormatWriter.write( sparkSession = sparkSession, - plan = plan, + queryExecution = queryExecution, fileFormat = new HiveFileFormat(fileSinkConf), committer = committer, outputSpec = FileFormatWriter.OutputSpec(outputLocation, customPartitionLocations), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index e9bdcf00b9346..68af99ea272a8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -48,7 +48,7 @@ private[hive] case class HiveSimpleUDF( with Logging with UserDefinedExpression { - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def nullable: Boolean = true @@ -131,7 +131,7 @@ private[hive] case class HiveGenericUDF( override def nullable: Boolean = true - override def deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) + override lazy val deterministic: Boolean = isUDFDeterministic && children.forall(_.deterministic) override def foldable: Boolean = isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 4d92a67044373..d26ec15410d95 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -32,6 +32,7 @@ import org.apache.hadoop.io.{NullWritable, Writable} import org.apache.hadoop.mapred.{JobConf, OutputFormat => MapRedOutputFormat, RecordWriter, Reporter} import org.apache.hadoop.mapreduce._ import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat, FileSplit} +import org.apache.orc.OrcConf.COMPRESS import org.apache.spark.TaskContext import org.apache.spark.sql.SparkSession @@ -58,7 +59,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable options: Map[String, String], files: Seq[FileStatus]): Option[StructType] = { OrcFileOperator.readSchema( - files.map(_.getPath.toUri.toString), + files.map(_.getPath.toString), Some(sparkSession.sessionState.newHadoopConf()) ) } @@ -72,7 +73,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(OrcRelation.ORC_COMPRESSION, orcOptions.compressionCodec) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) @@ -93,8 +94,8 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable override def getFileExtension(context: TaskAttemptContext): String = { val compressionExtension: String = { - val name = context.getConfiguration.get(OrcRelation.ORC_COMPRESSION) - OrcRelation.extensionsForCompressionCodecNames.getOrElse(name, "") + val name = context.getConfiguration.get(COMPRESS.getAttribute) + OrcFileFormat.extensionsForCompressionCodecNames.getOrElse(name, "") } compressionExtension + ".orc" @@ -120,7 +121,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable if (sparkSession.sessionState.conf.orcFilterPushDown) { // Sets pushed predicates OrcFilters.createFilter(requiredSchema, filters.toArray).foreach { f => - hadoopConf.set(OrcRelation.SARG_PUSHDOWN, f.toKryo) + hadoopConf.set(OrcFileFormat.SARG_PUSHDOWN, f.toKryo) hadoopConf.setBoolean(ConfVars.HIVEOPTINDEXFILTER.varname, true) } } @@ -134,12 +135,11 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable // SPARK-8501: Empty ORC files always have an empty schema stored in their footer. In this // case, `OrcFileOperator.readSchema` returns `None`, and we can't read the underlying file // using the given physical schema. Instead, we simply return an empty iterator. 
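The SARG_PUSHDOWN handling above only takes effect when ORC predicate pushdown is switched on, and the compression codec written by OrcOutputWriter is whatever OrcOptions resolves for OrcConf.COMPRESS. At the user level that is roughly the following (the path is hypothetical):

// Enable pushdown of supported filters into the ORC reader.
spark.conf.set("spark.sql.orc.filterPushDown", "true")
val df = spark.read.orc("/tmp/orc_table")   // hypothetical location
// The filter is converted to a SearchArgument and handed to the ORC record reader.
df.filter("id > 100").count()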
- val maybePhysicalSchema = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)) - if (maybePhysicalSchema.isEmpty) { + val isEmptyFile = OrcFileOperator.readSchema(Seq(file.filePath), Some(conf)).isEmpty + if (isEmptyFile) { Iterator.empty } else { - val physicalSchema = maybePhysicalSchema.get - OrcRelation.setRequiredColumns(conf, physicalSchema, requiredSchema) + OrcFileFormat.setRequiredColumns(conf, dataSchema, requiredSchema) val orcRecordReader = { val job = Job.getInstance(conf) @@ -161,8 +161,9 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable Option(TaskContext.get()).foreach(_.addTaskCompletionListener(_ => recordsIterator.close())) // Unwraps `OrcStruct`s to `UnsafeRow`s - OrcRelation.unwrapOrcStructs( + OrcFileFormat.unwrapOrcStructs( conf, + dataSchema, requiredSchema, Some(orcRecordReader.getObjectInspector.asInstanceOf[StructObjectInspector]), recordsIterator) @@ -255,10 +256,7 @@ private[orc] class OrcOutputWriter( } } -private[orc] object OrcRelation extends HiveInspectors { - // The references of Hive's classes will be minimized. - val ORC_COMPRESSION = "orc.compress" - +private[orc] object OrcFileFormat extends HiveInspectors { // This constant duplicates `OrcInputFormat.SARG_PUSHDOWN`, which is unfortunately not public. private[orc] val SARG_PUSHDOWN = "sarg.pushdown" @@ -272,25 +270,32 @@ private[orc] object OrcRelation extends HiveInspectors { def unwrapOrcStructs( conf: Configuration, dataSchema: StructType, + requiredSchema: StructType, maybeStructOI: Option[StructObjectInspector], iterator: Iterator[Writable]): Iterator[InternalRow] = { val deserializer = new OrcSerde - val mutableRow = new SpecificInternalRow(dataSchema.map(_.dataType)) - val unsafeProjection = UnsafeProjection.create(dataSchema) + val mutableRow = new SpecificInternalRow(requiredSchema.map(_.dataType)) + val unsafeProjection = UnsafeProjection.create(requiredSchema) def unwrap(oi: StructObjectInspector): Iterator[InternalRow] = { - val (fieldRefs, fieldOrdinals) = dataSchema.zipWithIndex.map { - case (field, ordinal) => oi.getStructFieldRef(field.name) -> ordinal + val (fieldRefs, fieldOrdinals) = requiredSchema.zipWithIndex.map { + case (field, ordinal) => + var ref = oi.getStructFieldRef(field.name) + if (ref == null) { + ref = oi.getStructFieldRef("_col" + dataSchema.fieldIndex(field.name)) + } + ref -> ordinal }.unzip - val unwrappers = fieldRefs.map(unwrapperFor) + val unwrappers = fieldRefs.map(r => if (r == null) null else unwrapperFor(r)) iterator.map { value => val raw = deserializer.deserialize(value) var i = 0 val length = fieldRefs.length while (i < length) { - val fieldValue = oi.getStructFieldData(raw, fieldRefs(i)) + val fieldRef = fieldRefs(i) + val fieldValue = if (fieldRef == null) null else oi.getStructFieldData(raw, fieldRef) if (fieldValue == null) { mutableRow.setNullAt(fieldOrdinals(i)) } else { @@ -306,8 +311,8 @@ private[orc] object OrcRelation extends HiveInspectors { } def setRequiredColumns( - conf: Configuration, physicalSchema: StructType, requestedSchema: StructType): Unit = { - val ids = requestedSchema.map(a => physicalSchema.fieldIndex(a.name): Integer) + conf: Configuration, dataSchema: StructType, requestedSchema: StructType): Unit = { + val ids = requestedSchema.map(a => dataSchema.fieldIndex(a.name): Integer) val (sortedIDs, sortedNames) = ids.zip(requestedSchema.fieldNames).sorted.unzip HiveShim.appendReadColumns(conf, sortedIDs, sortedNames) } diff --git 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala index 7f94c8c579026..6ce90c07b4921 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcOptions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.hive.orc import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -40,9 +42,9 @@ private[orc] class OrcOptions( * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ val compressionCodec: String = { - // `compression`, `orc.compress`, and `spark.sql.orc.compression.codec` are - // in order of precedence from highest to lowest. - val orcCompressionConf = parameters.get(OrcRelation.ORC_COMPRESSION) + // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` + // are in order of precedence from highest to lowest. + val orcCompressionConf = parameters.get(COMPRESS.getAttribute) val codecName = parameters .get("compression") .orElse(orcCompressionConf) diff --git a/sql/hive/src/test/resources/SPARK-21101-1.0.jar b/sql/hive/src/test/resources/SPARK-21101-1.0.jar new file mode 100644 index 0000000000000..768b2334db5c3 Binary files /dev/null and b/sql/hive/src/test/resources/SPARK-21101-1.0.jar differ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index 305f5b533d592..5f8c9d5799662 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -53,7 +53,9 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def downloadSpark(version: String): Unit = { import scala.sys.process._ - val url = s"https://d3kbcqa49mib13.cloudfront.net/spark-$version-bin-hadoop2.7.tgz" + val preferredMirror = + Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim + val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" Seq("wget", url, "-q", "-P", sparkTestingDir.getCanonicalPath).! @@ -142,7 +144,7 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { object PROCESS_TABLES extends QueryTest with SQLTestUtils { // Tests the latest version of every release line. 
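The OrcOptions change above spells out the precedence: the `compression` option beats `orc.compress` (i.e., OrcConf.COMPRESS), which in turn beats spark.sql.orc.compression.codec. A small hypothetical writer example, in line with the SPARK-16610 test touched later in this patch:

// The written files end up ZLIB-compressed: `compression` wins over `orc.compress`.
spark.range(0, 10).write
  .option("compression", "zlib")
  .option("orc.compress", "SNAPPY")
  .orc("/tmp/orc_compression_demo")   // hypothetical output path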
- val testingVersions = Seq("2.0.2", "2.1.1", "2.2.0") + val testingVersions = Seq("2.0.2", "2.1.2", "2.2.0") protected var spark: SparkSession = _ diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala index 0c28a1b609bb8..e71aba72c31fe 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetadataCacheSuite.scala @@ -31,14 +31,22 @@ import org.apache.spark.sql.test.SQLTestUtils class HiveMetadataCacheSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-16337 temporary view refresh") { - withTempView("view_refresh") { + checkRefreshView(isTemp = true) + } + + test("view refresh") { + checkRefreshView(isTemp = false) + } + + private def checkRefreshView(isTemp: Boolean) { + withView("view_refresh") { withTable("view_table") { // Create a Parquet directory spark.range(start = 0, end = 100, step = 1, numPartitions = 3) .write.saveAsTable("view_table") - // Read the table in - spark.table("view_table").filter("id > -1").createOrReplaceTempView("view_refresh") + val temp = if (isTemp) "TEMPORARY" else "" + spark.sql(s"CREATE $temp VIEW view_refresh AS SELECT * FROM view_table WHERE id > -1") assert(sql("select count(*) from view_refresh").first().getLong(0) == 100) // Delete a file using the Hadoop file system interface since the path returned by diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala index aa5cae33f5cd9..ab91727049ff5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertSuite.scala @@ -728,4 +728,26 @@ class InsertSuite extends QueryTest with TestHiveSingleton with BeforeAndAfter assert(e.contains("mismatched input 'ROW'")) } } + + test("SPARK-21165: FileFormatWriter should only rely on attributes from analyzed plan") { + withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { + withTable("tab1", "tab2") { + Seq(("a", "b", 3)).toDF("word", "first", "length").write.saveAsTable("tab1") + + spark.sql( + """ + |CREATE TABLE tab2 (word string, length int) + |PARTITIONED BY (first string) + """.stripMargin) + + spark.sql( + """ + |INSERT INTO TABLE tab2 PARTITION(first) + |SELECT word, length, cast(first as string) as first FROM tab1 + """.stripMargin) + + checkAnswer(spark.table("tab2"), Row("a", 3, "b")) + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 29b0e6c8533ef..f5d41c91270a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -993,7 +993,6 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv spark.sql("""drop database if exists testdb8156 CASCADE""") } - test("skip hive metadata on table creation") { withTempDir { tempPath => val schema = StructType((1 to 5).map(i => StructField(s"c_$i", StringType))) @@ -1345,6 +1344,17 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv } } + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"SPARK-22146: read files containing special characters 
using $format") { + val nameWithSpecialChars = s"sp&cial%chars" + withTempDir { dir => + val tmpFile = s"$dir/$nameWithSpecialChars" + spark.createDataset(Seq("a", "b")).write.format(format).save(tmpFile) + spark.read.format(format).load(tmpFile) + } + } + } + private def withDebugMode(f: => Unit): Unit = { val previousValue = sparkSession.sparkContext.conf.get(DEBUG_MODE) try { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 9ff9ecf7f3677..b9a5ad7657134 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -937,26 +937,20 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } test("test statistics of LogicalRelation converted from Hive serde tables") { - val parquetTable = "parquetTable" - val orcTable = "orcTable" - withTable(parquetTable, orcTable) { - sql(s"CREATE TABLE $parquetTable (key STRING, value STRING) STORED AS PARQUET") - sql(s"CREATE TABLE $orcTable (key STRING, value STRING) STORED AS ORC") - sql(s"INSERT INTO TABLE $parquetTable SELECT * FROM src") - sql(s"INSERT INTO TABLE $orcTable SELECT * FROM src") - - // the default value for `spark.sql.hive.convertMetastoreParquet` is true, here we just set it - // for robustness - withSQLConf(HiveUtils.CONVERT_METASTORE_PARQUET.key -> "true") { - checkTableStats(parquetTable, hasSizeInBytes = false, expectedRowCounts = None) - sql(s"ANALYZE TABLE $parquetTable COMPUTE STATISTICS") - checkTableStats(parquetTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) - } - withSQLConf(HiveUtils.CONVERT_METASTORE_ORC.key -> "true") { - // We still can get tableSize from Hive before Analyze - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = None) - sql(s"ANALYZE TABLE $orcTable COMPUTE STATISTICS") - checkTableStats(orcTable, hasSizeInBytes = true, expectedRowCounts = Some(500)) + Seq("orc", "parquet").foreach { format => + Seq(true, false).foreach { isConverted => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> s"$isConverted", + HiveUtils.CONVERT_METASTORE_PARQUET.key -> s"$isConverted") { + withTable(format) { + sql(s"CREATE TABLE $format (key STRING, value STRING) STORED AS $format") + sql(s"INSERT INTO TABLE $format SELECT * FROM src") + + checkTableStats(format, hasSizeInBytes = !isConverted, expectedRowCounts = None) + sql(s"ANALYZE TABLE $format COMPUTE STATISTICS") + checkTableStats(format, hasSizeInBytes = true, expectedRowCounts = Some(500)) + } + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala index 031c1a5ec0ec3..19765695fbcb4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/FiltersSuite.scala @@ -26,13 +26,15 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ /** * A set of tests for the filter conversion logic used when pushing partition pruning into the * metastore */ -class FiltersSuite extends SparkFunSuite with Logging { +class FiltersSuite 
extends SparkFunSuite with Logging with PlanTest { private val shim = new Shim_v0_13 private val testTable = new org.apache.hadoop.hive.ql.metadata.Table("default", "test") @@ -72,10 +74,28 @@ class FiltersSuite extends SparkFunSuite with Logging { private def filterTest(name: String, filters: Seq[Expression], result: String) = { test(name) { - val converted = shim.convertFilters(testTable, filters) - if (converted != result) { - fail( - s"Expected filters ${filters.mkString(",")} to convert to '$result' but got '$converted'") + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> "true") { + val converted = shim.convertFilters(testTable, filters) + if (converted != result) { + fail(s"Expected ${filters.mkString(",")} to convert to '$result' but got '$converted'") + } + } + } + } + + test("turn on/off ADVANCED_PARTITION_PREDICATE_PUSHDOWN") { + import org.apache.spark.sql.catalyst.dsl.expressions._ + Seq(true, false).foreach { enabled => + withSQLConf(SQLConf.ADVANCED_PARTITION_PREDICATE_PUSHDOWN.key -> enabled.toString) { + val filters = + (Literal(1) === a("intcol", IntegerType) || + Literal(2) === a("intcol", IntegerType)) :: Nil + val converted = shim.convertFilters(testTable, filters) + if (enabled) { + assert(converted == "(1 = intcol or 2 = intcol)") + } else { + assert(converted.isEmpty) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index f245a79f805a2..ae675149df5e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -1015,7 +1015,7 @@ class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySu override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = { Seq("true", "false").foreach { enableTwoLevelMaps => - withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" -> + withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enabled" -> enableTwoLevelMaps) { (1 to 3).foreach { fallbackStartsAt => withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" -> diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 668da5fb47323..d3465a641a1a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -23,6 +23,8 @@ import java.net.URI import scala.language.existentials import org.apache.hadoop.fs.Path +import org.apache.parquet.format.converter.ParquetMetadataConverter.NO_FILTER +import org.apache.parquet.hadoop.ParquetFileReader import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException @@ -32,6 +34,7 @@ import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAl import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.execution.command.{DDLSuite, DDLUtils} import org.apache.spark.sql.hive.HiveExternalCatalog +import org.apache.spark.sql.hive.HiveUtils.{CONVERT_METASTORE_ORC, CONVERT_METASTORE_PARQUET} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} @@ -163,6 +166,14 @@ class HiveCatalogedDDLSuite 
extends DDLSuite with TestHiveSingleton with BeforeA test("drop table") { testDropTable(isDatasourceTable = false) } + + test("alter datasource table add columns - orc") { + testAddColumn("orc") + } + + test("alter datasource table add columns - partitioned - orc") { + testAddColumnPartitioned("orc") + } } class HiveDDLSuite @@ -1455,12 +1466,8 @@ class HiveDDLSuite sql("INSERT INTO t SELECT 1") checkAnswer(spark.table("t"), Row(1)) // Check if this is compressed as ZLIB. - val maybeOrcFile = path.listFiles().find(!_.getName.endsWith(".crc")) - assert(maybeOrcFile.isDefined) - val orcFilePath = maybeOrcFile.get.toPath.toString - val expectedCompressionKind = - OrcFileOperator.getFileReader(orcFilePath).get.getCompression - assert("ZLIB" === expectedCompressionKind.name()) + val maybeOrcFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeOrcFile, "orc", "ZLIB") sql("CREATE TABLE t2 USING HIVE AS SELECT 1 AS c1, 'a' AS c2") val table2 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t2")) @@ -2009,4 +2016,47 @@ class HiveDDLSuite } } } + + private def assertCompression(maybeFile: Option[File], format: String, compression: String) = { + assert(maybeFile.isDefined) + + val actualCompression = format match { + case "orc" => + OrcFileOperator.getFileReader(maybeFile.get.toPath.toString).get.getCompression.name + + case "parquet" => + val footer = ParquetFileReader.readFooter( + sparkContext.hadoopConfiguration, new Path(maybeFile.get.getPath), NO_FILTER) + footer.getBlocks.get(0).getColumns.get(0).getCodec.toString + } + + assert(compression === actualCompression) + } + + Seq(("orc", "ZLIB"), ("parquet", "GZIP")).foreach { case (fileFormat, compression) => + test(s"SPARK-22158 convertMetastore should not ignore table property - $fileFormat") { + withSQLConf(CONVERT_METASTORE_ORC.key -> "true", CONVERT_METASTORE_PARQUET.key -> "true") { + withTable("t") { + withTempPath { path => + sql( + s""" + |CREATE TABLE t(id int) USING hive + |OPTIONS(fileFormat '$fileFormat', compression '$compression') + |LOCATION '${path.toURI}' + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(DDLUtils.isHiveTable(table)) + assert(table.storage.serde.get.contains(fileFormat)) + assert(table.storage.properties.get("compression") == Some(compression)) + assert(spark.table("t").collect().isEmpty) + + sql("INSERT INTO t SELECT 1") + checkAnswer(spark.table("t"), Row(1)) + val maybeFile = path.listFiles().find(_.getName.startsWith("part")) + assertCompression(maybeFile, fileFormat, compression) + } + } + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 09c59000b3e3f..c11e37a516646 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.hive.HiveUtils +import org.apache.spark.sql.hive.{HiveExternalCatalog, HiveUtils} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf import 
org.apache.spark.sql.test.SQLTestUtils @@ -1376,113 +1376,125 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("SPARK-8976 Wrong Result for Rollup #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH ROLLUP"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH ROLLUP"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for Rollup #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM src GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM src GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for Rollup #3") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, 0, 5, 0), - (1, 0, 15, 0), - (1, 0, 25, 0), - (1, 0, 60, 0), - (1, 0, 75, 0), - (1, 0, 80, 0), - (1, 0, 100, 0), - (1, 0, 140, 0), - (1, 0, 145, 0), - (1, 0, 150, 0) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH ROLLUP ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, 0, 5, 0), + (1, 0, 15, 0), + (1, 0, 25, 0), + (1, 0, 60, 0), + (1, 0, 75, 0), + (1, 0, 80, 0), + (1, 0, 100, 0), + (1, 0, 140, 0), + (1, 0, 145, 0), + (1, 0, 150, 0) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for CUBE #1") { - checkAnswer(sql( - "SELECT count(*) AS cnt, key % 5, grouping_id() FROM src GROUP BY key%5 WITH CUBE"), - Seq( - (113, 3, 0), - (91, 0, 0), - (500, null, 1), - (84, 1, 0), - (105, 2, 0), - (107, 4, 0) - ).map(i => Row(i._1, i._2, i._3))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s"SELECT count(*) AS cnt, key % 5, $gid FROM src GROUP BY key%5 WITH CUBE"), + Seq( + (113, 3, 0), + (91, 0, 0), + (500, null, 1), + (84, 1, 0), + (105, 2, 0), + (107, 4, 0) + ).map(i => Row(i._1, i._2, i._3))) + } } test("SPARK-8976 Wrong Result for CUBE #2") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 - 
""".stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |WITH CUBE ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } test("SPARK-8976 Wrong Result for GroupingSet") { - checkAnswer(sql( - """ - |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, grouping_id() AS k3 - |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 - |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 - """.stripMargin), - Seq( - (1, null, -3, 2), - (1, null, -1, 2), - (1, null, 3, 2), - (1, null, 4, 2), - (1, null, 5, 2), - (1, null, 6, 2), - (1, null, 12, 2), - (1, null, 14, 2), - (1, null, 15, 2), - (1, null, 22, 2) - ).map(i => Row(i._1, i._2, i._3, i._4))) + Seq("grouping_id()", "grouping__id").foreach { gid => + checkAnswer(sql( + s""" + |SELECT count(*) AS cnt, key % 5 AS k1, key-5 AS k2, $gid AS k3 + |FROM (SELECT key, key%2, key - 5 FROM src) t GROUP BY key%5, key-5 + |GROUPING SETS (key%5, key-5) ORDER BY cnt, k1, k2, k3 LIMIT 10 + """.stripMargin), + Seq( + (1, null, -3, 2), + (1, null, -1, 2), + (1, null, 3, 2), + (1, null, 4, 2), + (1, null, 5, 2), + (1, null, 6, 2), + (1, null, 12, 2), + (1, null, 14, 2), + (1, null, 15, 2), + (1, null, 22, 2) + ).map(i => Row(i._1, i._2, i._3, i._4))) + } } ignore("SPARK-10562: partition by column with mixed case name") { @@ -1998,6 +2010,32 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } + test("SPARK-21101 UDTF should override initialize(ObjectInspector[] args)") { + withUserDefinedFunction("udtf_stack1" -> true, "udtf_stack2" -> true) { + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack1 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val cnt = + sql("SELECT udtf_stack1(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')").count() + assert(cnt === 2) + + sql( + s""" + |CREATE TEMPORARY FUNCTION udtf_stack2 + |AS 'org.apache.spark.sql.hive.execution.UDTFStack2' + |USING JAR '${hiveContext.getHiveFile("SPARK-21101-1.0.jar").toURI}' + """.stripMargin) + val e = intercept[org.apache.spark.sql.AnalysisException] { + sql("SELECT udtf_stack2(2, 'A', 10, date '2015-01-01', 'B', 20, date '2016-01-01')") + } + assert( + e.getMessage.contains("public StructObjectInspector initialize(ObjectInspector[] args)")) + } + } + test("SPARK-21721: Clear FileSystem deleterOnExit cache if path is successfully removed") { val table = "test21721" withTable(table) { @@ -2019,8 +2057,8 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { test("SPARK-21912 ORC/Parquet table should not create invalid column names") { Seq(" ", ",", ";", "{", "}", "(", ")", "\n", "\t", "=").foreach { name => - withTable("t21912") { - Seq("ORC", "PARQUET").foreach { source => + Seq("ORC", "PARQUET").foreach { source => + withTable("t21912") { 
val m = intercept[AnalysisException] { sql(s"CREATE TABLE t21912(`col$name` INT) USING $source") }.getMessage @@ -2037,17 +2075,82 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { }.getMessage assert(m3.contains(s"contains invalid character(s)")) } - } - // TODO: After SPARK-21929, we need to check ORC, too. - Seq("PARQUET").foreach { source => sql(s"CREATE TABLE t21912(`col` INT) USING $source") - val m = intercept[AnalysisException] { + val m4 = intercept[AnalysisException] { sql(s"ALTER TABLE t21912 ADD COLUMNS(`col$name` INT)") }.getMessage - assert(m.contains(s"contains invalid character(s)")) + assert(m4.contains(s"contains invalid character(s)")) + } + } + } + } + + Seq("orc", "parquet").foreach { format => + test(s"SPARK-18355 Read data from a hive table with a new column - $format") { + val client = spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client + + Seq("true", "false").foreach { value => + withSQLConf( + HiveUtils.CONVERT_METASTORE_ORC.key -> value, + HiveUtils.CONVERT_METASTORE_PARQUET.key -> value) { + withTempDatabase { db => + client.runSqlHive( + s""" + |CREATE TABLE $db.t( + | click_id string, + | search_id string, + | uid bigint) + |PARTITIONED BY ( + | ts string, + | hour string) + |STORED AS $format + """.stripMargin) + + client.runSqlHive( + s""" + |INSERT INTO TABLE $db.t + |PARTITION (ts = '98765', hour = '01') + |VALUES (12, 2, 12345) + """.stripMargin + ) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, ts, hour FROM $db.t"), + Row("12", "2", 12345, "98765", "01")) + + client.runSqlHive(s"ALTER TABLE $db.t ADD COLUMNS (dummy string)") + + checkAnswer( + sql(s"SELECT click_id, search_id FROM $db.t"), + Row("12", "2")) + + checkAnswer( + sql(s"SELECT search_id, click_id FROM $db.t"), + Row("2", "12")) + + checkAnswer( + sql(s"SELECT search_id FROM $db.t"), + Row("2")) + + checkAnswer( + sql(s"SELECT dummy, click_id FROM $db.t"), + Row(null, "12")) + + checkAnswer( + sql(s"SELECT click_id, search_id, uid, dummy, ts, hour FROM $db.t"), + Row("12", "2", 12345, null, "98765", "01")) + } } } } } + + Seq("orc", "parquet", "csv", "json", "text").foreach { format => + test(s"Writing empty datasets should not fail - $format") { + withTempDir { dir => + Seq("str").toDS.limit(0).write.format(format).save(dir.getCanonicalPath + "/tmp") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala index 60ccd996d6d58..1fa9091f967a3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcQuerySuite.scala @@ -22,6 +22,7 @@ import java.sql.Timestamp import org.apache.hadoop.conf.Configuration import org.apache.hadoop.hive.ql.io.orc.{OrcStruct, SparkOrcNewRecordReader} +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql._ @@ -176,11 +177,11 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { } } - test("SPARK-16610: Respect orc.compress option when compression is unset") { - // Respect `orc.compress`. + test("SPARK-16610: Respect orc.compress (i.e., OrcConf.COMPRESS) when compression is unset") { + // Respect `orc.compress` (i.e., OrcConf.COMPRESS). 
withTempPath { file => spark.range(0, 10).write - .option("orc.compress", "ZLIB") + .option(COMPRESS.getAttribute, "ZLIB") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -191,7 +192,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { withTempPath { file => spark.range(0, 10).write .option("compression", "ZLIB") - .option("orc.compress", "SNAPPY") + .option(COMPRESS.getAttribute, "SNAPPY") .orc(file.getCanonicalPath) val expectedCompressionKind = OrcFileOperator.getFileReader(file.getCanonicalPath).get.getCompression @@ -598,7 +599,7 @@ class OrcQuerySuite extends QueryTest with BeforeAndAfterAll with OrcTest { val requestedSchema = StructType(Nil) val conf = new Configuration() val physicalSchema = OrcFileOperator.readSchema(Seq(path), Some(conf)).get - OrcRelation.setRequiredColumns(conf, physicalSchema, requestedSchema) + OrcFileFormat.setRequiredColumns(conf, physicalSchema, requestedSchema) val maybeOrcReader = OrcFileOperator.getFileReader(path, Some(conf)) assert(maybeOrcReader.isDefined) val orcRecordReader = new SparkOrcNewRecordReader( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala index 781de6631f324..ef9e67c743837 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcSourceSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.hive.orc import java.io.File +import java.util.Locale +import org.apache.orc.OrcConf.COMPRESS import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.{QueryTest, Row} @@ -150,7 +152,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = sqlContext.sessionState.conf - assert(new OrcOptions(Map("Orc.Compress" -> "NONE"), conf).compressionCodec == "NONE") + val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) + assert(option.compressionCodec == "NONE") } test("SPARK-19459/SPARK-18220: read char/varchar column written by Hive") { @@ -205,8 +208,8 @@ abstract class OrcSuite extends QueryTest with TestHiveSingleton with BeforeAndA // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") - val map1 = Map("orc.compress" -> "zlib") - val map2 = Map("orc.compress" -> "zlib", "compression" -> "lzo") + val map1 = Map(COMPRESS.getAttribute -> "zlib") + val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") assert(new OrcOptions(map2, conf).compressionCodec == "LZO") } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala index d6e15cfdd2723..ab7c8558321c8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/FileBasedWriteAheadLog.scala @@ -139,7 +139,7 @@ private[streaming] class FileBasedWriteAheadLog( def readFile(file: String): Iterator[ByteBuffer] = { logDebug(s"Creating log reader with 
$file") val reader = new FileBasedWriteAheadLogReader(file, hadoopConf) - CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, reader.close _) + CompletionIterator[ByteBuffer, Iterator[ByteBuffer]](reader, () => reader.close()) } if (!closeFileAfterWrite) { logFilesToRead.iterator.map(readFile).flatten.asJava