diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index f48c61c1d59c5..667fff7192b59 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -101,6 +101,7 @@ exportMethods("arrange", "withColumn", "withColumnRenamed", "write.df", + "write.jdbc", "write.json", "write.parquet", "write.text") @@ -125,6 +126,7 @@ exportMethods("%in%", "between", "bin", "bitwiseNOT", + "bround", "cast", "cbrt", "ceil", @@ -284,6 +286,7 @@ export("as.DataFrame", "loadDF", "parquetFile", "read.df", + "read.jdbc", "read.json", "read.parquet", "read.text", @@ -292,7 +295,8 @@ export("as.DataFrame", "tableToDF", "tableNames", "tables", - "uncacheTable") + "uncacheTable", + "print.summary.GeneralizedLinearRegressionModel") export("structField", "structField.jobj", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index a64a013b654ef..95e2eb2be037f 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2296,12 +2296,8 @@ setMethod("fillna", #' } setMethod("as.data.frame", signature(x = "DataFrame"), - function(x, ...) { - # Check if additional parameters have been passed - if (length(list(...)) > 0) { - stop(paste("Unused argument(s): ", paste(list(...), collapse = ", "))) - } - collect(x) + function(x, row.names = NULL, optional = FALSE, ...) { + as.data.frame(collect(x), row.names, optional, ...) }) #' The specified DataFrame is attached to the R search path. This means that @@ -2363,7 +2359,7 @@ setMethod("with", #' @examples \dontrun{ #' # Create a DataFrame from the Iris dataset #' irisDF <- createDataFrame(sqlContext, iris) -#' +#' #' # Show the structure of the DataFrame #' str(irisDF) #' } @@ -2468,3 +2464,40 @@ setMethod("drop", function(x) { base::drop(x) }) + +#' Saves the content of the DataFrame to an external database table via JDBC +#' +#' Additional JDBC database connection properties can be set (...) +#' +#' Also, mode is used to specify the behavior of the save operation when +#' data already exists in the data source. There are four modes: \cr +#' append: Contents of this DataFrame are expected to be appended to existing data. \cr +#' overwrite: Existing data is expected to be overwritten by the contents of this DataFrame. \cr +#' error: An exception is expected to be thrown. \cr +#' ignore: The save operation is expected to not save the contents of the DataFrame +#' and to not change the existing data. \cr +#' +#' @param x A SparkSQL DataFrame +#' @param url JDBC database url of the form `jdbc:subprotocol:subname` +#' @param tableName The name of the table in the external database +#' @param mode One of 'append', 'overwrite', 'error', 'ignore' save mode (it is 'error' by default) +#' @family DataFrame functions +#' @rdname write.jdbc +#' @name write.jdbc +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename" +#' write.jdbc(df, jdbcUrl, "table", user = "username", password = "password") +#' } +setMethod("write.jdbc", + signature(x = "DataFrame", url = "character", tableName = "character"), + function(x, url, tableName, mode = "error", ...){ + jmode <- convertToJSaveMode(mode) + jprops <- varargsToJProperties(...) 
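+            # The named arguments captured in `...` (e.g. user, password) are collected into a
+            # java.util.Properties object and handed to DataFrameWriter.jdbc() on the JVM side
+            # by the calls below.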
+ write <- callJMethod(x@sdf, "write") + write <- callJMethod(write, "mode", jmode) + invisible(callJMethod(write, "jdbc", url, tableName, jprops)) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 16a2578678cd3..b726c1e1b9f2c 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -583,3 +583,61 @@ createExternalTable <- function(sqlContext, tableName, path = NULL, source = NUL sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options) dataFrame(sdf) } + +#' Create a DataFrame representing the database table accessible via JDBC URL +#' +#' Additional JDBC database connection properties can be set (...) +#' +#' Only one of partitionColumn or predicates should be set. Partitions of the table will be +#' retrieved in parallel based on the `numPartitions` or by the predicates. +#' +#' Don't create too many partitions in parallel on a large cluster; otherwise Spark might crash +#' your external database systems. +#' +#' @param sqlContext SQLContext to use +#' @param url JDBC database url of the form `jdbc:subprotocol:subname` +#' @param tableName the name of the table in the external database +#' @param partitionColumn the name of a column of integral type that will be used for partitioning +#' @param lowerBound the minimum value of `partitionColumn` used to decide partition stride +#' @param upperBound the maximum value of `partitionColumn` used to decide partition stride +#' @param numPartitions the number of partitions, This, along with `lowerBound` (inclusive), +#' `upperBound` (exclusive), form partition strides for generated WHERE +#' clause expressions used to split the column `partitionColumn` evenly. +#' This defaults to SparkContext.defaultParallelism when unset. +#' @param predicates a list of conditions in the where clause; each one defines one partition +#' @return DataFrame +#' @rdname read.jdbc +#' @name read.jdbc +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' jdbcUrl <- "jdbc:mysql://localhost:3306/databasename" +#' df <- read.jdbc(sqlContext, jdbcUrl, "table", predicates = list("field<=123"), user = "username") +#' df2 <- read.jdbc(sqlContext, jdbcUrl, "table2", partitionColumn = "index", lowerBound = 0, +#' upperBound = 10000, user = "username", password = "password") +#' } + +read.jdbc <- function(sqlContext, url, tableName, + partitionColumn = NULL, lowerBound = NULL, upperBound = NULL, + numPartitions = 0L, predicates = list(), ...) { + jprops <- varargsToJProperties(...) + + read <- callJMethod(sqlContext, "read") + if (!is.null(partitionColumn)) { + if (is.null(numPartitions) || numPartitions == 0) { + sc <- callJMethod(sqlContext, "sparkContext") + numPartitions <- callJMethod(sc, "defaultParallelism") + } else { + numPartitions <- numToInt(numPartitions) + } + sdf <- callJMethod(read, "jdbc", url, tableName, as.character(partitionColumn), + numToInt(lowerBound), numToInt(upperBound), numPartitions, jprops) + } else if (length(predicates) > 0) { + sdf <- callJMethod(read, "jdbc", url, tableName, as.list(as.character(predicates)), jprops) + } else { + sdf <- callJMethod(read, "jdbc", url, tableName, jprops) + } + dataFrame(sdf) +} diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index db877b2d63d30..54234b0455eab 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -994,7 +994,7 @@ setMethod("rint", #' round #' -#' Returns the value of the column `e` rounded to 0 decimal places. 
+#' Returns the value of the column `e` rounded to 0 decimal places using HALF_UP rounding mode. #' #' @rdname round #' @name round @@ -1008,6 +1008,26 @@ setMethod("round", column(jc) }) +#' bround +#' +#' Returns the value of the column `e` rounded to `scale` decimal places using HALF_EVEN rounding +#' mode if `scale` >= 0 or at integral part when `scale` < 0. +#' Also known as Gaussian rounding or bankers' rounding that rounds to the nearest even number. +#' bround(2.5, 0) = 2, bround(3.5, 0) = 4. +#' +#' @rdname bround +#' @name bround +#' @family math_funcs +#' @export +#' @examples \dontrun{bround(df$c, 0)} +setMethod("bround", + signature(x = "Column"), + function(x, scale = 0) { + jc <- callJStatic("org.apache.spark.sql.functions", "bround", x@jc, as.integer(scale)) + column(jc) + }) + + #' rtrim #' #' Trim the spaces from right end for the specified string value. diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ecdeea5ec4912..6b67258d77e6c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -397,7 +397,10 @@ setGeneric("arrange", function(x, col, ...) { standardGeneric("arrange") }) #' @rdname as.data.frame #' @export -setGeneric("as.data.frame") +setGeneric("as.data.frame", + function(x, row.names = NULL, optional = FALSE, ...) { + standardGeneric("as.data.frame") + }) #' @rdname attach #' @export @@ -577,6 +580,12 @@ setGeneric("saveDF", function(df, path, source = NULL, mode = "error", ...) { standardGeneric("saveDF") }) +#' @rdname write.jdbc +#' @export +setGeneric("write.jdbc", function(x, url, tableName, mode = "error", ...) { + standardGeneric("write.jdbc") +}) + #' @rdname write.json #' @export setGeneric("write.json", function(x, path) { standardGeneric("write.json") }) @@ -751,6 +760,10 @@ setGeneric("bin", function(x) { standardGeneric("bin") }) #' @export setGeneric("bitwiseNOT", function(x) { standardGeneric("bitwiseNOT") }) +#' @rdname bround +#' @export +setGeneric("bround", function(x, ...) { standardGeneric("bround") }) + #' @rdname cbrt #' @export setGeneric("cbrt", function(x) { standardGeneric("cbrt") }) diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 31bca16580451..922a9b13dbfe6 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -101,12 +101,55 @@ setMethod("summary", signature(object = "GeneralizedLinearRegressionModel"), jobj <- object@jobj features <- callJMethod(jobj, "rFeatures") coefficients <- callJMethod(jobj, "rCoefficients") - coefficients <- as.matrix(unlist(coefficients)) - colnames(coefficients) <- c("Estimate") + deviance.resid <- callJMethod(jobj, "rDevianceResiduals") + dispersion <- callJMethod(jobj, "rDispersion") + null.deviance <- callJMethod(jobj, "rNullDeviance") + deviance <- callJMethod(jobj, "rDeviance") + df.null <- callJMethod(jobj, "rResidualDegreeOfFreedomNull") + df.residual <- callJMethod(jobj, "rResidualDegreeOfFreedom") + aic <- callJMethod(jobj, "rAic") + iter <- callJMethod(jobj, "rNumIterations") + family <- callJMethod(jobj, "rFamily") + + deviance.resid <- dataFrame(deviance.resid) + coefficients <- matrix(coefficients, ncol = 4) + colnames(coefficients) <- c("Estimate", "Std. 
Error", "t value", "Pr(>|t|)") rownames(coefficients) <- unlist(features) - return(list(coefficients = coefficients)) + ans <- list(deviance.resid = deviance.resid, coefficients = coefficients, + dispersion = dispersion, null.deviance = null.deviance, + deviance = deviance, df.null = df.null, df.residual = df.residual, + aic = aic, iter = iter, family = family) + class(ans) <- "summary.GeneralizedLinearRegressionModel" + return(ans) }) +#' Print the summary of GeneralizedLinearRegressionModel +#' +#' @rdname print +#' @name print.summary.GeneralizedLinearRegressionModel +#' @export +print.summary.GeneralizedLinearRegressionModel <- function(x, ...) { + x$deviance.resid <- setNames(unlist(approxQuantile(x$deviance.resid, "devianceResiduals", + c(0.0, 0.25, 0.5, 0.75, 1.0), 0.01)), c("Min", "1Q", "Median", "3Q", "Max")) + x$deviance.resid <- zapsmall(x$deviance.resid, 5L) + cat("\nDeviance Residuals: \n") + cat("(Note: These are approximate quantiles with relative error <= 0.01)\n") + print.default(x$deviance.resid, digits = 5L, na.print = "", print.gap = 2L) + + cat("\nCoefficients:\n") + print.default(x$coefficients, digits = 5L, na.print = "", print.gap = 2L) + + cat("\n(Dispersion parameter for ", x$family, " family taken to be ", format(x$dispersion), + ")\n\n", apply(cbind(paste(format(c("Null", "Residual"), justify = "right"), "deviance:"), + format(unlist(x[c("null.deviance", "deviance")]), digits = 5L), + " on", format(unlist(x[c("df.null", "df.residual")])), " degrees of freedom\n"), + 1L, paste, collapse = " "), sep = "") + cat("AIC: ", format(x$aic, digits = 4L), "\n\n", + "Number of Fisher Scoring iterations: ", x$iter, "\n", sep = "") + cat("\n") + invisible(x) + } + #' Make predictions from a generalized linear model #' #' Makes predictions from a generalized linear model produced by glm(), similarly to R's predict(). diff --git a/R/pkg/R/utils.R b/R/pkg/R/utils.R index fb6575cb42907..b425ccf6e7a36 100644 --- a/R/pkg/R/utils.R +++ b/R/pkg/R/utils.R @@ -650,3 +650,14 @@ convertToJSaveMode <- function(mode) { jmode <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "saveMode", mode) jmode } + +varargsToJProperties <- function(...) { + pairs <- list(...) 
+ props <- newJObject("java.util.Properties") + if (length(pairs) > 0) { + lapply(ls(pairs), function(k) { + callJMethod(props, "setProperty", as.character(k), as.character(pairs[[k]])) + }) + } + props +} diff --git a/R/pkg/inst/tests/testthat/test_context.R b/R/pkg/inst/tests/testthat/test_context.R index 6e06c974c291f..9f51161230e1a 100644 --- a/R/pkg/inst/tests/testthat/test_context.R +++ b/R/pkg/inst/tests/testthat/test_context.R @@ -26,7 +26,7 @@ test_that("Check masked functions", { maskedBySparkR <- masked[funcSparkROrEmpty] namesOfMasked <- c("describe", "cov", "filter", "lag", "na.omit", "predict", "sd", "var", "colnames", "colnames<-", "intersect", "rank", "rbind", "sample", "subset", - "summary", "transform", "drop", "window") + "summary", "transform", "drop", "window", "as.data.frame") expect_equal(length(maskedBySparkR), length(namesOfMasked)) expect_equal(sort(maskedBySparkR), sort(namesOfMasked)) # above are those reported as masked when `library(SparkR)` diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index a9dbd2bdc4cc0..47bbf7e5bd2b5 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -77,6 +77,55 @@ test_that("glm and predict", { expect_equal(length(predict(lm(y ~ x))), 15) }) +test_that("glm summary", { + # gaussian family + training <- suppressWarnings(createDataFrame(sqlContext, iris)) + stats <- summary(glm(Sepal_Width ~ Sepal_Length + Species, data = training)) + + rStats <- summary(glm(Sepal.Width ~ Sepal.Length + Species, data = iris)) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Species_versicolor", "Species_virginica"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # binomial family + df <- suppressWarnings(createDataFrame(sqlContext, iris)) + training <- df[df$Species %in% c("versicolor", "virginica"), ] + stats <- summary(glm(Species ~ Sepal_Length + Sepal_Width, data = training, + family = binomial(link = "logit"))) + + rTraining <- iris[iris$Species %in% c("versicolor", "virginica"), ] + rStats <- summary(glm(Species ~ Sepal.Length + Sepal.Width, data = rTraining, + family = binomial(link = "logit"))) + + coefs <- unlist(stats$coefficients) + rCoefs <- unlist(rStats$coefficients) + expect_true(all(abs(rCoefs - coefs) < 1e-4)) + expect_true(all( + rownames(stats$coefficients) == + c("(Intercept)", "Sepal_Length", "Sepal_Width"))) + expect_equal(stats$dispersion, rStats$dispersion) + expect_equal(stats$null.deviance, rStats$null.deviance) + expect_equal(stats$deviance, rStats$deviance) + expect_equal(stats$df.null, rStats$df.null) + expect_equal(stats$df.residual, rStats$df.residual) + expect_equal(stats$aic, rStats$aic) + + # Test summary works on base GLM models + baseModel <- stats::glm(Sepal.Width ~ Sepal.Length + Species, data = iris) + baseSummary <- summary(baseModel) + expect_true(abs(baseSummary$deviance - 12.19313) < 1e-4) +}) + test_that("kmeans", { newIris <- iris newIris$Species <- NULL diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index d747d4f83f24b..b923ccf6bb1ae 100644 --- 
a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1087,6 +1087,11 @@ test_that("column functions", { expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) expect_equal(collect(select(df, last("age")))[[1]], 19) expect_equal(collect(select(df, last("age", TRUE)))[[1]], 19) + + # Test bround() + df <- createDataFrame(sqlContext, data.frame(x = c(2.5, 3.5))) + expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) + expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) }) test_that("column binary mathfunctions", { @@ -1863,6 +1868,9 @@ test_that("Method as.data.frame as a synonym for collect()", { expect_equal(as.data.frame(irisDF), collect(irisDF)) irisDF2 <- irisDF[irisDF$Species == "setosa", ] expect_equal(as.data.frame(irisDF2), collect(irisDF2)) + + # Make sure as.data.frame in the R base package is not covered + expect_that(as.data.frame(c(1, 2)), not(throws_error())) }) test_that("attach() on a DataFrame", { diff --git a/R/pkg/inst/tests/testthat/test_utils.R b/R/pkg/inst/tests/testthat/test_utils.R index 4218138f641d1..01694ab5c4f61 100644 --- a/R/pkg/inst/tests/testthat/test_utils.R +++ b/R/pkg/inst/tests/testthat/test_utils.R @@ -140,3 +140,27 @@ test_that("cleanClosure on R functions", { expect_equal(ls(env), "aBroadcast") expect_equal(get("aBroadcast", envir = env, inherits = FALSE), aBroadcast) }) + +test_that("varargsToJProperties", { + jprops <- newJObject("java.util.Properties") + expect_true(class(jprops) == "jobj") + + jprops <- varargsToJProperties(abc = "123") + expect_true(class(jprops) == "jobj") + expect_equal(callJMethod(jprops, "getProperty", "abc"), "123") + + jprops <- varargsToJProperties(abc = "abc", b = 1) + expect_equal(callJMethod(jprops, "getProperty", "abc"), "abc") + expect_equal(callJMethod(jprops, "getProperty", "b"), "1") + + jprops <- varargsToJProperties() + expect_equal(callJMethod(jprops, "size"), 0L) +}) + +test_that("convertToJSaveMode", { + s <- convertToJSaveMode("error") + expect_true(class(s) == "jobj") + expect_match(capture.output(print.jobj(s)), "Java ref type org.apache.spark.sql.SaveMode id ") + expect_error(convertToJSaveMode("foo"), + 'mode should be one of "append", "overwrite", "error", "ignore"') #nolint +}) diff --git a/bin/spark-class2.cmd b/bin/spark-class2.cmd index 579efff909535..db680218dc964 100644 --- a/bin/spark-class2.cmd +++ b/bin/spark-class2.cmd @@ -36,7 +36,7 @@ if exist "%SPARK_HOME%\RELEASE" ( ) if not exist "%SPARK_JARS_DIR%"\ ( - echo Failed to find Spark assembly JAR. + echo Failed to find Spark jars directory. echo You need to build Spark before running this program. exit /b 1 ) diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java index ce5c68e85375e..30712012669bd 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolver.java @@ -49,7 +49,7 @@ * Manages converting shuffle BlockIds into physical segments of local files, from a process outside * of Executors. Each Executor must register its own configuration about where it stores its files * (local dirs) and how (shuffle manager). The logic for retrieval of individual files is replicated - * from Spark's FileShuffleBlockResolver and IndexShuffleBlockResolver. 
+ * from Spark's IndexShuffleBlockResolver. */ public class ExternalShuffleBlockResolver { private static final Logger logger = LoggerFactory.getLogger(ExternalShuffleBlockResolver.class); @@ -185,8 +185,6 @@ public ManagedBuffer getBlockData(String appId, String execId, String blockId) { if ("sort".equals(executor.shuffleManager) || "tungsten-sort".equals(executor.shuffleManager)) { return getSortBasedShuffleBlockData(executor, shuffleId, mapId, reduceId); - } else if ("hash".equals(executor.shuffleManager)) { - return getHashBasedShuffleBlockData(executor, blockId); } else { throw new UnsupportedOperationException( "Unsupported shuffle manager: " + executor.shuffleManager); @@ -250,15 +248,6 @@ private void deleteExecutorDirs(String[] dirs) { } } - /** - * Hash-based shuffle data is simply stored as one file per block. - * This logic is from FileShuffleBlockResolver. - */ - private ManagedBuffer getHashBasedShuffleBlockData(ExecutorShuffleInfo executor, String blockId) { - File shuffleFile = getFile(executor.localDirs, executor.subDirsPerLocalDir, blockId); - return new FileSegmentManagedBuffer(conf, shuffleFile, 0, shuffleFile.length()); - } - /** * Sort-based shuffle data uses an index called "shuffle_ShuffleId_MapId_0.index" into a data file * called "shuffle_ShuffleId_MapId_0.data". This logic is from IndexShuffleBlockResolver, diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java index 102d4efb8bf3b..93758bdc58fb0 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/ExecutorShuffleInfo.java @@ -33,7 +33,7 @@ public class ExecutorShuffleInfo implements Encodable { public final String[] localDirs; /** Number of subdirectories created within each localDir. */ public final int subDirsPerLocalDir; - /** Shuffle manager (SortShuffleManager or HashShuffleManager) that the executor is using. */ + /** Shuffle manager (SortShuffleManager) that the executor is using. */ public final String shuffleManager; @JsonCreator diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java index d9b5f0261aaba..de4840a5880c2 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleBlockResolverSuite.java @@ -38,9 +38,6 @@ public class ExternalShuffleBlockResolverSuite { private static final String sortBlock0 = "Hello!"; private static final String sortBlock1 = "World!"; - private static final String hashBlock0 = "Elementary"; - private static final String hashBlock1 = "Tabular"; - private static TestShuffleDataContext dataContext; private static final TransportConf conf = @@ -51,13 +48,10 @@ public static void beforeAll() throws IOException { dataContext = new TestShuffleDataContext(2, 5); dataContext.create(); - // Write some sort and hash data. + // Write some sort data. 
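+    // (The hash-shuffle fixtures previously written here were removed together with
+    // hash-based shuffle support.)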
dataContext.insertSortShuffleData(0, 0, new byte[][] { sortBlock0.getBytes(StandardCharsets.UTF_8), sortBlock1.getBytes(StandardCharsets.UTF_8)}); - dataContext.insertHashShuffleData(1, 0, new byte[][] { - hashBlock0.getBytes(StandardCharsets.UTF_8), - hashBlock1.getBytes(StandardCharsets.UTF_8)}); } @AfterClass @@ -117,27 +111,6 @@ public void testSortShuffleBlocks() throws IOException { assertEquals(sortBlock1, block1); } - @Test - public void testHashShuffleBlocks() throws IOException { - ExternalShuffleBlockResolver resolver = new ExternalShuffleBlockResolver(conf, null); - resolver.registerExecutor("app0", "exec0", - dataContext.createExecutorInfo("hash")); - - InputStream block0Stream = - resolver.getBlockData("app0", "exec0", "shuffle_1_0_0").createInputStream(); - String block0 = CharStreams.toString( - new InputStreamReader(block0Stream, StandardCharsets.UTF_8)); - block0Stream.close(); - assertEquals(hashBlock0, block0); - - InputStream block1Stream = - resolver.getBlockData("app0", "exec0", "shuffle_1_0_1").createInputStream(); - String block1 = CharStreams.toString( - new InputStreamReader(block1Stream, StandardCharsets.UTF_8)); - block1Stream.close(); - assertEquals(hashBlock1, block1); - } - @Test public void jsonSerializationOfExecutorRegistration() throws IOException { ObjectMapper mapper = new ObjectMapper(); @@ -147,7 +120,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { assertEquals(parsedAppId, appId); ExecutorShuffleInfo shuffleInfo = - new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "hash"); + new ExecutorShuffleInfo(new String[]{"/bippy", "/flippy"}, 7, "sort"); String shuffleJson = mapper.writeValueAsString(shuffleInfo); ExecutorShuffleInfo parsedShuffleInfo = mapper.readValue(shuffleJson, ExecutorShuffleInfo.class); @@ -158,7 +131,7 @@ public void jsonSerializationOfExecutorRegistration() throws IOException { String legacyAppIdJson = "{\"appId\":\"foo\", \"execId\":\"bar\"}"; assertEquals(appId, mapper.readValue(legacyAppIdJson, AppExecId.class)); String legacyShuffleJson = "{\"localDirs\": [\"/bippy\", \"/flippy\"], " + - "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"hash\"}"; + "\"subDirsPerLocalDir\": 7, \"shuffleManager\": \"sort\"}"; assertEquals(shuffleInfo, mapper.readValue(legacyShuffleJson, ExecutorShuffleInfo.class)); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java index 43d0201405872..fa5cd1398aa0f 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleCleanupSuite.java @@ -144,9 +144,6 @@ private static TestShuffleDataContext createSomeData() throws IOException { dataContext.insertSortShuffleData(rand.nextInt(1000), rand.nextInt(1000), new byte[][] { "ABC".getBytes(StandardCharsets.UTF_8), "DEF".getBytes(StandardCharsets.UTF_8)}); - dataContext.insertHashShuffleData(rand.nextInt(1000), rand.nextInt(1000) + 1000, new byte[][] { - "GHI".getBytes(StandardCharsets.UTF_8), - "JKLMNOPQRSTUVWXYZ".getBytes(StandardCharsets.UTF_8)}); return dataContext; } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java index 
ecbbe7bfa3b11..067c815c30a51 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalShuffleIntegrationSuite.java @@ -50,12 +50,9 @@ public class ExternalShuffleIntegrationSuite { static String APP_ID = "app-id"; static String SORT_MANAGER = "sort"; - static String HASH_MANAGER = "hash"; // Executor 0 is sort-based static TestShuffleDataContext dataContext0; - // Executor 1 is hash-based - static TestShuffleDataContext dataContext1; static ExternalShuffleBlockHandler handler; static TransportServer server; @@ -87,10 +84,6 @@ public static void beforeAll() throws IOException { dataContext0.create(); dataContext0.insertSortShuffleData(0, 0, exec0Blocks); - dataContext1 = new TestShuffleDataContext(6, 2); - dataContext1.create(); - dataContext1.insertHashShuffleData(1, 0, exec1Blocks); - conf = new TransportConf("shuffle", new SystemPropertyConfigProvider()); handler = new ExternalShuffleBlockHandler(conf, null); TransportContext transportContext = new TransportContext(conf, handler); @@ -100,7 +93,6 @@ public static void beforeAll() throws IOException { @AfterClass public static void afterAll() { dataContext0.cleanup(); - dataContext1.cleanup(); server.close(); } @@ -192,40 +184,18 @@ public void testFetchThreeSort() throws Exception { exec0Fetch.releaseBuffers(); } - @Test - public void testFetchHash() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(HASH_MANAGER)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.successBlocks); - assertTrue(execFetch.failedBlocks.isEmpty()); - assertBufferListsEqual(execFetch.buffers, Lists.newArrayList(exec1Blocks)); - execFetch.releaseBuffers(); - } - - @Test - public void testFetchWrongShuffle() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0", "shuffle_1_0_1" }); - assertTrue(execFetch.successBlocks.isEmpty()); - assertEquals(Sets.newHashSet("shuffle_1_0_0", "shuffle_1_0_1"), execFetch.failedBlocks); - } - @Test public void testFetchInvalidShuffle() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo("unknown sort manager")); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "shuffle_1_0_0" }); + registerExecutor("exec-1", dataContext0.createExecutorInfo("unknown sort manager")); + FetchResult execFetch = fetchBlocks("exec-1", new String[] { "shuffle_1_0_0" }); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet("shuffle_1_0_0"), execFetch.failedBlocks); } @Test public void testFetchWrongBlockId() throws Exception { - registerExecutor("exec-1", dataContext1.createExecutorInfo(SORT_MANAGER /* wrong manager */)); - FetchResult execFetch = fetchBlocks("exec-1", - new String[] { "rdd_1_0_0" }); + registerExecutor("exec-1", dataContext0.createExecutorInfo(SORT_MANAGER)); + FetchResult execFetch = fetchBlocks("exec-1", new String[] { "rdd_1_0_0" }); assertTrue(execFetch.successBlocks.isEmpty()); assertEquals(Sets.newHashSet("rdd_1_0_0"), execFetch.failedBlocks); } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java 
b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java index 7ac1ca128aed0..62a1fb42b023d 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/TestShuffleDataContext.java @@ -29,7 +29,7 @@ import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; /** - * Manages some sort- and hash-based shuffle data, including the creation + * Manages some sort-shuffle data, including the creation * and cleanup of directories that can be read by the {@link ExternalShuffleBlockResolver}. */ public class TestShuffleDataContext { @@ -85,15 +85,6 @@ public void insertSortShuffleData(int shuffleId, int mapId, byte[][] blocks) thr } } - /** Creates reducer blocks in a hash-based data format within our local dirs. */ - public void insertHashShuffleData(int shuffleId, int mapId, byte[][] blocks) throws IOException { - for (int i = 0; i < blocks.length; i ++) { - String blockId = "shuffle_" + shuffleId + "_" + mapId + "_" + i; - Files.write(blocks[i], - ExternalShuffleBlockResolver.getFile(localDirs, subDirsPerLocalDir, blockId)); - } - } - /** * Creates an ExecutorShuffleInfo object based on the given shuffle manager which targets this * context's directories. diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 7a60c3eb35740..0e9defe5b4a51 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -114,7 +114,7 @@ final class BypassMergeSortShuffleWriter extends ShuffleWriter { this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); - this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); + this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); this.serializer = dep.serializer(); this.shuffleBlockResolver = shuffleBlockResolver; } diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java index 0c5fb883a8326..daa63d47e6aed 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java @@ -118,7 +118,7 @@ public UnsafeShuffleWriter( this.shuffleId = dep.shuffleId(); this.serializer = dep.serializer().newInstance(); this.partitioner = dep.partitioner(); - this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); + this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); this.taskContext = taskContext; this.sparkConf = sparkConf; this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true); diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java index ef79b49083479..3e32dd9d63e31 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java @@ -129,7 +129,7 @@ private UnsafeExternalSorter( // Use getSizeAsKb (not bytes) to maintain backwards 
compatibility for units // this.fileBufferSizeBytes = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024; this.fileBufferSizeBytes = 32 * 1024; - this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics(); + this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics(); if (existingInMemorySorter == null) { this.inMemSorter = new UnsafeInMemorySorter( diff --git a/core/src/main/resources/org/apache/spark/ui/static/log-view.js b/core/src/main/resources/org/apache/spark/ui/static/log-view.js new file mode 100644 index 0000000000000..1782b4f209c09 --- /dev/null +++ b/core/src/main/resources/org/apache/spark/ui/static/log-view.js @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +var baseParams; + +var curLogLength; +var startByte; +var endByte; +var totalLogLength; + +var byteLength; + +function setLogScroll(oldHeight) { + var logContent = $(".log-content"); + logContent.scrollTop(logContent[0].scrollHeight - oldHeight); +} + +function tailLog() { + var logContent = $(".log-content"); + logContent.scrollTop(logContent[0].scrollHeight); +} + +function setLogData() { + $('#log-data').html("Showing " + curLogLength + " Bytes: " + startByte + + " - " + endByte + " of " + totalLogLength); +} + +function disableMoreButton() { + var moreBtn = $(".log-more-btn"); + moreBtn.attr("disabled", "disabled"); + moreBtn.html("Top of Log"); +} + +function noNewAlert() { + var alert = $(".no-new-alert"); + alert.css("display", "block"); + window.setTimeout(function () {alert.css("display", "none");}, 4000); +} + +function loadMore() { + var offset = Math.max(startByte - byteLength, 0); + var moreByteLength = Math.min(byteLength, startByte); + + $.ajax({ + type: "GET", + url: "/log" + baseParams + "&offset=" + offset + "&byteLength=" + moreByteLength, + success: function (data) { + var oldHeight = $(".log-content")[0].scrollHeight; + var newlineIndex = data.indexOf('\n'); + var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); + var retStartByte = dataInfo[0]; + var retLogLength = dataInfo[2]; + + var cleanData = data.substring(newlineIndex + 1); + if (retStartByte == 0) { + disableMoreButton(); + } + $("pre", ".log-content").prepend(cleanData); + + curLogLength = curLogLength + (startByte - retStartByte); + startByte = retStartByte; + totalLogLength = retLogLength; + setLogScroll(oldHeight); + setLogData(); + } + }); +} + +function loadNew() { + $.ajax({ + type: "GET", + url: "/log" + baseParams + "&byteLength=0", + success: function (data) { + var dataInfo = data.substring(0, data.indexOf('\n')).match(/\d+/g); + var newDataLen = dataInfo[2] - totalLogLength; + if (newDataLen != 0) { + $.ajax({ + type: "GET", + url: "/log" + baseParams + "&byteLength=" + newDataLen, + success: function (data) { + 
var newlineIndex = data.indexOf('\n'); + var dataInfo = data.substring(0, newlineIndex).match(/\d+/g); + var retStartByte = dataInfo[0]; + var retEndByte = dataInfo[1]; + var retLogLength = dataInfo[2]; + + var cleanData = data.substring(newlineIndex + 1); + $("pre", ".log-content").append(cleanData); + + curLogLength = curLogLength + (retEndByte - retStartByte); + endByte = retEndByte; + totalLogLength = retLogLength; + tailLog(); + setLogData(); + } + }); + } else { + noNewAlert(); + } + } + }); +} + +function initLogPage(params, logLen, start, end, totLogLen, defaultLen) { + baseParams = params; + curLogLength = logLen; + startByte = start; + endByte = end; + totalLogLength = totLogLen; + byteLength = defaultLen; + tailLog(); + if (startByte == 0) { + disableMoreButton(); + } +} \ No newline at end of file diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css index 47dd9162a1bfa..595e80ab5e3ad 100644 --- a/core/src/main/resources/org/apache/spark/ui/static/webui.css +++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css @@ -237,3 +237,13 @@ a.expandbutton { color: #333; text-decoration: none; } + +.log-more-btn, .log-new-btn { + width: 100% +} + +.no-new-alert { + text-align: center; + margin: 0; + padding: 4px 0; +} \ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 339266a5d48b2..a50600f1488c9 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -28,6 +28,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaFutureAction import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.JobWaiter +import org.apache.spark.util.ThreadUtils /** @@ -45,6 +46,7 @@ trait FutureAction[T] extends Future[T] { /** * Blocks until this action completes. + * * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf * for unbounded waiting, or a finite positive duration * @return this FutureAction @@ -53,6 +55,7 @@ trait FutureAction[T] extends Future[T] { /** * Awaits and returns the result (of type T) of this action. + * * @param atMost maximum wait time, which may be negative (no waiting is done), Duration.Inf * for unbounded waiting, or a finite positive duration * @throws Exception exception during action execution @@ -89,8 +92,8 @@ trait FutureAction[T] extends Future[T] { /** * Blocks and returns the result of this job. */ - @throws(classOf[Exception]) - def get(): T = Await.result(this, Duration.Inf) + @throws(classOf[SparkException]) + def get(): T = ThreadUtils.awaitResult(this, Duration.Inf) /** * Returns the job IDs run by the underlying async operation. diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala deleted file mode 100644 index 982b6d6b61732..0000000000000 --- a/core/src/main/scala/org/apache/spark/HttpServer.scala +++ /dev/null @@ -1,181 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import java.io.File - -import org.eclipse.jetty.security.{ConstraintMapping, ConstraintSecurityHandler, HashLoginService} -import org.eclipse.jetty.security.authentication.DigestAuthenticator -import org.eclipse.jetty.server.Server -import org.eclipse.jetty.server.bio.SocketConnector -import org.eclipse.jetty.server.ssl.SslSocketConnector -import org.eclipse.jetty.servlet.{DefaultServlet, ServletContextHandler, ServletHolder} -import org.eclipse.jetty.util.component.LifeCycle -import org.eclipse.jetty.util.security.{Constraint, Password} -import org.eclipse.jetty.util.thread.QueuedThreadPool - -import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils - -/** - * Exception type thrown by HttpServer when it is in the wrong state for an operation. - */ -private[spark] class ServerStateException(message: String) extends Exception(message) - -/** - * An HTTP server for static content used to allow worker nodes to access JARs added to SparkContext - * as well as classes created by the interpreter when the user types in code. This is just a wrapper - * around a Jetty server. - */ -private[spark] class HttpServer( - conf: SparkConf, - resourceBase: File, - securityManager: SecurityManager, - requestedPort: Int = 0, - serverName: String = "HTTP server") - extends Logging { - - private var server: Server = null - private var port: Int = requestedPort - private val servlets = { - val handler = new ServletContextHandler() - handler.setContextPath("/") - handler - } - - def start() { - if (server != null) { - throw new ServerStateException("Server is already started") - } else { - logInfo("Starting HTTP Server") - val (actualServer, actualPort) = - Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName) - server = actualServer - port = actualPort - } - } - - def addDirectory(contextPath: String, resourceBase: String): Unit = { - val holder = new ServletHolder() - holder.setInitParameter("resourceBase", resourceBase) - holder.setInitParameter("pathInfoOnly", "true") - holder.setServlet(new DefaultServlet()) - servlets.addServlet(holder, contextPath.stripSuffix("/") + "/*") - } - - /** - * Actually start the HTTP server on the given port. - * - * Note that this is only best effort in the sense that we may end up binding to a nearby port - * in the event of port collision. Return the bound server and the actual port used. 
- */ - private def doStart(startPort: Int): (Server, Int) = { - val server = new Server() - - val connector = securityManager.fileServerSSLOptions.createJettySslContextFactory() - .map(new SslSocketConnector(_)).getOrElse(new SocketConnector) - - connector.setMaxIdleTime(60 * 1000) - connector.setSoLingerTime(-1) - connector.setPort(startPort) - server.addConnector(connector) - - val threadPool = new QueuedThreadPool - threadPool.setDaemon(true) - server.setThreadPool(threadPool) - addDirectory("/", resourceBase.getAbsolutePath) - - if (securityManager.isAuthenticationEnabled()) { - logDebug("HttpServer is using security") - val sh = setupSecurityHandler(securityManager) - // make sure we go through security handler to get resources - sh.setHandler(servlets) - server.setHandler(sh) - } else { - logDebug("HttpServer is not using security") - server.setHandler(servlets) - } - - server.start() - val actualPort = server.getConnectors()(0).getLocalPort - - (server, actualPort) - } - - /** - * Setup Jetty to the HashLoginService using a single user with our - * shared secret. Configure it to use DIGEST-MD5 authentication so that the password - * isn't passed in plaintext. - */ - private def setupSecurityHandler(securityMgr: SecurityManager): ConstraintSecurityHandler = { - val constraint = new Constraint() - // use DIGEST-MD5 as the authentication mechanism - constraint.setName(Constraint.__DIGEST_AUTH) - constraint.setRoles(Array("user")) - constraint.setAuthenticate(true) - constraint.setDataConstraint(Constraint.DC_NONE) - - val cm = new ConstraintMapping() - cm.setConstraint(constraint) - cm.setPathSpec("/*") - val sh = new ConstraintSecurityHandler() - - // the hashLoginService lets us do a single user and - // secret right now. This could be changed to use the - // JAASLoginService for other options. - val hashLogin = new HashLoginService() - - val userCred = new Password(securityMgr.getSecretKey()) - if (userCred == null) { - throw new Exception("Error: secret key is null with authentication on") - } - hashLogin.putUser(securityMgr.getHttpUser(), userCred, Array("user")) - sh.setLoginService(hashLogin) - sh.setAuthenticator(new DigestAuthenticator()); - sh.setConstraintMappings(Array(cm)) - sh - } - - def stop() { - if (server == null) { - throw new ServerStateException("Server is already stopped") - } else { - server.stop() - // Stop the ThreadPool if it supports stop() method (through LifeCycle). - // It is needed because stopping the Server won't stop the ThreadPool it uses. 
- val threadPool = server.getThreadPool - if (threadPool != null && threadPool.isInstanceOf[LifeCycle]) { - threadPool.asInstanceOf[LifeCycle].stop - } - port = -1 - server = null - } - } - - /** - * Get the URI of this HTTP server (http://host:port or https://host:port) - */ - def uri: String = { - if (server == null) { - throw new ServerStateException("Server is not started") - } else { - val scheme = if (securityManager.fileServerSSLOptions.enabled) "https" else "http" - s"$scheme://${Utils.localHostNameForURI()}:$port" - } - } -} diff --git a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala index 0dd4ec656f5a2..0b494c146fa1b 100644 --- a/core/src/main/scala/org/apache/spark/InternalAccumulator.scala +++ b/core/src/main/scala/org/apache/spark/InternalAccumulator.scala @@ -17,17 +17,11 @@ package org.apache.spark -import org.apache.spark.storage.{BlockId, BlockStatus} - - /** * A collection of fields and methods concerned with internal accumulators that represent * task level metrics. */ private[spark] object InternalAccumulator { - - import AccumulatorParam._ - // Prefixes used in names of internal task level metrics val METRICS_PREFIX = "internal.metrics." val SHUFFLE_READ_METRICS_PREFIX = METRICS_PREFIX + "shuffle.read." @@ -68,142 +62,15 @@ private[spark] object InternalAccumulator { // Names of output metrics object output { - val WRITE_METHOD = OUTPUT_METRICS_PREFIX + "writeMethod" val BYTES_WRITTEN = OUTPUT_METRICS_PREFIX + "bytesWritten" val RECORDS_WRITTEN = OUTPUT_METRICS_PREFIX + "recordsWritten" } // Names of input metrics object input { - val READ_METHOD = INPUT_METRICS_PREFIX + "readMethod" val BYTES_READ = INPUT_METRICS_PREFIX + "bytesRead" val RECORDS_READ = INPUT_METRICS_PREFIX + "recordsRead" } // scalastyle:on - - /** - * Create an internal [[Accumulator]] by name, which must begin with [[METRICS_PREFIX]]. - */ - def create(name: String): Accumulator[_] = { - require(name.startsWith(METRICS_PREFIX), - s"internal accumulator name must start with '$METRICS_PREFIX': $name") - getParam(name) match { - case p @ LongAccumulatorParam => newMetric[Long](0L, name, p) - case p @ IntAccumulatorParam => newMetric[Int](0, name, p) - case p @ StringAccumulatorParam => newMetric[String]("", name, p) - case p @ UpdatedBlockStatusesAccumulatorParam => - newMetric[Seq[(BlockId, BlockStatus)]](Seq(), name, p) - case p => throw new IllegalArgumentException( - s"unsupported accumulator param '${p.getClass.getSimpleName}' for metric '$name'.") - } - } - - /** - * Get the [[AccumulatorParam]] associated with the internal metric name, - * which must begin with [[METRICS_PREFIX]]. - */ - def getParam(name: String): AccumulatorParam[_] = { - require(name.startsWith(METRICS_PREFIX), - s"internal accumulator name must start with '$METRICS_PREFIX': $name") - name match { - case UPDATED_BLOCK_STATUSES => UpdatedBlockStatusesAccumulatorParam - case shuffleRead.LOCAL_BLOCKS_FETCHED => IntAccumulatorParam - case shuffleRead.REMOTE_BLOCKS_FETCHED => IntAccumulatorParam - case input.READ_METHOD => StringAccumulatorParam - case output.WRITE_METHOD => StringAccumulatorParam - case _ => LongAccumulatorParam - } - } - - /** - * Accumulators for tracking internal metrics. 
- */ - def createAll(): Seq[Accumulator[_]] = { - Seq[String]( - EXECUTOR_DESERIALIZE_TIME, - EXECUTOR_RUN_TIME, - RESULT_SIZE, - JVM_GC_TIME, - RESULT_SERIALIZATION_TIME, - MEMORY_BYTES_SPILLED, - DISK_BYTES_SPILLED, - PEAK_EXECUTION_MEMORY, - UPDATED_BLOCK_STATUSES).map(create) ++ - createShuffleReadAccums() ++ - createShuffleWriteAccums() ++ - createInputAccums() ++ - createOutputAccums() ++ - sys.props.get("spark.testing").map(_ => create(TEST_ACCUM)).toSeq - } - - /** - * Accumulators for tracking shuffle read metrics. - */ - def createShuffleReadAccums(): Seq[Accumulator[_]] = { - Seq[String]( - shuffleRead.REMOTE_BLOCKS_FETCHED, - shuffleRead.LOCAL_BLOCKS_FETCHED, - shuffleRead.REMOTE_BYTES_READ, - shuffleRead.LOCAL_BYTES_READ, - shuffleRead.FETCH_WAIT_TIME, - shuffleRead.RECORDS_READ).map(create) - } - - /** - * Accumulators for tracking shuffle write metrics. - */ - def createShuffleWriteAccums(): Seq[Accumulator[_]] = { - Seq[String]( - shuffleWrite.BYTES_WRITTEN, - shuffleWrite.RECORDS_WRITTEN, - shuffleWrite.WRITE_TIME).map(create) - } - - /** - * Accumulators for tracking input metrics. - */ - def createInputAccums(): Seq[Accumulator[_]] = { - Seq[String]( - input.READ_METHOD, - input.BYTES_READ, - input.RECORDS_READ).map(create) - } - - /** - * Accumulators for tracking output metrics. - */ - def createOutputAccums(): Seq[Accumulator[_]] = { - Seq[String]( - output.WRITE_METHOD, - output.BYTES_WRITTEN, - output.RECORDS_WRITTEN).map(create) - } - - /** - * Accumulators for tracking internal metrics. - * - * These accumulators are created with the stage such that all tasks in the stage will - * add to the same set of accumulators. We do this to report the distribution of accumulator - * values across all tasks within each stage. - */ - def createAll(sc: SparkContext): Seq[Accumulator[_]] = { - val accums = createAll() - accums.foreach { accum => - Accumulators.register(accum) - sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum)) - } - accums - } - - /** - * Create a new accumulator representing an internal task metric. - */ - private def newMetric[T]( - initialValue: T, - name: String, - param: AccumulatorParam[T]): Accumulator[T] = { - new Accumulator[T](initialValue, param, Some(name), internal = true, countFailedValues = true) - } - } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index e41088f7c8f69..e7eabd289699c 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -20,7 +20,7 @@ package org.apache.spark import java.io._ import java.lang.reflect.Constructor import java.net.URI -import java.util.{Arrays, Properties, UUID} +import java.util.{Arrays, Properties, ServiceLoader, UUID} import java.util.concurrent.ConcurrentMap import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} @@ -2453,9 +2453,32 @@ object SparkContext extends Logging { "in the form mesos://zk://host:port. 
Current Master URL will stop working in Spark 2.0.") createTaskScheduler(sc, "mesos://" + zkUrl, deployMode) - case _ => - throw new SparkException("Could not parse Master URL: '" + master + "'") + case masterUrl => + val cm = getClusterManager(masterUrl) match { + case Some(clusterMgr) => clusterMgr + case None => throw new SparkException("Could not parse Master URL: '" + master + "'") + } + try { + val scheduler = cm.createTaskScheduler(sc, masterUrl) + val backend = cm.createSchedulerBackend(sc, masterUrl, scheduler) + cm.initialize(scheduler, backend) + (backend, scheduler) + } catch { + case NonFatal(e) => + throw new SparkException("External scheduler cannot be instantiated", e) + } + } + } + + private def getClusterManager(url: String): Option[ExternalClusterManager] = { + val loader = Utils.getContextOrSparkClassLoader + val serviceLoaders = + ServiceLoader.load(classOf[ExternalClusterManager], loader).asScala.filter(_.canCreate(url)) + if (serviceLoaders.size > 1) { + throw new SparkException(s"Multiple Cluster Managers ($serviceLoaders) registered " + + s"for the url $url:") } + serviceLoaders.headOption } } diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 3d11db7461c05..27497e21b829d 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -298,9 +298,8 @@ object SparkEnv extends Logging { // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( - "hash" -> "org.apache.spark.shuffle.hash.HashShuffleManager", - "sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager", - "tungsten-sort" -> "org.apache.spark.shuffle.sort.SortShuffleManager") + "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, + "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) val shuffleMgrName = conf.get("spark.shuffle.manager", "sort") val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase, shuffleMgrName) val shuffleManager = instantiateClass[ShuffleManager](shuffleMgrClass) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 757c1b5116f3c..e7940bd9eddcd 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -62,12 +62,11 @@ object TaskContext { protected[spark] def unset(): Unit = taskContext.remove() /** - * An empty task context that does not represent an actual task. + * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) } - } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index fa0b2d3d28293..e8f83c6d14b37 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -36,15 +36,10 @@ private[spark] class TaskContextImpl( override val taskMemoryManager: TaskMemoryManager, localProperties: Properties, @transient private val metricsSystem: MetricsSystem, - initialAccumulators: Seq[Accumulator[_]] = InternalAccumulator.createAll()) + override val taskMetrics: TaskMetrics = new TaskMetrics) extends TaskContext with Logging { - /** - * Metrics associated with this task. 
- */ - override val taskMetrics: TaskMetrics = new TaskMetrics(initialAccumulators) - /** List of callback functions to execute when the task completes. */ @transient private val onCompleteCallbacks = new ArrayBuffer[TaskCompletionListener] diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index 83af226bfd6f1..7487cfe9c5509 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -149,9 +149,7 @@ case class ExceptionFailure( this(e, accumUpdates, preserveCause = true) } - def exception: Option[Throwable] = exceptionWrapper.flatMap { - (w: ThrowableSerializationWrapper) => Option(w.exception) - } + def exception: Option[Throwable] = exceptionWrapper.flatMap(w => Option(w.exception)) override def toErrorString: String = if (fullStackTrace == null) { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 4212027122544..6f3b8faf03b04 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -105,7 +105,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def mapToDouble[R](f: DoubleFunction[T]): JavaDoubleRDD = { - new JavaDoubleRDD(rdd.map(x => f.call(x).doubleValue())) + new JavaDoubleRDD(rdd.map(f.call(_).doubleValue())) } /** @@ -131,7 +131,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { */ def flatMapToDouble(f: DoubleFlatMapFunction[T]): JavaDoubleRDD = { def fn: (T) => Iterator[jl.Double] = (x: T) => f.call(x).asScala - new JavaDoubleRDD(rdd.flatMap(fn).map((x: jl.Double) => x.doubleValue())) + new JavaDoubleRDD(rdd.flatMap(fn).map(_.doubleValue())) } /** @@ -173,7 +173,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def fn: (Iterator[T]) => Iterator[jl.Double] = { (x: Iterator[T]) => f.call(x.asJava).asScala } - new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: jl.Double) => x.doubleValue())) + new JavaDoubleRDD(rdd.mapPartitions(fn).map(_.doubleValue())) } /** @@ -196,7 +196,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { (x: Iterator[T]) => f.call(x.asJava).asScala } new JavaDoubleRDD(rdd.mapPartitions(fn, preservesPartitioning) - .map(x => x.doubleValue())) + .map(_.doubleValue())) } /** @@ -215,7 +215,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Applies a function f to each partition of this RDD. */ def foreachPartition(f: VoidFunction[java.util.Iterator[T]]) { - rdd.foreachPartition((x => f.call(x.asJava))) + rdd.foreachPartition(x => f.call(x.asJava)) } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala index 48df5bedd6e41..8e4e80a24acee 100644 --- a/core/src/main/scala/org/apache/spark/api/r/SerDe.scala +++ b/core/src/main/scala/org/apache/spark/api/r/SerDe.scala @@ -356,6 +356,13 @@ private[spark] object SerDe { writeInt(dos, v.length) v.foreach(elem => writeObject(dos, elem)) + // Handle Properties + // This must be above the case java.util.Map below. 
+ // (Properties implements Map and will be serialized as map otherwise) + case v: java.util.Properties => + writeType(dos, "jobj") + writeJObj(dos, value) + // Handle map case v: java.util.Map[_, _] => writeType(dos, "map") diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index abb98f95a1ee8..79f4d06c8460e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -23,7 +23,7 @@ import java.nio.charset.StandardCharsets import java.util.concurrent.TimeoutException import scala.collection.mutable.ListBuffer -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Future, Promise} import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.language.postfixOps @@ -35,7 +35,7 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.deploy.master.RecoveryState import org.apache.spark.internal.Logging -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * This suite tests the fault tolerance of the Spark standalone scheduler, mainly the Master. @@ -265,7 +265,7 @@ private object FaultToleranceTest extends App with Logging { } // Avoid waiting indefinitely (e.g., we could register but get no executors). - assertTrue(Await.result(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) } /** @@ -318,7 +318,7 @@ private object FaultToleranceTest extends App with Logging { } try { - assertTrue(Await.result(f, 120 seconds)) + assertTrue(ThreadUtils.awaitResult(f, 120 seconds)) } catch { case e: TimeoutException => logError("Master states: " + masters.map(_.state)) @@ -422,7 +422,7 @@ private object SparkDocker { } dockerCmd.run(ProcessLogger(findIpAndLog _)) - val ip = Await.result(ipPromise.future, 30 seconds) + val ip = ThreadUtils.awaitResult(ipPromise.future, 30 seconds) val dockerId = Docker.getLastProcessId (ip, dockerId, outFile) } diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala index c0a9e3f280ba1..6227a30dc949c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala @@ -62,7 +62,7 @@ object PythonRunner { // ready to serve connections. thread.join() - // Build up a PYTHONPATH that includes the Spark assembly JAR (where this class is), the + // Build up a PYTHONPATH that includes the Spark assembly (where this class is), the // python directories in SPARK_HOME (if set), and any files in the pyFiles argument val pathElements = new ArrayBuffer[String] pathElements ++= formattedPyFiles diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index ec6d48485f110..78da1b70c54a5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -478,7 +478,8 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S val command = sys.env.get("_SPARK_CMD_USAGE").getOrElse( """Usage: spark-submit [options] [app arguments] |Usage: spark-submit --kill [submission ID] --master [spark://...] 
- |Usage: spark-submit --status [submission ID] --master [spark://...]""".stripMargin) + |Usage: spark-submit --status [submission ID] --master [spark://...] + |Usage: spark-submit run-example [options] example-class [example args]""".stripMargin) outStream.println(command) val mem_mb = Utils.DEFAULT_DRIVER_MEM_MB diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala index d5afb33c7118a..2bd4a46e16fc9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala @@ -353,7 +353,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) * the name of the file being compressed. */ def zipFileToStream(file: Path, entryName: String, outputStream: ZipOutputStream): Unit = { - val fs = FileSystem.get(hadoopConf) + val fs = file.getFileSystem(hadoopConf) val inputStream = fs.open(file, 1 * 1024 * 1024) // 1MB Buffer try { outputStream.putNextEntry(new ZipEntry(entryName)) @@ -372,7 +372,7 @@ private[history] class FsHistoryProvider(conf: SparkConf, clock: Clock) attempt.attemptId.isEmpty || attemptId.isEmpty || attempt.attemptId.get == attemptId.get }.foreach { attempt => val logPath = new Path(logDir, attempt.logPath) - zipFileToStream(new Path(logDir, attempt.logPath), attempt.logPath, zipStream) + zipFileToStream(logPath, attempt.logPath, zipStream) } } finally { zipStream.close() diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b443e8f0519f4..edc9be2a8a8cb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -24,7 +24,7 @@ import java.util.Date import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, TimeUnit} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import scala.language.postfixOps import scala.util.Random @@ -959,7 +959,7 @@ private[deploy] class Master( */ private[master] def rebuildSparkUI(app: ApplicationInfo): Option[SparkUI] = { val futureUI = asyncRebuildSparkUI(app) - Await.result(futureUI, Duration.Inf) + ThreadUtils.awaitResult(futureUI, Duration.Inf) } /** Rebuild a new SparkUI asynchronously to not block RPC event loop */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 1b18cf0ded69d..96274958d1422 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -35,9 +35,8 @@ private[ui] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") val state = master.askWithRetry[MasterStateResponse](RequestMasterState) - val app = state.activeApps.find(_.id == appId).getOrElse({ - state.completedApps.find(_.id == appId).getOrElse(null) - }) + val app = state.activeApps.find(_.id == appId) + .getOrElse(state.completedApps.find(_.id == appId).orNull) if (app == null) { val msg =
No running application with ID {appId}
return UIUtils.basicSparkPage(msg, "Not Found") diff --git a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala index c5a5876a896cc..21cb94142b15b 100644 --- a/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/rest/RestSubmissionClient.scala @@ -27,10 +27,11 @@ import scala.collection.mutable import scala.concurrent.{Await, Future} import scala.concurrent.duration._ import scala.io.Source +import scala.util.control.NonFatal import com.fasterxml.jackson.core.JsonProcessingException -import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf} +import org.apache.spark.{SPARK_VERSION => sparkVersion, SparkConf, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.util.Utils @@ -258,13 +259,17 @@ private[spark] class RestSubmissionClient(master: String) extends Logging { } } + // scalastyle:off awaitresult try { Await.result(responseFuture, 10.seconds) } catch { + // scalastyle:on awaitresult case unreachable @ (_: FileNotFoundException | _: SocketException) => throw new SubmitRestConnectionException("Unable to connect to server", unreachable) case malformed @ (_: JsonProcessingException | _: SubmitRestProtocolException) => throw new SubmitRestProtocolException("Malformed response received from server", malformed) case timeout: TimeoutException => throw new SubmitRestConnectionException("No response from server", timeout) + case NonFatal(t) => + throw new SparkException("Exception while waiting for response", t) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index aad2e91b25554..f4376dedea725 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -68,7 +68,10 @@ private[deploy] class DriverRunner( private var clock: Clock = new SystemClock() private var sleeper = new Sleeper { - def sleep(seconds: Int): Unit = (0 until seconds).takeWhile(f => {Thread.sleep(1000); !killed}) + def sleep(seconds: Int): Unit = (0 until seconds).takeWhile { _ => + Thread.sleep(1000) + !killed + } } /** Starts a thread to run and manage the driver. 
*/ @@ -116,7 +119,7 @@ private[deploy] class DriverRunner( /** Terminate this driver (or prevent it from ever starting if not yet started) */ private[worker] def kill() { synchronized { - process.foreach(p => p.destroy()) + process.foreach(_.destroy()) killed = true } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala index e75c0cec4acc7..3473c41b935fd 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/LogPage.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker.ui import java.io.File import javax.servlet.http.HttpServletRequest -import scala.xml.Node +import scala.xml.{Node, Unparsed} import org.apache.spark.internal.Logging import org.apache.spark.ui.{UIUtils, WebUIPage} @@ -31,10 +31,9 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with private val worker = parent.worker private val workDir = new File(parent.workDir.toURI.normalize().getPath) private val supportedLogTypes = Set("stderr", "stdout") + private val defaultBytes = 100 * 1024 def renderLog(request: HttpServletRequest): String = { - val defaultBytes = 100 * 1024 - val appId = Option(request.getParameter("appId")) val executorId = Option(request.getParameter("executorId")) val driverId = Option(request.getParameter("driverId")) @@ -44,9 +43,9 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with val logDir = (appId, executorId, driverId) match { case (Some(a), Some(e), None) => - s"${workDir.getPath}/$appId/$executorId/" + s"${workDir.getPath}/$a/$e/" case (None, None, Some(d)) => - s"${workDir.getPath}/$driverId/" + s"${workDir.getPath}/$d/" case _ => throw new Exception("Request must specify either application or driver identifiers") } @@ -57,7 +56,6 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with } def render(request: HttpServletRequest): Seq[Node] = { - val defaultBytes = 100 * 1024 val appId = Option(request.getParameter("appId")) val executorId = Option(request.getParameter("executorId")) val driverId = Option(request.getParameter("driverId")) @@ -76,49 +74,44 @@ private[ui] class LogPage(parent: WorkerWebUI) extends WebUIPage("logPage") with val (logText, startByte, endByte, logLength) = getLog(logDir, logType, offset, byteLength) val linkToMaster =

Back to Master

- val range = Bytes {startByte.toString} - {endByte.toString} of {logLength} - - val backButton = - if (startByte > 0) { - - - - } else { - - } + val curLogLength = endByte - startByte + val range = + + Showing {curLogLength} Bytes: {startByte.toString} - {endByte.toString} of {logLength} + + + val moreButton = + + + val newButton = + + + val alert = + - val nextButton = - if (endByte < logLength) { - - - - } else { - - } + val logParams = "?%s&logType=%s".format(params, logType) + val jsOnload = "window.onload = " + + s"initLogPage('$logParams', $curLogLength, $startByte, $endByte, $logLength, $byteLength);" val content =
{linkToMaster} -
-
{backButton}
-
{range}
-
{nextButton}
-
-
-
+ {range} +
+
{moreButton}
{logText}
+ {alert} +
{newButton}
+
+ UIUtils.basicSparkPage(content, logType + " log page for " + pageName) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 71b4ad160d679..e08729510926b 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -64,7 +64,7 @@ private[spark] class CoarseGrainedExecutorBackend( // Always receive `true`. Just ignore it case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) - System.exit(1) + exitExecutor(1) }(ThreadUtils.sameThread) } @@ -81,12 +81,12 @@ private[spark] class CoarseGrainedExecutorBackend( case RegisterExecutorFailed(message) => logError("Slave registration failed: " + message) - System.exit(1) + exitExecutor(1) case LaunchTask(data) => if (executor == null) { logError("Received LaunchTask command but executor was null") - System.exit(1) + exitExecutor(1) } else { val taskDesc = ser.deserialize[TaskDescription](data.value) logInfo("Got assigned task " + taskDesc.taskId) @@ -97,7 +97,7 @@ private[spark] class CoarseGrainedExecutorBackend( case KillTask(taskId, _, interruptThread) => if (executor == null) { logError("Received KillTask command but executor was null") - System.exit(1) + exitExecutor(1) } else { executor.killTask(taskId, interruptThread) } @@ -127,7 +127,7 @@ private[spark] class CoarseGrainedExecutorBackend( logInfo(s"Driver from $remoteAddress disconnected during shutdown") } else if (driver.exists(_.address == remoteAddress)) { logError(s"Driver $remoteAddress disassociated! Shutting down.") - System.exit(1) + exitExecutor(1) } else { logWarning(s"An unknown ($remoteAddress) driver disconnected.") } @@ -140,6 +140,13 @@ private[spark] class CoarseGrainedExecutorBackend( case None => logWarning(s"Drop $msg because has not yet connected to driver") } } + + /** + * This function can be overloaded by other child classes to handle + * executor exits differently. For e.g. when an executor goes down, + * back-end may not want to take the parent process down. + */ + protected def exitExecutor(code: Int): Unit = System.exit(code) } private[spark] object CoarseGrainedExecutorBackend extends Logging { diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 9f94fdef24ebe..650f05c309d20 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -153,6 +153,21 @@ private[spark] class Executor( } } + /** + * Function to kill the running tasks in an executor. + * This can be called by executor back-ends to kill the + * tasks instead of taking the JVM down. + * @param interruptThread whether to interrupt the task thread + */ + def killAllTasks(interruptThread: Boolean) : Unit = { + // kill all the running tasks + for (taskRunner <- runningTasks.values().asScala) { + if (taskRunner != null) { + taskRunner.kill(interruptThread) + } + } + } + def stop(): Unit = { env.metricsSystem.report() heartbeater.shutdown() @@ -278,16 +293,14 @@ private[spark] class Executor( val valueBytes = resultSer.serialize(value) val afterSerialization = System.currentTimeMillis() - for (m <- task.metrics) { - // Deserialization happens in two parts: first, we deserialize a Task object, which - // includes the Partition. 
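The exitExecutor hook added above, together with the new Executor.killAllTasks, lets a backend react to fatal conditions without tearing down its JVM. A self-contained sketch of that pattern follows; BackendBase and InProcessBackend are illustrative names, not Spark classes.

    abstract class BackendBase {
      // Default policy: a fatal condition exits the process, matching the old System.exit calls.
      protected def exitExecutor(code: Int): Unit = sys.exit(code)

      def onRegistrationFailed(message: String): Unit = {
        Console.err.println(s"Registration failed: $message")
        exitExecutor(1)
      }
    }

    // A backend embedded in a larger process overrides the hook so that only the running
    // tasks are killed (compare Executor.killAllTasks(interruptThread)) and the JVM stays up.
    class InProcessBackend(killAllTasks: Boolean => Unit) extends BackendBase {
      override protected def exitExecutor(code: Int): Unit = {
        Console.err.println(s"Executor failed with code $code; killing tasks only")
        killAllTasks(true)
      }
    }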
Second, Task.run() deserializes the RDD and function to be run. - m.setExecutorDeserializeTime( - (taskStart - deserializeStartTime) + task.executorDeserializeTime) - // We need to subtract Task.run()'s deserialization time to avoid double-counting - m.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) - m.setJvmGCTime(computeTotalGcTime() - startGCTime) - m.setResultSerializationTime(afterSerialization - beforeSerialization) - } + // Deserialization happens in two parts: first, we deserialize a Task object, which + // includes the Partition. Second, Task.run() deserializes the RDD and function to be run. + task.metrics.setExecutorDeserializeTime( + (taskStart - deserializeStartTime) + task.executorDeserializeTime) + // We need to subtract Task.run()'s deserialization time to avoid double-counting + task.metrics.setExecutorRunTime((taskFinish - taskStart) - task.executorDeserializeTime) + task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) + task.metrics.setResultSerializationTime(afterSerialization - beforeSerialization) // Note: accumulator updates must be collected after TaskMetrics is updated val accumUpdates = task.collectAccumulatorUpdates() @@ -342,10 +355,8 @@ private[spark] class Executor( // Collect latest accumulator values to report back to the driver val accumulatorUpdates: Seq[AccumulableInfo] = if (task != null) { - task.metrics.foreach { m => - m.setExecutorRunTime(System.currentTimeMillis() - taskStart) - m.setJvmGCTime(computeTotalGcTime() - startGCTime) - } + task.metrics.setExecutorRunTime(System.currentTimeMillis() - taskStart) + task.metrics.setJvmGCTime(computeTotalGcTime() - startGCTime) task.collectAccumulatorUpdates(taskFailed = true) } else { Seq.empty[AccumulableInfo] @@ -470,11 +481,9 @@ private[spark] class Executor( for (taskRunner <- runningTasks.values().asScala) { if (taskRunner.task != null) { - taskRunner.task.metrics.foreach { metrics => - metrics.mergeShuffleReadMetrics() - metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) - accumUpdates += ((taskRunner.taskId, metrics.accumulatorUpdates())) - } + taskRunner.task.metrics.mergeShuffleReadMetrics() + taskRunner.task.metrics.setJvmGCTime(curGCTime - taskRunner.startGCTime) + accumUpdates += ((taskRunner.taskId, taskRunner.task.metrics.accumulatorUpdates())) } } diff --git a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala index 83e11c5e236d4..535352e7dd7a1 100644 --- a/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/InputMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.InternalAccumulator import org.apache.spark.annotation.DeveloperApi @@ -39,32 +39,11 @@ object DataReadMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about reading data from external systems. 
*/ @DeveloperApi -class InputMetrics private ( - _bytesRead: Accumulator[Long], - _recordsRead: Accumulator[Long], - _readMethod: Accumulator[String]) - extends Serializable { +class InputMetrics private[spark] () extends Serializable { + import InternalAccumulator._ - private[executor] def this(accumMap: Map[String, Accumulator[_]]) { - this( - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.BYTES_READ), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.input.RECORDS_READ), - TaskMetrics.getAccum[String](accumMap, InternalAccumulator.input.READ_METHOD)) - } - - /** - * Create a new [[InputMetrics]] that is not associated with any particular task. - * - * This mainly exists because of SPARK-5225, where we are forced to use a dummy [[InputMetrics]] - * because we want to ignore metrics from a second read method. In the future, we should revisit - * whether this is needed. - * - * A better alternative is [[TaskMetrics.registerInputMetrics]]. - */ - private[executor] def this() { - this(InternalAccumulator.createInputAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]]) - } + private[executor] val _bytesRead = TaskMetrics.createLongAccum(input.BYTES_READ) + private[executor] val _recordsRead = TaskMetrics.createLongAccum(input.RECORDS_READ) /** * Total number of bytes read. @@ -76,14 +55,7 @@ class InputMetrics private ( */ def recordsRead: Long = _recordsRead.localValue - /** - * The source from which this task reads its input. - */ - def readMethod: DataReadMethod.Value = DataReadMethod.withName(_readMethod.localValue) - private[spark] def incBytesRead(v: Long): Unit = _bytesRead.add(v) private[spark] def incRecordsRead(v: Long): Unit = _recordsRead.add(v) private[spark] def setBytesRead(v: Long): Unit = _bytesRead.setValue(v) - private[spark] def setReadMethod(v: DataReadMethod.Value): Unit = _readMethod.setValue(v.toString) - } diff --git a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala index 93f953846fe26..586c98b15637b 100644 --- a/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/OutputMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.InternalAccumulator import org.apache.spark.annotation.DeveloperApi @@ -38,18 +38,11 @@ object DataWriteMethod extends Enumeration with Serializable { * A collection of accumulators that represents metrics about writing data to external systems. */ @DeveloperApi -class OutputMetrics private ( - _bytesWritten: Accumulator[Long], - _recordsWritten: Accumulator[Long], - _writeMethod: Accumulator[String]) - extends Serializable { +class OutputMetrics private[spark] () extends Serializable { + import InternalAccumulator._ - private[executor] def this(accumMap: Map[String, Accumulator[_]]) { - this( - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.BYTES_WRITTEN), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.output.RECORDS_WRITTEN), - TaskMetrics.getAccum[String](accumMap, InternalAccumulator.output.WRITE_METHOD)) - } + private[executor] val _bytesWritten = TaskMetrics.createLongAccum(output.BYTES_WRITTEN) + private[executor] val _recordsWritten = TaskMetrics.createLongAccum(output.RECORDS_WRITTEN) /** * Total number of bytes written. 
@@ -61,14 +54,6 @@ class OutputMetrics private ( */ def recordsWritten: Long = _recordsWritten.localValue - /** - * The source to which this task writes its output. - */ - def writeMethod: DataWriteMethod.Value = DataWriteMethod.withName(_writeMethod.localValue) - private[spark] def setBytesWritten(v: Long): Unit = _bytesWritten.setValue(v) private[spark] def setRecordsWritten(v: Long): Unit = _recordsWritten.setValue(v) - private[spark] def setWriteMethod(v: DataWriteMethod.Value): Unit = - _writeMethod.setValue(v.toString) - } diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala index 71a24770b50ae..8e9a332b7c556 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleReadMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.InternalAccumulator import org.apache.spark.annotation.DeveloperApi @@ -27,38 +27,21 @@ import org.apache.spark.annotation.DeveloperApi * Operations are not thread-safe. */ @DeveloperApi -class ShuffleReadMetrics private ( - _remoteBlocksFetched: Accumulator[Int], - _localBlocksFetched: Accumulator[Int], - _remoteBytesRead: Accumulator[Long], - _localBytesRead: Accumulator[Long], - _fetchWaitTime: Accumulator[Long], - _recordsRead: Accumulator[Long]) - extends Serializable { - - private[executor] def this(accumMap: Map[String, Accumulator[_]]) { - this( - TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.REMOTE_BLOCKS_FETCHED), - TaskMetrics.getAccum[Int](accumMap, InternalAccumulator.shuffleRead.LOCAL_BLOCKS_FETCHED), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.REMOTE_BYTES_READ), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.LOCAL_BYTES_READ), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.FETCH_WAIT_TIME), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleRead.RECORDS_READ)) - } - - /** - * Create a new [[ShuffleReadMetrics]] that is not associated with any particular task. - * - * This mainly exists for legacy reasons, because we use dummy [[ShuffleReadMetrics]] in - * many places only to merge their values together later. In the future, we should revisit - * whether this is needed. - * - * A better alternative is [[TaskMetrics.registerTempShuffleReadMetrics]] followed by - * [[TaskMetrics.mergeShuffleReadMetrics]]. - */ - private[spark] def this() { - this(InternalAccumulator.createShuffleReadAccums().map { a => (a.name.get, a) }.toMap) - } +class ShuffleReadMetrics private[spark] () extends Serializable { + import InternalAccumulator._ + + private[executor] val _remoteBlocksFetched = + TaskMetrics.createIntAccum(shuffleRead.REMOTE_BLOCKS_FETCHED) + private[executor] val _localBlocksFetched = + TaskMetrics.createIntAccum(shuffleRead.LOCAL_BLOCKS_FETCHED) + private[executor] val _remoteBytesRead = + TaskMetrics.createLongAccum(shuffleRead.REMOTE_BYTES_READ) + private[executor] val _localBytesRead = + TaskMetrics.createLongAccum(shuffleRead.LOCAL_BYTES_READ) + private[executor] val _fetchWaitTime = + TaskMetrics.createLongAccum(shuffleRead.FETCH_WAIT_TIME) + private[executor] val _recordsRead = + TaskMetrics.createLongAccum(shuffleRead.RECORDS_READ) /** * Number of remote blocks fetched in this shuffle by this task. 
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala index c7aaabb561bba..7326fba841587 100644 --- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala @@ -17,7 +17,7 @@ package org.apache.spark.executor -import org.apache.spark.{Accumulator, InternalAccumulator} +import org.apache.spark.InternalAccumulator import org.apache.spark.annotation.DeveloperApi @@ -27,31 +27,15 @@ import org.apache.spark.annotation.DeveloperApi * Operations are not thread-safe. */ @DeveloperApi -class ShuffleWriteMetrics private ( - _bytesWritten: Accumulator[Long], - _recordsWritten: Accumulator[Long], - _writeTime: Accumulator[Long]) - extends Serializable { +class ShuffleWriteMetrics private[spark] () extends Serializable { + import InternalAccumulator._ - private[executor] def this(accumMap: Map[String, Accumulator[_]]) { - this( - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.BYTES_WRITTEN), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.RECORDS_WRITTEN), - TaskMetrics.getAccum[Long](accumMap, InternalAccumulator.shuffleWrite.WRITE_TIME)) - } - - /** - * Create a new [[ShuffleWriteMetrics]] that is not associated with any particular task. - * - * This mainly exists for legacy reasons, because we use dummy [[ShuffleWriteMetrics]] in - * many places only to merge their values together later. In the future, we should revisit - * whether this is needed. - * - * A better alternative is [[TaskMetrics.registerShuffleWriteMetrics]]. - */ - private[spark] def this() { - this(InternalAccumulator.createShuffleWriteAccums().map { a => (a.name.get, a) }.toMap) - } + private[executor] val _bytesWritten = + TaskMetrics.createLongAccum(shuffleWrite.BYTES_WRITTEN) + private[executor] val _recordsWritten = + TaskMetrics.createLongAccum(shuffleWrite.RECORDS_WRITTEN) + private[executor] val _writeTime = + TaskMetrics.createLongAccum(shuffleWrite.WRITE_TIME) /** * Number of bytes written for the shuffle by this task. diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index bda2a91d9d2ca..4558fbb4d95d8 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -17,10 +17,10 @@ package org.apache.spark.executor -import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import org.apache.spark._ +import org.apache.spark.AccumulatorParam.{IntAccumulatorParam, LongAccumulatorParam, UpdatedBlockStatusesAccumulatorParam} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.scheduler.AccumulableInfo @@ -39,57 +39,21 @@ import org.apache.spark.storage.{BlockId, BlockStatus} * The accumulator updates are also sent to the driver periodically (on executor heartbeat) * and when the task failed with an exception. The [[TaskMetrics]] object itself should never * be sent to the driver. - * - * @param initialAccums the initial set of accumulators that this [[TaskMetrics]] depends on. - * Each accumulator in this initial set must be uniquely named and marked - * as internal. Additional accumulators registered later need not satisfy - * these requirements. 
*/ @DeveloperApi -class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Serializable { +class TaskMetrics private[spark] () extends Serializable { import InternalAccumulator._ - // Needed for Java tests - def this() { - this(InternalAccumulator.createAll()) - } - - /** - * All accumulators registered with this task. - */ - private val accums = new ArrayBuffer[Accumulable[_, _]] - accums ++= initialAccums - - /** - * A map for quickly accessing the initial set of accumulators by name. - */ - private val initialAccumsMap: Map[String, Accumulator[_]] = { - val map = new mutable.HashMap[String, Accumulator[_]] - initialAccums.foreach { a => - val name = a.name.getOrElse { - throw new IllegalArgumentException( - "initial accumulators passed to TaskMetrics must be named") - } - require(a.isInternal, - s"initial accumulator '$name' passed to TaskMetrics must be marked as internal") - require(!map.contains(name), - s"detected duplicate accumulator name '$name' when constructing TaskMetrics") - map(name) = a - } - map.toMap - } - // Each metric is internally represented as an accumulator - private val _executorDeserializeTime = getAccum(EXECUTOR_DESERIALIZE_TIME) - private val _executorRunTime = getAccum(EXECUTOR_RUN_TIME) - private val _resultSize = getAccum(RESULT_SIZE) - private val _jvmGCTime = getAccum(JVM_GC_TIME) - private val _resultSerializationTime = getAccum(RESULT_SERIALIZATION_TIME) - private val _memoryBytesSpilled = getAccum(MEMORY_BYTES_SPILLED) - private val _diskBytesSpilled = getAccum(DISK_BYTES_SPILLED) - private val _peakExecutionMemory = getAccum(PEAK_EXECUTION_MEMORY) - private val _updatedBlockStatuses = - TaskMetrics.getAccum[Seq[(BlockId, BlockStatus)]](initialAccumsMap, UPDATED_BLOCK_STATUSES) + private val _executorDeserializeTime = TaskMetrics.createLongAccum(EXECUTOR_DESERIALIZE_TIME) + private val _executorRunTime = TaskMetrics.createLongAccum(EXECUTOR_RUN_TIME) + private val _resultSize = TaskMetrics.createLongAccum(RESULT_SIZE) + private val _jvmGCTime = TaskMetrics.createLongAccum(JVM_GC_TIME) + private val _resultSerializationTime = TaskMetrics.createLongAccum(RESULT_SERIALIZATION_TIME) + private val _memoryBytesSpilled = TaskMetrics.createLongAccum(MEMORY_BYTES_SPILLED) + private val _diskBytesSpilled = TaskMetrics.createLongAccum(DISK_BYTES_SPILLED) + private val _peakExecutionMemory = TaskMetrics.createLongAccum(PEAK_EXECUTION_MEMORY) + private val _updatedBlockStatuses = TaskMetrics.createBlocksAccum(UPDATED_BLOCK_STATUSES) /** * Time taken on the executor to deserialize this task. @@ -155,91 +119,28 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se private[spark] def setUpdatedBlockStatuses(v: Seq[(BlockId, BlockStatus)]): Unit = _updatedBlockStatuses.setValue(v) - /** - * Get a Long accumulator from the given map by name, assuming it exists. - * Note: this only searches the initial set of accumulators passed into the constructor. - */ - private[spark] def getAccum(name: String): Accumulator[Long] = { - TaskMetrics.getAccum[Long](initialAccumsMap, name) - } - - - /* ========================== * - | INPUT METRICS | - * ========================== */ - - private var _inputMetrics: Option[InputMetrics] = None - /** * Metrics related to reading data from a [[org.apache.spark.rdd.HadoopRDD]] or from persisted * data, defined only in tasks with input. */ - def inputMetrics: Option[InputMetrics] = _inputMetrics - - /** - * Get or create a new [[InputMetrics]] associated with this task. 
- */ - private[spark] def registerInputMetrics(readMethod: DataReadMethod.Value): InputMetrics = { - synchronized { - val metrics = _inputMetrics.getOrElse { - val metrics = new InputMetrics(initialAccumsMap) - metrics.setReadMethod(readMethod) - _inputMetrics = Some(metrics) - metrics - } - // If there already exists an InputMetric with the same read method, we can just return - // that one. Otherwise, if the read method is different from the one previously seen by - // this task, we return a new dummy one to avoid clobbering the values of the old metrics. - // In the future we should try to store input metrics from all different read methods at - // the same time (SPARK-5225). - if (metrics.readMethod == readMethod) { - metrics - } else { - val m = new InputMetrics - m.setReadMethod(readMethod) - m - } - } - } - - - /* ============================ * - | OUTPUT METRICS | - * ============================ */ - - private var _outputMetrics: Option[OutputMetrics] = None + val inputMetrics: InputMetrics = new InputMetrics() /** * Metrics related to writing data externally (e.g. to a distributed filesystem), * defined only in tasks with output. */ - def outputMetrics: Option[OutputMetrics] = _outputMetrics + val outputMetrics: OutputMetrics = new OutputMetrics() /** - * Get or create a new [[OutputMetrics]] associated with this task. + * Metrics related to shuffle read aggregated across all shuffle dependencies. + * This is defined only if there are shuffle dependencies in this task. */ - private[spark] def registerOutputMetrics( - writeMethod: DataWriteMethod.Value): OutputMetrics = synchronized { - _outputMetrics.getOrElse { - val metrics = new OutputMetrics(initialAccumsMap) - metrics.setWriteMethod(writeMethod) - _outputMetrics = Some(metrics) - metrics - } - } - - - /* ================================== * - | SHUFFLE READ METRICS | - * ================================== */ - - private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None + val shuffleReadMetrics: ShuffleReadMetrics = new ShuffleReadMetrics() /** - * Metrics related to shuffle read aggregated across all shuffle dependencies. - * This is defined only if there are shuffle dependencies in this task. + * Metrics related to shuffle write, defined only in shuffle map stages. */ - def shuffleReadMetrics: Option[ShuffleReadMetrics] = _shuffleReadMetrics + val shuffleWriteMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics() /** * Temporary list of [[ShuffleReadMetrics]], one per shuffle dependency. @@ -257,7 +158,7 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se * merges the temporary values synchronously. Otherwise, all temporary data collected will * be lost. 
*/ - private[spark] def registerTempShuffleReadMetrics(): ShuffleReadMetrics = synchronized { + private[spark] def createTempShuffleReadMetrics(): ShuffleReadMetrics = synchronized { val readMetrics = new ShuffleReadMetrics tempShuffleReadMetrics += readMetrics readMetrics @@ -269,41 +170,45 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se */ private[spark] def mergeShuffleReadMetrics(): Unit = synchronized { if (tempShuffleReadMetrics.nonEmpty) { - val metrics = new ShuffleReadMetrics(initialAccumsMap) - metrics.setMergeValues(tempShuffleReadMetrics) - _shuffleReadMetrics = Some(metrics) + shuffleReadMetrics.setMergeValues(tempShuffleReadMetrics) } } - /* =================================== * - | SHUFFLE WRITE METRICS | - * =================================== */ - - private var _shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None - - /** - * Metrics related to shuffle write, defined only in shuffle map stages. - */ - def shuffleWriteMetrics: Option[ShuffleWriteMetrics] = _shuffleWriteMetrics - - /** - * Get or create a new [[ShuffleWriteMetrics]] associated with this task. - */ - private[spark] def registerShuffleWriteMetrics(): ShuffleWriteMetrics = synchronized { - _shuffleWriteMetrics.getOrElse { - val metrics = new ShuffleWriteMetrics(initialAccumsMap) - _shuffleWriteMetrics = Some(metrics) - metrics - } + // Only used for test + private[spark] val testAccum = + sys.props.get("spark.testing").map(_ => TaskMetrics.createLongAccum(TEST_ACCUM)) + + @transient private[spark] lazy val internalAccums: Seq[Accumulable[_, _]] = { + val in = inputMetrics + val out = outputMetrics + val sr = shuffleReadMetrics + val sw = shuffleWriteMetrics + Seq(_executorDeserializeTime, _executorRunTime, _resultSize, _jvmGCTime, + _resultSerializationTime, _memoryBytesSpilled, _diskBytesSpilled, _peakExecutionMemory, + _updatedBlockStatuses, sr._remoteBlocksFetched, sr._localBlocksFetched, sr._remoteBytesRead, + sr._localBytesRead, sr._fetchWaitTime, sr._recordsRead, sw._bytesWritten, sw._recordsWritten, + sw._writeTime, in._bytesRead, in._recordsRead, out._bytesWritten, out._recordsWritten) ++ + testAccum } - /* ========================== * | OTHER THINGS | * ========================== */ + private[spark] def registerAccums(sc: SparkContext): Unit = { + internalAccums.foreach { accum => + Accumulators.register(accum) + sc.cleaner.foreach(_.registerAccumulatorForCleanup(accum)) + } + } + + /** + * External accumulators registered with this task. + */ + @transient private lazy val externalAccums = new ArrayBuffer[Accumulable[_, _]] + private[spark] def registerAccumulator(a: Accumulable[_, _]): Unit = { - accums += a + externalAccums += a } /** @@ -314,30 +219,8 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se * not the aggregated value across multiple tasks. */ def accumulatorUpdates(): Seq[AccumulableInfo] = { - accums.map { a => a.toInfo(Some(a.localValue), None) } + (internalAccums ++ externalAccums).map { a => a.toInfo(Some(a.localValue), None) } } - - // If we are reconstructing this TaskMetrics on the driver, some metrics may already be set. - // If so, initialize all relevant metrics classes so listeners can access them downstream. 
- { - var (hasShuffleRead, hasShuffleWrite, hasInput, hasOutput) = (false, false, false, false) - initialAccums - .filter { a => a.localValue != a.zero } - .foreach { a => - a.name.get match { - case sr if sr.startsWith(SHUFFLE_READ_METRICS_PREFIX) => hasShuffleRead = true - case sw if sw.startsWith(SHUFFLE_WRITE_METRICS_PREFIX) => hasShuffleWrite = true - case in if in.startsWith(INPUT_METRICS_PREFIX) => hasInput = true - case out if out.startsWith(OUTPUT_METRICS_PREFIX) => hasOutput = true - case _ => - } - } - if (hasShuffleRead) { _shuffleReadMetrics = Some(new ShuffleReadMetrics(initialAccumsMap)) } - if (hasShuffleWrite) { _shuffleWriteMetrics = Some(new ShuffleWriteMetrics(initialAccumsMap)) } - if (hasInput) { _inputMetrics = Some(new InputMetrics(initialAccumsMap)) } - if (hasOutput) { _outputMetrics = Some(new OutputMetrics(initialAccumsMap)) } - } - } /** @@ -350,9 +233,7 @@ class TaskMetrics private[spark] (initialAccums: Seq[Accumulator[_]]) extends Se * UnsupportedOperationException, we choose not to do so because the overrides would quickly become * out-of-date when new metrics are added. */ -private[spark] class ListenerTaskMetrics( - initialAccums: Seq[Accumulator[_]], - accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics(initialAccums) { +private[spark] class ListenerTaskMetrics(accumUpdates: Seq[AccumulableInfo]) extends TaskMetrics { override def accumulatorUpdates(): Seq[AccumulableInfo] = accumUpdates @@ -366,18 +247,25 @@ private[spark] object TaskMetrics extends Logging { def empty: TaskMetrics = new TaskMetrics /** - * Get an accumulator from the given map by name, assuming it exists. + * Create a new accumulator representing an internal task metric. */ - def getAccum[T](accumMap: Map[String, Accumulator[_]], name: String): Accumulator[T] = { - require(accumMap.contains(name), s"metric '$name' is missing") - val accum = accumMap(name) - try { - // Note: we can't do pattern matching here because types are erased by compile time - accum.asInstanceOf[Accumulator[T]] - } catch { - case e: ClassCastException => - throw new SparkException(s"accumulator $name was of unexpected type", e) - } + private def newMetric[T]( + initialValue: T, + name: String, + param: AccumulatorParam[T]): Accumulator[T] = { + new Accumulator[T](initialValue, param, Some(name), internal = true, countFailedValues = true) + } + + def createLongAccum(name: String): Accumulator[Long] = { + newMetric(0L, name, LongAccumulatorParam) + } + + def createIntAccum(name: String): Accumulator[Int] = { + newMetric(0, name, IntAccumulatorParam) + } + + def createBlocksAccum(name: String): Accumulator[Seq[(BlockId, BlockStatus)]] = { + newMetric(Nil, name, UpdatedBlockStatusesAccumulatorParam) } /** @@ -391,18 +279,11 @@ private[spark] object TaskMetrics extends Logging { * internal task level metrics. */ def fromAccumulatorUpdates(accumUpdates: Seq[AccumulableInfo]): TaskMetrics = { - // Initial accumulators are passed into the TaskMetrics constructor first because these - // are required to be uniquely named. The rest of the accumulators from this task are - // registered later because they need not satisfy this requirement. 
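Because the sub-metrics are now plain vals on TaskMetrics, the call sites rewritten in this diff drop the Option/registerXxxMetrics indirection. A minimal sketch of the resulting shape; the helper name and parameters are illustrative, and the setters involved are private[spark], so this models Spark-internal code only.

    private def recordTaskRun(
        context: org.apache.spark.TaskContext,
        bytesRead: Long,
        runStartMs: Long,
        runEndMs: Long): Unit = {
      val metrics = context.taskMetrics()
      // Previously: task.metrics.foreach(_.setExecutorRunTime(...))
      metrics.setExecutorRunTime(runEndMs - runStartMs)
      // Previously: registerInputMetrics(DataReadMethod.Hadoop).incBytesRead(...)
      metrics.inputMetrics.incBytesRead(bytesRead)
      // Per-dependency shuffle reads still go through temporary objects that are merged here.
      metrics.mergeShuffleReadMetrics()
    }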
- val definedAccumUpdates = accumUpdates.filter { info => info.update.isDefined } - val initialAccums = definedAccumUpdates - .filter { info => info.name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX)) } - .map { info => - val accum = InternalAccumulator.create(info.name.get) - accum.setValueAny(info.update.get) - accum - } - new ListenerTaskMetrics(initialAccums, definedAccumUpdates) + val definedAccumUpdates = accumUpdates.filter(_.update.isDefined) + val metrics = new ListenerTaskMetrics(definedAccumUpdates) + definedAccumUpdates.filter(_.internal).foreach { accum => + metrics.internalAccums.find(_.name == accum.name).foreach(_.setValueAny(accum.update.get)) + } + metrics } - } diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 94b50ee06520c..2c1e0b71e3613 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -89,4 +89,11 @@ package object config { .stringConf .toSequence .createWithDefault(Nil) + + // Note: This is a SQL config but needs to be in core because the REPL depends on it + private[spark] val CATALOG_IMPLEMENTATION = ConfigBuilder("spark.sql.catalogImplementation") + .internal() + .stringConf + .checkValues(Set("hive", "in-memory")) + .createWithDefault("in-memory") } diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala index fa9c021f70376..82023b533d660 100644 --- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala +++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala @@ -206,7 +206,7 @@ object UnifiedMemoryManager { val systemMemory = conf.getLong("spark.testing.memory", Runtime.getRuntime.maxMemory) val reservedMemory = conf.getLong("spark.testing.reservedMemory", if (conf.contains("spark.testing")) 0 else RESERVED_SYSTEM_MEMORY_BYTES) - val minSystemMemory = reservedMemory * 1.5 + val minSystemMemory = (reservedMemory * 1.5).ceil.toLong if (systemMemory < minSystemMemory) { throw new IllegalArgumentException(s"System memory $systemMemory must " + s"be at least $minSystemMemory. 
Please increase heap size using the --driver-memory " + diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index 09ce012e4e692..cb9d389dd7ea6 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -20,7 +20,7 @@ package org.apache.spark.network import java.io.Closeable import java.nio.ByteBuffer -import scala.concurrent.{Await, Future, Promise} +import scala.concurrent.{Future, Promise} import scala.concurrent.duration.Duration import scala.reflect.ClassTag @@ -28,6 +28,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient} import org.apache.spark.storage.{BlockId, StorageLevel} +import org.apache.spark.util.ThreadUtils private[spark] abstract class BlockTransferService extends ShuffleClient with Closeable with Logging { @@ -100,8 +101,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo result.success(new NioManagedBuffer(ret)) } }) - - Await.result(result.future, Duration.Inf) + ThreadUtils.awaitResult(result.future, Duration.Inf) } /** @@ -119,6 +119,6 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo level: StorageLevel, classTag: ClassTag[_]): Unit = { val future = uploadBlock(hostname, port, execId, blockId, blockData, level, classTag) - Await.result(future, Duration.Inf) + ThreadUtils.awaitResult(future, Duration.Inf) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index 90d9735cb3f69..e75f1dbf8107a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -70,23 +70,23 @@ private[spark] case class CoalescedRDDPartition( * parent partitions * @param prev RDD to be coalesced * @param maxPartitions number of desired partitions in the coalesced RDD (must be positive) - * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance + * @param partitionCoalescer [[PartitionCoalescer]] implementation to use for coalescing */ private[spark] class CoalescedRDD[T: ClassTag]( @transient var prev: RDD[T], maxPartitions: Int, - balanceSlack: Double = 0.10) + partitionCoalescer: Option[PartitionCoalescer] = None) extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies require(maxPartitions > 0 || maxPartitions == prev.partitions.length, s"Number of partitions ($maxPartitions) must be positive.") override def getPartitions: Array[Partition] = { - val pc = new PartitionCoalescer(maxPartitions, prev, balanceSlack) + val pc = partitionCoalescer.getOrElse(new DefaultPartitionCoalescer()) - pc.run().zipWithIndex.map { + pc.coalesce(maxPartitions, prev).zipWithIndex.map { case (pg, i) => - val ids = pg.arr.map(_.index).toArray + val ids = pg.partitions.map(_.index).toArray new CoalescedRDDPartition(i, prev, ids, pg.prefLoc) } } @@ -144,15 +144,15 @@ private[spark] class CoalescedRDD[T: ClassTag]( * desired partitions is greater than the number of preferred machines (can happen), it needs to * start picking duplicate preferred machines. This is determined using coupon collector estimation * (2n log(n)). 
The load balancing is done using power-of-two randomized bins-balls with one twist: - * it tries to also achieve locality. This is done by allowing a slack (balanceSlack) between two - * bins. If two bins are within the slack in terms of balance, the algorithm will assign partitions - * according to locality. (contact alig for questions) - * + * it tries to also achieve locality. This is done by allowing a slack (balanceSlack, where + * 1.0 is all locality, 0 is all balance) between two bins. If two bins are within the slack + * in terms of balance, the algorithm will assign partitions according to locality. + * (contact alig for questions) */ -private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { - - def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size +private class DefaultPartitionCoalescer(val balanceSlack: Double = 0.10) + extends PartitionCoalescer { + def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.numPartitions < o2.numPartitions def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean = if (o1 == None) false else if (o2 == None) true else compare(o1.get, o2.get) @@ -167,14 +167,10 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // hash used for the first maxPartitions (to avoid duplicates) val initialHash = mutable.Set[Partition]() - // determines the tradeoff between load-balancing the partitions sizes and their locality - // e.g. balanceSlack=0.10 means that it allows up to 10% imbalance in favor of locality - val slack = (balanceSlack * prev.partitions.length).toInt - var noLocality = true // if true if no preferredLocations exists for parent RDD // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones) - def currPrefLocs(part: Partition): Seq[String] = { + def currPrefLocs(part: Partition, prev: RDD[_]): Seq[String] = { prev.context.getPreferredLocs(prev, part.index).map(tl => tl.host) } @@ -190,11 +186,11 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: // initializes/resets to start iterating from the beginning def resetIterator(): Iterator[(String, Partition)] = { - val iterators = (0 to 2).map( x => - prev.partitions.iterator.flatMap(p => { - if (currPrefLocs(p).size > x) Some((currPrefLocs(p)(x), p)) else None - } ) - ) + val iterators = (0 to 2).map { x => + prev.partitions.iterator.flatMap { p => + if (currPrefLocs(p, prev).size > x) Some((currPrefLocs(p, prev)(x), p)) else None + } + } iterators.reduceLeft((x, y) => x ++ y) } @@ -215,8 +211,9 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: /** * Sorts and gets the least element of the list associated with key in groupHash * The returned PartitionGroup is the least loaded of all groups that represent the machine "key" + * * @param key string representing a partitioned group on preferred machine key - * @return Option of PartitionGroup that has least elements for key + * @return Option of [[PartitionGroup]] that has least elements for key */ def getLeastGroupHash(key: String): Option[PartitionGroup] = { groupHash.get(key).map(_.sortWith(compare).head) @@ -224,7 +221,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: def addPartToPGroup(part: Partition, pgroup: PartitionGroup): Boolean = { if (!initialHash.contains(part)) { - pgroup.arr += part // already assign this element + pgroup.partitions += part // already assign this element initialHash 
+= part // needed to avoid assigning partitions to multiple buckets true } else { false } @@ -236,12 +233,12 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: * until it has seen most of the preferred locations (2 * n log(n)) * @param targetLen */ - def setupGroups(targetLen: Int) { + def setupGroups(targetLen: Int, prev: RDD[_]) { val rotIt = new LocationIterator(prev) // deal with empty case, just create targetLen partition groups with no preferred location if (!rotIt.hasNext) { - (1 to targetLen).foreach(x => groupArr += PartitionGroup()) + (1 to targetLen).foreach(x => groupArr += new PartitionGroup()) return } @@ -259,7 +256,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: tries += 1 val (nxt_replica, nxt_part) = rotIt.next() if (!groupHash.contains(nxt_replica)) { - val pgroup = PartitionGroup(nxt_replica) + val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup addPartToPGroup(nxt_part, pgroup) groupHash.put(nxt_replica, ArrayBuffer(pgroup)) // list in case we have multiple @@ -269,7 +266,7 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: while (numCreated < targetLen) { // if we don't have enough partition groups, create duplicates var (nxt_replica, nxt_part) = rotIt.next() - val pgroup = PartitionGroup(nxt_replica) + val pgroup = new PartitionGroup(Some(nxt_replica)) groupArr += pgroup groupHash.getOrElseUpdate(nxt_replica, ArrayBuffer()) += pgroup var tries = 0 @@ -285,17 +282,29 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: /** * Takes a parent RDD partition and decides which of the partition groups to put it in * Takes locality into account, but also uses power of 2 choices to load balance - * It strikes a balance between the two use the balanceSlack variable + * It strikes a balance between the two using the balanceSlack variable * @param p partition (ball to be thrown) + * @param balanceSlack determines the trade-off between load-balancing the partitions sizes and + * their locality. 
e.g., balanceSlack=0.10 means that it allows up to 10% + * imbalance in favor of locality * @return partition group (bin to be put in) */ - def pickBin(p: Partition): PartitionGroup = { - val pref = currPrefLocs(p).map(getLeastGroupHash(_)).sortWith(compare) // least loaded pref locs + def pickBin(p: Partition, prev: RDD[_], balanceSlack: Double): PartitionGroup = { + val slack = (balanceSlack * prev.partitions.length).toInt + // least loaded pref locs + val pref = currPrefLocs(p, prev).map(getLeastGroupHash(_)).sortWith(compare) val prefPart = if (pref == Nil) None else pref.head val r1 = rnd.nextInt(groupArr.size) val r2 = rnd.nextInt(groupArr.size) - val minPowerOfTwo = if (groupArr(r1).size < groupArr(r2).size) groupArr(r1) else groupArr(r2) + val minPowerOfTwo = { + if (groupArr(r1).numPartitions < groupArr(r2).numPartitions) { + groupArr(r1) + } + else { + groupArr(r2) + } + } if (prefPart.isEmpty) { // if no preferred locations, just use basic power of two return minPowerOfTwo @@ -303,55 +312,45 @@ private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: val prefPartActual = prefPart.get - if (minPowerOfTwo.size + slack <= prefPartActual.size) { // more imbalance than the slack allows + // more imbalance than the slack allows + if (minPowerOfTwo.numPartitions + slack <= prefPartActual.numPartitions) { minPowerOfTwo // prefer balance over locality } else { prefPartActual // prefer locality over balance } } - def throwBalls() { + def throwBalls(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) { if (noLocality) { // no preferredLocations in parent RDD, no randomization needed if (maxPartitions > groupArr.size) { // just return prev.partitions for ((p, i) <- prev.partitions.zipWithIndex) { - groupArr(i).arr += p + groupArr(i).partitions += p } } else { // no locality available, then simply split partitions based on positions in array for (i <- 0 until maxPartitions) { val rangeStart = ((i.toLong * prev.partitions.length) / maxPartitions).toInt val rangeEnd = (((i.toLong + 1) * prev.partitions.length) / maxPartitions).toInt - (rangeStart until rangeEnd).foreach{ j => groupArr(i).arr += prev.partitions(j) } + (rangeStart until rangeEnd).foreach{ j => groupArr(i).partitions += prev.partitions(j) } } } } else { for (p <- prev.partitions if (!initialHash.contains(p))) { // throw every partition into group - pickBin(p).arr += p + pickBin(p, prev, balanceSlack).partitions += p } } } - def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.size > 0).toArray + def getPartitions: Array[PartitionGroup] = groupArr.filter( pg => pg.numPartitions > 0).toArray /** * Runs the packing algorithm and returns an array of PartitionGroups that if possible are * load balanced and grouped by locality - * @return array of partition groups + * + * @return array of partition groups */ - def run(): Array[PartitionGroup] = { - setupGroups(math.min(prev.partitions.length, maxPartitions)) // setup the groups (bins) - throwBalls() // assign partitions (balls) to each group (bins) + def coalesce(maxPartitions: Int, prev: RDD[_]): Array[PartitionGroup] = { + setupGroups(math.min(prev.partitions.length, maxPartitions), prev) // setup the groups (bins) + throwBalls(maxPartitions, prev, balanceSlack) // assign partitions (balls) to each group (bins) getPartitions } } - -private case class PartitionGroup(prefLoc: Option[String] = None) { - var arr = mutable.ArrayBuffer[Partition]() - def size: Int = arr.size -} - -private object PartitionGroup { - def apply(prefLoc: String): 
PartitionGroup = { - require(prefLoc != "", "Preferred location must not be empty") - PartitionGroup(Some(prefLoc)) - } -} diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 35d190b464ff4..6b1e15572c03a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,7 +213,7 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().inputMetrics val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index 526138093d3ea..5426bf80bafc5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -65,11 +65,11 @@ class JdbcRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { // bounds are inclusive, hence the + 1 here and - 1 on end val length = BigInt(1) + upperBound - lowerBound - (0 until numPartitions).map(i => { + (0 until numPartitions).map { i => val start = lowerBound + ((i * length) / numPartitions) val end = lowerBound + (((i + 1) * length) / numPartitions) - 1 new JdbcPartition(i, start.toLong, end.toLong) - }).toArray + }.toArray } override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala index 3ccd616cbfd57..a71c191b318ea 100644 --- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala @@ -130,7 +130,7 @@ class NewHadoopRDD[K, V]( logInfo("Input split: " + split.serializableHadoopSplit) val conf = getConf - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) + val inputMetrics = context.taskMetrics().inputMetrics val existingBytesRead = inputMetrics.bytesRead // Find a function that will return the FileSystem bytes read by this thread. 
Do this before diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 085829af6eee7..7936d8e1d45a2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -1218,7 +1218,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)]) context: TaskContext): Option[(OutputMetrics, () => Long)] = { val bytesWrittenCallback = SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback() bytesWrittenCallback.map { b => - (context.taskMetrics().registerOutputMetrics(DataWriteMethod.Hadoop), b) + (context.taskMetrics().outputMetrics, b) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index bb84e4af15b15..34a1c112cbcd0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -129,7 +129,7 @@ private object ParallelCollectionRDD { } seq match { case r: Range => - positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => + positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) => // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == numSlices - 1) { new Range.Inclusive(r.start + start * r.step, r.end, r.step) @@ -137,7 +137,7 @@ private object ParallelCollectionRDD { else { new Range(r.start + start * r.step, r.start + end * r.step, r.step) } - }).toSeq.asInstanceOf[Seq[Seq[T]]] + }.toSeq.asInstanceOf[Seq[Seq[T]]] case nr: NumericRange[_] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) @@ -150,10 +150,9 @@ private object ParallelCollectionRDD { slices case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc - positions(array.length, numSlices).map({ - case (start, end) => + positions(array.length, numSlices).map { case (start, end) => array.slice(start, end).toSeq - }).toSeq + }.toSeq } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 36ff3bcaaec62..499a8b9aa1a89 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -332,7 +332,7 @@ abstract class RDD[T: ClassTag]( }) match { case Left(blockResult) => if (readCachedBlock) { - val existingMetrics = context.taskMetrics().registerInputMetrics(blockResult.readMethod) + val existingMetrics = context.taskMetrics().inputMetrics existingMetrics.incBytesRead(blockResult.bytes) new InterruptibleIterator[T](context, blockResult.data.asInstanceOf[Iterator[T]]) { override def next(): T = { @@ -433,7 +433,9 @@ abstract class RDD[T: ClassTag]( * coalesce(1000, shuffle = true) will result in 1000 partitions with the * data distributed using a hash partitioner. */ - def coalesce(numPartitions: Int, shuffle: Boolean = false)(implicit ord: Ordering[T] = null) + def coalesce(numPartitions: Int, shuffle: Boolean = false, + partitionCoalescer: Option[PartitionCoalescer] = Option.empty) + (implicit ord: Ordering[T] = null) : RDD[T] = withScope { if (shuffle) { /** Distributes elements evenly across output partitions, starting from a random partition. 
*/ @@ -451,9 +453,10 @@ abstract class RDD[T: ClassTag]( new CoalescedRDD( new ShuffledRDD[Int, T, T](mapPartitionsWithIndex(distributePartition), new HashPartitioner(numPartitions)), - numPartitions).values + numPartitions, + partitionCoalescer).values } else { - new CoalescedRDD(this, numPartitions) + new CoalescedRDD(this, numPartitions, partitionCoalescer) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala new file mode 100644 index 0000000000000..d8a80aa5aeb15 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/coalesce-public.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.collection.mutable + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.Partition + +/** + * ::DeveloperApi:: + * A PartitionCoalescer defines how to coalesce the partitions of a given RDD. + */ +@DeveloperApi +trait PartitionCoalescer { + + /** + * Coalesce the partitions of the given RDD. + * + * @param maxPartitions the maximum number of partitions to have after coalescing + * @param parent the parent RDD whose partitions to coalesce + * @return an array of [[PartitionGroup]]s, where each element is itself an array of + * [[Partition]]s and represents a partition after coalescing is performed. + */ + def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] +} + +/** + * ::DeveloperApi:: + * A group of [[Partition]]s + * @param prefLoc preferred location for the partition group + */ +@DeveloperApi +class PartitionGroup(val prefLoc: Option[String] = None) { + val partitions = mutable.ArrayBuffer[Partition]() + def numPartitions: Int = partitions.size +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala index 2950df62bf285..2761d39e37029 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcTimeout.scala @@ -19,10 +19,11 @@ package org.apache.spark.rpc import java.util.concurrent.TimeoutException -import scala.concurrent.{Await, Awaitable} +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal -import org.apache.spark.SparkConf +import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.util.Utils /** @@ -65,14 +66,21 @@ private[spark] class RpcTimeout(val duration: FiniteDuration, val timeoutProp: S /** * Wait for the completed result and return it. If the result is not available within this * timeout, throw a [[RpcTimeoutException]] to indicate which configuration controls the timeout. 
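
As an illustration only (not part of the patch), a minimal sketch of a custom coalescer built on the new PartitionCoalescer developer API from coalesce-public.scala and passed through the extended coalesce(numPartitions, shuffle, partitionCoalescer) signature shown above; the RoundRobinCoalescer name and its grouping strategy are hypothetical:

import org.apache.spark.rdd.{PartitionCoalescer, PartitionGroup, RDD}

// Hypothetical coalescer: spreads parent partitions round-robin across the target
// number of groups and ignores preferred locations entirely.
class RoundRobinCoalescer extends PartitionCoalescer {
  override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
    val numGroups = math.min(maxPartitions, parent.partitions.length)
    val groups = Array.fill(numGroups)(new PartitionGroup())
    parent.partitions.zipWithIndex.foreach { case (p, i) =>
      groups(i % numGroups).partitions += p
    }
    groups
  }
}

// Usage, assuming an existing SparkContext `sc`:
// val coalesced = sc.parallelize(1 to 1000, 100)
//   .coalesce(10, shuffle = false, partitionCoalescer = Some(new RoundRobinCoalescer))

When no coalescer is supplied, the locality-aware logic refactored earlier in this file is presumably still used as the default.
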
- * @param awaitable the `Awaitable` to be awaited - * @throws RpcTimeoutException if after waiting for the specified time `awaitable` + * + * @param future the `Future` to be awaited + * @throws RpcTimeoutException if after waiting for the specified time `future` * is still not ready */ - def awaitResult[T](awaitable: Awaitable[T]): T = { + def awaitResult[T](future: Future[T]): T = { + val wrapAndRethrow: PartialFunction[Throwable, T] = { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult", t) + } try { - Await.result(awaitable, duration) - } catch addMessageIfTimeout + // scalastyle:off awaitresult + Await.result(future, duration) + // scalastyle:on awaitresult + } catch addMessageIfTimeout.orElse(wrapAndRethrow) } } @@ -82,6 +90,7 @@ private[spark] object RpcTimeout { /** * Lookup the timeout property in the configuration and create * a RpcTimeout with the property key in the description. + * * @param conf configuration properties containing the timeout * @param timeoutProp property key for the timeout in seconds * @throws NoSuchElementException if property is not set @@ -95,6 +104,7 @@ private[spark] object RpcTimeout { * Lookup the timeout property in the configuration and create * a RpcTimeout with the property key in the description. * Uses the given default value if property is not set + * * @param conf configuration properties containing the timeout * @param timeoutProp property key for the timeout in seconds * @param defaultValue default timeout value in seconds if property not found @@ -109,6 +119,7 @@ private[spark] object RpcTimeout { * and create a RpcTimeout with the first set property key in the * description. * Uses the given default value if property is not set + * * @param conf configuration properties containing the timeout * @param timeoutPropList prioritized list of property keys for the timeout in seconds * @param defaultValue default timeout value in seconds if no properties found diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index c27aad268d32a..b7fb608ea5064 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -1029,7 +1029,7 @@ class DAGScheduler( val locs = taskIdToLocations(id) val part = stage.rdd.partitions(id) new ShuffleMapTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, stage.latestInfo.internalAccumulators, properties) + taskBinary, part, locs, stage.latestInfo.taskMetrics, properties) } case stage: ResultStage => @@ -1039,7 +1039,7 @@ class DAGScheduler( val part = stage.rdd.partitions(p) val locs = taskIdToLocations(id) new ResultTask(stage.id, stage.latestInfo.attemptId, - taskBinary, part, locs, id, properties, stage.latestInfo.internalAccumulators) + taskBinary, part, locs, id, properties, stage.latestInfo.taskMetrics) } } } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala new file mode 100644 index 0000000000000..d1ac7131baba5 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/scheduler/ExternalClusterManager.scala @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler + +import org.apache.spark.SparkContext + +/** + * A cluster manager interface to plugin external scheduler. + */ +private[spark] trait ExternalClusterManager { + + /** + * Check if this cluster manager instance can create scheduler components + * for a certain master URL. + * @param masterURL the master URL + * @return True if the cluster manager can create scheduler backend/ + */ + def canCreate(masterURL: String): Boolean + + /** + * Create a task scheduler instance for the given SparkContext + * @param sc SparkContext + * @param masterURL the master URL + * @return TaskScheduler that will be responsible for task handling + */ + def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler + + /** + * Create a scheduler backend for the given SparkContext and scheduler. This is + * called after task scheduler is created using [[ExternalClusterManager.createTaskScheduler()]]. + * @param sc SparkContext + * @param masterURL the master URL + * @param scheduler TaskScheduler that will be used with the scheduler backend. + * @return SchedulerBackend that works with a TaskScheduler + */ + def createSchedulerBackend(sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend + + /** + * Initialize task scheduler and backend scheduler. This is called after the + * scheduler components are created + * @param scheduler TaskScheduler that will be responsible for task handling + * @param backend SchedulerBackend that works with a TaskScheduler + */ + def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index db6276f75d781..75c6018e214d8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -23,6 +23,7 @@ import java.util.Properties import org.apache.spark._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.executor.TaskMetrics import org.apache.spark.rdd.RDD /** @@ -40,9 +41,7 @@ import org.apache.spark.rdd.RDD * @param outputId index of the task in this job (a job can launch tasks on only a subset of the * input RDD's partitions). * @param localProperties copy of thread-local properties set by the user on the driver side. - * @param _initialAccums initial set of accumulators to be used in this task for tracking - * internal metrics. Other accumulators will be registered later when - * they are deserialized on the executors. + * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. 
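
As a sketch only (not in the patch), one way the new ExternalClusterManager plugin point above could be implemented; the manager name and the "dummy://" master scheme are made up, and scheduler-backend construction is left as a placeholder since it depends entirely on the target cluster. How a manager gets registered and matched against the master URL is not shown in this excerpt.

package org.apache.spark.scheduler

import org.apache.spark.SparkContext

// Hypothetical manager that claims master URLs of the form "dummy://..."; because the
// trait is private[spark] in this patch, the sketch lives in the same package.
private[spark] class DummyExternalClusterManager extends ExternalClusterManager {

  override def canCreate(masterURL: String): Boolean = masterURL.startsWith("dummy://")

  override def createTaskScheduler(sc: SparkContext, masterURL: String): TaskScheduler =
    new TaskSchedulerImpl(sc)

  override def createSchedulerBackend(
      sc: SparkContext,
      masterURL: String,
      scheduler: TaskScheduler): SchedulerBackend =
    ??? // a real manager would build a backend that talks to its cluster here

  override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit =
    scheduler.asInstanceOf[TaskSchedulerImpl].initialize(backend)
}
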
*/ private[spark] class ResultTask[T, U]( stageId: Int, @@ -52,8 +51,8 @@ private[spark] class ResultTask[T, U]( locs: Seq[TaskLocation], val outputId: Int, localProperties: Properties, - _initialAccums: Seq[Accumulator[_]] = InternalAccumulator.createAll()) - extends Task[U](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) + metrics: TaskMetrics) + extends Task[U](stageId, stageAttemptId, partition.index, metrics, localProperties) with Serializable { @transient private[this] val preferredLocs: Seq[TaskLocation] = { @@ -68,7 +67,6 @@ private[spark] class ResultTask[T, U]( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime - metrics = Some(context.taskMetrics) func(context, rdd.iterator(partition, context)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index b7cab7013ef6f..84b3e5ba6c1f3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -24,6 +24,7 @@ import scala.language.existentials import org.apache.spark._ import org.apache.spark.broadcast.Broadcast +import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.shuffle.ShuffleWriter @@ -40,9 +41,7 @@ import org.apache.spark.shuffle.ShuffleWriter * the type should be (RDD[_], ShuffleDependency[_, _, _]). * @param partition partition of the RDD this task is associated with * @param locs preferred task execution locations for locality scheduling - * @param _initialAccums initial set of accumulators to be used in this task for tracking - * internal metrics. Other accumulators will be registered later when - * they are deserialized on the executors. + * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. */ private[spark] class ShuffleMapTask( @@ -51,9 +50,9 @@ private[spark] class ShuffleMapTask( taskBinary: Broadcast[Array[Byte]], partition: Partition, @transient private var locs: Seq[TaskLocation], - _initialAccums: Seq[Accumulator[_]], + metrics: TaskMetrics, localProperties: Properties) - extends Task[MapStatus](stageId, stageAttemptId, partition.index, _initialAccums, localProperties) + extends Task[MapStatus](stageId, stageAttemptId, partition.index, metrics, localProperties) with Logging { /** A constructor used only in test suites. This does not require passing in an RDD. 
*/ @@ -73,7 +72,6 @@ private[spark] class ShuffleMapTask( ByteBuffer.wrap(taskBinary.value), Thread.currentThread.getContextClassLoader) _executorDeserializeTime = System.currentTimeMillis() - deserializeStartTime - metrics = Some(context.taskMetrics) var writer: ShuffleWriter[Any, Any] = null try { val manager = SparkEnv.get.shuffleManager diff --git a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala index b6d4e39fe532a..d5cf6b82e86f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Stage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Stage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashSet import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.util.CallSite @@ -110,9 +111,10 @@ private[scheduler] abstract class Stage( def makeNewStageAttempt( numPartitionsToCompute: Int, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty): Unit = { + val metrics = new TaskMetrics + metrics.registerAccums(rdd.sparkContext) _latestInfo = StageInfo.fromStage( - this, nextAttemptId, Some(numPartitionsToCompute), - InternalAccumulator.createAll(rdd.sparkContext), taskLocalityPreferences) + this, nextAttemptId, Some(numPartitionsToCompute), metrics, taskLocalityPreferences) nextAttemptId += 1 } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 0fd58c41cdceb..58349fe250887 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -19,8 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable.HashMap -import org.apache.spark.Accumulator import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.executor.TaskMetrics import org.apache.spark.storage.RDDInfo /** @@ -36,7 +36,7 @@ class StageInfo( val rddInfos: Seq[RDDInfo], val parentIds: Seq[Int], val details: String, - val internalAccumulators: Seq[Accumulator[_]] = Seq.empty, + val taskMetrics: TaskMetrics = new TaskMetrics, private[spark] val taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. 
*/ var submissionTime: Option[Long] = None @@ -81,7 +81,7 @@ private[spark] object StageInfo { stage: Stage, attemptId: Int, numTasks: Option[Int] = None, - internalAccumulators: Seq[Accumulator[_]] = Seq.empty, + taskMetrics: TaskMetrics = new TaskMetrics, taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd) @@ -94,7 +94,7 @@ private[spark] object StageInfo { rddInfos, stage.parents.map(_.id), stage.details, - internalAccumulators, + taskMetrics, taskLocalityPreferences) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala index 309f4b806bf70..3c8cab7504c17 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StatsReportListener.scala @@ -47,19 +47,19 @@ class StatsReportListener extends SparkListener with Logging { override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) { implicit val sc = stageCompleted this.logInfo(s"Finished stage: ${getStatusDetail(stageCompleted.stageInfo)}") - showMillisDistribution("task runtime:", (info, _) => Some(info.duration), taskInfoMetrics) + showMillisDistribution("task runtime:", (info, _) => info.duration, taskInfoMetrics) // Shuffle write showBytesDistribution("shuffle bytes written:", - (_, metric) => metric.shuffleWriteMetrics.map(_.bytesWritten), taskInfoMetrics) + (_, metric) => metric.shuffleWriteMetrics.bytesWritten, taskInfoMetrics) // Fetch & I/O showMillisDistribution("fetch wait time:", - (_, metric) => metric.shuffleReadMetrics.map(_.fetchWaitTime), taskInfoMetrics) + (_, metric) => metric.shuffleReadMetrics.fetchWaitTime, taskInfoMetrics) showBytesDistribution("remote bytes read:", - (_, metric) => metric.shuffleReadMetrics.map(_.remoteBytesRead), taskInfoMetrics) + (_, metric) => metric.shuffleReadMetrics.remoteBytesRead, taskInfoMetrics) showBytesDistribution("task result size:", - (_, metric) => Some(metric.resultSize), taskInfoMetrics) + (_, metric) => metric.resultSize, taskInfoMetrics) // Runtime breakdown val runtimePcts = taskInfoMetrics.map { case (info, metrics) => @@ -95,17 +95,17 @@ private[spark] object StatsReportListener extends Logging { def extractDoubleDistribution( taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Double]): Option[Distribution] = { - Distribution(taskInfoMetrics.flatMap { case (info, metric) => getMetric(info, metric) }) + getMetric: (TaskInfo, TaskMetrics) => Double): Option[Distribution] = { + Distribution(taskInfoMetrics.map { case (info, metric) => getMetric(info, metric) }) } // Is there some way to setup the types that I can get rid of this completely? 
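
The listener-facing effect of making the TaskMetrics sub-metrics always present, rather than Option-wrapped, runs through the hunks above and below. An illustrative sketch (not part of the patch) of the resulting access pattern in user listener code; the listener name is made up:

import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}

// Hypothetical listener illustrating the post-change pattern: inputMetrics and the other
// sub-metrics are always present (zero-valued when unused), so no Option unwrapping is needed.
class BytesReadListener extends SparkListener {
  @volatile private var totalBytesRead = 0L

  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
    val metrics = taskEnd.taskMetrics
    if (metrics != null) { // metrics can still be null, as the ExecutorsTab hunk below also checks
      // before this change: metrics.inputMetrics.map(_.bytesRead).getOrElse(0L)
      totalBytesRead += metrics.inputMetrics.bytesRead
    }
  }

  def bytesRead: Long = totalBytesRead
}
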
def extractLongDistribution( taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)], - getMetric: (TaskInfo, TaskMetrics) => Option[Long]): Option[Distribution] = { + getMetric: (TaskInfo, TaskMetrics) => Long): Option[Distribution] = { extractDoubleDistribution( taskInfoMetrics, - (info, metric) => { getMetric(info, metric).map(_.toDouble) }) + (info, metric) => { getMetric(info, metric).toDouble }) } def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { @@ -117,9 +117,9 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution( - heading: String, - dOpt: Option[Distribution], - formatNumber: Double => String) { + heading: String, + dOpt: Option[Distribution], + formatNumber: Double => String) { dOpt.foreach { d => showDistribution(heading, d, formatNumber)} } @@ -129,17 +129,17 @@ private[spark] object StatsReportListener extends Logging { } def showDistribution( - heading: String, - format: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Double], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + heading: String, + format: String, + getMetric: (TaskInfo, TaskMetrics) => Double, + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showDistribution(heading, extractDoubleDistribution(taskInfoMetrics, getMetric), format) } def showBytesDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Long, + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showBytesDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) } @@ -157,9 +157,9 @@ private[spark] object StatsReportListener extends Logging { } def showMillisDistribution( - heading: String, - getMetric: (TaskInfo, TaskMetrics) => Option[Long], - taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { + heading: String, + getMetric: (TaskInfo, TaskMetrics) => Long, + taskInfoMetrics: Seq[(TaskInfo, TaskMetrics)]) { showMillisDistribution(heading, extractLongDistribution(taskInfoMetrics, getMetric)) } @@ -190,7 +190,7 @@ private case class RuntimePercentage(executorPct: Double, fetchPct: Option[Doubl private object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { val denom = totalTime.toDouble - val fetchTime = metrics.shuffleReadMetrics.map(_.fetchWaitTime) + val fetchTime = Some(metrics.shuffleReadMetrics.fetchWaitTime) val fetch = fetchTime.map(_ / denom) val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom val other = 1.0 - (exec + fetch.getOrElse(0d)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 1ff9d7795f42e..9f2fa02c69ab1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -23,7 +23,7 @@ import java.util.Properties import scala.collection.mutable.HashMap -import org.apache.spark.{Accumulator, SparkEnv, TaskContext, TaskContextImpl} +import org.apache.spark.{SparkEnv, TaskContext, TaskContextImpl} import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem @@ -44,17 +44,17 @@ import org.apache.spark.util.{ByteBufferInputStream, ByteBufferOutputStream, Uti * @param stageId id of the stage this task belongs to * @param stageAttemptId attempt id of the stage this task belongs to * @param partitionId index 
of the number in the RDD - * @param initialAccumulators initial set of accumulators to be used in this task for tracking - * internal metrics. Other accumulators will be registered later when - * they are deserialized on the executors. + * @param metrics a [[TaskMetrics]] that is created at driver side and sent to executor side. * @param localProperties copy of thread-local properties set by the user on the driver side. + * + * The default values for `metrics` and `localProperties` are used by tests only. */ private[spark] abstract class Task[T]( val stageId: Int, val stageAttemptId: Int, val partitionId: Int, - val initialAccumulators: Seq[Accumulator[_]], - @transient var localProperties: Properties) extends Serializable { + val metrics: TaskMetrics = new TaskMetrics, + @transient var localProperties: Properties = new Properties) extends Serializable { /** * Called by [[org.apache.spark.executor.Executor]] to run this task. @@ -76,7 +76,7 @@ private[spark] abstract class Task[T]( taskMemoryManager, localProperties, metricsSystem, - initialAccumulators) + metrics) TaskContext.setTaskContext(context) taskThread = Thread.currentThread() if (_killed) { @@ -128,8 +128,6 @@ private[spark] abstract class Task[T]( // Map output tracker epoch. Will be set by TaskScheduler. var epoch: Long = -1 - var metrics: Option[TaskMetrics] = None - // Task context, to be initialized in run(). @transient protected var context: TaskContextImpl = _ diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 876cdfaa87601..5794f542b7564 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -67,7 +67,7 @@ private[spark] class BlockStoreShuffleReader[K, C]( } // Update the context task metrics for each record read. - val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() + val readMetrics = context.taskMetrics.createTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( recordIter.map { record => readMetrics.incRecordsRead(1) diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala deleted file mode 100644 index be1e84a2ba938..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.shuffle - -import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue} - -import scala.collection.JavaConverters._ - -import org.apache.spark.{SparkConf, SparkEnv} -import org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.internal.Logging -import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} -import org.apache.spark.network.netty.SparkTransportConf -import org.apache.spark.serializer.Serializer -import org.apache.spark.storage._ -import org.apache.spark.util.Utils - -/** A group of writers for a ShuffleMapTask, one writer per reducer. */ -private[spark] trait ShuffleWriterGroup { - val writers: Array[DiskBlockObjectWriter] - - /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ - def releaseWriters(success: Boolean): Unit -} - -/** - * Manages assigning disk-based block writers to shuffle tasks. Each shuffle task gets one file - * per reducer. - */ -// Note: Changes to the format in this file should be kept in sync with -// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getHashBasedShuffleBlockData(). -private[spark] class FileShuffleBlockResolver(conf: SparkConf) - extends ShuffleBlockResolver with Logging { - - private val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle") - - private lazy val blockManager = SparkEnv.get.blockManager - - // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided - private val bufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - - /** - * Contains all the state related to a particular shuffle. - */ - private class ShuffleState(val numReducers: Int) { - /** - * The mapIds of all map tasks completed on this Executor for this shuffle. - */ - val completedMapTasks = new ConcurrentLinkedQueue[Int]() - } - - private val shuffleStates = new ConcurrentHashMap[ShuffleId, ShuffleState] - - /** - * Get a ShuffleWriterGroup for the given map task, which will register it as complete - * when the writers are closed successfully - */ - def forMapTask(shuffleId: Int, mapId: Int, numReducers: Int, serializer: Serializer, - writeMetrics: ShuffleWriteMetrics): ShuffleWriterGroup = { - new ShuffleWriterGroup { - private val shuffleState: ShuffleState = { - // Note: we do _not_ want to just wrap this java ConcurrentHashMap into a Scala map and use - // .getOrElseUpdate() because that's actually NOT atomic. - shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numReducers)) - shuffleStates.get(shuffleId) - } - val openStartTime = System.nanoTime - val serializerInstance = serializer.newInstance() - val writers: Array[DiskBlockObjectWriter] = { - Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId => - val blockId = ShuffleBlockId(shuffleId, mapId, bucketId) - val blockFile = blockManager.diskBlockManager.getFile(blockId) - val tmp = Utils.tempFileWith(blockFile) - blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics) - } - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, so should be included in the shuffle write time. 
- writeMetrics.incWriteTime(System.nanoTime - openStartTime) - - override def releaseWriters(success: Boolean) { - shuffleState.completedMapTasks.add(mapId) - } - } - } - - override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = { - val file = blockManager.diskBlockManager.getFile(blockId) - new FileSegmentManagedBuffer(transportConf, file, 0, file.length) - } - - /** Remove all the blocks / files and metadata related to a particular shuffle. */ - def removeShuffle(shuffleId: ShuffleId): Boolean = { - // Do not change the ordering of this, if shuffleStates should be removed only - // after the corresponding shuffle blocks have been removed - val cleaned = removeShuffleBlocks(shuffleId) - shuffleStates.remove(shuffleId) - cleaned - } - - /** Remove all the blocks / files related to a particular shuffle. */ - private def removeShuffleBlocks(shuffleId: ShuffleId): Boolean = { - Option(shuffleStates.get(shuffleId)) match { - case Some(state) => - for (mapId <- state.completedMapTasks.asScala; reduceId <- 0 until state.numReducers) { - val blockId = new ShuffleBlockId(shuffleId, mapId, reduceId) - val file = blockManager.diskBlockManager.getFile(blockId) - if (!file.delete()) { - logWarning(s"Error deleting ${file.getPath()}") - } - } - logInfo("Deleted all files for shuffle " + shuffleId) - true - case None => - logInfo("Could not find files for shuffle " + shuffleId + " for deleting") - false - } - } - - override def stop(): Unit = {} -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala deleted file mode 100644 index 6bb4ff94b546d..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleManager.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.shuffle._ - -/** - * A ShuffleManager using hashing, that creates one output file per reduce partition on each - * mapper (possibly reusing these across waves of tasks). - */ -private[spark] class HashShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { - - if (!conf.getBoolean("spark.shuffle.spill", true)) { - logWarning( - "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + - " Shuffle will continue to spill to disk when necessary.") - } - - private val fileShuffleBlockResolver = new FileShuffleBlockResolver(conf) - - override val shortName: String = "hash" - - /* Register a shuffle with the manager and obtain a handle for it to pass to tasks. 
*/ - override def registerShuffle[K, V, C]( - shuffleId: Int, - numMaps: Int, - dependency: ShuffleDependency[K, V, C]): ShuffleHandle = { - new BaseShuffleHandle(shuffleId, numMaps, dependency) - } - - /** - * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive). - * Called on executors by reduce tasks. - */ - override def getReader[K, C]( - handle: ShuffleHandle, - startPartition: Int, - endPartition: Int, - context: TaskContext): ShuffleReader[K, C] = { - new BlockStoreShuffleReader( - handle.asInstanceOf[BaseShuffleHandle[K, _, C]], startPartition, endPartition, context) - } - - /** Get a writer for a given partition. Called on executors by map tasks. */ - override def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext) - : ShuffleWriter[K, V] = { - new HashShuffleWriter( - shuffleBlockResolver, handle.asInstanceOf[BaseShuffleHandle[K, V, _]], mapId, context) - } - - /** Remove a shuffle's metadata from the ShuffleManager. */ - override def unregisterShuffle(shuffleId: Int): Boolean = { - shuffleBlockResolver.removeShuffle(shuffleId) - } - - override def shuffleBlockResolver: FileShuffleBlockResolver = { - fileShuffleBlockResolver - } - - /** Shut down this ShuffleManager. */ - override def stop(): Unit = { - shuffleBlockResolver.stop() - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala deleted file mode 100644 index 9276d95012f2f..0000000000000 --- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.shuffle.hash - -import java.io.IOException - -import org.apache.spark._ -import org.apache.spark.internal.Logging -import org.apache.spark.scheduler.MapStatus -import org.apache.spark.shuffle._ -import org.apache.spark.storage.DiskBlockObjectWriter - -private[spark] class HashShuffleWriter[K, V]( - shuffleBlockResolver: FileShuffleBlockResolver, - handle: BaseShuffleHandle[K, V, _], - mapId: Int, - context: TaskContext) - extends ShuffleWriter[K, V] with Logging { - - private val dep = handle.dependency - private val numOutputSplits = dep.partitioner.numPartitions - private val metrics = context.taskMetrics - - // Are we in the process of stopping? Because map tasks can call stop() with success = true - // and then call stop() with success = false if they get an exception, we want to make sure - // we don't try deleting files, etc twice. 
- private var stopping = false - - private val writeMetrics = metrics.registerShuffleWriteMetrics() - - private val blockManager = SparkEnv.get.blockManager - private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, - dep.serializer, writeMetrics) - - /** Write a bunch of records to this task's output */ - override def write(records: Iterator[Product2[K, V]]): Unit = { - val iter = if (dep.aggregator.isDefined) { - if (dep.mapSideCombine) { - dep.aggregator.get.combineValuesByKey(records, context) - } else { - records - } - } else { - require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!") - records - } - - for (elem <- iter) { - val bucketId = dep.partitioner.getPartition(elem._1) - shuffle.writers(bucketId).write(elem._1, elem._2) - } - } - - /** Close this writer, passing along whether the map completed */ - override def stop(initiallySuccess: Boolean): Option[MapStatus] = { - var success = initiallySuccess - try { - if (stopping) { - return None - } - stopping = true - if (success) { - try { - Some(commitWritesAndBuildStatus()) - } catch { - case e: Exception => - success = false - revertWrites() - throw e - } - } else { - revertWrites() - None - } - } finally { - // Release the writers back to the shuffle block manager. - if (shuffle != null && shuffle.writers != null) { - try { - shuffle.releaseWriters(success) - } catch { - case e: Exception => logError("Failed to release shuffle writers", e) - } - } - } - } - - private def commitWritesAndBuildStatus(): MapStatus = { - // Commit the writes. Get the size of each bucket block (total block size). - val sizes: Array[Long] = shuffle.writers.map { writer: DiskBlockObjectWriter => - writer.commitAndClose() - writer.fileSegment().length - } - // rename all shuffle files to final paths - // Note: there is only one ShuffleBlockResolver in executor - shuffleBlockResolver.synchronized { - shuffle.writers.zipWithIndex.foreach { case (writer, i) => - val output = blockManager.diskBlockManager.getFile(writer.blockId) - if (sizes(i) > 0) { - if (output.exists()) { - // Use length of existing file and delete our own temporary one - sizes(i) = output.length() - writer.file.delete() - } else { - // Commit by renaming our temporary file to something the fetcher expects - if (!writer.file.renameTo(output)) { - throw new IOException(s"fail to rename ${writer.file} to $output") - } - } - } else { - if (output.exists()) { - output.delete() - } - } - } - } - MapStatus(blockManager.shuffleServerId, sizes) - } - - private def revertWrites(): Unit = { - if (shuffle != null && shuffle.writers != null) { - for (writer <- shuffle.writers) { - writer.revertPartialWritesAndClose() - } - } - } -} diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala index 8ab1cee2e842d..1adacabc86c05 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala @@ -45,7 +45,7 @@ private[spark] class SortShuffleWriter[K, V, C]( private var mapStatus: MapStatus = null - private val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() + private val writeMetrics = context.taskMetrics().shuffleWriteMetrics /** Write a bunch of records to this task's output */ override def write(records: Iterator[Product2[K, V]]): Unit = { diff --git 
a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala index f8d6e9fbbb90d..eddc36edc9611 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/AllStagesResource.scala @@ -167,35 +167,32 @@ private[v1] object AllStagesResource { // to make it a little easier to deal w/ all of the nested options. Mostly it lets us just // implement one "build" method, which just builds the quantiles for each field. - val inputMetrics: Option[InputMetricDistributions] = + val inputMetrics: InputMetricDistributions = new MetricHelper[InternalInputMetrics, InputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): Option[InternalInputMetrics] = { - raw.inputMetrics - } + def getSubmetrics(raw: InternalTaskMetrics): InternalInputMetrics = raw.inputMetrics def build: InputMetricDistributions = new InputMetricDistributions( bytesRead = submetricQuantiles(_.bytesRead), recordsRead = submetricQuantiles(_.recordsRead) ) - }.metricOption + }.build - val outputMetrics: Option[OutputMetricDistributions] = + val outputMetrics: OutputMetricDistributions = new MetricHelper[InternalOutputMetrics, OutputMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): Option[InternalOutputMetrics] = { - raw.outputMetrics - } + def getSubmetrics(raw: InternalTaskMetrics): InternalOutputMetrics = raw.outputMetrics + def build: OutputMetricDistributions = new OutputMetricDistributions( bytesWritten = submetricQuantiles(_.bytesWritten), recordsWritten = submetricQuantiles(_.recordsWritten) ) - }.metricOption + }.build - val shuffleReadMetrics: Option[ShuffleReadMetricDistributions] = + val shuffleReadMetrics: ShuffleReadMetricDistributions = new MetricHelper[InternalShuffleReadMetrics, ShuffleReadMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleReadMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleReadMetrics = raw.shuffleReadMetrics - } + def build: ShuffleReadMetricDistributions = new ShuffleReadMetricDistributions( readBytes = submetricQuantiles(_.totalBytesRead), readRecords = submetricQuantiles(_.recordsRead), @@ -205,20 +202,20 @@ private[v1] object AllStagesResource { totalBlocksFetched = submetricQuantiles(_.totalBlocksFetched), fetchWaitTime = submetricQuantiles(_.fetchWaitTime) ) - }.metricOption + }.build - val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions] = + val shuffleWriteMetrics: ShuffleWriteMetricDistributions = new MetricHelper[InternalShuffleWriteMetrics, ShuffleWriteMetricDistributions](rawMetrics, quantiles) { - def getSubmetrics(raw: InternalTaskMetrics): Option[InternalShuffleWriteMetrics] = { + def getSubmetrics(raw: InternalTaskMetrics): InternalShuffleWriteMetrics = raw.shuffleWriteMetrics - } + def build: ShuffleWriteMetricDistributions = new ShuffleWriteMetricDistributions( writeBytes = submetricQuantiles(_.bytesWritten), writeRecords = submetricQuantiles(_.recordsWritten), writeTime = submetricQuantiles(_.writeTime) ) - }.metricOption + }.build new TaskMetricDistributions( quantiles = quantiles, @@ -250,10 +247,10 @@ private[v1] object AllStagesResource { resultSerializationTime = internal.resultSerializationTime, memoryBytesSpilled = internal.memoryBytesSpilled, diskBytesSpilled = internal.diskBytesSpilled, - inputMetrics = 
internal.inputMetrics.map { convertInputMetrics }, - outputMetrics = Option(internal.outputMetrics).flatten.map { convertOutputMetrics }, - shuffleReadMetrics = internal.shuffleReadMetrics.map { convertShuffleReadMetrics }, - shuffleWriteMetrics = internal.shuffleWriteMetrics.map { convertShuffleWriteMetrics } + inputMetrics = convertInputMetrics(internal.inputMetrics), + outputMetrics = convertOutputMetrics(internal.outputMetrics), + shuffleReadMetrics = convertShuffleReadMetrics(internal.shuffleReadMetrics), + shuffleWriteMetrics = convertShuffleWriteMetrics(internal.shuffleWriteMetrics) ) } @@ -277,7 +274,7 @@ private[v1] object AllStagesResource { localBlocksFetched = internal.localBlocksFetched, fetchWaitTime = internal.fetchWaitTime, remoteBytesRead = internal.remoteBytesRead, - totalBlocksFetched = internal.totalBlocksFetched, + localBytesRead = internal.localBytesRead, recordsRead = internal.recordsRead ) } @@ -292,31 +289,20 @@ private[v1] object AllStagesResource { } /** - * Helper for getting distributions from nested metric types. Many of the metrics we want are - * contained in options inside TaskMetrics (eg., ShuffleWriteMetrics). This makes it easy to handle - * the options (returning None if the metrics are all empty), and extract the quantiles for each - * metric. After creating an instance, call metricOption to get the result type. + * Helper for getting distributions from nested metric types. */ private[v1] abstract class MetricHelper[I, O]( rawMetrics: Seq[InternalTaskMetrics], quantiles: Array[Double]) { - def getSubmetrics(raw: InternalTaskMetrics): Option[I] + def getSubmetrics(raw: InternalTaskMetrics): I def build: O - val data: Seq[I] = rawMetrics.flatMap(getSubmetrics) + val data: Seq[I] = rawMetrics.map(getSubmetrics) /** applies the given function to all input metrics, and returns the quantiles */ def submetricQuantiles(f: I => Double): IndexedSeq[Double] = { Distribution(data.map { d => f(d) }).get.getQuantiles(quantiles) } - - def metricOption: Option[O] = { - if (data.isEmpty) { - None - } else { - Some(build) - } - } } diff --git a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala index ebbbf4814880f..ff28796a60f67 100644 --- a/core/src/main/scala/org/apache/spark/status/api/v1/api.scala +++ b/core/src/main/scala/org/apache/spark/status/api/v1/api.scala @@ -172,10 +172,10 @@ class TaskMetrics private[spark]( val resultSerializationTime: Long, val memoryBytesSpilled: Long, val diskBytesSpilled: Long, - val inputMetrics: Option[InputMetrics], - val outputMetrics: Option[OutputMetrics], - val shuffleReadMetrics: Option[ShuffleReadMetrics], - val shuffleWriteMetrics: Option[ShuffleWriteMetrics]) + val inputMetrics: InputMetrics, + val outputMetrics: OutputMetrics, + val shuffleReadMetrics: ShuffleReadMetrics, + val shuffleWriteMetrics: ShuffleWriteMetrics) class InputMetrics private[spark]( val bytesRead: Long, @@ -190,7 +190,7 @@ class ShuffleReadMetrics private[spark]( val localBlocksFetched: Int, val fetchWaitTime: Long, val remoteBytesRead: Long, - val totalBlocksFetched: Int, + val localBytesRead: Long, val recordsRead: Long) class ShuffleWriteMetrics private[spark]( @@ -209,10 +209,10 @@ class TaskMetricDistributions private[spark]( val memoryBytesSpilled: IndexedSeq[Double], val diskBytesSpilled: IndexedSeq[Double], - val inputMetrics: Option[InputMetricDistributions], - val outputMetrics: Option[OutputMetricDistributions], - val shuffleReadMetrics: 
Option[ShuffleReadMetricDistributions], - val shuffleWriteMetrics: Option[ShuffleWriteMetricDistributions]) + val inputMetrics: InputMetricDistributions, + val outputMetrics: OutputMetricDistributions, + val shuffleReadMetrics: ShuffleReadMetricDistributions, + val shuffleWriteMetrics: ShuffleWriteMetricDistributions) class InputMetricDistributions private[spark]( val bytesRead: IndexedSeq[Double], diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 35a6c63ad193e..22bc76b143516 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -260,7 +260,12 @@ private[spark] class BlockManager( def waitForAsyncReregister(): Unit = { val task = asyncReregisterTask if (task != null) { - Await.ready(task, Duration.Inf) + try { + Await.ready(task, Duration.Inf) + } catch { + case NonFatal(t) => + throw new Exception("Error occurred while waiting for async. reregistration", t) + } } } @@ -802,7 +807,12 @@ private[spark] class BlockManager( logDebug("Put block %s locally took %s".format(blockId, Utils.getUsedTimeMs(startTimeMs))) if (level.replication > 1) { // Wait for asynchronous replication to finish - Await.ready(replicationFuture, Duration.Inf) + try { + Await.ready(replicationFuture, Duration.Inf) + } catch { + case NonFatal(t) => + throw new Exception("Error occurred while waiting for replication to finish", t) + } } if (blockWasSuccessfullyStored) { None diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 4ec5b4bbb07cb..4dc2f362329a0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -108,7 +108,7 @@ final class ShuffleBlockFetcherIterator( /** Current number of requests in flight */ private[this] var reqsInFlight = 0 - private[this] val shuffleMetrics = context.taskMetrics().registerTempShuffleReadMetrics() + private[this] val shuffleMetrics = context.taskMetrics().createTempShuffleReadMetrics() /** * Whether the iterator is still active. 
If isZombie is true, the callback interface will no diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala index 119165f724f59..db24f0319ba05 100644 --- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala @@ -84,9 +84,7 @@ private[spark] object JettyUtils extends Logging { val result = servletParams.responder(request) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") response.setHeader("X-Frame-Options", xFrameOptionsValue) - // scalastyle:off println - response.getWriter.println(servletParams.extractFn(result)) - // scalastyle:on println + response.getWriter.print(servletParams.extractFn(result)) } else { response.setStatus(HttpServletResponse.SC_UNAUTHORIZED) response.setHeader("Cache-Control", "no-cache, no-store, must-revalidate") diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala index 28d277df4ae12..6241593bba32f 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala @@ -168,6 +168,7 @@ private[spark] object UIUtils extends Logging { + } def vizHeaderNodes: Seq[Node] = { diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala index 3fd0efd3a1e74..676f4457510c2 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala @@ -119,26 +119,19 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener, conf: Spar // Update shuffle read/write val metrics = taskEnd.taskMetrics if (metrics != null) { - metrics.inputMetrics.foreach { inputMetrics => - executorToInputBytes(eid) = - executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead - executorToInputRecords(eid) = - executorToInputRecords.getOrElse(eid, 0L) + inputMetrics.recordsRead - } - metrics.outputMetrics.foreach { outputMetrics => - executorToOutputBytes(eid) = - executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten - executorToOutputRecords(eid) = - executorToOutputRecords.getOrElse(eid, 0L) + outputMetrics.recordsWritten - } - metrics.shuffleReadMetrics.foreach { shuffleRead => - executorToShuffleRead(eid) = - executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead - } - metrics.shuffleWriteMetrics.foreach { shuffleWrite => - executorToShuffleWrite(eid) = - executorToShuffleWrite.getOrElse(eid, 0L) + shuffleWrite.bytesWritten - } + executorToInputBytes(eid) = + executorToInputBytes.getOrElse(eid, 0L) + metrics.inputMetrics.bytesRead + executorToInputRecords(eid) = + executorToInputRecords.getOrElse(eid, 0L) + metrics.inputMetrics.recordsRead + executorToOutputBytes(eid) = + executorToOutputBytes.getOrElse(eid, 0L) + metrics.outputMetrics.bytesWritten + executorToOutputRecords(eid) = + executorToOutputRecords.getOrElse(eid, 0L) + metrics.outputMetrics.recordsWritten + + executorToShuffleRead(eid) = + executorToShuffleRead.getOrElse(eid, 0L) + metrics.shuffleReadMetrics.remoteBytesRead + executorToShuffleWrite(eid) = + executorToShuffleWrite.getOrElse(eid, 0L) + metrics.shuffleWriteMetrics.bytesWritten executorToJvmGCTime(eid) = executorToJvmGCTime.getOrElse(eid, 0L) + metrics.jvmGCTime } } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala 
b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala index bd4797ae8e0c5..645e2d2e360bb 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala @@ -203,7 +203,7 @@ private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") { // This could be empty if the JobProgressListener hasn't received information about the // stage or if the stage information has been garbage collected listener.stageIdToInfo.getOrElse(stageId, - new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown", Seq.empty)) + new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, Seq.empty, "Unknown")) } val activeStages = Buffer[StageInfo]() diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index 13f5f84d06feb..9e4771ce4ac51 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -434,50 +434,50 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging { val execSummary = stageData.executorSummary.getOrElseUpdate(execId, new ExecutorSummary) val shuffleWriteDelta = - (taskMetrics.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.bytesWritten).getOrElse(0L)) + taskMetrics.shuffleWriteMetrics.bytesWritten - + oldMetrics.map(_.shuffleWriteMetrics.bytesWritten).getOrElse(0L) stageData.shuffleWriteBytes += shuffleWriteDelta execSummary.shuffleWrite += shuffleWriteDelta val shuffleWriteRecordsDelta = - (taskMetrics.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleWriteMetrics).map(_.recordsWritten).getOrElse(0L)) + taskMetrics.shuffleWriteMetrics.recordsWritten - + oldMetrics.map(_.shuffleWriteMetrics.recordsWritten).getOrElse(0L) stageData.shuffleWriteRecords += shuffleWriteRecordsDelta execSummary.shuffleWriteRecords += shuffleWriteRecordsDelta val shuffleReadDelta = - (taskMetrics.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.totalBytesRead).getOrElse(0L)) + taskMetrics.shuffleReadMetrics.totalBytesRead - + oldMetrics.map(_.shuffleReadMetrics.totalBytesRead).getOrElse(0L) stageData.shuffleReadTotalBytes += shuffleReadDelta execSummary.shuffleRead += shuffleReadDelta val shuffleReadRecordsDelta = - (taskMetrics.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L) - - oldMetrics.flatMap(_.shuffleReadMetrics).map(_.recordsRead).getOrElse(0L)) + taskMetrics.shuffleReadMetrics.recordsRead - + oldMetrics.map(_.shuffleReadMetrics.recordsRead).getOrElse(0L) stageData.shuffleReadRecords += shuffleReadRecordsDelta execSummary.shuffleReadRecords += shuffleReadRecordsDelta val inputBytesDelta = - (taskMetrics.inputMetrics.map(_.bytesRead).getOrElse(0L) - - oldMetrics.flatMap(_.inputMetrics).map(_.bytesRead).getOrElse(0L)) + taskMetrics.inputMetrics.bytesRead - + oldMetrics.map(_.inputMetrics.bytesRead).getOrElse(0L) stageData.inputBytes += inputBytesDelta execSummary.inputBytes += inputBytesDelta val inputRecordsDelta = - (taskMetrics.inputMetrics.map(_.recordsRead).getOrElse(0L) - - oldMetrics.flatMap(_.inputMetrics).map(_.recordsRead).getOrElse(0L)) + taskMetrics.inputMetrics.recordsRead - + oldMetrics.map(_.inputMetrics.recordsRead).getOrElse(0L) stageData.inputRecords += inputRecordsDelta execSummary.inputRecords += inputRecordsDelta val 
outputBytesDelta = - (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L) - - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L)) + taskMetrics.outputMetrics.bytesWritten - + oldMetrics.map(_.outputMetrics.bytesWritten).getOrElse(0L) stageData.outputBytes += outputBytesDelta execSummary.outputBytes += outputBytesDelta val outputRecordsDelta = - (taskMetrics.outputMetrics.map(_.recordsWritten).getOrElse(0L) - - oldMetrics.flatMap(_.outputMetrics).map(_.recordsWritten).getOrElse(0L)) + taskMetrics.outputMetrics.recordsWritten - + oldMetrics.map(_.outputMetrics.recordsWritten).getOrElse(0L) stageData.outputRecords += outputRecordsDelta execSummary.outputRecords += outputRecordsDelta diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index 8a44bbd9fcd57..5d1928ac6b2ca 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -428,29 +428,29 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { } val inputSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.inputMetrics.map(_.bytesRead).getOrElse(0L).toDouble + taskUIData.metrics.get.inputMetrics.bytesRead.toDouble } val inputRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.inputMetrics.map(_.recordsRead).getOrElse(0L).toDouble + taskUIData.metrics.get.inputMetrics.recordsRead.toDouble } val inputQuantiles = Input Size / Records +: getFormattedSizeQuantilesWithRecords(inputSizes, inputRecords) val outputSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + taskUIData.metrics.get.outputMetrics.bytesWritten.toDouble } val outputRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.outputMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + taskUIData.metrics.get.outputMetrics.recordsWritten.toDouble } val outputQuantiles = Output Size / Records +: getFormattedSizeQuantilesWithRecords(outputSizes, outputRecords) val shuffleReadBlockedTimes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.map(_.fetchWaitTime).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleReadMetrics.fetchWaitTime.toDouble } val shuffleReadBlockedQuantiles = @@ -462,10 +462,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedTimeQuantiles(shuffleReadBlockedTimes) val shuffleReadTotalSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.map(_.totalBytesRead).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleReadMetrics.totalBytesRead.toDouble } val shuffleReadTotalRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.map(_.recordsRead).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleReadMetrics.recordsRead.toDouble } val shuffleReadTotalQuantiles = @@ -477,7 +477,7 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedSizeQuantilesWithRecords(shuffleReadTotalSizes, shuffleReadTotalRecords) val shuffleReadRemoteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleReadMetrics.map(_.remoteBytesRead).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleReadMetrics.remoteBytesRead.toDouble } val shuffleReadRemoteQuantiles = @@ -489,11 +489,11 @@ private[ui] class 
StagePage(parent: StagesTab) extends WebUIPage("stage") { getFormattedSizeQuantiles(shuffleReadRemoteSizes) val shuffleWriteSizes = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.map(_.bytesWritten).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleWriteMetrics.bytesWritten.toDouble } val shuffleWriteRecords = validTasks.map { taskUIData: TaskUIData => - taskUIData.metrics.get.shuffleWriteMetrics.map(_.recordsWritten).getOrElse(0L).toDouble + taskUIData.metrics.get.shuffleWriteMetrics.recordsWritten.toDouble } val shuffleWriteQuantiles = Shuffle Write Size / Records +: @@ -603,11 +603,10 @@ private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") { val metricsOpt = taskUIData.metrics val shuffleReadTime = - metricsOpt.flatMap(_.shuffleReadMetrics.map(_.fetchWaitTime)).getOrElse(0L) + metricsOpt.map(_.shuffleReadMetrics.fetchWaitTime).getOrElse(0L) val shuffleReadTimeProportion = toProportion(shuffleReadTime) val shuffleWriteTime = - (metricsOpt.flatMap(_.shuffleWriteMetrics - .map(_.writeTime)).getOrElse(0L) / 1e6).toLong + (metricsOpt.map(_.shuffleWriteMetrics.writeTime).getOrElse(0L) / 1e6).toLong val shuffleWriteTimeProportion = toProportion(shuffleWriteTime) val serializationTime = metricsOpt.map(_.resultSerializationTime).getOrElse(0L) @@ -890,21 +889,21 @@ private[ui] class TaskDataSource( } val peakExecutionMemoryUsed = metrics.map(_.peakExecutionMemory).getOrElse(0L) - val maybeInput = metrics.flatMap(_.inputMetrics) + val maybeInput = metrics.map(_.inputMetrics) val inputSortable = maybeInput.map(_.bytesRead).getOrElse(0L) val inputReadable = maybeInput - .map(m => s"${Utils.bytesToString(m.bytesRead)} (${m.readMethod.toString.toLowerCase()})") + .map(m => s"${Utils.bytesToString(m.bytesRead)}") .getOrElse("") val inputRecords = maybeInput.map(_.recordsRead.toString).getOrElse("") - val maybeOutput = metrics.flatMap(_.outputMetrics) + val maybeOutput = metrics.map(_.outputMetrics) val outputSortable = maybeOutput.map(_.bytesWritten).getOrElse(0L) val outputReadable = maybeOutput .map(m => s"${Utils.bytesToString(m.bytesWritten)}") .getOrElse("") val outputRecords = maybeOutput.map(_.recordsWritten.toString).getOrElse("") - val maybeShuffleRead = metrics.flatMap(_.shuffleReadMetrics) + val maybeShuffleRead = metrics.map(_.shuffleReadMetrics) val shuffleReadBlockedTimeSortable = maybeShuffleRead.map(_.fetchWaitTime).getOrElse(0L) val shuffleReadBlockedTimeReadable = maybeShuffleRead.map(ms => UIUtils.formatDuration(ms.fetchWaitTime)).getOrElse("") @@ -918,14 +917,14 @@ private[ui] class TaskDataSource( val shuffleReadRemoteSortable = remoteShuffleBytes.getOrElse(0L) val shuffleReadRemoteReadable = remoteShuffleBytes.map(Utils.bytesToString).getOrElse("") - val maybeShuffleWrite = metrics.flatMap(_.shuffleWriteMetrics) + val maybeShuffleWrite = metrics.map(_.shuffleWriteMetrics) val shuffleWriteSortable = maybeShuffleWrite.map(_.bytesWritten).getOrElse(0L) val shuffleWriteReadable = maybeShuffleWrite .map(m => s"${Utils.bytesToString(m.bytesWritten)}").getOrElse("") val shuffleWriteRecords = maybeShuffleWrite .map(_.recordsWritten.toString).getOrElse("") - val maybeWriteTime = metrics.flatMap(_.shuffleWriteMetrics).map(_.writeTime) + val maybeWriteTime = metrics.map(_.shuffleWriteMetrics.writeTime) val writeTimeSortable = maybeWriteTime.getOrElse(0L) val writeTimeReadable = maybeWriteTime.map(t => t / (1000 * 1000)).map { ms => if (ms == 0) "" else UIUtils.formatDuration(ms) diff --git 
a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala index bb6b663f1ead3..84ca750e1a96a 100644 --- a/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala +++ b/core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraph.scala @@ -17,6 +17,8 @@ package org.apache.spark.ui.scope +import java.util.Objects + import scala.collection.mutable import scala.collection.mutable.{ListBuffer, StringBuilder} @@ -72,6 +74,22 @@ private[ui] class RDDOperationCluster(val id: String, private var _name: String) def getCachedNodes: Seq[RDDOperationNode] = { _childNodes.filter(_.cached) ++ _childClusters.flatMap(_.getCachedNodes) } + + def canEqual(other: Any): Boolean = other.isInstanceOf[RDDOperationCluster] + + override def equals(other: Any): Boolean = other match { + case that: RDDOperationCluster => + (that canEqual this) && + _childClusters == that._childClusters && + id == that.id && + _name == that._name + case _ => false + } + + override def hashCode(): Int = { + val state = Seq(_childClusters, id, _name) + state.map(Objects.hashCode).foldLeft(0)((a, b) => 31 * a + b) + } } private[ui] object RDDOperationGraph extends Logging { diff --git a/core/src/main/scala/org/apache/spark/util/Benchmark.scala b/core/src/main/scala/org/apache/spark/util/Benchmark.scala index 9e40bafd521d7..1fc0ad7a4d6d3 100644 --- a/core/src/main/scala/org/apache/spark/util/Benchmark.scala +++ b/core/src/main/scala/org/apache/spark/util/Benchmark.scala @@ -42,7 +42,24 @@ private[spark] class Benchmark( outputPerIteration: Boolean = false) { val benchmarks = mutable.ArrayBuffer.empty[Benchmark.Case] + /** + * Adds a case to run when run() is called. The given function will be run for several + * iterations to collect timing statistics. + */ def addCase(name: String)(f: Int => Unit): Unit = { + addTimerCase(name) { timer => + timer.startTiming() + f(timer.iteration) + timer.stopTiming() + } + } + + /** + * Adds a case with manual timing control. When the function is run, timing does not start + * until timer.startTiming() is called within the given function. The corresponding + * timer.stopTiming() method must be called before the function returns. + */ + def addTimerCase(name: String)(f: Benchmark.Timer => Unit): Unit = { benchmarks += Benchmark.Case(name, f) } @@ -84,7 +101,34 @@ private[spark] class Benchmark( } private[spark] object Benchmark { - case class Case(name: String, fn: Int => Unit) + + /** + * Object available to benchmark code to control timing e.g. to exclude set-up time. 
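The addTimerCase method documented just above exists so a benchmark case can exclude per-iteration set-up from the measured time. A rough usage sketch only (Benchmark is private[spark], and the second constructor argument, the number of values per iteration, is assumed from context rather than verified here):

    val benchmark = new Benchmark("Array fill", 1 << 20)

    benchmark.addTimerCase("fill only, allocation excluded") { timer =>
      val data = new Array[Long](1 << 20) // set-up work, not timed
      timer.startTiming()
      java.util.Arrays.fill(data, 42L)    // only this region contributes to totalTime()
      timer.stopTiming()
    }

    benchmark.run()
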
+ * + * @param iteration specifies this is the nth iteration of running the benchmark case + */ + class Timer(val iteration: Int) { + private var accumulatedTime: Long = 0L + private var timeStart: Long = 0L + + def startTiming(): Unit = { + assert(timeStart == 0L, "Already started timing.") + timeStart = System.nanoTime + } + + def stopTiming(): Unit = { + assert(timeStart != 0L, "Have not started timing.") + accumulatedTime += System.nanoTime - timeStart + timeStart = 0L + } + + def totalTime(): Long = { + assert(timeStart == 0L, "Have not stopped timing.") + accumulatedTime + } + } + + case class Case(name: String, fn: Timer => Unit) case class Result(avgMs: Double, bestRate: Double, bestMs: Double) /** @@ -96,9 +140,9 @@ private[spark] object Benchmark { Utils.executeAndGetOutput(Seq("/usr/sbin/sysctl", "-n", "machdep.cpu.brand_string")) } else if (SystemUtils.IS_OS_LINUX) { Try { - val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")) + val grepPath = Utils.executeAndGetOutput(Seq("which", "grep")).stripLineEnd Utils.executeAndGetOutput(Seq(grepPath, "-m", "1", "model name", "/proc/cpuinfo")) - .replaceFirst("model name[\\s*]:[\\s*]", "") + .stripLineEnd.replaceFirst("model name[\\s*]:[\\s*]", "") }.getOrElse("Unknown processor") } else { System.getenv("PROCESSOR_IDENTIFIER") @@ -123,15 +167,12 @@ private[spark] object Benchmark { * Runs a single function `f` for iters, returning the average time the function took and * the rate of the function. */ - def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Int => Unit): Result = { + def measure(num: Long, iters: Int, outputPerIteration: Boolean)(f: Timer => Unit): Result = { val runTimes = ArrayBuffer[Long]() for (i <- 0 until iters + 1) { - val start = System.nanoTime() - - f(i) - - val end = System.nanoTime() - val runTime = end - start + val timer = new Benchmark.Timer(i) + f(timer) + val runTime = timer.totalTime() if (i > 0) { runTimes += runTime } diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala index 558767e36f7da..6c50c72a91ef2 100644 --- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala +++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala @@ -304,20 +304,17 @@ private[spark] object JsonProtocol { * The behavior here must match that of [[accumValueFromJson]]. Exposed for testing. 
*/ private[util] def accumValueToJson(name: Option[String], value: Any): JValue = { - import AccumulatorParam._ if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { - (value, InternalAccumulator.getParam(name.get)) match { - case (v: Int, IntAccumulatorParam) => JInt(v) - case (v: Long, LongAccumulatorParam) => JInt(v) - case (v: String, StringAccumulatorParam) => JString(v) - case (v, UpdatedBlockStatusesAccumulatorParam) => + value match { + case v: Int => JInt(v) + case v: Long => JInt(v) + // We only have 3 kind of internal accumulator types, so if it's not int or long, it must be + // the blocks accumulator, whose type is `Seq[(BlockId, BlockStatus)]` + case v => JArray(v.asInstanceOf[Seq[(BlockId, BlockStatus)]].toList.map { case (id, status) => ("Block ID" -> id.toString) ~ ("Status" -> blockStatusToJson(status)) }) - case (v, p) => - throw new IllegalArgumentException(s"unexpected combination of accumulator value " + - s"type (${v.getClass.getName}) and param (${p.getClass.getName}) in '${name.get}'") } } else { // For all external accumulators, just use strings @@ -327,36 +324,26 @@ private[spark] object JsonProtocol { def taskMetricsToJson(taskMetrics: TaskMetrics): JValue = { val shuffleReadMetrics: JValue = - taskMetrics.shuffleReadMetrics.map { rm => - ("Remote Blocks Fetched" -> rm.remoteBlocksFetched) ~ - ("Local Blocks Fetched" -> rm.localBlocksFetched) ~ - ("Fetch Wait Time" -> rm.fetchWaitTime) ~ - ("Remote Bytes Read" -> rm.remoteBytesRead) ~ - ("Local Bytes Read" -> rm.localBytesRead) ~ - ("Total Records Read" -> rm.recordsRead) - }.getOrElse(JNothing) + ("Remote Blocks Fetched" -> taskMetrics.shuffleReadMetrics.remoteBlocksFetched) ~ + ("Local Blocks Fetched" -> taskMetrics.shuffleReadMetrics.localBlocksFetched) ~ + ("Fetch Wait Time" -> taskMetrics.shuffleReadMetrics.fetchWaitTime) ~ + ("Remote Bytes Read" -> taskMetrics.shuffleReadMetrics.remoteBytesRead) ~ + ("Local Bytes Read" -> taskMetrics.shuffleReadMetrics.localBytesRead) ~ + ("Total Records Read" -> taskMetrics.shuffleReadMetrics.recordsRead) val shuffleWriteMetrics: JValue = - taskMetrics.shuffleWriteMetrics.map { wm => - ("Shuffle Bytes Written" -> wm.bytesWritten) ~ - ("Shuffle Write Time" -> wm.writeTime) ~ - ("Shuffle Records Written" -> wm.recordsWritten) - }.getOrElse(JNothing) + ("Shuffle Bytes Written" -> taskMetrics.shuffleWriteMetrics.bytesWritten) ~ + ("Shuffle Write Time" -> taskMetrics.shuffleWriteMetrics.writeTime) ~ + ("Shuffle Records Written" -> taskMetrics.shuffleWriteMetrics.recordsWritten) val inputMetrics: JValue = - taskMetrics.inputMetrics.map { im => - ("Data Read Method" -> im.readMethod.toString) ~ - ("Bytes Read" -> im.bytesRead) ~ - ("Records Read" -> im.recordsRead) - }.getOrElse(JNothing) + ("Bytes Read" -> taskMetrics.inputMetrics.bytesRead) ~ + ("Records Read" -> taskMetrics.inputMetrics.recordsRead) val outputMetrics: JValue = - taskMetrics.outputMetrics.map { om => - ("Data Write Method" -> om.writeMethod.toString) ~ - ("Bytes Written" -> om.bytesWritten) ~ - ("Records Written" -> om.recordsWritten) - }.getOrElse(JNothing) + ("Bytes Written" -> taskMetrics.outputMetrics.bytesWritten) ~ + ("Records Written" -> taskMetrics.outputMetrics.recordsWritten) val updatedBlocks = JArray(taskMetrics.updatedBlockStatuses.toList.map { case (id, status) => ("Block ID" -> id.toString) ~ - ("Status" -> blockStatusToJson(status)) + ("Status" -> blockStatusToJson(status)) }) ("Executor Deserialize Time" -> taskMetrics.executorDeserializeTime) ~ ("Executor Run Time" -> 
taskMetrics.executorRunTime) ~ @@ -579,7 +566,7 @@ private[spark] object JsonProtocol { val stageInfos = Utils.jsonOption(json \ "Stage Infos") .map(_.extract[Seq[JValue]].map(stageInfoFromJson)).getOrElse { stageIds.map { id => - new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown", Seq.empty) + new StageInfo(id, 0, "unknown", 0, Seq.empty, Seq.empty, "unknown") } } SparkListenerJobStart(jobId, submissionTime, stageInfos, properties) @@ -688,7 +675,7 @@ private[spark] object JsonProtocol { } val stageInfo = new StageInfo( - stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details, Seq.empty) + stageId, attemptId, stageName, numTasks, rddInfos, parentIds, details) stageInfo.submissionTime = submissionTime stageInfo.completionTime = completionTime stageInfo.failureReason = failureReason @@ -745,25 +732,21 @@ private[spark] object JsonProtocol { * The behavior here must match that of [[accumValueToJson]]. Exposed for testing. */ private[util] def accumValueFromJson(name: Option[String], value: JValue): Any = { - import AccumulatorParam._ if (name.exists(_.startsWith(InternalAccumulator.METRICS_PREFIX))) { - (value, InternalAccumulator.getParam(name.get)) match { - case (JInt(v), IntAccumulatorParam) => v.toInt - case (JInt(v), LongAccumulatorParam) => v.toLong - case (JString(v), StringAccumulatorParam) => v - case (JArray(v), UpdatedBlockStatusesAccumulatorParam) => + value match { + case JInt(v) => v.toLong + case JArray(v) => v.map { blockJson => val id = BlockId((blockJson \ "Block ID").extract[String]) val status = blockStatusFromJson(blockJson \ "Status") (id, status) } - case (v, p) => - throw new IllegalArgumentException(s"unexpected combination of accumulator " + - s"value in JSON ($v) and accumulator param (${p.getClass.getName}) in '${name.get}'") - } - } else { - value.extract[String] - } + case _ => throw new IllegalArgumentException(s"unexpected json value $value for " + + "accumulator " + name.get) + } + } else { + value.extract[String] + } } def taskMetricsFromJson(json: JValue): TaskMetrics = { @@ -781,7 +764,7 @@ private[spark] object JsonProtocol { // Shuffle read metrics Utils.jsonOption(json \ "Shuffle Read Metrics").foreach { readJson => - val readMetrics = metrics.registerTempShuffleReadMetrics() + val readMetrics = metrics.createTempShuffleReadMetrics() readMetrics.incRemoteBlocksFetched((readJson \ "Remote Blocks Fetched").extract[Int]) readMetrics.incLocalBlocksFetched((readJson \ "Local Blocks Fetched").extract[Int]) readMetrics.incRemoteBytesRead((readJson \ "Remote Bytes Read").extract[Long]) @@ -794,7 +777,7 @@ private[spark] object JsonProtocol { // Shuffle write metrics // TODO: Drop the redundant "Shuffle" since it's inconsistent with related classes. 
Utils.jsonOption(json \ "Shuffle Write Metrics").foreach { writeJson => - val writeMetrics = metrics.registerShuffleWriteMetrics() + val writeMetrics = metrics.shuffleWriteMetrics writeMetrics.incBytesWritten((writeJson \ "Shuffle Bytes Written").extract[Long]) writeMetrics.incRecordsWritten((writeJson \ "Shuffle Records Written") .extractOpt[Long].getOrElse(0L)) @@ -803,16 +786,14 @@ private[spark] object JsonProtocol { // Output metrics Utils.jsonOption(json \ "Output Metrics").foreach { outJson => - val writeMethod = DataWriteMethod.withName((outJson \ "Data Write Method").extract[String]) - val outputMetrics = metrics.registerOutputMetrics(writeMethod) + val outputMetrics = metrics.outputMetrics outputMetrics.setBytesWritten((outJson \ "Bytes Written").extract[Long]) outputMetrics.setRecordsWritten((outJson \ "Records Written").extractOpt[Long].getOrElse(0L)) } // Input metrics Utils.jsonOption(json \ "Input Metrics").foreach { inJson => - val readMethod = DataReadMethod.withName((inJson \ "Data Read Method").extract[String]) - val inputMetrics = metrics.registerInputMetrics(readMethod) + val inputMetrics = metrics.inputMetrics inputMetrics.incBytesRead((inJson \ "Bytes Read").extract[Long]) inputMetrics.incRecordsRead((inJson \ "Records Read").extractOpt[Long].getOrElse(0L)) } diff --git a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala index 9abbf4a7a3971..5a6dbc830448a 100644 --- a/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/ThreadUtils.scala @@ -19,12 +19,15 @@ package org.apache.spark.util import java.util.concurrent._ -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.{Await, Awaitable, ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.duration.Duration import scala.concurrent.forkjoin.{ForkJoinPool => SForkJoinPool, ForkJoinWorkerThread => SForkJoinWorkerThread} import scala.util.control.NonFatal import com.google.common.util.concurrent.{MoreExecutors, ThreadFactoryBuilder} +import org.apache.spark.SparkException + private[spark] object ThreadUtils { private val sameThreadExecutionContext = @@ -174,4 +177,21 @@ private[spark] object ThreadUtils { false // asyncMode ) } + + // scalastyle:off awaitresult + /** + * Preferred alternative to [[Await.result()]]. This method wraps and re-throws any exceptions + * thrown by the underlying [[Await]] call, ensuring that this thread's stack trace appears in + * logs. + */ + @throws(classOf[SparkException]) + def awaitResult[T](awaitable: Awaitable[T], atMost: Duration): T = { + try { + Await.result(awaitable, atMost) + // scalastyle:on awaitresult + } catch { + case NonFatal(t) => + throw new SparkException("Exception thrown in awaitResult: ", t) + } + } } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 78e164cff7738..848f7d7adbc7e 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -1598,6 +1598,7 @@ private[spark] object Utils extends Logging { /** * Timing method based on iterations that permit JVM JIT optimization. + * * @param numIters number of iterations * @param f function to be executed. If prepare is not None, the running time of each call to f * must be an order of magnitude longer than one millisecond for accurate timing. 
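The ThreadUtils.awaitResult helper added above, like the guarded Await.ready calls in BlockManager earlier in this patch, relies on the same wrap-and-rethrow idea: catching non-fatal failures on the waiting thread and rethrowing them there puts that thread's stack trace into the logs. A standalone sketch of the pattern, with illustrative names rather than Spark's own:

    import scala.concurrent.{Await, Awaitable}
    import scala.concurrent.duration.Duration
    import scala.util.control.NonFatal

    def awaitWithCallerTrace[T](awaitable: Awaitable[T], atMost: Duration): T = {
      try {
        Await.result(awaitable, atMost)
      } catch {
        // Rethrowing on the waiting thread makes the caller's stack visible,
        // rather than only the stack of whichever thread failed the future.
        case NonFatal(t) => throw new RuntimeException("Exception thrown in awaitResult", t)
      }
    }
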
@@ -1639,6 +1640,7 @@ private[spark] object Utils extends Logging { /** * Creates a symlink. + * * @param src absolute path to the source * @param dst relative path for the destination */ diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 561ba22df557f..916053f42d072 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -645,7 +645,7 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, outputFile: File): Array[Long] = { - val writeMetrics = context.taskMetrics().registerShuffleWriteMetrics() + val writeMetrics = context.taskMetrics().shuffleWriteMetrics // Track location of each range in the output file val lengths = new Array[Long](numPartitions) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index 22d7a4988bb56..10ab0b3f89964 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -25,6 +25,9 @@ import scala.reflect.ClassTag * space overhead. * * Under the hood, it uses our OpenHashSet implementation. + * + * NOTE: when using numeric type as the value type, the user of this class should be careful to + * distinguish between the 0/0.0/0L and non-exist value */ private[spark] class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 30750b1bf1980..fbaaa1cf49982 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -249,8 +249,8 @@ public void writeEmptyIterator() throws Exception { assertTrue(mapStatus.isDefined()); assertTrue(mergedOutputFile.exists()); assertArrayEquals(new long[NUM_PARTITITONS], partitionSizesInMergedFile); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().recordsWritten()); - assertEquals(0, taskMetrics.shuffleWriteMetrics().get().bytesWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().recordsWritten()); + assertEquals(0, taskMetrics.shuffleWriteMetrics().bytesWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); } @@ -279,7 +279,7 @@ public void writeWithoutSpilling() throws Exception { HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertEquals(0, taskMetrics.diskBytesSpilled()); assertEquals(0, taskMetrics.memoryBytesSpilled()); @@ -321,7 +321,7 @@ private void testMergingSpills( assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), 
shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); @@ -383,7 +383,7 @@ public void writeEnoughDataToTriggerSpill() throws Exception { writer.stop(true); readRecordsFromFile(); assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); @@ -404,7 +404,7 @@ public void writeEnoughRecordsToTriggerSortBufferExpansionAndSpill() throws Exce writer.stop(true); readRecordsFromFile(); assertSpillFilesWereCleanedUp(); - ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics().get(); + ShuffleWriteMetrics shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics(); assertEquals(dataToWrite.size(), shuffleWriteMetrics.recordsWritten()); assertThat(taskMetrics.diskBytesSpilled(), greaterThan(0L)); assertThat(taskMetrics.diskBytesSpilled(), lessThan(mergedOutputFile.length())); diff --git a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json index 1a13233133b1e..cba44c848e012 100644 --- a/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/application_list_json_expectation.json @@ -2,108 +2,108 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917391398, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", "duration" : 10505, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1430917381535", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917380950, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", "duration" : 57, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1430917380880, - "endTimeEpoch" : 1430917380890, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", "duration" : 10, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1426633910242, - "endTimeEpoch" : 1426633945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 }, { 
"attemptId" : "1", - "startTimeEpoch" : 1426533910242, - "endTimeEpoch" : 1426533945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1425081758277, - "endTimeEpoch" : 1425081766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981779720, - "endTimeEpoch" : 1422981788731, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", "duration" : 9011, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981758277, - "endTimeEpoch" : 1422981766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json index 1a13233133b1e..cba44c848e012 100644 --- a/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/completed_app_list_json_expectation.json @@ -2,108 +2,108 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917391398, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", "duration" : 10505, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1430917381535", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917380950, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", "duration" : 57, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1430917380880, - "endTimeEpoch" : 1430917380890, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", "duration" : 10, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380880, + 
"endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1426633910242, - "endTimeEpoch" : 1426633945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1426533910242, - "endTimeEpoch" : 1426533945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1425081759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1425081758277, - "endTimeEpoch" : 1425081766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-28T00:02:38.277GMT", "endTime" : "2015-02-28T00:02:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981779720, - "endTimeEpoch" : 1422981788731, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", "duration" : 9011, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981758277, - "endTimeEpoch" : 1422981766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json index efc865919b0d7..e7db6742c25e1 100644 --- a/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/executor_list_json_expectation.json @@ -18,4 +18,4 @@ "totalShuffleWrite" : 13180, "maxMemory" : 278302556, "executorLogs" : { } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json index 2e92e1fa0ec23..bb6bf434be90b 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_1__expectation.json @@ -12,4 +12,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} ] \ No newline at end of file +} ] diff --git 
a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json index 2e92e1fa0ec23..bb6bf434be90b 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_from_multi_attempt_app_json_2__expectation.json @@ -12,4 +12,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json index cab4750270dfa..1583e5ddef565 100644 --- a/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/job_list_json_expectation.json @@ -40,4 +40,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json index eacf04b9016ac..a525d61543a88 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate2_app_list_json_expectation.json @@ -2,14 +2,14 @@ "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981758277, - "endTimeEpoch" : 1422981766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json index adad25bf17fd5..cc567f66f02e8 100644 --- a/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/maxDate_app_list_json_expectation.json @@ -2,28 +2,28 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981779720, - "endTimeEpoch" : 1422981788731, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", "duration" : 9011, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1422981759269", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981758277, - "endTimeEpoch" : 1422981766912, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:38.277GMT", "endTime" : "2015-02-03T16:42:46.912GMT", "lastUpdated" : "", "duration" : 8635, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981758277, + "endTimeEpoch" : 1422981766912, + "lastUpdatedEpoch" : 0 } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json 
b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json index a658909088a4a..c934a871724b5 100644 --- a/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/minDate_app_list_json_expectation.json @@ -2,82 +2,80 @@ "id" : "local-1430917381534", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917391398, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:11.398GMT", "lastUpdated" : "", "duration" : 10505, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917391398, + "lastUpdatedEpoch" : 0 } ] -}, { +}, { "id" : "local-1430917381535", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1430917380893, - "endTimeEpoch" : 1430917380950, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.893GMT", "endTime" : "2015-05-06T13:03:00.950GMT", "lastUpdated" : "", "duration" : 57, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380893, + "endTimeEpoch" : 1430917380950, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1430917380880, - "endTimeEpoch" : 1430917380890, - "lastUpdatedEpoch" : 0, "startTime" : "2015-05-06T13:03:00.880GMT", "endTime" : "2015-05-06T13:03:00.890GMT", "lastUpdated" : "", "duration" : 10, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1430917380880, + "endTimeEpoch" : 1430917380890, + "lastUpdatedEpoch" : 0 } ] }, { "id" : "local-1426533911241", "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1426633910242, - "endTimeEpoch" : 1426633945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1426533910242, - "endTimeEpoch" : 1426533945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 } ] }, { - "id": "local-1425081759269", - "name": "Spark shell", - "attempts": [ - { - "startTimeEpoch" : 1425081758277, - "endTimeEpoch" : 1425081766912, - "lastUpdatedEpoch" : 0, - "startTime": "2015-02-28T00:02:38.277GMT", - "endTime": "2015-02-28T00:02:46.912GMT", - "lastUpdated" : "", - "duration" : 8635, - "sparkUser": "irashid", - "completed": true - } - ] + "id" : "local-1425081759269", + "name" : "Spark shell", + "attempts" : [ { + "startTime" : "2015-02-28T00:02:38.277GMT", + "endTime" : "2015-02-28T00:02:46.912GMT", + "lastUpdated" : "", + "duration" : 8635, + "sparkUser" : "irashid", + "completed" : true, + "startTimeEpoch" : 1425081758277, + "endTimeEpoch" : 1425081766912, + "lastUpdatedEpoch" : 0 + } ] } ] diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json index 0217facad9ded..f486d46313d8b 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_json_expectation.json @@ -2,14 +2,14 @@ "id" : "local-1422981780767", "name" : "Spark shell", "attempts" : [ { - "startTimeEpoch" : 1422981779720, - "endTimeEpoch" : 1422981788731, - "lastUpdatedEpoch" : 0, "startTime" : "2015-02-03T16:42:59.720GMT", "endTime" : "2015-02-03T16:43:08.731GMT", "lastUpdated" : "", "duration" : 9011, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1422981779720, + "endTimeEpoch" : 1422981788731, + "lastUpdatedEpoch" : 0 } ] } diff --git a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json index b20a26648e430..e63039f6a17fc 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_app_multi_attempt_json_expectation.json @@ -3,25 +3,25 @@ "name" : "Spark shell", "attempts" : [ { "attemptId" : "2", - "startTimeEpoch" : 1426633910242, - "endTimeEpoch" : 1426633945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-17T23:11:50.242GMT", "endTime" : "2015-03-17T23:12:25.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426633910242, + "endTimeEpoch" : 1426633945177, + "lastUpdatedEpoch" : 0 }, { "attemptId" : "1", - "startTimeEpoch" : 1426533910242, - "endTimeEpoch" : 1426533945177, - "lastUpdatedEpoch" : 0, "startTime" : "2015-03-16T19:25:10.242GMT", "endTime" : "2015-03-16T19:25:45.177GMT", "lastUpdated" : "", "duration" : 34935, "sparkUser" : "irashid", - "completed" : true + "completed" : true, + "startTimeEpoch" : 1426533910242, + "endTimeEpoch" : 1426533945177, + "lastUpdatedEpoch" : 0 } ] } diff --git a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json index 4a29072bdb6e4..f1f0ec885587b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_job_json_expectation.json @@ -12,4 +12,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/one_rdd_storage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_rdd_storage_json_expectation.json deleted file mode 100644 index 38b5328ffbb03..0000000000000 --- a/core/src/test/resources/HistoryServerExpectations/one_rdd_storage_json_expectation.json +++ /dev/null @@ -1,64 +0,0 @@ -{ - "id" : 0, - "name" : "0", - "numPartitions" : 8, - "numCachedPartitions" : 8, - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 28000128, - "diskUsed" : 0, - "dataDistribution" : [ { - "address" : "localhost:57971", - "memoryUsed" : 28000128, - "memoryRemaining" : 250302428, - "diskUsed" : 0 - } ], - "partitions" : [ { - "blockName" : "rdd_0_0", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_1", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - 
"blockName" : "rdd_0_2", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_3", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_4", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_5", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_6", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - }, { - "blockName" : "rdd_0_7", - "storageLevel" : "Memory Deserialized 1x Replicated", - "memoryUsed" : 3500016, - "diskUsed" : 0, - "executors" : [ "localhost:57971" ] - } ] -} \ No newline at end of file diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json index b07011d4f113f..477a2fec8b69b 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_attempt_json_expectation.json @@ -46,6 +46,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 94000, @@ -75,6 +87,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1647, "writeTime" : 83000, @@ -104,6 +128,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 88000, @@ -133,6 +169,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 73000, @@ -162,6 +210,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 76000, @@ -191,6 +251,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 
0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 98000, @@ -220,6 +292,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1645, "writeTime" : 101000, @@ -249,6 +333,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 79000, diff --git a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json index 2f71520549e1f..388e51f77a24d 100644 --- a/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/one_stage_json_expectation.json @@ -46,6 +46,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 94000, @@ -75,6 +87,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1647, "writeTime" : 83000, @@ -104,6 +128,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 88000, @@ -133,6 +169,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 73000, @@ -162,6 +210,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 76000, @@ -191,6 +251,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + 
"remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 98000, @@ -220,6 +292,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1645, "writeTime" : 101000, @@ -249,6 +333,18 @@ "bytesRead" : 3500016, "recordsRead" : 0 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1648, "writeTime" : 79000, diff --git a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json index 8878e547a7984..1e3ec7217afba 100644 --- a/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/rdd_list_storage_json_expectation.json @@ -1 +1 @@ -[ ] \ No newline at end of file +[ ] diff --git a/core/src/test/resources/HistoryServerExpectations/running_app_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/running_app_list_json_expectation.json index 8878e547a7984..1e3ec7217afba 100644 --- a/core/src/test/resources/HistoryServerExpectations/running_app_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/running_app_list_json_expectation.json @@ -1 +1 @@ -[ ] \ No newline at end of file +[ ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json index f2cb29b31c85f..8e09aabbad7c9 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_expectation.json @@ -20,6 +20,18 @@ "bytesRead" : 49294, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3842811, @@ -48,6 +60,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3934399, @@ -76,6 +100,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89885, @@ -104,6 +140,18 @@ "bytesRead" : 
60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 1311694, @@ -132,6 +180,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 83022, @@ -160,6 +220,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3675510, @@ -188,6 +260,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 4016617, @@ -216,6 +300,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 2579051, @@ -244,6 +340,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 121551, @@ -272,6 +380,18 @@ "bytesRead" : 60489, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 101664, @@ -300,6 +420,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94709, @@ -328,6 +460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94507, @@ -356,6 +500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + 
"outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102476, @@ -384,6 +540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95004, @@ -412,6 +580,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95646, @@ -440,6 +620,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 602780, @@ -468,6 +660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 108320, @@ -496,6 +700,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 99944, @@ -524,6 +740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100836, @@ -552,10 +780,22 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95788, "recordsWritten" : 10 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json index c3febc5fc9447..1dbf72b42a926 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json +++ 
b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_1__expectation.json @@ -20,7 +20,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 1, @@ -44,7 +65,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 2, @@ -68,7 +110,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 3, @@ -92,7 +155,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 4, @@ -116,7 +200,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 5, @@ -140,7 +245,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + 
"recordsWritten" : 0 + } } }, { "taskId" : 6, @@ -164,7 +290,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 7, @@ -188,6 +335,27 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json index 56d667d88917c..483492282dd64 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_from_multi_attempt_app_json_2__expectation.json @@ -20,7 +20,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 1, @@ -44,7 +65,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 2, @@ -68,7 +110,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 3, @@ -92,7 +155,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 
2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 4, @@ -116,7 +200,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 5, @@ -140,7 +245,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 6, @@ -164,7 +290,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, { "taskId" : 7, @@ -188,6 +335,27 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json index e5ec3bc4c7126..624f2bb16df48 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__offset___length_expectation.json @@ -20,6 +20,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 
0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94709, @@ -48,6 +60,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94507, @@ -76,6 +100,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102476, @@ -104,6 +140,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95004, @@ -132,6 +180,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95646, @@ -160,6 +220,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 602780, @@ -188,6 +260,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 108320, @@ -216,6 +300,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 99944, @@ -244,6 +340,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100836, @@ -272,6 +380,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + 
"remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95788, @@ -300,6 +420,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 97716, @@ -328,6 +460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100270, @@ -356,6 +500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 143427, @@ -384,6 +540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 91844, @@ -412,6 +580,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 157194, @@ -440,6 +620,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94134, @@ -468,6 +660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 108213, @@ -496,6 +700,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102019, @@ -524,6 +740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + 
"localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 104299, @@ -552,6 +780,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 114938, @@ -580,6 +820,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 119770, @@ -608,6 +860,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 92619, @@ -636,6 +900,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89603, @@ -664,6 +940,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 118329, @@ -692,6 +980,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 127746, @@ -720,6 +1020,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 160963, @@ -748,6 +1060,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 123855, @@ -776,6 +1100,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + 
"recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 111869, @@ -804,6 +1140,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 131158, @@ -832,6 +1180,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98748, @@ -860,6 +1220,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94792, @@ -888,6 +1260,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90765, @@ -916,6 +1300,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 103713, @@ -944,6 +1340,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 171516, @@ -972,6 +1380,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98293, @@ -1000,6 +1420,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 92985, @@ -1028,6 +1460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, 
"shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 113322, @@ -1056,6 +1500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 103015, @@ -1084,6 +1540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 139844, @@ -1112,6 +1580,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94984, @@ -1140,6 +1620,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90836, @@ -1168,6 +1660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 96013, @@ -1196,6 +1700,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89664, @@ -1224,6 +1740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 92835, @@ -1252,6 +1780,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90506, @@ -1280,6 +1820,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : 
{ "bytesWritten" : 1710, "writeTime" : 108309, @@ -1308,6 +1860,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90329, @@ -1336,6 +1900,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 96849, @@ -1364,6 +1940,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 97521, @@ -1392,10 +1980,22 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100753, "recordsWritten" : 10 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json index 5657123a2db15..11eec0b49c40b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_expectation.json @@ -20,6 +20,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 4016617, @@ -48,6 +60,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3675510, @@ -76,6 +100,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3934399, @@ -104,6 +140,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + 
"remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 83022, @@ -132,6 +180,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 2579051, @@ -160,6 +220,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 1311694, @@ -188,6 +260,18 @@ "bytesRead" : 49294, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3842811, @@ -216,6 +300,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89885, @@ -244,6 +340,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 143427, @@ -272,6 +380,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100836, @@ -300,6 +420,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 99944, @@ -328,6 +460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100270, @@ -356,6 +500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + 
"localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 108320, @@ -384,6 +540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95788, @@ -412,6 +580,18 @@ "bytesRead" : 60489, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 101664, @@ -440,6 +620,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 97716, @@ -468,6 +660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95646, @@ -496,6 +700,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 121551, @@ -524,6 +740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102476, @@ -552,10 +780,22 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95004, "recordsWritten" : 10 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json index 5657123a2db15..11eec0b49c40b 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names___runtime_expectation.json @@ -20,6 +20,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, 
+ "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 4016617, @@ -48,6 +60,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3675510, @@ -76,6 +100,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3934399, @@ -104,6 +140,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 83022, @@ -132,6 +180,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 2579051, @@ -160,6 +220,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 1311694, @@ -188,6 +260,18 @@ "bytesRead" : 49294, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 3842811, @@ -216,6 +300,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89885, @@ -244,6 +340,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 143427, @@ -272,6 +380,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + 
"shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100836, @@ -300,6 +420,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 99944, @@ -328,6 +460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100270, @@ -356,6 +500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 108320, @@ -384,6 +540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95788, @@ -412,6 +580,18 @@ "bytesRead" : 60489, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 101664, @@ -440,6 +620,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 97716, @@ -468,6 +660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95646, @@ -496,6 +700,18 @@ "bytesRead" : 60488, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 121551, @@ -524,6 +740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + 
"remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102476, @@ -552,10 +780,22 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95004, "recordsWritten" : 10 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json index 72fe017e9f85d..9528d872ef731 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_list_w__sortBy_short_names__runtime_expectation.json @@ -20,6 +20,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 94792, @@ -48,6 +60,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 95848, @@ -76,6 +100,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90765, @@ -104,6 +140,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 101750, @@ -132,6 +180,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 97521, @@ -160,6 +220,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 171516, @@ -188,6 
+260,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 96849, @@ -216,6 +300,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 100753, @@ -244,6 +340,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 89603, @@ -272,6 +380,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102159, @@ -300,6 +420,18 @@ "bytesRead" : 70565, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 133964, @@ -328,6 +460,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 102779, @@ -356,6 +500,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98472, @@ -384,6 +540,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98748, @@ -412,6 +580,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 103713, @@ -440,6 +620,18 @@ "bytesRead" : 70564, 
"recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 96013, @@ -468,6 +660,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 90836, @@ -496,6 +700,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 92835, @@ -524,6 +740,18 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98293, @@ -552,10 +780,22 @@ "bytesRead" : 70564, "recordsRead" : 10000 }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, "shuffleWriteMetrics" : { "bytesWritten" : 1710, "writeTime" : 98069, "recordsWritten" : 10 } } -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json index bc3c302813de2..76d1553bc8f77 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w__custom_quantiles_expectation.json @@ -11,9 +11,22 @@ "bytesRead" : [ 60488.0, 70564.0, 70565.0 ], "recordsRead" : [ 10000.0, 10000.0, 10000.0 ] }, + "outputMetrics" : { + "bytesWritten" : [ 0.0, 0.0, 0.0 ], + "recordsWritten" : [ 0.0, 0.0, 0.0 ] + }, + "shuffleReadMetrics" : { + "readBytes" : [ 0.0, 0.0, 0.0 ], + "readRecords" : [ 0.0, 0.0, 0.0 ], + "remoteBlocksFetched" : [ 0.0, 0.0, 0.0 ], + "localBlocksFetched" : [ 0.0, 0.0, 0.0 ], + "fetchWaitTime" : [ 0.0, 0.0, 0.0 ], + "remoteBytesRead" : [ 0.0, 0.0, 0.0 ], + "totalBlocksFetched" : [ 0.0, 0.0, 0.0 ] + }, "shuffleWriteMetrics" : { "writeBytes" : [ 1710.0, 1710.0, 1710.0 ], "writeRecords" : [ 10.0, 10.0, 10.0 ], "writeTime" : [ 89437.0, 102159.0, 4016617.0 ] } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json index e084c839f1d5a..7baffc5df0b0f 100644 --- 
a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_read_expectation.json @@ -7,6 +7,14 @@ "resultSerializationTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "memoryBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "diskBytesSpilled" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "inputMetrics" : { + "bytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "recordsRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] + }, + "outputMetrics" : { + "bytesWritten" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "recordsWritten" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] + }, "shuffleReadMetrics" : { "readBytes" : [ 17100.0, 17100.0, 17100.0, 17100.0, 17100.0 ], "readRecords" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ], @@ -15,5 +23,10 @@ "fetchWaitTime" : [ 0.0, 0.0, 0.0, 1.0, 1.0 ], "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], "totalBlocksFetched" : [ 100.0, 100.0, 100.0, 100.0, 100.0 ] + }, + "shuffleWriteMetrics" : { + "writeBytes" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "writeRecords" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "writeTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json index 6ac7811ce691b..f8c4b7c128733 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_task_summary_w_shuffle_write_expectation.json @@ -11,9 +11,22 @@ "bytesRead" : [ 60488.0, 70564.0, 70564.0, 70564.0, 70564.0 ], "recordsRead" : [ 10000.0, 10000.0, 10000.0, 10000.0, 10000.0 ] }, + "outputMetrics" : { + "bytesWritten" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "recordsWritten" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] + }, + "shuffleReadMetrics" : { + "readBytes" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "readRecords" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "localBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "fetchWaitTime" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "remoteBytesRead" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ], + "totalBlocksFetched" : [ 0.0, 0.0, 0.0, 0.0, 0.0 ] + }, "shuffleWriteMetrics" : { "writeBytes" : [ 1710.0, 1710.0, 1710.0, 1710.0, 1710.0 ], "writeRecords" : [ 10.0, 10.0, 10.0, 10.0, 10.0 ], "writeTime" : [ 90329.0, 95848.0, 102159.0, 121551.0, 2579051.0 ] } -} \ No newline at end of file +} diff --git a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json index 12665a152c9ec..ce008bf40967d 100644 --- a/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/stage_with_accumulable_json_expectation.json @@ -50,7 +50,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "5" : { @@ 
-75,7 +96,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "4" : { @@ -100,7 +142,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 1, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "7" : { @@ -125,7 +188,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "1" : { @@ -150,7 +234,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "3" : { @@ -175,7 +280,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "6" : { @@ -200,7 +326,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + "diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } }, "0" : { @@ -225,7 +372,28 @@ "jvmGcTime" : 0, "resultSerializationTime" : 2, "memoryBytesSpilled" : 0, - "diskBytesSpilled" : 0 + 
"diskBytesSpilled" : 0, + "inputMetrics" : { + "bytesRead" : 0, + "recordsRead" : 0 + }, + "outputMetrics" : { + "bytesWritten" : 0, + "recordsWritten" : 0 + }, + "shuffleReadMetrics" : { + "remoteBlocksFetched" : 0, + "localBlocksFetched" : 0, + "fetchWaitTime" : 0, + "remoteBytesRead" : 0, + "localBytesRead" : 0, + "recordsRead" : 0 + }, + "shuffleWriteMetrics" : { + "bytesWritten" : 0, + "writeTime" : 0, + "recordsWritten" : 0 + } } } }, diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json index cab4750270dfa..1583e5ddef565 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_failed_job_list_json_expectation.json @@ -40,4 +40,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json index 6fd25befbf7e8..c232c98323755 100644 --- a/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json +++ b/core/src/test/resources/HistoryServerExpectations/succeeded_job_list_json_expectation.json @@ -26,4 +26,4 @@ "numCompletedStages" : 1, "numSkippedStages" : 0, "numFailedStages" : 0 -} ] \ No newline at end of file +} ] diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager new file mode 100644 index 0000000000000..3c570ffd8f566 --- /dev/null +++ b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -0,0 +1 @@ +org.apache.spark.scheduler.DummyExternalClusterManager \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 37879d11caec4..454c42517ca1b 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -17,7 +17,6 @@ package org.apache.spark -import java.util.Properties import java.util.concurrent.Semaphore import javax.annotation.concurrent.GuardedBy @@ -29,6 +28,7 @@ import scala.util.control.NonFatal import org.scalatest.Matchers import org.scalatest.exceptions.TestFailedException +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler._ import org.apache.spark.serializer.JavaSerializer @@ -278,16 +278,13 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex val acc1 = new Accumulator(0, IntAccumulatorParam, Some("thing"), internal = false) val acc2 = new Accumulator(0L, LongAccumulatorParam, Some("thing2"), internal = false) val externalAccums = Seq(acc1, acc2) - val internalAccums = InternalAccumulator.createAll() + val taskMetrics = new TaskMetrics // Set some values; these should not be observed later on the "executors" acc1.setValue(10) acc2.setValue(20L) - internalAccums - .find(_.name == Some(InternalAccumulator.TEST_ACCUM)) - .get.asInstanceOf[Accumulator[Long]] - .setValue(30L) + taskMetrics.testAccum.get.asInstanceOf[Accumulator[Long]].setValue(30L) // Simulate the task being serialized and sent 
to the executors. - val dummyTask = new DummyTask(internalAccums, externalAccums) + val dummyTask = new DummyTask(taskMetrics, externalAccums) val serInstance = new JavaSerializer(new SparkConf).newInstance() val taskSer = Task.serializeWithDependencies( dummyTask, mutable.HashMap(), mutable.HashMap(), serInstance) @@ -298,7 +295,7 @@ class AccumulatorSuite extends SparkFunSuite with Matchers with LocalSparkContex taskBytes, Thread.currentThread.getContextClassLoader) // Assert that executors see only zeros taskDeser.externalAccums.foreach { a => assert(a.localValue == a.zero) } - taskDeser.internalAccums.foreach { a => assert(a.localValue == a.zero) } + taskDeser.metrics.internalAccums.foreach { a => assert(a.localValue == a.zero) } } } @@ -402,8 +399,7 @@ private class SaveInfoListener extends SparkListener { * A dummy [[Task]] that contains internal and external [[Accumulator]]s. */ private[spark] class DummyTask( - val internalAccums: Seq[Accumulator[_]], - val externalAccums: Seq[Accumulator[_]]) - extends Task[Int](0, 0, 0, internalAccums, new Properties) { + metrics: TaskMetrics, + val externalAccums: Seq[Accumulator[_]]) extends Task[Int](0, 0, 0, metrics) { override def runTask(c: TaskContext): Int = 1 } diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala index f98150536d8a8..69ff6c7c28ee9 100644 --- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala @@ -30,7 +30,6 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.internal.Logging import org.apache.spark.rdd.{RDD, ReliableRDDCheckpointData} -import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage._ @@ -39,7 +38,7 @@ import org.apache.spark.storage._ * suitable for cleaner tests and provides some utility functions. 
Subclasses can use different * config options, in particular, a different shuffle manager class */ -abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[HashShuffleManager]) +abstract class ContextCleanerSuiteBase(val shuffleManager: Class[_] = classOf[SortShuffleManager]) extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { implicit val defaultTimeout = timeout(10000 millis) @@ -353,84 +352,6 @@ class ContextCleanerSuite extends ContextCleanerSuiteBase { } -/** - * A copy of the shuffle tests for sort-based shuffle - */ -class SortShuffleContextCleanerSuite extends ContextCleanerSuiteBase(classOf[SortShuffleManager]) { - test("cleanup shuffle") { - val (rdd, shuffleDeps) = newRDDWithShuffleDependencies() - val collected = rdd.collect().toList - val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId)) - - // Explicit cleanup - shuffleDeps.foreach(s => cleaner.doCleanupShuffle(s.shuffleId, blocking = true)) - tester.assertCleanup() - - // Verify that shuffles can be re-executed after cleaning up - assert(rdd.collect().toList.equals(collected)) - } - - test("automatically cleanup shuffle") { - var rdd = newShuffleRDD() - rdd.count() - - // Test that GC does not cause shuffle cleanup due to a strong reference - val preGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - runGC() - intercept[Exception] { - preGCTester.assertCleanup()(timeout(1000 millis)) - } - rdd.count() // Defeat early collection by the JVM - - // Test that GC causes shuffle cleanup after dereferencing the RDD - val postGCTester = new CleanerTester(sc, shuffleIds = Seq(0)) - rdd = null // Make RDD out of scope, so that corresponding shuffle goes out of scope - runGC() - postGCTester.assertCleanup() - } - - test("automatically cleanup RDD + shuffle + broadcast in distributed mode") { - sc.stop() - - val conf2 = new SparkConf() - .setMaster("local-cluster[2, 1, 1024]") - .setAppName("ContextCleanerSuite") - .set("spark.cleaner.referenceTracking.blocking", "true") - .set("spark.cleaner.referenceTracking.blocking.shuffle", "true") - .set("spark.shuffle.manager", shuffleManager.getName) - sc = new SparkContext(conf2) - - val numRdds = 10 - val numBroadcasts = 4 // Broadcasts are more costly - val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer - val broadcastBuffer = (1 to numBroadcasts).map(i => newBroadcast).toBuffer - val rddIds = sc.persistentRdds.keys.toSeq - val shuffleIds = 0 until sc.newShuffleId() - val broadcastIds = broadcastBuffer.map(_.id) - - val preGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) - runGC() - intercept[Exception] { - preGCTester.assertCleanup()(timeout(1000 millis)) - } - - // Test that GC triggers the cleanup of all variables after the dereferencing them - val postGCTester = new CleanerTester(sc, rddIds, shuffleIds, broadcastIds) - broadcastBuffer.clear() - rddBuffer.clear() - runGC() - postGCTester.assertCleanup() - - // Make sure the broadcasted task closure no longer exists after GC. - val taskClosureBroadcastId = broadcastIds.max + 1 - assert(sc.env.blockManager.master.getMatchingBlockIds({ - case BroadcastBlockId(`taskClosureBroadcastId`, _) => true - case _ => false - }, askSlaves = true).isEmpty) - } -} - - /** * Class to test whether RDDs, shuffles, etc. have been successfully cleaned. * The checkpoint here refers only to normal (reliable) checkpoints, not local checkpoints. 
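The AccumulatorSuite and ContextCleanerSuite hunks above reflect the same API shift: internal accumulators are no longer built via InternalAccumulator.createAll() but live on a TaskMetrics instance, and the cleaner suites now default to SortShuffleManager since the hash-based manager is gone. Below is a minimal sketch of the new access pattern; it only compiles inside Spark's own org.apache.spark test sources because TaskMetrics members are private[spark], and it assumes testAccum is only defined when the spark.testing property is set.

import org.apache.spark.Accumulator
import org.apache.spark.executor.TaskMetrics

val taskMetrics = new TaskMetrics
// Named internal accumulators are now reached through accessors on TaskMetrics.
taskMetrics.testAccum.foreach(_.asInstanceOf[Accumulator[Long]].setValue(30L))
// The full set of internal accumulators is still available for bookkeeping checks.
assert(taskMetrics.internalAccums.nonEmpty)
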
diff --git a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala index ee6b991461902..c130649830416 100644 --- a/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ExecutorAllocationManagerSuite.scala @@ -929,7 +929,7 @@ private object ExecutorAllocationManagerSuite extends PrivateMethodTester { taskLocalityPreferences: Seq[Seq[TaskLocation]] = Seq.empty ): StageInfo = { new StageInfo(stageId, 0, "name", numTasks, Seq.empty, Seq.empty, "no details", - Seq.empty, taskLocalityPreferences) + taskLocalityPreferences = taskLocalityPreferences) } private def createTaskInfo(taskId: Int, taskIndex: Int, executorId: String): TaskInfo = { diff --git a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala index 1102aea96b548..70b6309be7d53 100644 --- a/core/src/test/scala/org/apache/spark/FutureActionSuite.scala +++ b/core/src/test/scala/org/apache/spark/FutureActionSuite.scala @@ -17,11 +17,12 @@ package org.apache.spark -import scala.concurrent.Await import scala.concurrent.duration.Duration import org.scalatest.{BeforeAndAfter, Matchers} +import org.apache.spark.util.ThreadUtils + class FutureActionSuite extends SparkFunSuite @@ -36,7 +37,7 @@ class FutureActionSuite test("simple async action") { val rdd = sc.parallelize(1 to 10, 2) val job = rdd.countAsync() - val res = Await.result(job, Duration.Inf) + val res = ThreadUtils.awaitResult(job, Duration.Inf) res should be (10) job.jobIds.size should be (1) } @@ -44,7 +45,7 @@ class FutureActionSuite test("complex async action") { val rdd = sc.parallelize(1 to 15, 3) val job = rdd.takeAsync(10) - val res = Await.result(job, Duration.Inf) + val res = ThreadUtils.awaitResult(job, Duration.Inf) res should be (1 to 10) job.jobIds.size should be (2) } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala index 713d5e58b4ffc..4d2b3e7f3b14b 100644 --- a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -21,7 +21,6 @@ import java.util.concurrent.{ExecutorService, TimeUnit} import scala.collection.Map import scala.collection.mutable -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -36,7 +35,7 @@ import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.ManualClock +import org.apache.spark.util.{ManualClock, ThreadUtils} /** * A test suite for the heartbeating behavior between the driver and the executors. 
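The FutureActionSuite hunks above (and the HeartbeatReceiverSuite hunks that follow) replace scala.concurrent.Await.result with ThreadUtils.awaitResult, Spark's helper that rethrows any failure wrapped in a SparkException. A condensed sketch of the calling pattern, assuming a live SparkContext named sc:

import scala.concurrent.duration.Duration
import org.apache.spark.util.ThreadUtils

// Block on an asynchronous action via Spark's helper rather than Await.result.
val job = sc.parallelize(1 to 10, 2).countAsync()
val count = ThreadUtils.awaitResult(job, Duration.Inf)
assert(count == 10)
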
@@ -231,14 +230,14 @@ class HeartbeatReceiverSuite private def addExecutorAndVerify(executorId: String): Unit = { assert( heartbeatReceiver.addExecutor(executorId).map { f => - Await.result(f, 10.seconds) + ThreadUtils.awaitResult(f, 10.seconds) } === Some(true)) } private def removeExecutorAndVerify(executorId: String): Unit = { assert( heartbeatReceiver.removeExecutor(executorId).map { f => - Await.result(f, 10.seconds) + ThreadUtils.awaitResult(f, 10.seconds) } === Some(true)) } diff --git a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala index 474550608ba2f..b074b95424731 100644 --- a/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/InternalAccumulatorSuite.scala @@ -19,14 +19,13 @@ package org.apache.spark import scala.collection.mutable.ArrayBuffer +import org.apache.spark.executor.TaskMetrics import org.apache.spark.scheduler.AccumulableInfo import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockId, BlockStatus} class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { import InternalAccumulator._ - import AccumulatorParam._ override def afterEach(): Unit = { try { @@ -36,125 +35,12 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } } - test("get param") { - assert(getParam(EXECUTOR_DESERIALIZE_TIME) === LongAccumulatorParam) - assert(getParam(EXECUTOR_RUN_TIME) === LongAccumulatorParam) - assert(getParam(RESULT_SIZE) === LongAccumulatorParam) - assert(getParam(JVM_GC_TIME) === LongAccumulatorParam) - assert(getParam(RESULT_SERIALIZATION_TIME) === LongAccumulatorParam) - assert(getParam(MEMORY_BYTES_SPILLED) === LongAccumulatorParam) - assert(getParam(DISK_BYTES_SPILLED) === LongAccumulatorParam) - assert(getParam(PEAK_EXECUTION_MEMORY) === LongAccumulatorParam) - assert(getParam(UPDATED_BLOCK_STATUSES) === UpdatedBlockStatusesAccumulatorParam) - assert(getParam(TEST_ACCUM) === LongAccumulatorParam) - // shuffle read - assert(getParam(shuffleRead.REMOTE_BLOCKS_FETCHED) === IntAccumulatorParam) - assert(getParam(shuffleRead.LOCAL_BLOCKS_FETCHED) === IntAccumulatorParam) - assert(getParam(shuffleRead.REMOTE_BYTES_READ) === LongAccumulatorParam) - assert(getParam(shuffleRead.LOCAL_BYTES_READ) === LongAccumulatorParam) - assert(getParam(shuffleRead.FETCH_WAIT_TIME) === LongAccumulatorParam) - assert(getParam(shuffleRead.RECORDS_READ) === LongAccumulatorParam) - // shuffle write - assert(getParam(shuffleWrite.BYTES_WRITTEN) === LongAccumulatorParam) - assert(getParam(shuffleWrite.RECORDS_WRITTEN) === LongAccumulatorParam) - assert(getParam(shuffleWrite.WRITE_TIME) === LongAccumulatorParam) - // input - assert(getParam(input.READ_METHOD) === StringAccumulatorParam) - assert(getParam(input.RECORDS_READ) === LongAccumulatorParam) - assert(getParam(input.BYTES_READ) === LongAccumulatorParam) - // output - assert(getParam(output.WRITE_METHOD) === StringAccumulatorParam) - assert(getParam(output.RECORDS_WRITTEN) === LongAccumulatorParam) - assert(getParam(output.BYTES_WRITTEN) === LongAccumulatorParam) - // default to Long - assert(getParam(METRICS_PREFIX + "anything") === LongAccumulatorParam) - intercept[IllegalArgumentException] { - getParam("something that does not start with the right prefix") - } - } - - test("create by name") { - val executorRunTime = create(EXECUTOR_RUN_TIME) - val updatedBlockStatuses = create(UPDATED_BLOCK_STATUSES) - val 
shuffleRemoteBlocksRead = create(shuffleRead.REMOTE_BLOCKS_FETCHED) - val inputReadMethod = create(input.READ_METHOD) - assert(executorRunTime.name === Some(EXECUTOR_RUN_TIME)) - assert(updatedBlockStatuses.name === Some(UPDATED_BLOCK_STATUSES)) - assert(shuffleRemoteBlocksRead.name === Some(shuffleRead.REMOTE_BLOCKS_FETCHED)) - assert(inputReadMethod.name === Some(input.READ_METHOD)) - assert(executorRunTime.value.isInstanceOf[Long]) - assert(updatedBlockStatuses.value.isInstanceOf[Seq[_]]) - // We cannot assert the type of the value directly since the type parameter is erased. - // Instead, try casting a `Seq` of expected type and see if it fails in run time. - updatedBlockStatuses.setValueAny(Seq.empty[(BlockId, BlockStatus)]) - assert(shuffleRemoteBlocksRead.value.isInstanceOf[Int]) - assert(inputReadMethod.value.isInstanceOf[String]) - // default to Long - val anything = create(METRICS_PREFIX + "anything") - assert(anything.value.isInstanceOf[Long]) - } - - test("create") { - val accums = createAll() - val shuffleReadAccums = createShuffleReadAccums() - val shuffleWriteAccums = createShuffleWriteAccums() - val inputAccums = createInputAccums() - val outputAccums = createOutputAccums() - // assert they're all internal - assert(accums.forall(_.isInternal)) - assert(shuffleReadAccums.forall(_.isInternal)) - assert(shuffleWriteAccums.forall(_.isInternal)) - assert(inputAccums.forall(_.isInternal)) - assert(outputAccums.forall(_.isInternal)) - // assert they all count on failures - assert(accums.forall(_.countFailedValues)) - assert(shuffleReadAccums.forall(_.countFailedValues)) - assert(shuffleWriteAccums.forall(_.countFailedValues)) - assert(inputAccums.forall(_.countFailedValues)) - assert(outputAccums.forall(_.countFailedValues)) - // assert they all have names - assert(accums.forall(_.name.isDefined)) - assert(shuffleReadAccums.forall(_.name.isDefined)) - assert(shuffleWriteAccums.forall(_.name.isDefined)) - assert(inputAccums.forall(_.name.isDefined)) - assert(outputAccums.forall(_.name.isDefined)) - // assert `accums` is a strict superset of the others - val accumNames = accums.map(_.name.get).toSet - val shuffleReadAccumNames = shuffleReadAccums.map(_.name.get).toSet - val shuffleWriteAccumNames = shuffleWriteAccums.map(_.name.get).toSet - val inputAccumNames = inputAccums.map(_.name.get).toSet - val outputAccumNames = outputAccums.map(_.name.get).toSet - assert(shuffleReadAccumNames.subsetOf(accumNames)) - assert(shuffleWriteAccumNames.subsetOf(accumNames)) - assert(inputAccumNames.subsetOf(accumNames)) - assert(outputAccumNames.subsetOf(accumNames)) - } - - test("naming") { - val accums = createAll() - val shuffleReadAccums = createShuffleReadAccums() - val shuffleWriteAccums = createShuffleWriteAccums() - val inputAccums = createInputAccums() - val outputAccums = createOutputAccums() - // assert that prefixes are properly namespaced - assert(SHUFFLE_READ_METRICS_PREFIX.startsWith(METRICS_PREFIX)) - assert(SHUFFLE_WRITE_METRICS_PREFIX.startsWith(METRICS_PREFIX)) - assert(INPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) - assert(OUTPUT_METRICS_PREFIX.startsWith(METRICS_PREFIX)) - assert(accums.forall(_.name.get.startsWith(METRICS_PREFIX))) - // assert they all start with the expected prefixes - assert(shuffleReadAccums.forall(_.name.get.startsWith(SHUFFLE_READ_METRICS_PREFIX))) - assert(shuffleWriteAccums.forall(_.name.get.startsWith(SHUFFLE_WRITE_METRICS_PREFIX))) - assert(inputAccums.forall(_.name.get.startsWith(INPUT_METRICS_PREFIX))) - 
assert(outputAccums.forall(_.name.get.startsWith(OUTPUT_METRICS_PREFIX))) - } - test("internal accumulators in TaskContext") { val taskContext = TaskContext.empty() val accumUpdates = taskContext.taskMetrics.accumulatorUpdates() assert(accumUpdates.size > 0) assert(accumUpdates.forall(_.internal)) - val testAccum = taskContext.taskMetrics.getAccum(TEST_ACCUM) + val testAccum = taskContext.taskMetrics.testAccum.get assert(accumUpdates.exists(_.id == testAccum.id)) } @@ -165,7 +51,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { sc.addSparkListener(listener) // Have each task add 1 to the internal accumulator val rdd = sc.parallelize(1 to 100, numPartitions).mapPartitions { iter => - TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + TaskContext.get().taskMetrics().testAccum.get += 1 iter } // Register asserts in job completion callback to avoid flakiness @@ -201,17 +87,17 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { val rdd = sc.parallelize(1 to 100, numPartitions) .map { i => (i, i) } .mapPartitions { iter => - TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1 + TaskContext.get().taskMetrics().testAccum.get += 1 iter } .reduceByKey { case (x, y) => x + y } .mapPartitions { iter => - TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 10 + TaskContext.get().taskMetrics().testAccum.get += 10 iter } .repartition(numPartitions * 2) .mapPartitions { iter => - TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 100 + TaskContext.get().taskMetrics().testAccum.get += 100 iter } // Register asserts in job completion callback to avoid flakiness @@ -241,7 +127,7 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { // This should retry both stages in the scheduler. Note that we only want to fail the // first stage attempt because we want the stage to eventually succeed. 
val x = sc.parallelize(1 to 100, numPartitions) - .mapPartitions { iter => TaskContext.get().taskMetrics().getAccum(TEST_ACCUM) += 1; iter } + .mapPartitions { iter => TaskContext.get().taskMetrics().testAccum.get += 1; iter } .groupBy(identity) val sid = x.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]].shuffleHandle.shuffleId val rdd = x.mapPartitionsWithIndex { case (i, iter) => @@ -299,15 +185,15 @@ class InternalAccumulatorSuite extends SparkFunSuite with LocalSparkContext { } assert(Accumulators.originals.isEmpty) sc.parallelize(1 to 100).map { i => (i, i) }.reduceByKey { _ + _ }.count() - val internalAccums = InternalAccumulator.createAll() + val numInternalAccums = TaskMetrics.empty.internalAccums.length // We ran 2 stages, so we should have 2 sets of internal accumulators, 1 for each stage - assert(Accumulators.originals.size === internalAccums.size * 2) + assert(Accumulators.originals.size === numInternalAccums * 2) val accumsRegistered = sc.cleaner match { case Some(cleaner: SaveAccumContextCleaner) => cleaner.accumsRegisteredForCleanup case _ => Seq.empty[Long] } // Make sure the same set of accumulators is registered for cleanup - assert(accumsRegistered.size === internalAccums.size * 2) + assert(accumsRegistered.size === numInternalAccums * 2) assert(accumsRegistered.toSet === Accumulators.originals.keys.toSet) } diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index c347ab8dc8020..a3490fc79e458 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark import java.util.concurrent.Semaphore -import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.concurrent.Future @@ -28,6 +27,7 @@ import org.scalatest.BeforeAndAfter import org.scalatest.Matchers import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskStart} +import org.apache.spark.util.ThreadUtils /** * Test suite for cancelling running jobs. We run the cancellation tasks for single job action @@ -137,7 +137,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, Duration.Inf) } + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, Duration.Inf) }.getCause assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. @@ -202,7 +202,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sc.clearJobGroup() val jobB = sc.parallelize(1 to 100, 2).countAsync() sc.cancelJobGroup("jobA") - val e = intercept[SparkException] { Await.result(jobA, 5.seconds) } + val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, 5.seconds) }.getCause assert(e.getMessage contains "cancel") // Once A is cancelled, job B should finish fairly quickly. 
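Because ThreadUtils.awaitResult wraps whatever the future failed with in a SparkException, the JobCancellationSuite hunks above and below add .getCause before inspecting the message. A condensed sketch of that unwrapping pattern, assuming sc and ScalaTest's intercept are in scope and glossing over the semaphore-based timing the real test uses:

import scala.concurrent.duration.Duration
import org.scalatest.Assertions.intercept
import org.apache.spark.SparkException
import org.apache.spark.util.ThreadUtils

sc.setJobGroup("jobA", "a job to be cancelled")
val jobA = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync()
// Cancel the group; the await then fails with a SparkException whose cause carries the reason.
sc.cancelJobGroup("jobA")
val e = intercept[SparkException] { ThreadUtils.awaitResult(jobA, Duration.Inf) }.getCause
assert(e.getMessage contains "cancel")
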
@@ -248,7 +248,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.countAsync() Future { f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -268,7 +268,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire() f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } } @@ -278,7 +278,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft { val f = sc.parallelize(1 to 10000, 2).map { i => Thread.sleep(10); i }.takeAsync(5000) Future { f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } @@ -296,7 +296,7 @@ class JobCancellationSuite extends SparkFunSuite with Matchers with BeforeAndAft sem.acquire() f.cancel() } - val e = intercept[SparkException] { f.get() } + val e = intercept[SparkException] { f.get() }.getCause assert(e.getMessage.contains("cancelled") || e.getMessage.contains("killed")) } } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index cd7d2e15700d3..a854f5bb9b7ce 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -336,16 +336,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.createAll(sc))) + new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. 
val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.createAll(sc))) + new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -373,8 +371,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem, - InternalAccumulator.createAll(sc))) + new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) @@ -450,14 +447,10 @@ object ShuffleSuite { @volatile var bytesRead: Long = 0 val listener = new SparkListener { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskEnd.taskMetrics.shuffleWriteMetrics.foreach { m => - recordsWritten += m.recordsWritten - bytesWritten += m.bytesWritten - } - taskEnd.taskMetrics.shuffleReadMetrics.foreach { m => - recordsRead += m.recordsRead - bytesRead += m.totalBytesRead - } + recordsWritten += taskEnd.taskMetrics.shuffleWriteMetrics.recordsWritten + bytesWritten += taskEnd.taskMetrics.shuffleWriteMetrics.bytesWritten + recordsRead += taskEnd.taskMetrics.shuffleReadMetrics.recordsRead + bytesRead += taskEnd.taskMetrics.shuffleReadMetrics.totalBytesRead } } sc.addSparkListener(listener) diff --git a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala index 2a013aca7b895..631a7cd9d5d7a 100644 --- a/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/history/HistoryServerSuite.scala @@ -153,37 +153,39 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers code should be (HttpServletResponse.SC_OK) jsonOpt should be ('defined) errOpt should be (None) - val jsonOrg = jsonOpt.get - - // SPARK-10873 added the lastUpdated field for each application's attempt, - // the REST API returns the last modified time of EVENT LOG file for this field. - // It is not applicable to hard-code this dynamic field in a static expected file, - // so here we skip checking the lastUpdated field's value (setting it as ""). - val json = if (jsonOrg.indexOf("lastUpdated") >= 0) { - val subStrings = jsonOrg.split(",") - for (i <- subStrings.indices) { - if (subStrings(i).indexOf("lastUpdatedEpoch") >= 0) { - subStrings(i) = subStrings(i).replaceAll("(\\d+)", "0") - } else if (subStrings(i).indexOf("lastUpdated") >= 0) { - subStrings(i) = "\"lastUpdated\":\"\"" - } - } - subStrings.mkString(",") - } else { - jsonOrg - } val exp = IOUtils.toString(new FileInputStream( new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json"))) // compare the ASTs so formatting differences don't cause failures import org.json4s._ import org.json4s.jackson.JsonMethods._ - val jsonAst = parse(json) + val jsonAst = parse(clearLastUpdated(jsonOpt.get)) val expAst = parse(exp) assertValidDataInJson(jsonAst, expAst) } } + // SPARK-10873 added the lastUpdated field for each application's attempt, + // the REST API returns the last modified time of EVENT LOG file for this field. 
+ // It is not applicable to hard-code this dynamic field in a static expected file, + // so here we skip checking the lastUpdated field's value (setting it as ""). + private def clearLastUpdated(json: String): String = { + if (json.indexOf("lastUpdated") >= 0) { + val subStrings = json.split(",") + for (i <- subStrings.indices) { + if (subStrings(i).indexOf("lastUpdatedEpoch") >= 0) { + subStrings(i) = subStrings(i).replaceAll("(\\d+)", "0") + } else if (subStrings(i).indexOf("lastUpdated") >= 0) { + val regex = "\"lastUpdated\"\\s*:\\s*\".*\"".r + subStrings(i) = regex.replaceAllIn(subStrings(i), "\"lastUpdated\" : \"\"") + } + } + subStrings.mkString(",") + } else { + json + } + } + test("download all logs for app with multiple attempts") { doDownloadTest("local-1430917381535", None) } @@ -486,7 +488,8 @@ class HistoryServerSuite extends SparkFunSuite with BeforeAndAfter with Matchers val json = getUrl(path) val file = new File(expRoot, HistoryServerSuite.sanitizePath(name) + "_expectation.json") val out = new FileWriter(file) - out.write(json) + out.write(clearLastUpdated(json)) + out.write('\n') out.close() } diff --git a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala index d91f50f18f431..fbc2fae08df24 100644 --- a/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/TaskMetricsSuite.scala @@ -26,124 +26,20 @@ import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel, TestBlockId class TaskMetricsSuite extends SparkFunSuite { import AccumulatorParam._ - import InternalAccumulator._ import StorageLevel._ import TaskMetricsSuite._ - test("create") { - val internalAccums = InternalAccumulator.createAll() - val tm1 = new TaskMetrics - val tm2 = new TaskMetrics(internalAccums) - assert(tm1.accumulatorUpdates().size === internalAccums.size) - assert(tm1.shuffleReadMetrics.isEmpty) - assert(tm1.shuffleWriteMetrics.isEmpty) - assert(tm1.inputMetrics.isEmpty) - assert(tm1.outputMetrics.isEmpty) - assert(tm2.accumulatorUpdates().size === internalAccums.size) - assert(tm2.shuffleReadMetrics.isEmpty) - assert(tm2.shuffleWriteMetrics.isEmpty) - assert(tm2.inputMetrics.isEmpty) - assert(tm2.outputMetrics.isEmpty) - // TaskMetrics constructor expects minimal set of initial accumulators - intercept[IllegalArgumentException] { new TaskMetrics(Seq.empty[Accumulator[_]]) } - } - - test("create with unnamed accum") { - intercept[IllegalArgumentException] { - new TaskMetrics( - InternalAccumulator.createAll() ++ Seq( - new Accumulator(0, IntAccumulatorParam, None, internal = true))) - } - } - - test("create with duplicate name accum") { - intercept[IllegalArgumentException] { - new TaskMetrics( - InternalAccumulator.createAll() ++ Seq( - new Accumulator(0, IntAccumulatorParam, Some(RESULT_SIZE), internal = true))) - } - } - - test("create with external accum") { - intercept[IllegalArgumentException] { - new TaskMetrics( - InternalAccumulator.createAll() ++ Seq( - new Accumulator(0, IntAccumulatorParam, Some("x")))) - } - } - - test("create shuffle read metrics") { - import shuffleRead._ - val accums = InternalAccumulator.createShuffleReadAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] - accums(REMOTE_BLOCKS_FETCHED).setValueAny(1) - accums(LOCAL_BLOCKS_FETCHED).setValueAny(2) - accums(REMOTE_BYTES_READ).setValueAny(3L) - accums(LOCAL_BYTES_READ).setValueAny(4L) - accums(FETCH_WAIT_TIME).setValueAny(5L) - 
accums(RECORDS_READ).setValueAny(6L) - val sr = new ShuffleReadMetrics(accums) - assert(sr.remoteBlocksFetched === 1) - assert(sr.localBlocksFetched === 2) - assert(sr.remoteBytesRead === 3L) - assert(sr.localBytesRead === 4L) - assert(sr.fetchWaitTime === 5L) - assert(sr.recordsRead === 6L) - } - - test("create shuffle write metrics") { - import shuffleWrite._ - val accums = InternalAccumulator.createShuffleWriteAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] - accums(BYTES_WRITTEN).setValueAny(1L) - accums(RECORDS_WRITTEN).setValueAny(2L) - accums(WRITE_TIME).setValueAny(3L) - val sw = new ShuffleWriteMetrics(accums) - assert(sw.bytesWritten === 1L) - assert(sw.recordsWritten === 2L) - assert(sw.writeTime === 3L) - } - - test("create input metrics") { - import input._ - val accums = InternalAccumulator.createInputAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] - accums(BYTES_READ).setValueAny(1L) - accums(RECORDS_READ).setValueAny(2L) - accums(READ_METHOD).setValueAny(DataReadMethod.Hadoop.toString) - val im = new InputMetrics(accums) - assert(im.bytesRead === 1L) - assert(im.recordsRead === 2L) - assert(im.readMethod === DataReadMethod.Hadoop) - } - - test("create output metrics") { - import output._ - val accums = InternalAccumulator.createOutputAccums() - .map { a => (a.name.get, a) }.toMap[String, Accumulator[_]] - accums(BYTES_WRITTEN).setValueAny(1L) - accums(RECORDS_WRITTEN).setValueAny(2L) - accums(WRITE_METHOD).setValueAny(DataWriteMethod.Hadoop.toString) - val om = new OutputMetrics(accums) - assert(om.bytesWritten === 1L) - assert(om.recordsWritten === 2L) - assert(om.writeMethod === DataWriteMethod.Hadoop) - } - test("mutating values") { - val accums = InternalAccumulator.createAll() - val tm = new TaskMetrics(accums) - // initial values - assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 0L) - assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 0L) - assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 0L) - assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 0L) - assertValueEquals(tm, _.resultSerializationTime, accums, RESULT_SERIALIZATION_TIME, 0L) - assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 0L) - assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 0L) - assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 0L) - assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, - Seq.empty[(BlockId, BlockStatus)]) + val tm = new TaskMetrics + assert(tm.executorDeserializeTime == 0L) + assert(tm.executorRunTime == 0L) + assert(tm.resultSize == 0L) + assert(tm.jvmGCTime == 0L) + assert(tm.resultSerializationTime == 0L) + assert(tm.memoryBytesSpilled == 0L) + assert(tm.diskBytesSpilled == 0L) + assert(tm.peakExecutionMemory == 0L) + assert(tm.updatedBlockStatuses.isEmpty) // set or increment values tm.setExecutorDeserializeTime(100L) tm.setExecutorDeserializeTime(1L) // overwrite @@ -166,38 +62,27 @@ class TaskMetricsSuite extends SparkFunSuite { tm.incUpdatedBlockStatuses(Seq(block1)) tm.incUpdatedBlockStatuses(Seq(block2)) // assert new values exist - assertValueEquals(tm, _.executorDeserializeTime, accums, EXECUTOR_DESERIALIZE_TIME, 1L) - assertValueEquals(tm, _.executorRunTime, accums, EXECUTOR_RUN_TIME, 2L) - assertValueEquals(tm, _.resultSize, accums, RESULT_SIZE, 3L) - assertValueEquals(tm, _.jvmGCTime, accums, JVM_GC_TIME, 4L) - assertValueEquals(tm, _.resultSerializationTime, 
accums, RESULT_SERIALIZATION_TIME, 5L) - assertValueEquals(tm, _.memoryBytesSpilled, accums, MEMORY_BYTES_SPILLED, 606L) - assertValueEquals(tm, _.diskBytesSpilled, accums, DISK_BYTES_SPILLED, 707L) - assertValueEquals(tm, _.peakExecutionMemory, accums, PEAK_EXECUTION_MEMORY, 808L) - assertValueEquals(tm, _.updatedBlockStatuses, accums, UPDATED_BLOCK_STATUSES, - Seq(block1, block2)) + assert(tm.executorDeserializeTime == 1L) + assert(tm.executorRunTime == 2L) + assert(tm.resultSize == 3L) + assert(tm.jvmGCTime == 4L) + assert(tm.resultSerializationTime == 5L) + assert(tm.memoryBytesSpilled == 606L) + assert(tm.diskBytesSpilled == 707L) + assert(tm.peakExecutionMemory == 808L) + assert(tm.updatedBlockStatuses == Seq(block1, block2)) } test("mutating shuffle read metrics values") { - import shuffleRead._ - val accums = InternalAccumulator.createAll() - val tm = new TaskMetrics(accums) - def assertValEquals[T](tmValue: ShuffleReadMetrics => T, name: String, value: T): Unit = { - assertValueEquals(tm, tm => tmValue(tm.shuffleReadMetrics.get), accums, name, value) - } - // create shuffle read metrics - assert(tm.shuffleReadMetrics.isEmpty) - tm.registerTempShuffleReadMetrics() - tm.mergeShuffleReadMetrics() - assert(tm.shuffleReadMetrics.isDefined) - val sr = tm.shuffleReadMetrics.get + val tm = new TaskMetrics + val sr = tm.shuffleReadMetrics // initial values - assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 0) - assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 0) - assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 0L) - assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 0L) - assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 0L) - assertValEquals(_.recordsRead, RECORDS_READ, 0L) + assert(sr.remoteBlocksFetched == 0) + assert(sr.localBlocksFetched == 0) + assert(sr.remoteBytesRead == 0L) + assert(sr.localBytesRead == 0L) + assert(sr.fetchWaitTime == 0L) + assert(sr.recordsRead == 0L) // set and increment values sr.setRemoteBlocksFetched(100) sr.setRemoteBlocksFetched(10) @@ -224,30 +109,21 @@ class TaskMetricsSuite extends SparkFunSuite { sr.incRecordsRead(6L) sr.incRecordsRead(6L) // assert new values exist - assertValEquals(_.remoteBlocksFetched, REMOTE_BLOCKS_FETCHED, 12) - assertValEquals(_.localBlocksFetched, LOCAL_BLOCKS_FETCHED, 24) - assertValEquals(_.remoteBytesRead, REMOTE_BYTES_READ, 36L) - assertValEquals(_.localBytesRead, LOCAL_BYTES_READ, 48L) - assertValEquals(_.fetchWaitTime, FETCH_WAIT_TIME, 60L) - assertValEquals(_.recordsRead, RECORDS_READ, 72L) + assert(sr.remoteBlocksFetched == 12) + assert(sr.localBlocksFetched == 24) + assert(sr.remoteBytesRead == 36L) + assert(sr.localBytesRead == 48L) + assert(sr.fetchWaitTime == 60L) + assert(sr.recordsRead == 72L) } test("mutating shuffle write metrics values") { - import shuffleWrite._ - val accums = InternalAccumulator.createAll() - val tm = new TaskMetrics(accums) - def assertValEquals[T](tmValue: ShuffleWriteMetrics => T, name: String, value: T): Unit = { - assertValueEquals(tm, tm => tmValue(tm.shuffleWriteMetrics.get), accums, name, value) - } - // create shuffle write metrics - assert(tm.shuffleWriteMetrics.isEmpty) - tm.registerShuffleWriteMetrics() - assert(tm.shuffleWriteMetrics.isDefined) - val sw = tm.shuffleWriteMetrics.get + val tm = new TaskMetrics + val sw = tm.shuffleWriteMetrics // initial values - assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) - assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) - assertValEquals(_.writeTime, WRITE_TIME, 0L) + assert(sw.bytesWritten == 0L) + 
assert(sw.recordsWritten == 0L) + assert(sw.writeTime == 0L) // increment and decrement values sw.incBytesWritten(100L) sw.incBytesWritten(10L) // 100 + 10 @@ -260,130 +136,65 @@ class TaskMetricsSuite extends SparkFunSuite { sw.incWriteTime(300L) sw.incWriteTime(30L) // assert new values exist - assertValEquals(_.bytesWritten, BYTES_WRITTEN, 108L) - assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 216L) - assertValEquals(_.writeTime, WRITE_TIME, 330L) + assert(sw.bytesWritten == 108L) + assert(sw.recordsWritten == 216L) + assert(sw.writeTime == 330L) } test("mutating input metrics values") { - import input._ - val accums = InternalAccumulator.createAll() - val tm = new TaskMetrics(accums) - def assertValEquals(tmValue: InputMetrics => Any, name: String, value: Any): Unit = { - assertValueEquals(tm, tm => tmValue(tm.inputMetrics.get), accums, name, value, - (x: Any, y: Any) => assert(x.toString === y.toString)) - } - // create input metrics - assert(tm.inputMetrics.isEmpty) - tm.registerInputMetrics(DataReadMethod.Memory) - assert(tm.inputMetrics.isDefined) - val in = tm.inputMetrics.get + val tm = new TaskMetrics + val in = tm.inputMetrics // initial values - assertValEquals(_.bytesRead, BYTES_READ, 0L) - assertValEquals(_.recordsRead, RECORDS_READ, 0L) - assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Memory) + assert(in.bytesRead == 0L) + assert(in.recordsRead == 0L) // set and increment values in.setBytesRead(1L) in.setBytesRead(2L) in.incRecordsRead(1L) in.incRecordsRead(2L) - in.setReadMethod(DataReadMethod.Disk) // assert new values exist - assertValEquals(_.bytesRead, BYTES_READ, 2L) - assertValEquals(_.recordsRead, RECORDS_READ, 3L) - assertValEquals(_.readMethod, READ_METHOD, DataReadMethod.Disk) + assert(in.bytesRead == 2L) + assert(in.recordsRead == 3L) } test("mutating output metrics values") { - import output._ - val accums = InternalAccumulator.createAll() - val tm = new TaskMetrics(accums) - def assertValEquals(tmValue: OutputMetrics => Any, name: String, value: Any): Unit = { - assertValueEquals(tm, tm => tmValue(tm.outputMetrics.get), accums, name, value, - (x: Any, y: Any) => assert(x.toString === y.toString)) - } - // create input metrics - assert(tm.outputMetrics.isEmpty) - tm.registerOutputMetrics(DataWriteMethod.Hadoop) - assert(tm.outputMetrics.isDefined) - val out = tm.outputMetrics.get + val tm = new TaskMetrics + val out = tm.outputMetrics // initial values - assertValEquals(_.bytesWritten, BYTES_WRITTEN, 0L) - assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 0L) - assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + assert(out.bytesWritten == 0L) + assert(out.recordsWritten == 0L) // set values out.setBytesWritten(1L) out.setBytesWritten(2L) out.setRecordsWritten(3L) out.setRecordsWritten(4L) - out.setWriteMethod(DataWriteMethod.Hadoop) // assert new values exist - assertValEquals(_.bytesWritten, BYTES_WRITTEN, 2L) - assertValEquals(_.recordsWritten, RECORDS_WRITTEN, 4L) - // Note: this doesn't actually test anything, but there's only one DataWriteMethod - // so we can't set it to anything else - assertValEquals(_.writeMethod, WRITE_METHOD, DataWriteMethod.Hadoop) + assert(out.bytesWritten == 2L) + assert(out.recordsWritten == 4L) } test("merging multiple shuffle read metrics") { val tm = new TaskMetrics - assert(tm.shuffleReadMetrics.isEmpty) - val sr1 = tm.registerTempShuffleReadMetrics() - val sr2 = tm.registerTempShuffleReadMetrics() - val sr3 = tm.registerTempShuffleReadMetrics() - assert(tm.shuffleReadMetrics.isEmpty) + val 
sr1 = tm.createTempShuffleReadMetrics() + val sr2 = tm.createTempShuffleReadMetrics() + val sr3 = tm.createTempShuffleReadMetrics() sr1.setRecordsRead(10L) sr2.setRecordsRead(10L) sr1.setFetchWaitTime(1L) sr2.setFetchWaitTime(2L) sr3.setFetchWaitTime(3L) tm.mergeShuffleReadMetrics() - assert(tm.shuffleReadMetrics.isDefined) - val sr = tm.shuffleReadMetrics.get - assert(sr.remoteBlocksFetched === 0L) - assert(sr.recordsRead === 20L) - assert(sr.fetchWaitTime === 6L) + assert(tm.shuffleReadMetrics.remoteBlocksFetched === 0L) + assert(tm.shuffleReadMetrics.recordsRead === 20L) + assert(tm.shuffleReadMetrics.fetchWaitTime === 6L) // SPARK-5701: calling merge without any shuffle deps does nothing val tm2 = new TaskMetrics tm2.mergeShuffleReadMetrics() - assert(tm2.shuffleReadMetrics.isEmpty) - } - - test("register multiple shuffle write metrics") { - val tm = new TaskMetrics - val sw1 = tm.registerShuffleWriteMetrics() - val sw2 = tm.registerShuffleWriteMetrics() - assert(sw1 === sw2) - assert(tm.shuffleWriteMetrics === Some(sw1)) - } - - test("register multiple input metrics") { - val tm = new TaskMetrics - val im1 = tm.registerInputMetrics(DataReadMethod.Memory) - val im2 = tm.registerInputMetrics(DataReadMethod.Memory) - // input metrics with a different read method than the one already registered are ignored - val im3 = tm.registerInputMetrics(DataReadMethod.Hadoop) - assert(im1 === im2) - assert(im1 !== im3) - assert(tm.inputMetrics === Some(im1)) - im2.setBytesRead(50L) - im3.setBytesRead(100L) - assert(tm.inputMetrics.get.bytesRead === 50L) - } - - test("register multiple output metrics") { - val tm = new TaskMetrics - val om1 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) - val om2 = tm.registerOutputMetrics(DataWriteMethod.Hadoop) - assert(om1 === om2) - assert(tm.outputMetrics === Some(om1)) } test("additional accumulables") { - val internalAccums = InternalAccumulator.createAll() - val tm = new TaskMetrics(internalAccums) - assert(tm.accumulatorUpdates().size === internalAccums.size) + val tm = new TaskMetrics val acc1 = new Accumulator(0, IntAccumulatorParam, Some("a")) val acc2 = new Accumulator(0, IntAccumulatorParam, Some("b")) val acc3 = new Accumulator(0, IntAccumulatorParam, Some("c")) @@ -414,63 +225,11 @@ class TaskMetricsSuite extends SparkFunSuite { assert(newUpdates(acc4.id).countFailedValues) assert(newUpdates.values.map(_.update).forall(_.isDefined)) assert(newUpdates.values.map(_.value).forall(_.isEmpty)) - assert(newUpdates.size === internalAccums.size + 4) - } - - test("existing values in shuffle read accums") { - // set shuffle read accum before passing it into TaskMetrics - val accums = InternalAccumulator.createAll() - val srAccum = accums.find(_.name === Some(shuffleRead.FETCH_WAIT_TIME)) - assert(srAccum.isDefined) - srAccum.get.asInstanceOf[Accumulator[Long]] += 10L - val tm = new TaskMetrics(accums) - assert(tm.shuffleReadMetrics.isDefined) - assert(tm.shuffleWriteMetrics.isEmpty) - assert(tm.inputMetrics.isEmpty) - assert(tm.outputMetrics.isEmpty) - } - - test("existing values in shuffle write accums") { - // set shuffle write accum before passing it into TaskMetrics - val accums = InternalAccumulator.createAll() - val swAccum = accums.find(_.name === Some(shuffleWrite.RECORDS_WRITTEN)) - assert(swAccum.isDefined) - swAccum.get.asInstanceOf[Accumulator[Long]] += 10L - val tm = new TaskMetrics(accums) - assert(tm.shuffleReadMetrics.isEmpty) - assert(tm.shuffleWriteMetrics.isDefined) - assert(tm.inputMetrics.isEmpty) - assert(tm.outputMetrics.isEmpty) - } - 
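The removed tests above and below, and the rewritten assertions throughout this suite, reflect the same API change: the Option-returning register* accessors are gone and every TaskMetrics instance now carries its metrics objects directly. A minimal sketch of the new usage, restricted to calls that already appear in these hunks and assuming (like the suite itself) spark-internal test scope; the literal values are illustrative only:

    import org.apache.spark.executor.TaskMetrics

    val tm = new TaskMetrics                        // no registration step; the metrics always exist
    tm.inputMetrics.setBytesRead(64L)               // previously required registerInputMetrics(...) first
    tm.outputMetrics.setRecordsWritten(8L)          // previously registerOutputMetrics(DataWriteMethod.Hadoop)
    val tempSr = tm.createTempShuffleReadMetrics()  // one temporary object per shuffle dependency
    tempSr.setRecordsRead(4L)
    tm.mergeShuffleReadMetrics()                    // folds the temporary objects into tm.shuffleReadMetrics
    assert(tm.shuffleReadMetrics.recordsRead == 4L)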
- test("existing values in input accums") { - // set input accum before passing it into TaskMetrics - val accums = InternalAccumulator.createAll() - val inAccum = accums.find(_.name === Some(input.RECORDS_READ)) - assert(inAccum.isDefined) - inAccum.get.asInstanceOf[Accumulator[Long]] += 10L - val tm = new TaskMetrics(accums) - assert(tm.shuffleReadMetrics.isEmpty) - assert(tm.shuffleWriteMetrics.isEmpty) - assert(tm.inputMetrics.isDefined) - assert(tm.outputMetrics.isEmpty) - } - - test("existing values in output accums") { - // set output accum before passing it into TaskMetrics - val accums = InternalAccumulator.createAll() - val outAccum = accums.find(_.name === Some(output.RECORDS_WRITTEN)) - assert(outAccum.isDefined) - outAccum.get.asInstanceOf[Accumulator[Long]] += 10L - val tm4 = new TaskMetrics(accums) - assert(tm4.shuffleReadMetrics.isEmpty) - assert(tm4.shuffleWriteMetrics.isEmpty) - assert(tm4.inputMetrics.isEmpty) - assert(tm4.outputMetrics.isDefined) + assert(newUpdates.size === tm.internalAccums.size + 4) } test("from accumulator updates") { - val accumUpdates1 = InternalAccumulator.createAll().map { a => + val accumUpdates1 = TaskMetrics.empty.internalAccums.map { a => AccumulableInfo(a.id, a.name, Some(3L), None, a.isInternal, a.countFailedValues) } val metrics1 = TaskMetrics.fromAccumulatorUpdates(accumUpdates1) @@ -504,29 +263,6 @@ class TaskMetricsSuite extends SparkFunSuite { private[spark] object TaskMetricsSuite extends Assertions { - /** - * Assert that the following three things are equal to `value`: - * (1) TaskMetrics value - * (2) TaskMetrics accumulator update value - * (3) Original accumulator value - */ - def assertValueEquals( - tm: TaskMetrics, - tmValue: TaskMetrics => Any, - accums: Seq[Accumulator[_]], - metricName: String, - value: Any, - assertEquals: (Any, Any) => Unit = (x: Any, y: Any) => assert(x === y)): Unit = { - assertEquals(tmValue(tm), value) - val accum = accums.find(_.name == Some(metricName)) - assert(accum.isDefined) - assertEquals(accum.get.value, value) - val accumUpdate = tm.accumulatorUpdates().find(_.name == Some(metricName)) - assert(accumUpdate.isDefined) - assert(accumUpdate.get.value === None) - assertEquals(accumUpdate.get.update, Some(value)) - } - /** * Assert that two lists of accumulator updates are equal. * Note: this does NOT check accumulator ID equality. @@ -550,5 +286,4 @@ private[spark] object TaskMetricsSuite extends Assertions { * info as an accumulator update. 
*/ def makeInfo(a: Accumulable[_, _]): AccumulableInfo = a.toInfo(Some(a.value), None) - } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 99d5b496bcd2e..a1286523a235d 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.memory import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration.Duration import org.mockito.Matchers.{any, anyLong} @@ -33,6 +33,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark.SparkFunSuite import org.apache.spark.storage.{BlockId, BlockStatus, StorageLevel} import org.apache.spark.storage.memory.MemoryStore +import org.apache.spark.util.ThreadUtils /** @@ -172,15 +173,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // Have both tasks request 500 bytes, then wait until both requests have been granted: val t1Result1 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 500L) - assert(Await.result(t2Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) // Have both tasks each request 500 bytes more; both should immediately return 0 as they are // both now at 1 / N val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, 200.millis) === 0L) - assert(Await.result(t2Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } test("two tasks cannot grow past 1 / N of on-heap execution memory") { @@ -192,15 +193,15 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // Have both tasks request 250 bytes, then wait until both requests have been granted: val t1Result1 = Future { t1MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 250L) - assert(Await.result(t2Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Have both tasks each request 500 bytes more. 
// We should only grant 250 bytes to each of them on this second request val t1Result2 = Future { t1MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, futureTimeout) === 250L) - assert(Await.result(t2Result2, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t1Result2, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 250L) } test("tasks can block to get at least 1 / 2N of on-heap execution memory") { @@ -211,17 +212,17 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 1000L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(250L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. Thread.sleep(300) t1MemManager.releaseExecutionMemory(250L, MemoryMode.ON_HEAP, null) // The memory freed from t1 should now be granted to t2. - assert(Await.result(t2Result1, futureTimeout) === 250L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 250L) // Further requests by t2 should be denied immediately because it now has 1 / 2N of the memory. val t2Result2 = Future { t2MemManager.acquireExecutionMemory(100L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result2, 200.millis) === 0L) } test("TaskMemoryManager.cleanUpAllAllocatedMemory") { @@ -232,18 +233,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft // t1 grabs 1000 bytes and then waits until t2 is ready to make a request. val t1Result1 = Future { t1MemManager.acquireExecutionMemory(1000L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 1000L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 1000L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } // Make sure that t2 didn't grab the memory right away. This is hacky but it would be difficult // to make sure the other thread blocks for some time otherwise. 
Thread.sleep(300) // t1 releases all of its memory, so t2 should be able to grab all of the memory t1MemManager.cleanUpAllAllocatedMemory() - assert(Await.result(t2Result1, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 500L) val t2Result2 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result2, futureTimeout) === 500L) + assert(ThreadUtils.awaitResult(t2Result2, futureTimeout) === 500L) val t2Result3 = Future { t2MemManager.acquireExecutionMemory(500L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result3, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t2Result3, 200.millis) === 0L) } test("tasks should not be granted a negative amount of execution memory") { @@ -254,13 +255,13 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val futureTimeout: Duration = 20.seconds val t1Result1 = Future { t1MemManager.acquireExecutionMemory(700L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result1, futureTimeout) === 700L) + assert(ThreadUtils.awaitResult(t1Result1, futureTimeout) === 700L) val t2Result1 = Future { t2MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t2Result1, futureTimeout) === 300L) + assert(ThreadUtils.awaitResult(t2Result1, futureTimeout) === 300L) val t1Result2 = Future { t1MemManager.acquireExecutionMemory(300L, MemoryMode.ON_HEAP, null) } - assert(Await.result(t1Result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(t1Result2, 200.millis) === 0L) } test("off-heap execution allocations cannot exceed limit") { @@ -270,11 +271,11 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft val tMemManager = new TaskMemoryManager(memoryManager, 1) val result1 = Future { tMemManager.acquireExecutionMemory(1000L, MemoryMode.OFF_HEAP, null) } - assert(Await.result(result1, 200.millis) === 1000L) + assert(ThreadUtils.awaitResult(result1, 200.millis) === 1000L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) val result2 = Future { tMemManager.acquireExecutionMemory(300L, MemoryMode.OFF_HEAP, null) } - assert(Await.result(result2, 200.millis) === 0L) + assert(ThreadUtils.awaitResult(result2, 200.millis) === 0L) assert(tMemManager.getMemoryConsumptionForThisTask === 1000L) tMemManager.releaseExecutionMemory(500L, MemoryMode.OFF_HEAP, null) diff --git a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala index 056e5463a0abf..f8054f5fd7701 100644 --- a/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala +++ b/core/src/test/scala/org/apache/spark/metrics/InputOutputMetricsSuite.scala @@ -25,16 +25,10 @@ import org.apache.commons.lang3.RandomUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.io.{LongWritable, Text} -import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, - JobConf, LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, - Reporter, TextInputFormat => OldTextInputFormat} -import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, - CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} -import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, - TaskAttemptContext} -import 
org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, - CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, - FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} +import org.apache.hadoop.mapred.{FileSplit => OldFileSplit, InputSplit => OldInputSplit, JobConf, LineRecordReader => OldLineRecordReader, RecordReader => OldRecordReader, Reporter, TextInputFormat => OldTextInputFormat} +import org.apache.hadoop.mapred.lib.{CombineFileInputFormat => OldCombineFileInputFormat, CombineFileRecordReader => OldCombineFileRecordReader, CombineFileSplit => OldCombineFileSplit} +import org.apache.hadoop.mapreduce.{InputSplit => NewInputSplit, RecordReader => NewRecordReader, TaskAttemptContext} +import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat => NewCombineFileInputFormat, CombineFileRecordReader => NewCombineFileRecordReader, CombineFileSplit => NewCombineFileSplit, FileSplit => NewFileSplit, TextInputFormat => NewTextInputFormat} import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat} import org.scalatest.BeforeAndAfter @@ -103,40 +97,6 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext assert(bytesRead2 == bytesRead) } - /** - * This checks the situation where we have interleaved reads from - * different sources. Currently, we only accumulate from the first - * read method we find in the task. This test uses cartesian to create - * the interleaved reads. - * - * Once https://issues.apache.org/jira/browse/SPARK-5225 is fixed - * this test should break. - */ - test("input metrics with mixed read method") { - // prime the cache manager - val numPartitions = 2 - val rdd = sc.parallelize(1 to 100, numPartitions).cache() - rdd.collect() - - val rdd2 = sc.textFile(tmpFilePath, numPartitions) - - val bytesRead = runAndReturnBytesRead { - rdd.count() - } - val bytesRead2 = runAndReturnBytesRead { - rdd2.count() - } - - val cartRead = runAndReturnBytesRead { - rdd.cartesian(rdd2).count() - } - - assert(cartRead != 0) - assert(bytesRead != 0) - // We read from the first rdd of the cartesian once per partition. 
- assert(cartRead == bytesRead * numPartitions) - } - test("input metrics for new Hadoop API with coalesce") { val bytesRead = runAndReturnBytesRead { sc.newAPIHadoopFile(tmpFilePath, classOf[NewTextInputFormat], classOf[LongWritable], @@ -209,10 +169,10 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { val metrics = taskEnd.taskMetrics - metrics.inputMetrics.foreach(inputRead += _.recordsRead) - metrics.outputMetrics.foreach(outputWritten += _.recordsWritten) - metrics.shuffleReadMetrics.foreach(shuffleRead += _.recordsRead) - metrics.shuffleWriteMetrics.foreach(shuffleWritten += _.recordsWritten) + inputRead += metrics.inputMetrics.recordsRead + outputWritten += metrics.outputMetrics.recordsWritten + shuffleRead += metrics.shuffleReadMetrics.recordsRead + shuffleWritten += metrics.shuffleWriteMetrics.recordsWritten } }) @@ -272,19 +232,18 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext } private def runAndReturnBytesRead(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.bytesRead)) + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.bytesRead) } private def runAndReturnRecordsRead(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.inputMetrics.map(_.recordsRead)) + runAndReturnMetrics(job, _.taskMetrics.inputMetrics.recordsRead) } private def runAndReturnRecordsWritten(job: => Unit): Long = { - runAndReturnMetrics(job, _.taskMetrics.outputMetrics.map(_.recordsWritten)) + runAndReturnMetrics(job, _.taskMetrics.outputMetrics.recordsWritten) } - private def runAndReturnMetrics(job: => Unit, - collector: (SparkListenerTaskEnd) => Option[Long]): Long = { + private def runAndReturnMetrics(job: => Unit, collector: (SparkListenerTaskEnd) => Long): Long = { val taskMetrics = new ArrayBuffer[Long]() // Avoid receiving earlier taskEnd events @@ -292,7 +251,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - collector(taskEnd).foreach(taskMetrics += _) + taskMetrics += collector(taskEnd) } }) @@ -337,7 +296,7 @@ class InputOutputMetricsSuite extends SparkFunSuite with SharedSparkContext val taskBytesWritten = new ArrayBuffer[Long]() sc.addSparkListener(new SparkListener() { override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - taskBytesWritten += taskEnd.taskMetrics.outputMetrics.get.bytesWritten + taskBytesWritten += taskEnd.taskMetrics.outputMetrics.bytesWritten } }) diff --git a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala index f3c156e4f709d..e7df7cb419339 100644 --- a/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala +++ b/core/src/test/scala/org/apache/spark/network/netty/NettyBlockTransferServiceSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.network.netty +import scala.util.Random + import org.mockito.Mockito.mock import org.scalatest._ @@ -59,19 +61,26 @@ class NettyBlockTransferServiceSuite } test("can bind to a specific port") { - val port = 17634 + val port = 17634 + Random.nextInt(10000) + logInfo("random port for test: " + port) service0 = createService(port) - service0.port should be >= port - service0.port should be <= (port + 10) // avoid testing equality in case of 
simultaneous tests + verifyServicePort(expectedPort = port, actualPort = service0.port) } test("can bind to a specific port twice and the second increments") { - val port = 17634 + val port = 17634 + Random.nextInt(10000) + logInfo("random port for test: " + port) service0 = createService(port) - service1 = createService(port) - service0.port should be >= port - service0.port should be <= (port + 10) - service1.port should be (service0.port + 1) + verifyServicePort(expectedPort = port, actualPort = service0.port) + service1 = createService(service0.port) + // `service0.port` is occupied, so `service1.port` should not be `service0.port` + verifyServicePort(expectedPort = service0.port + 1, actualPort = service1.port) + } + + private def verifyServicePort(expectedPort: Int, actualPort: Int): Unit = { + actualPort should be >= expectedPort + // avoid testing equality in case of simultaneous tests + actualPort should be <= (expectedPort + 10) } private def createService(port: Int): NettyBlockTransferService = { diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index d18bde790b40a..8cb0a295b0773 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -28,6 +28,7 @@ import org.scalatest.concurrent.Timeouts import org.scalatest.time.SpanSugar._ import org.apache.spark._ +import org.apache.spark.util.ThreadUtils class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Timeouts { @@ -185,22 +186,23 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim test("FutureAction result, infinite wait") { val f = sc.parallelize(1 to 100, 4) .countAsync() - assert(Await.result(f, Duration.Inf) === 100) + assert(ThreadUtils.awaitResult(f, Duration.Inf) === 100) } test("FutureAction result, finite wait") { val f = sc.parallelize(1 to 100, 4) .countAsync() - assert(Await.result(f, Duration(30, "seconds")) === 100) + assert(ThreadUtils.awaitResult(f, Duration(30, "seconds")) === 100) } test("FutureAction result, timeout") { val f = sc.parallelize(1 to 100, 4) .mapPartitions(itr => { Thread.sleep(20); itr }) .countAsync() - intercept[TimeoutException] { - Await.result(f, Duration(20, "milliseconds")) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, Duration(20, "milliseconds")) } + assert(e.getCause.isInstanceOf[TimeoutException]) } private def testAsyncAction[R](action: RDD[Int] => FutureAction[R]): Unit = { @@ -221,7 +223,7 @@ class AsyncRDDActionsSuite extends SparkFunSuite with BeforeAndAfterAll with Tim // Now allow the executors to proceed with task processing. starter.release(rdd.partitions.length) // Waiting for the result verifies that the tasks were successfully processed. 
- Await.result(executionContextInvoked.future, atMost = 15.seconds) + ThreadUtils.awaitResult(executionContextInvoked.future, atMost = 15.seconds) } test("SimpleFutureAction callback must not consume a thread while waiting") { diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 24daedab2090f..8dc463d56d182 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -17,13 +17,15 @@ package org.apache.spark.rdd -import java.io.{IOException, ObjectInputStream, ObjectOutputStream} +import java.io.{File, IOException, ObjectInputStream, ObjectOutputStream} import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.reflect.ClassTag import com.esotericsoftware.kryo.KryoException +import org.apache.hadoop.io.{LongWritable, Text} +import org.apache.hadoop.mapred.{FileSplit, TextInputFormat} import org.apache.spark._ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext} @@ -31,6 +33,20 @@ import org.apache.spark.rdd.RDDSuiteUtils._ import org.apache.spark.util.Utils class RDDSuite extends SparkFunSuite with SharedSparkContext { + var tempDir: File = _ + + override def beforeAll(): Unit = { + super.beforeAll() + tempDir = Utils.createTempDir() + } + + override def afterAll(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterAll() + } + } test("basic operations") { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) @@ -951,6 +967,32 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(thrown.getMessage.contains("SPARK-5063")) } + test("custom RDD coalescer") { + val maxSplitSize = 512 + val outDir = new File(tempDir, "output").getAbsolutePath + sc.makeRDD(1 to 1000, 10).saveAsTextFile(outDir) + val hadoopRDD = + sc.hadoopFile(outDir, classOf[TextInputFormat], classOf[LongWritable], classOf[Text]) + val coalescedHadoopRDD = + hadoopRDD.coalesce(2, partitionCoalescer = Option(new SizeBasedCoalescer(maxSplitSize))) + assert(coalescedHadoopRDD.partitions.size <= 10) + var totalPartitionCount = 0L + coalescedHadoopRDD.partitions.foreach(partition => { + var splitSizeSum = 0L + partition.asInstanceOf[CoalescedRDDPartition].parents.foreach(partition => { + val split = partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit] + splitSizeSum += split.getLength + totalPartitionCount += 1 + }) + assert(splitSizeSum <= maxSplitSize) + }) + assert(totalPartitionCount == 10) + } + + // NOTE + // The tests below that call sc.stop() must be the last tests in this suite. Any test that runs + // after them and accesses sc will fail, because sc has already been stopped and sc is shared + // across this suite (it mixes in SharedSparkContext). test("cannot run actions after SparkContext has been stopped (SPARK-5063)") { val existingRDD = sc.parallelize(1 to 100) sc.stop() @@ -971,5 +1013,60 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assertFails { sc.parallelize(1 to 100) } assertFails { sc.textFile("/nonexistent-path") } } +} +/** + * Coalesces partitions based on their size assuming that the parent RDD is a [[HadoopRDD]]. + * Took this class out of the test suite to prevent "Task not serializable" exceptions. 
+ */ +class SizeBasedCoalescer(val maxSize: Int) extends PartitionCoalescer with Serializable { + override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = { + val partitions: Array[Partition] = parent.asInstanceOf[HadoopRDD[Any, Any]].getPartitions + val groups = ArrayBuffer[PartitionGroup]() + var currentGroup = new PartitionGroup() + var currentSum = 0L + var totalSum = 0L + var index = 0 + + // sort partitions based on the size of the corresponding input splits + val sortedPartitions = partitions.sortWith((partition1, partition2) => { + val partition1Size = partition1.asInstanceOf[HadoopPartition].inputSplit.value.getLength + val partition2Size = partition2.asInstanceOf[HadoopPartition].inputSplit.value.getLength + partition1Size < partition2Size + }) + + def updateGroups(): Unit = { + groups += currentGroup + currentGroup = new PartitionGroup() + currentSum = 0 + } + + def addPartition(partition: Partition, splitSize: Long): Unit = { + currentGroup.partitions += partition + currentSum += splitSize + totalSum += splitSize + } + + while (index < sortedPartitions.size) { + val partition = sortedPartitions(index) + val fileSplit = + partition.asInstanceOf[HadoopPartition].inputSplit.value.asInstanceOf[FileSplit] + val splitSize = fileSplit.getLength + if (currentSum + splitSize < maxSize) { + addPartition(partition, splitSize) + index += 1 + if (index == sortedPartitions.size) { + updateGroups() + } + } else { + if (currentGroup.partitions.size == 0) { + addPartition(partition, splitSize) + index += 1 + } else { + updateGroups() + } + } + } + groups.toArray + } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index cebac2097f380..73803ec21a567 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -35,7 +35,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Common tests for an RpcEnv implementation. 
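Before the RpcEnvSuite hunks below, a note on the SizeBasedCoalescer added above: the only contract a custom coalescer needs is the one visible in this patch, namely coalesce(maxPartitions, parent) returning Array[PartitionGroup] and the mutable partitions buffer on PartitionGroup. A hedged sketch of a simpler strategy built on that same contract (the class name and the round-robin policy are illustrative, not part of this patch):

    import org.apache.spark.rdd.{PartitionCoalescer, PartitionGroup, RDD}

    // Distributes parent partitions round-robin over at most maxPartitions groups.
    class RoundRobinCoalescer extends PartitionCoalescer with Serializable {
      override def coalesce(maxPartitions: Int, parent: RDD[_]): Array[PartitionGroup] = {
        val groups = Array.fill(maxPartitions)(new PartitionGroup())
        parent.partitions.zipWithIndex.foreach { case (p, i) =>
          groups(i % maxPartitions).partitions += p
        }
        groups.filter(_.partitions.nonEmpty)  // drop empty groups when the parent has few partitions
      }
    }

It would be passed in the same way as in the "custom RDD coalescer" test above, e.g. rdd.coalesce(4, partitionCoalescer = Option(new RoundRobinCoalescer)).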
@@ -415,7 +415,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { }) val f = endpointRef.ask[String]("Hi") - val ack = Await.result(f, 5 seconds) + val ack = ThreadUtils.awaitResult(f, 5 seconds) assert("ack" === ack) env.stop(endpointRef) @@ -435,7 +435,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val rpcEndpointRef = anotherEnv.setupEndpointRef(env.address, "sendWithReply-remotely") try { val f = rpcEndpointRef.ask[String]("hello") - val ack = Await.result(f, 5 seconds) + val ack = ThreadUtils.awaitResult(f, 5 seconds) assert("ack" === ack) } finally { anotherEnv.shutdown() @@ -454,9 +454,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { val f = endpointRef.ask[String]("Hi") val e = intercept[SparkException] { - Await.result(f, 5 seconds) + ThreadUtils.awaitResult(f, 5 seconds) } - assert("Oops" === e.getMessage) + assert("Oops" === e.getCause.getMessage) env.stop(endpointRef) } @@ -476,9 +476,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { try { val f = rpcEndpointRef.ask[String]("hello") val e = intercept[SparkException] { - Await.result(f, 5 seconds) + ThreadUtils.awaitResult(f, 5 seconds) } - assert("Oops" === e.getMessage) + assert("Oops" === e.getCause.getMessage) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -487,6 +487,7 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { /** * Setup an [[RpcEndpoint]] to collect all network events. + * * @return the [[RpcEndpointRef]] and an `ConcurrentLinkedQueue` that contains network events. */ private def setupNetworkEndpoint( @@ -620,10 +621,10 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { anotherEnv.setupEndpointRef(env.address, "sendWithReply-unserializable-error") try { val f = rpcEndpointRef.ask[String]("hello") - val e = intercept[Exception] { - Await.result(f, 1 seconds) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, 1 seconds) } - assert(e.isInstanceOf[NotSerializableException]) + assert(e.getCause.isInstanceOf[NotSerializableException]) } finally { anotherEnv.shutdown() anotherEnv.awaitTermination() @@ -754,15 +755,17 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { // RpcTimeout.awaitResult should have added the property to the TimeoutException message assert(reply2.contains(shortTimeout.timeoutProp)) - // Ask with delayed response and allow the Future to timeout before Await.result + // Ask with delayed response and allow the Future to timeout before ThreadUtils.awaitResult val fut3 = rpcEndpointRef.ask[String](NeverReply("goodbye"), shortTimeout) + // scalastyle:off awaitresult // Allow future to complete with failure using plain Await.result, this will return // once the future is complete to verify addMessageIfTimeout was invoked val reply3 = intercept[RpcTimeoutException] { Await.result(fut3, 2000 millis) }.getMessage + // scalastyle:on awaitresult // When the future timed out, the recover callback should have used // RpcTimeout.addMessageIfTimeout to add the property to the TimeoutException message diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala index 994a58836bd0d..2d6543d328618 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/NettyRpcEnvSuite.scala @@ -17,7 +17,7 @@ package 
org.apache.spark.rpc.netty -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark._ import org.apache.spark.rpc._ class NettyRpcEnvSuite extends RpcEnvSuite { @@ -34,10 +34,11 @@ class NettyRpcEnvSuite extends RpcEnvSuite { test("non-existent endpoint") { val uri = RpcEndpointAddress(env.address, "nonexist-endpoint").toString - val e = intercept[RpcEndpointNotFoundException] { + val e = intercept[SparkException] { env.setupEndpointRef(env.address, "nonexist-endpoint") } - assert(e.getMessage.contains(uri)) + assert(e.getCause.isInstanceOf[RpcEndpointNotFoundException]) + assert(e.getCause.getMessage.contains(uri)) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index fd96fb04f8b29..b76c0a4bd1dde 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -1144,7 +1144,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou // SPARK-9809 -- this stage is submitted without a task for each partition (because some of // the shuffle map output is still available from stage 0); make sure we've still got internal // accumulators setup - assert(scheduler.stageIdToStage(2).latestInfo.internalAccumulators.nonEmpty) + assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics != null) completeShuffleMapStageSuccessfully(2, 0, 2) completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) assert(results === Map(0 -> 1234, 1 -> 1235)) @@ -2010,7 +2010,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou extraAccumUpdates: Seq[AccumulableInfo] = Seq.empty[AccumulableInfo], taskInfo: TaskInfo = createFakeTaskInfo()): CompletionEvent = { val accumUpdates = reason match { - case Success => task.initialAccumulators.map { a => a.toInfo(Some(a.zero), None) } + case Success => task.metrics.accumulatorUpdates() case ef: ExceptionFailure => ef.accumUpdates case _ => Seq.empty[AccumulableInfo] } diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala new file mode 100644 index 0000000000000..9971d48a52ce7 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.storage.BlockManagerId + +class ExternalClusterManagerSuite extends SparkFunSuite with LocalSparkContext +{ + test("launch of backend and scheduler") { + val conf = new SparkConf().setMaster("myclusterManager"). + setAppName("testcm").set("spark.driver.allowMultipleContexts", "true") + sc = new SparkContext(conf) + // check if the scheduler components are created + assert(sc.schedulerBackend.isInstanceOf[DummySchedulerBackend]) + assert(sc.taskScheduler.isInstanceOf[DummyTaskScheduler]) + } +} + +private class DummyExternalClusterManager extends ExternalClusterManager { + + def canCreate(masterURL: String): Boolean = masterURL == "myclusterManager" + + def createTaskScheduler(sc: SparkContext, + masterURL: String): TaskScheduler = new DummyTaskScheduler + + def createSchedulerBackend(sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = new DummySchedulerBackend() + + def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = {} + +} + +private class DummySchedulerBackend extends SchedulerBackend { + def start() {} + def stop() {} + def reviveOffers() {} + def defaultParallelism(): Int = 1 +} + +private class DummyTaskScheduler extends TaskScheduler { + override def rootPool: Pool = null + override def schedulingMode: SchedulingMode = SchedulingMode.NONE + override def start(): Unit = {} + override def stop(): Unit = {} + override def submitTasks(taskSet: TaskSet): Unit = {} + override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = {} + override def setDAGScheduler(dagScheduler: DAGScheduler): Unit = {} + override def defaultParallelism(): Int = 2 + override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {} + override def applicationAttemptId(): Option[String] = None + def executorHeartbeatReceived( + execId: String, + accumUpdates: Array[(Long, Seq[AccumulableInfo])], + blockManagerId: BlockManagerId): Boolean = true +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala index e3e6df6831def..4fe705b201ec8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/FakeTask.scala @@ -17,14 +17,11 @@ package org.apache.spark.scheduler -import java.util.Properties - import org.apache.spark.TaskContext class FakeTask( stageId: Int, - prefLocs: Seq[TaskLocation] = Nil) - extends Task[Int](stageId, 0, 0, Seq.empty, new Properties) { + prefLocs: Seq[TaskLocation] = Nil) extends Task[Int](stageId, 0, 0) { override def runTask(context: TaskContext): Int = 0 override def preferredLocations: Seq[TaskLocation] = prefLocs } diff --git a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala index 76a7087645961..255be6f46b06b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/NotSerializableFakeTask.scala @@ -18,7 +18,6 @@ package org.apache.spark.scheduler import java.io.{IOException, ObjectInputStream, ObjectOutputStream} -import java.util.Properties import org.apache.spark.TaskContext @@ -26,7 +25,7 @@ import org.apache.spark.TaskContext * A Task 
implementation that fails to serialize. */ private[spark] class NotSerializableFakeTask(myId: Int, stageId: Int) - extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { + extends Task[Array[Byte]](stageId, 0, 0) { override def runTask(context: TaskContext): Array[Byte] = Array.empty[Byte] override def preferredLocations: Seq[TaskLocation] = Seq[TaskLocation]() diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 8e509de7677c3..83288db92bb43 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import java.io.File import java.util.concurrent.TimeoutException -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps @@ -33,7 +32,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark._ import org.apache.spark.rdd.{FakeOutputCommitter, RDD} -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * Unit tests for the output commit coordination functionality. @@ -159,9 +158,10 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { 0 until rdd.partitions.size, resultHandler, () => Unit) // It's an error if the job completes successfully even though no committer was authorized, // so throw an exception if the job was allowed to complete. - intercept[TimeoutException] { - Await.result(futureAction, 5 seconds) + val e = intercept[SparkException] { + ThreadUtils.awaitResult(futureAction, 5 seconds) } + assert(e.getCause.isInstanceOf[TimeoutException]) assert(tempDir.list().size === 0) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index b854d742b5bdd..5ba67afc0cd62 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -266,18 +266,13 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match taskInfoMetrics.foreach { case (taskInfo, taskMetrics) => taskMetrics.resultSize should be > (0L) if (stageInfo.rddInfos.exists(info => info.name == d2.name || info.name == d3.name)) { - taskMetrics.inputMetrics should not be ('defined) - taskMetrics.outputMetrics should not be ('defined) - taskMetrics.shuffleWriteMetrics should be ('defined) - taskMetrics.shuffleWriteMetrics.get.bytesWritten should be > (0L) + assert(taskMetrics.shuffleWriteMetrics.bytesWritten > 0L) } if (stageInfo.rddInfos.exists(_.name == d4.name)) { - taskMetrics.shuffleReadMetrics should be ('defined) - val sm = taskMetrics.shuffleReadMetrics.get - sm.totalBlocksFetched should be (2*numSlices) - sm.localBlocksFetched should be (2*numSlices) - sm.remoteBlocksFetched should be (0) - sm.remoteBytesRead should be (0L) + assert(taskMetrics.shuffleReadMetrics.totalBlocksFetched == 2 * numSlices) + assert(taskMetrics.shuffleReadMetrics.localBlocksFetched == 2 * numSlices) + assert(taskMetrics.shuffleReadMetrics.remoteBlocksFetched == 0) + assert(taskMetrics.shuffleReadMetrics.remoteBytesRead == 0L) } } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index 
86911d2211a3a..bda4c996b27df 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -24,7 +24,7 @@ import org.mockito.Mockito._ import org.scalatest.BeforeAndAfter import org.apache.spark._ -import org.apache.spark.executor.{Executor, TaskMetricsSuite} +import org.apache.spark.executor.{Executor, TaskMetrics, TaskMetricsSuite} import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils @@ -62,7 +62,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) intercept[RuntimeException] { task.run(0, 0, null) } @@ -83,7 +83,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val func = (c: TaskContext, i: Iterator[String]) => i.next() val taskBinary = sc.broadcast(JavaUtils.bufferToArray(closureSerializer.serialize((rdd, func)))) val task = new ResultTask[String, String]( - 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties) + 0, 0, taskBinary, rdd.partitions(0), Seq.empty, 0, new Properties, new TaskMetrics) intercept[RuntimeException] { task.run(0, 0, null) } @@ -171,26 +171,27 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark val param = AccumulatorParam.LongAccumulatorParam val acc1 = new Accumulator(0L, param, Some("x"), internal = false, countFailedValues = true) val acc2 = new Accumulator(0L, param, Some("y"), internal = false, countFailedValues = false) - val initialAccums = InternalAccumulator.createAll() // Create a dummy task. We won't end up running this; we just want to collect // accumulator updates from it. - val task = new Task[Int](0, 0, 0, Seq.empty[Accumulator[_]], new Properties) { + val taskMetrics = new TaskMetrics + val task = new Task[Int](0, 0, 0) { context = new TaskContextImpl(0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, - initialAccums) - context.taskMetrics.registerAccumulator(acc1) - context.taskMetrics.registerAccumulator(acc2) + taskMetrics) + taskMetrics.registerAccumulator(acc1) + taskMetrics.registerAccumulator(acc2) override def runTask(tc: TaskContext): Int = 0 } // First, simulate task success. This should give us all the accumulators. val accumUpdates1 = task.collectAccumulatorUpdates(taskFailed = false) - val accumUpdates2 = (initialAccums ++ Seq(acc1, acc2)).map(TaskMetricsSuite.makeInfo) + val accumUpdates2 = (taskMetrics.internalAccums ++ Seq(acc1, acc2)) + .map(TaskMetricsSuite.makeInfo) TaskMetricsSuite.assertUpdatesEquals(accumUpdates1, accumUpdates2) // Now, simulate task failures. This should give us only the accums that count failed values. 
val accumUpdates3 = task.collectAccumulatorUpdates(taskFailed = true) - val accumUpdates4 = (initialAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo) + val accumUpdates4 = (taskMetrics.internalAccums ++ Seq(acc1)).map(TaskMetricsSuite.makeInfo) TaskMetricsSuite.assertUpdatesEquals(accumUpdates3, accumUpdates4) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ade8e84d848f0..ecf4b76da5586 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,9 +17,8 @@ package org.apache.spark.scheduler -import java.util.{Properties, Random} +import java.util.Random -import scala.collection.Map import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -138,7 +137,8 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex /** * A Task implementation that results in a large serialized task. */ -class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0, Seq.empty, new Properties) { +class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0, 0) { + val randomBuffer = new Array[Byte](TaskSetManager.TASK_SIZE_TO_WARN_KB * 1024) val random = new Random(0) random.nextBytes(randomBuffer) @@ -166,7 +166,8 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(1) val clock = new ManualClock val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) - val accumUpdates = taskSet.tasks.head.initialAccumulators.map { a => a.toInfo(Some(0L), None) } + val accumUpdates = + taskSet.tasks.head.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) } // Offer a host with NO_PREF as the constraint, // we should get a nopref task immediately since that's what we only have @@ -185,7 +186,7 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(3) val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) val accumUpdatesByTask: Array[Seq[AccumulableInfo]] = taskSet.tasks.map { task => - task.initialAccumulators.map { a => a.toInfo(Some(0L), None) } + task.metrics.internalAccums.map { a => a.toInfo(Some(0L), None) } } // First three offers should all find tasks diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 16418f855bbe1..5132384a5ed7d 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -144,7 +144,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(outputFile.exists()) assert(outputFile.length() === 0) assert(temporaryFilesCreated.isEmpty) - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics assert(shuffleWriteMetrics.bytesWritten === 0) assert(shuffleWriteMetrics.recordsWritten === 0) assert(taskMetrics.diskBytesSpilled === 0) @@ -168,7 +168,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(writer.getPartitionLengths.sum === outputFile.length()) assert(writer.getPartitionLengths.count(_ == 
0L) === 4) // should be 4 zero length files assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics.get + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) assert(shuffleWriteMetrics.recordsWritten === records.length) assert(taskMetrics.diskBytesSpilled === 0) diff --git a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala index 88817dccf3497..d223af1496a4b 100644 --- a/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala +++ b/core/src/test/scala/org/apache/spark/status/api/v1/AllStagesResourceSuite.scala @@ -38,7 +38,7 @@ class AllStagesResourceSuite extends SparkFunSuite { stageUiData.taskData = tasks val status = StageStatus.ACTIVE val stageInfo = new StageInfo( - 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc", Seq.empty) + 1, 1, "stage 1", 10, Seq.empty, Seq.empty, "details abc") val stageData = AllStagesResource.stageUiToStageData(status, stageInfo, stageUiData, false) stageData.firstTaskLaunchedTime diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 9d1bd7ec89bc7..9ee83b76e71dc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.BeforeAndAfterEach import org.scalatest.time.SpanSugar._ import org.apache.spark.{SparkException, SparkFunSuite, TaskContext, TaskContextImpl} +import org.apache.spark.util.ThreadUtils class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -124,8 +125,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { } // After downgrading to a read lock, both threads should wake up and acquire the shared // read lock. 
- assert(!Await.result(lock1Future, 1.seconds)) - assert(!Await.result(lock2Future, 1.seconds)) + assert(!ThreadUtils.awaitResult(lock1Future, 1.seconds)) + assert(!ThreadUtils.awaitResult(lock2Future, 1.seconds)) assert(blockInfoManager.get("block").get.readerCount === 3) } @@ -161,7 +162,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(winningTID) { blockInfoManager.unlock("block") } - assert(!Await.result(losingFuture, 1.seconds)) + assert(!ThreadUtils.awaitResult(losingFuture, 1.seconds)) assert(blockInfoManager.get("block").get.readerCount === 1) } @@ -262,8 +263,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(0) { blockInfoManager.unlock("block") } - assert(Await.result(get1Future, 1.seconds).isDefined) - assert(Await.result(get2Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(get1Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(get2Future, 1.seconds).isDefined) assert(blockInfoManager.get("block").get.readerCount === 2) } @@ -288,13 +289,14 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { blockInfoManager.unlock("block") } assert( - Await.result(Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) + ThreadUtils.awaitResult( + Future.firstCompletedOf(Seq(write1Future, write2Future)), 1.seconds).isDefined) val firstWriteWinner = if (write1Future.isCompleted) 1 else 2 withTaskId(firstWriteWinner) { blockInfoManager.unlock("block") } - assert(Await.result(write1Future, 1.seconds).isDefined) - assert(Await.result(write2Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(write1Future, 1.seconds).isDefined) + assert(ThreadUtils.awaitResult(write2Future, 1.seconds).isDefined) } test("removing a non-existent block throws IllegalArgumentException") { @@ -344,8 +346,8 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { withTaskId(0) { blockInfoManager.removeBlock("block") } - assert(Await.result(getFuture, 1.seconds).isEmpty) - assert(Await.result(writeFuture, 1.seconds).isEmpty) + assert(ThreadUtils.awaitResult(getFuture, 1.seconds).isEmpty) + assert(ThreadUtils.awaitResult(writeFuture, 1.seconds).isEmpty) } test("releaseAllLocksForTask releases write locks") { diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d26df7e760cea..d14728cb50555 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.StorageLevel._ /** Testsuite that tests block replication in BlockManager */ @@ -44,7 +44,7 @@ class BlockManagerReplicationSuite extends SparkFunSuite with Matchers with Befo private var master: BlockManagerMaster = null private val securityMgr = new SecurityManager(conf) private val mapOutputTracker = new MapOutputTrackerMaster(conf) - private val shuffleManager = new HashShuffleManager(conf) + private val shuffleManager = new SortShuffleManager(conf) // List of block 
manager created during an unit test, so that all of the them can be stopped // after the unit test. diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index a1c2933584acc..db1efaf2a20b8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -42,7 +42,7 @@ import org.apache.spark.network.shuffle.BlockFetchingListener import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer, SerializerManager} -import org.apache.spark.shuffle.hash.HashShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat import org.apache.spark.util._ import org.apache.spark.util.io.ChunkedByteBuffer @@ -60,7 +60,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with BeforeAndAfterE var master: BlockManagerMaster = null val securityMgr = new SecurityManager(new SparkConf(false)) val mapOutputTracker = new MapOutputTrackerMaster(new SparkConf(false)) - val shuffleManager = new HashShuffleManager(new SparkConf(false)) + val shuffleManager = new SortShuffleManager(new SparkConf(false)) // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test val serializer = new KryoSerializer(new SparkConf(false).set("spark.kryoserializer.buffer", "1m")) diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala index 7d4c0863bc963..221124829fc54 100644 --- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala @@ -184,7 +184,7 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val conf = new SparkConf() val listener = new JobProgressListener(conf) val taskMetrics = new TaskMetrics() - val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() + val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() assert(listener.stageIdToData.size === 0) // finish this task, should get updated shuffleRead @@ -269,13 +269,11 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with val execId = "exe-1" def makeTaskMetrics(base: Int): TaskMetrics = { - val accums = InternalAccumulator.createAll() - accums.foreach(Accumulators.register) - val taskMetrics = new TaskMetrics(accums) - val shuffleReadMetrics = taskMetrics.registerTempShuffleReadMetrics() - val shuffleWriteMetrics = taskMetrics.registerShuffleWriteMetrics() - val inputMetrics = taskMetrics.registerInputMetrics(DataReadMethod.Hadoop) - val outputMetrics = taskMetrics.registerOutputMetrics(DataWriteMethod.Hadoop) + val taskMetrics = new TaskMetrics + val shuffleReadMetrics = taskMetrics.createTempShuffleReadMetrics() + val shuffleWriteMetrics = taskMetrics.shuffleWriteMetrics + val inputMetrics = taskMetrics.inputMetrics + val outputMetrics = taskMetrics.outputMetrics shuffleReadMetrics.incRemoteBytesRead(base + 1) shuffleReadMetrics.incLocalBytesRead(base + 9) shuffleReadMetrics.incRemoteBlocksFetched(base + 2) @@ -322,12 +320,13 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 207) 
assert(stage0Data.outputBytes == 116) assert(stage1Data.outputBytes == 208) - assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get - .totalBlocksFetched == 2) - assert(stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.get - .totalBlocksFetched == 102) - assert(stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.get - .totalBlocksFetched == 202) + + assert( + stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 2) + assert( + stage0Data.taskData.get(1235L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 102) + assert( + stage1Data.taskData.get(1236L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 202) // task that was included in a heartbeat listener.onTaskEnd(SparkListenerTaskEnd(0, 0, taskType, Success, makeTaskInfo(1234L, 1), @@ -355,9 +354,9 @@ class JobProgressListenerSuite extends SparkFunSuite with LocalSparkContext with assert(stage1Data.inputBytes == 614) assert(stage0Data.outputBytes == 416) assert(stage1Data.outputBytes == 616) - assert(stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.get - .totalBlocksFetched == 302) - assert(stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.get - .totalBlocksFetched == 402) + assert( + stage0Data.taskData.get(1234L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 302) + assert( + stage1Data.taskData.get(1237L).get.metrics.get.shuffleReadMetrics.totalBlocksFetched == 402) } } diff --git a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphSuite.scala similarity index 55% rename from core/src/test/scala/org/apache/spark/HashShuffleSuite.scala rename to core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphSuite.scala index 10794235ed392..6ddcb5aba1678 100644 --- a/core/src/test/scala/org/apache/spark/HashShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphSuite.scala @@ -15,16 +15,23 @@ * limitations under the License. */ -package org.apache.spark +package org.apache.spark.ui.scope -import org.scalatest.BeforeAndAfterAll +import org.apache.spark.SparkFunSuite -class HashShuffleSuite extends ShuffleSuite with BeforeAndAfterAll { +class RDDOperationGraphSuite extends SparkFunSuite { + test("Test simple cluster equals") { + // create a 2-cluster chain with a child + val c1 = new RDDOperationCluster("1", "Bender") + val c2 = new RDDOperationCluster("2", "Hal") + c1.attachChildCluster(c2) + c1.attachChildNode(new RDDOperationNode(3, "Marvin", false, "collect!")) - // This test suite should run all tests in ShuffleSuite with hash-based shuffle. + // create an equal cluster, but without the child node + val c1copy = new RDDOperationCluster("1", "Bender") + val c2copy = new RDDOperationCluster("2", "Hal") + c1copy.attachChildCluster(c2copy) - override def beforeAll() { - super.beforeAll() - conf.set("spark.shuffle.manager", "hash") + assert(c1 == c1copy) } } diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala index de6f408fa82be..d3b6cdfe86eec 100644 --- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala @@ -197,49 +197,41 @@ class JsonProtocolSuite extends SparkFunSuite { test("InputMetrics backward compatibility") { // InputMetrics were added after 1.0.1. 
val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = false) - assert(metrics.inputMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Input Metrics" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) - assert(newMetrics.inputMetrics.isEmpty) } test("Input/Output records backwards compatibility") { // records read were added after 1.2 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = true, hasOutput = true, hasRecords = false) - assert(metrics.inputMetrics.nonEmpty) - assert(metrics.outputMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Records Read" } .removeField { case (field, _) => field == "Records Written" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) - assert(newMetrics.inputMetrics.get.recordsRead == 0) - assert(newMetrics.outputMetrics.get.recordsWritten == 0) + assert(newMetrics.inputMetrics.recordsRead == 0) + assert(newMetrics.outputMetrics.recordsWritten == 0) } test("Shuffle Read/Write records backwards compatibility") { // records read were added after 1.2 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = false, hasRecords = false) - assert(metrics.shuffleReadMetrics.nonEmpty) - assert(metrics.shuffleWriteMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Total Records Read" } .removeField { case (field, _) => field == "Shuffle Records Written" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) - assert(newMetrics.shuffleReadMetrics.get.recordsRead == 0) - assert(newMetrics.shuffleWriteMetrics.get.recordsWritten == 0) + assert(newMetrics.shuffleReadMetrics.recordsRead == 0) + assert(newMetrics.shuffleWriteMetrics.recordsWritten == 0) } test("OutputMetrics backward compatibility") { // OutputMetrics were added after 1.1 val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = true) - assert(metrics.outputMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Output Metrics" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) - assert(newMetrics.outputMetrics.isEmpty) } test("BlockManager events backward compatibility") { @@ -279,11 +271,10 @@ class JsonProtocolSuite extends SparkFunSuite { // Metrics about local shuffle bytes read were added in 1.3.1. 
val metrics = makeTaskMetrics(1L, 2L, 3L, 4L, 5, 6, hasHadoopInput = false, hasOutput = false, hasRecords = false) - assert(metrics.shuffleReadMetrics.nonEmpty) val newJson = JsonProtocol.taskMetricsToJson(metrics) val oldJson = newJson.removeField { case (field, _) => field == "Local Bytes Read" } val newMetrics = JsonProtocol.taskMetricsFromJson(oldJson) - assert(newMetrics.shuffleReadMetrics.get.localBytesRead == 0) + assert(newMetrics.shuffleReadMetrics.localBytesRead == 0) } test("SparkListenerApplicationStart backwards compatibility") { @@ -423,7 +414,6 @@ class JsonProtocolSuite extends SparkFunSuite { }) testAccumValue(Some(RESULT_SIZE), 3L, JInt(3)) testAccumValue(Some(shuffleRead.REMOTE_BLOCKS_FETCHED), 2, JInt(2)) - testAccumValue(Some(input.READ_METHOD), "aka", JString("aka")) testAccumValue(Some(UPDATED_BLOCK_STATUSES), blocks, blocksJson) // For anything else, we just cast the value to a string testAccumValue(Some("anything"), blocks, JString(blocks.toString)) @@ -619,12 +609,9 @@ private[spark] object JsonProtocolSuite extends Assertions { assert(metrics1.resultSerializationTime === metrics2.resultSerializationTime) assert(metrics1.memoryBytesSpilled === metrics2.memoryBytesSpilled) assert(metrics1.diskBytesSpilled === metrics2.diskBytesSpilled) - assertOptionEquals( - metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics, assertShuffleReadEquals) - assertOptionEquals( - metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics, assertShuffleWriteEquals) - assertOptionEquals( - metrics1.inputMetrics, metrics2.inputMetrics, assertInputMetricsEquals) + assertEquals(metrics1.shuffleReadMetrics, metrics2.shuffleReadMetrics) + assertEquals(metrics1.shuffleWriteMetrics, metrics2.shuffleWriteMetrics) + assertEquals(metrics1.inputMetrics, metrics2.inputMetrics) assertBlocksEquals(metrics1.updatedBlockStatuses, metrics2.updatedBlockStatuses) } @@ -641,7 +628,6 @@ private[spark] object JsonProtocolSuite extends Assertions { } private def assertEquals(metrics1: InputMetrics, metrics2: InputMetrics) { - assert(metrics1.readMethod === metrics2.readMethod) assert(metrics1.bytesRead === metrics2.bytesRead) } @@ -706,12 +692,13 @@ private[spark] object JsonProtocolSuite extends Assertions { } private def assertJsonStringEquals(expected: String, actual: String, metadata: String) { - val formatJsonString = (json: String) => json.replaceAll("[\\s|]", "") - if (formatJsonString(expected) != formatJsonString(actual)) { + val expectedJson = pretty(parse(expected)) + val actualJson = pretty(parse(actual)) + if (expectedJson != actualJson) { // scalastyle:off // This prints something useful if the JSON strings don't match - println("=== EXPECTED ===\n" + pretty(parse(expected)) + "\n") - println("=== ACTUAL ===\n" + pretty(parse(actual)) + "\n") + println("=== EXPECTED ===\n" + expectedJson + "\n") + println("=== ACTUAL ===\n" + actualJson + "\n") // scalastyle:on throw new TestFailedException(s"$metadata JSON did not equal", 1) } @@ -740,22 +727,6 @@ private[spark] object JsonProtocolSuite extends Assertions { * Use different names for methods we pass in to assertSeqEquals or assertOptionEquals */ - private def assertShuffleReadEquals(r1: ShuffleReadMetrics, r2: ShuffleReadMetrics) { - assertEquals(r1, r2) - } - - private def assertShuffleWriteEquals(w1: ShuffleWriteMetrics, w2: ShuffleWriteMetrics) { - assertEquals(w1, w2) - } - - private def assertInputMetricsEquals(i1: InputMetrics, i2: InputMetrics) { - assertEquals(i1, i2) - } - - private def assertTaskMetricsEquals(t1: TaskMetrics, t2: 
TaskMetrics) { - assertEquals(t1, t2) - } - private def assertBlocksEquals( blocks1: Seq[(BlockId, BlockStatus)], blocks2: Seq[(BlockId, BlockStatus)]) = { @@ -851,11 +822,11 @@ private[spark] object JsonProtocolSuite extends Assertions { t.incMemoryBytesSpilled(a + c) if (hasHadoopInput) { - val inputMetrics = t.registerInputMetrics(DataReadMethod.Hadoop) + val inputMetrics = t.inputMetrics inputMetrics.setBytesRead(d + e + f) inputMetrics.incRecordsRead(if (hasRecords) (d + e + f) / 100 else -1) } else { - val sr = t.registerTempShuffleReadMetrics() + val sr = t.createTempShuffleReadMetrics() sr.incRemoteBytesRead(b + d) sr.incLocalBlocksFetched(e) sr.incFetchWaitTime(a + d) @@ -865,11 +836,10 @@ private[spark] object JsonProtocolSuite extends Assertions { t.mergeShuffleReadMetrics() } if (hasOutput) { - val outputMetrics = t.registerOutputMetrics(DataWriteMethod.Hadoop) - outputMetrics.setBytesWritten(a + b + c) - outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c)/100 else -1) + t.outputMetrics.setBytesWritten(a + b + c) + t.outputMetrics.setRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) } else { - val sw = t.registerShuffleWriteMetrics() + val sw = t.shuffleWriteMetrics sw.incBytesWritten(a + b + c) sw.incWriteTime(b + c + d) sw.incRecordsWritten(if (hasRecords) (a + b + c) / 100 else -1) @@ -896,7 +866,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Stage Name": "greetings", | "Number of Tasks": 200, | "RDD Info": [], - | "ParentIDs" : [100, 200, 300], + | "Parent IDs" : [100, 200, 300], | "Details": "details", | "Accumulables": [ | { @@ -924,7 +894,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Ukraine": "Kiev" | } |} - """ + """.stripMargin private val stageCompletedJsonString = """ @@ -953,7 +923,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Disk Size": 501 | } | ], - | "ParentIDs" : [100, 200, 300], + | "Parent IDs" : [100, 200, 300], | "Details": "details", | "Accumulables": [ | { @@ -975,7 +945,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | ] | } |} - """ + """.stripMargin private val taskStartJsonString = """ @@ -1141,6 +1111,14 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Shuffle Write Time": 1500, | "Shuffle Records Written": 12 | }, + | "Input Metrics" : { + | "Bytes Read" : 0, + | "Records Read" : 0 + | }, + | "Output Metrics" : { + | "Bytes Written" : 0, + | "Records Written" : 0 + | }, | "Updated Blocks": [ | { | "Block ID": "rdd_0_0", @@ -1217,16 +1195,27 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Result Serialization Time": 700, | "Memory Bytes Spilled": 800, | "Disk Bytes Spilled": 0, + | "Shuffle Read Metrics" : { + | "Remote Blocks Fetched" : 0, + | "Local Blocks Fetched" : 0, + | "Fetch Wait Time" : 0, + | "Remote Bytes Read" : 0, + | "Local Bytes Read" : 0, + | "Total Records Read" : 0 + | }, | "Shuffle Write Metrics": { | "Shuffle Bytes Written": 1200, | "Shuffle Write Time": 1500, | "Shuffle Records Written": 12 | }, | "Input Metrics": { - | "Data Read Method": "Hadoop", | "Bytes Read": 2100, | "Records Read": 21 | }, + | "Output Metrics" : { + | "Bytes Written" : 0, + | "Records Written" : 0 + | }, | "Updated Blocks": [ | { | "Block ID": "rdd_0_0", @@ -1244,7 +1233,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | ] | } |} - """ + """.stripMargin private val taskEndWithOutputJsonString = """ @@ -1303,13 +1292,24 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Result 
Serialization Time": 700, | "Memory Bytes Spilled": 800, | "Disk Bytes Spilled": 0, + | "Shuffle Read Metrics" : { + | "Remote Blocks Fetched" : 0, + | "Local Blocks Fetched" : 0, + | "Fetch Wait Time" : 0, + | "Remote Bytes Read" : 0, + | "Local Bytes Read" : 0, + | "Total Records Read" : 0 + | }, + | "Shuffle Write Metrics" : { + | "Shuffle Bytes Written" : 0, + | "Shuffle Write Time" : 0, + | "Shuffle Records Written" : 0 + | }, | "Input Metrics": { - | "Data Read Method": "Hadoop", | "Bytes Read": 2100, | "Records Read": 21 | }, | "Output Metrics": { - | "Data Write Method": "Hadoop", | "Bytes Written": 1200, | "Records Written": 12 | }, @@ -1330,7 +1330,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | ] | } |} - """ + """.stripMargin private val jobStartJsonString = """ @@ -1422,7 +1422,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Disk Size": 1001 | } | ], - | "ParentIDs" : [100, 200, 300], + | "Parent IDs" : [100, 200, 300], | "Details": "details", | "Accumulables": [ | { @@ -1498,7 +1498,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Disk Size": 1502 | } | ], - | "ParentIDs" : [100, 200, 300], + | "Parent IDs" : [100, 200, 300], | "Details": "details", | "Accumulables": [ | { @@ -1590,7 +1590,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Disk Size": 2003 | } | ], - | "ParentIDs" : [100, 200, 300], + | "Parent IDs" : [100, 200, 300], | "Details": "details", | "Accumulables": [ | { @@ -1625,7 +1625,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Ukraine": "Kiev" | } |} - """ + """.stripMargin private val jobEndJsonString = """ @@ -1637,7 +1637,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Result": "JobSucceeded" | } |} - """ + """.stripMargin private val environmentUpdateJsonString = """ @@ -1658,7 +1658,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Super library": "/tmp/super_library" | } |} - """ + """.stripMargin private val blockManagerAddedJsonString = """ @@ -1672,7 +1672,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Maximum Memory": 500, | "Timestamp": 1 |} - """ + """.stripMargin private val blockManagerRemovedJsonString = """ @@ -1685,7 +1685,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | "Timestamp": 2 |} - """ + """.stripMargin private val unpersistRDDJsonString = """ @@ -1693,7 +1693,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Event": "SparkListenerUnpersistRDD", | "RDD ID": 12345 |} - """ + """.stripMargin private val applicationStartJsonString = """ @@ -1705,7 +1705,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "User": "Garfield", | "App Attempt ID": "appAttempt" |} - """ + """.stripMargin private val applicationStartJsonWithLogUrlsString = """ @@ -1721,7 +1721,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "stdout" : "mystdout" | } |} - """ + """.stripMargin private val applicationEndJsonString = """ @@ -1729,7 +1729,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Event": "SparkListenerApplicationEnd", | "Timestamp": 42 |} - """ + """.stripMargin private val executorAddedJsonString = s""" @@ -1746,7 +1746,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | } | } |} - """ + """.stripMargin private val executorRemovedJsonString = s""" @@ -1756,7 +1756,7 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Executor ID": "exec2", | "Removed Reason": "test 
reason" |} - """ + """.stripMargin private val executorMetricsUpdateJsonString = s""" @@ -1830,16 +1830,16 @@ private[spark] object JsonProtocolSuite extends Assertions { | "Name": "$UPDATED_BLOCK_STATUSES", | "Update": [ | { - | "BlockID": "rdd_0_0", + | "Block ID": "rdd_0_0", | "Status": { - | "StorageLevel": { - | "UseDisk": true, - | "UseMemory": true, + | "Storage Level": { + | "Use Disk": true, + | "Use Memory": true, | "Deserialized": false, | "Replication": 2 | }, - | "MemorySize": 0, - | "DiskSize": 0 + | "Memory Size": 0, + | "Disk Size": 0 | } | } | ], @@ -1911,48 +1911,34 @@ private[spark] object JsonProtocolSuite extends Assertions { | }, | { | "ID": 18, - | "Name": "${input.READ_METHOD}", - | "Update": "Hadoop", - | "Internal": true, - | "Count Failed Values": true - | }, - | { - | "ID": 19, | "Name": "${input.BYTES_READ}", | "Update": 2100, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 20, + | "ID": 19, | "Name": "${input.RECORDS_READ}", | "Update": 21, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 21, - | "Name": "${output.WRITE_METHOD}", - | "Update": "Hadoop", - | "Internal": true, - | "Count Failed Values": true - | }, - | { - | "ID": 22, + | "ID": 20, | "Name": "${output.BYTES_WRITTEN}", | "Update": 1200, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 23, + | "ID": 21, | "Name": "${output.RECORDS_WRITTEN}", | "Update": 12, | "Internal": true, | "Count Failed Values": true | }, | { - | "ID": 24, + | "ID": 22, | "Name": "$TEST_ACCUM", | "Update": 0, | "Internal": true, diff --git a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala index 6652a41b6990b..ae3b3d829f1bb 100644 --- a/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/ThreadUtilsSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.util import java.util.concurrent.{CountDownLatch, TimeUnit} -import scala.concurrent.{Await, Future} +import scala.concurrent.Future import scala.concurrent.duration._ import scala.util.Random @@ -109,7 +109,7 @@ class ThreadUtilsSuite extends SparkFunSuite { val f = Future { Thread.currentThread().getName() }(ThreadUtils.sameThread) - val futureThreadName = Await.result(f, 10.seconds) + val futureThreadName = ThreadUtils.awaitResult(f, 10.seconds) assert(futureThreadName === callerThreadName) } diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index dc3185a6d505a..2410118fb7172 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -237,7 +237,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { private def testSimpleSpilling(codec: Option[String] = None): Unit = { val size = 1000 val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home - conf.set("spark.shuffle.manager", "hash") // avoid using external sorter conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) @@ -401,7 +400,6 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { test("external aggregation updates peak execution memory") { val spillThreshold = 1000 val 
conf = createSparkConf(loadDefaults = false) - .set("spark.shuffle.manager", "hash") // make sure we're not also using ExternalSorter .set("spark.shuffle.spill.numElementsForceSpillThreshold", spillThreshold.toString) sc = new SparkContext("local", "test", conf) // No spilling diff --git a/dev/.rat-excludes b/dev/.rat-excludes index 8b5061415ff4c..a409a3cb3665c 100644 --- a/dev/.rat-excludes +++ b/dev/.rat-excludes @@ -98,3 +98,4 @@ LZ4BlockInputStream.java spark-deps-.* .*csv .*tsv +org.apache.spark.scheduler.ExternalClusterManager diff --git a/docs/building-spark.md b/docs/building-spark.md index 40661604af942..fec442af95e1b 100644 --- a/docs/building-spark.md +++ b/docs/building-spark.md @@ -192,7 +192,7 @@ If you have JDK 8 installed but it is not the system default, you can set JAVA_H # Packaging without Hadoop Dependencies for YARN -The assembly jar produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. +The assembly directory produced by `mvn package` will, by default, include all of Spark's dependencies, including Hadoop and some of its ecosystem projects. On YARN deployments, this causes multiple versions of these to appear on executor classpaths: the version packaged in the Spark assembly and the version on each node, included with `yarn.application.classpath`. The `hadoop-provided` profile builds the assembly without including Hadoop-ecosystem projects, like ZooKeeper and Hadoop itself. # Building with SBT diff --git a/docs/configuration.md b/docs/configuration.md index 16d5be62f9e89..6512e16faf4c1 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -455,15 +455,6 @@ Apart from these, the following properties are also available, and may be useful is 15 seconds by default, calculated as maxRetries * retryWait. - - spark.shuffle.manager - sort - - Implementation to use for shuffling data. There are two implementations available: - sort and hash. - Sort-based shuffle is more memory-efficient and is the default option starting in 1.2. - - spark.shuffle.service.enabled false diff --git a/docs/ec2-scripts.md b/docs/ec2-scripts.md new file mode 100644 index 0000000000000..6cd39dbed055d --- /dev/null +++ b/docs/ec2-scripts.md @@ -0,0 +1,7 @@ +--- +layout: global +title: Running Spark on EC2 +redirect: https://github.com/amplab/spark-ec2#readme +--- + +This document has been superseded and replaced by documentation at https://github.com/amplab/spark-ec2#readme diff --git a/docs/ml-features.md b/docs/ml-features.md index 70812eb5e2292..11d5acbb10c30 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -22,10 +22,19 @@ This section covers algorithms for working with features, roughly divided into t [Term Frequency-Inverse Document Frequency (TF-IDF)](http://en.wikipedia.org/wiki/Tf%E2%80%93idf) is a common text pre-processing step. In Spark ML, TF-IDF is separate into two parts: TF (+hashing) and IDF. -**TF**: `HashingTF` is a `Transformer` which takes sets of terms and converts those sets into fixed-length feature vectors. In text processing, a "set of terms" might be a bag of words. 
-The algorithm combines Term Frequency (TF) counts with the [hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. +**TF**: Both `HashingTF` and `CountVectorizer` can be used to generate the term frequency vectors. -**IDF**: `IDF` is an `Estimator` which fits on a dataset and produces an `IDFModel`. The `IDFModel` takes feature vectors (generally created from `HashingTF`) and scales each column. Intuitively, it down-weights columns which appear frequently in a corpus. +`HashingTF` is a `Transformer` which takes sets of terms and converts those sets into +fixed-length feature vectors. In text processing, a "set of terms" might be a bag of words. +The algorithm combines Term Frequency (TF) counts with the +[hashing trick](http://en.wikipedia.org/wiki/Feature_hashing) for dimensionality reduction. + +`CountVectorizer` converts text documents to vectors of term counts. Refer to [CountVectorizer +](ml-features.html#countvectorizer) for more details. + +**IDF**: `IDF` is an `Estimator` which is fit on a dataset and produces an `IDFModel`. The +`IDFModel` takes feature vectors (generally created from `HashingTF` or `CountVectorizer`) and scales each column. +Intuitively, it down-weights columns which appear frequently in a corpus. Please refer to the [MLlib user guide on TF-IDF](mllib-feature-extraction.html#tf-idf) for more details on Term Frequency and Inverse Document Frequency. @@ -1303,4 +1312,12 @@ for more details on the API. {% include_example java/org/apache/spark/examples/ml/JavaChiSqSelectorExample.java %}
+
+<div data-lang="python" markdown="1">
+
+Refer to the [ChiSqSelector Python docs](api/python/pyspark.ml.html#pyspark.ml.feature.ChiSqSelector)
+for more details on the API.
+
+{% include_example python/ml/chisq_selector_example.py %}
+</div>
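The new Python walkthrough mirrors the existing Scala and Java ChiSqSelector examples. For quick reference, a minimal Scala sketch of the same flow; the toy rows and column names follow the added python/ml/chisq_selector_example.py, and `sqlContext` is assumed to be in scope:

```scala
import org.apache.spark.ml.feature.ChiSqSelector
import org.apache.spark.mllib.linalg.Vectors

// Three labeled rows with four numeric features each.
val df = sqlContext.createDataFrame(Seq(
  (7, Vectors.dense(0.0, 0.0, 18.0, 1.0), 1.0),
  (8, Vectors.dense(0.0, 1.0, 12.0, 0.0), 0.0),
  (9, Vectors.dense(1.0, 0.0, 15.0, 0.1), 0.0)
)).toDF("id", "features", "clicked")

// Keep the single feature that the chi-squared test ranks as most predictive of "clicked".
val selector = new ChiSqSelector()
  .setNumTopFeatures(1)
  .setFeaturesCol("features")
  .setLabelCol("clicked")
  .setOutputCol("selectedFeatures")

selector.fit(df).transform(df).show()
```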
diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2d9849d0328e3..77887f4ca36be 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1651,7 +1651,7 @@ SELECT * FROM jsonTable Spark SQL also supports reading and writing data stored in [Apache Hive](http://hive.apache.org/). However, since Hive has a large number of dependencies, it is not included in the default Spark assembly. Hive support is enabled by adding the `-Phive` and `-Phive-thriftserver` flags to Spark's build. -This command builds a new assembly jar that includes Hive. Note that this Hive assembly jar must also be present +This command builds a new assembly directory that includes Hive. Note that this Hive assembly directory must also be present on all of the worker nodes, as they will need access to the Hive serialization and deserialization libraries (SerDes) in order to access data stored in Hive. @@ -1770,7 +1770,7 @@ The following options can be used to configure the version of Hive that is used property can be one of three options:
        <li><code>builtin</code></li>
-      Use Hive 1.2.1, which is bundled with the Spark assembly jar when <code>-Phive</code> is
+      Use Hive 1.2.1, which is bundled with the Spark assembly when <code>-Phive</code> is
        enabled. When this option is chosen, <code>spark.sql.hive.metastore.version</code> must be
        either <code>1.2.1</code> or not defined.
        <li><code>maven</code></li>
  4. diff --git a/examples/pom.xml b/examples/pom.xml index 4a20370f0668d..fcd60e3b776a2 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -297,6 +297,13 @@ true + + org.apache.maven.plugins + maven-jar-plugin + + ${jars.target.dir} + + diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java index 37a3d0d84dae2..107c835f2e01e 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaTfIdfExample.java @@ -63,6 +63,8 @@ public static void main(String[] args) { .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); Dataset featurizedData = hashingTF.transform(wordsData); + // alternatively, CountVectorizer can also be used to get term frequency vectors + IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); Dataset rescaledData = idfModel.transform(featurizedData); diff --git a/examples/src/main/python/ml/chisq_selector_example.py b/examples/src/main/python/ml/chisq_selector_example.py new file mode 100644 index 0000000000000..997a504735360 --- /dev/null +++ b/examples/src/main/python/ml/chisq_selector_example.py @@ -0,0 +1,44 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +from __future__ import print_function + +from pyspark import SparkContext +from pyspark.sql import SQLContext +# $example on$ +from pyspark.ml.feature import ChiSqSelector +from pyspark.mllib.linalg import Vectors +# $example off$ + +if __name__ == "__main__": + sc = SparkContext(appName="ChiSqSelectorExample") + sqlContext = SQLContext(sc) + + # $example on$ + df = sqlContext.createDataFrame([ + (7, Vectors.dense([0.0, 0.0, 18.0, 1.0]), 1.0,), + (8, Vectors.dense([0.0, 1.0, 12.0, 0.0]), 0.0,), + (9, Vectors.dense([1.0, 0.0, 15.0, 0.1]), 0.0,)], ["id", "features", "clicked"]) + + selector = ChiSqSelector(numTopFeatures=1, featuresCol="features", + outputCol="selectedFeatures", labelCol="clicked") + + result = selector.fit(df).transform(df) + result.show() + # $example off$ + + sc.stop() diff --git a/examples/src/main/python/ml/tf_idf_example.py b/examples/src/main/python/ml/tf_idf_example.py index c92313378eec7..141324d458530 100644 --- a/examples/src/main/python/ml/tf_idf_example.py +++ b/examples/src/main/python/ml/tf_idf_example.py @@ -37,6 +37,8 @@ wordsData = tokenizer.transform(sentenceData) hashingTF = HashingTF(inputCol="words", outputCol="rawFeatures", numFeatures=20) featurizedData = hashingTF.transform(wordsData) + # alternatively, CountVectorizer can also be used to get term frequency vectors + idf = IDF(inputCol="rawFeatures", outputCol="features") idfModel = idf.fit(featurizedData) rescaledData = idfModel.transform(featurizedData) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala deleted file mode 100644 index bca301d412f4c..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala +++ /dev/null @@ -1,114 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// scalastyle:off println -package org.apache.spark.examples.ml - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.ml.Pipeline -import org.apache.spark.ml.classification.LogisticRegression -import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator -import org.apache.spark.ml.feature.{HashingTF, Tokenizer} -import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder} -import org.apache.spark.mllib.linalg.Vector -import org.apache.spark.sql.{Row, SQLContext} - -/** - * A simple example demonstrating model selection using CrossValidator. - * This example also demonstrates how Pipelines are Estimators. - * - * This example uses the [[LabeledDocument]] and [[Document]] case classes from - * [[SimpleTextClassificationPipeline]]. 
- * - * Run with - * {{{ - * bin/run-example ml.CrossValidatorExample - * }}} - */ -object CrossValidatorExample { - - def main(args: Array[String]) { - val conf = new SparkConf().setAppName("CrossValidatorExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - import sqlContext.implicits._ - - // Prepare training documents, which are labeled. - val training = sc.parallelize(Seq( - LabeledDocument(0L, "a b c d e spark", 1.0), - LabeledDocument(1L, "b d", 0.0), - LabeledDocument(2L, "spark f g h", 1.0), - LabeledDocument(3L, "hadoop mapreduce", 0.0), - LabeledDocument(4L, "b spark who", 1.0), - LabeledDocument(5L, "g d a y", 0.0), - LabeledDocument(6L, "spark fly", 1.0), - LabeledDocument(7L, "was mapreduce", 0.0), - LabeledDocument(8L, "e spark program", 1.0), - LabeledDocument(9L, "a e c l", 0.0), - LabeledDocument(10L, "spark compile", 1.0), - LabeledDocument(11L, "hadoop software", 0.0))) - - // Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and lr. - val tokenizer = new Tokenizer() - .setInputCol("text") - .setOutputCol("words") - val hashingTF = new HashingTF() - .setInputCol(tokenizer.getOutputCol) - .setOutputCol("features") - val lr = new LogisticRegression() - .setMaxIter(10) - val pipeline = new Pipeline() - .setStages(Array(tokenizer, hashingTF, lr)) - - // We now treat the Pipeline as an Estimator, wrapping it in a CrossValidator instance. - // This will allow us to jointly choose parameters for all Pipeline stages. - // A CrossValidator requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - val crossval = new CrossValidator() - .setEstimator(pipeline) - .setEvaluator(new BinaryClassificationEvaluator) - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // With 3 values for hashingTF.numFeatures and 2 values for lr.regParam, - // this grid will have 3 x 2 = 6 parameter settings for CrossValidator to choose from. - val paramGrid = new ParamGridBuilder() - .addGrid(hashingTF.numFeatures, Array(10, 100, 1000)) - .addGrid(lr.regParam, Array(0.1, 0.01)) - .build() - crossval.setEstimatorParamMaps(paramGrid) - crossval.setNumFolds(2) // Use 3+ in practice - - // Run cross-validation, and choose the best set of parameters. - val cvModel = crossval.fit(training.toDF()) - - // Prepare test documents, which are unlabeled. - val test = sc.parallelize(Seq( - Document(4L, "spark i j k"), - Document(5L, "l m n"), - Document(6L, "mapreduce spark"), - Document(7L, "apache hadoop"))) - - // Make predictions on test documents. cvModel uses the best model found (lrModel). 
- cvModel.transform(test.toDF()) - .select("id", "text", "probability", "prediction") - .collect() - .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) => - println(s"($id, $text) --> prob=$prob, prediction=$prediction") - } - - sc.stop() - } -} -// scalastyle:on println diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala index 0331d6e7b35df..d1441b5497a86 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaCrossValidationExample.scala @@ -30,6 +30,15 @@ import org.apache.spark.sql.Row // $example off$ import org.apache.spark.sql.SQLContext +/** + * A simple example demonstrating model selection using CrossValidator. + * This example also demonstrates how Pipelines are Estimators. + * + * Run with + * {{{ + * bin/run-example ml.ModelSelectionViaCrossValidationExample + * }}} + */ object ModelSelectionViaCrossValidationExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala index 5a95344f223df..fcad17a817580 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ModelSelectionViaTrainValidationSplitExample.scala @@ -25,6 +25,14 @@ import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} // $example off$ import org.apache.spark.sql.SQLContext +/** + * A simple example demonstrating model selection using TrainValidationSplit. + * + * Run with + * {{{ + * bin/run-example ml.ModelSelectionViaTrainValidationSplitExample + * }}} + */ object ModelSelectionViaTrainValidationSplitExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala index 28115f939082e..396f073e6b322 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/TfIdfExample.scala @@ -43,6 +43,8 @@ object TfIdfExample { val hashingTF = new HashingTF() .setInputCol("words").setOutputCol("rawFeatures").setNumFeatures(20) val featurizedData = hashingTF.transform(wordsData) + // alternatively, CountVectorizer can also be used to get term frequency vectors + val idf = new IDF().setInputCol("rawFeatures").setOutputCol("features") val idfModel = idf.fit(featurizedData) val rescaledData = idfModel.transform(featurizedData) diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala deleted file mode 100644 index fbba17eba6a2f..0000000000000 --- a/examples/src/main/scala/org/apache/spark/examples/ml/TrainValidationSplitExample.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples.ml - -import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.ml.evaluation.RegressionEvaluator -import org.apache.spark.ml.regression.LinearRegression -import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit} -import org.apache.spark.sql.SQLContext - -/** - * A simple example demonstrating model selection using TrainValidationSplit. - * - * The example is based on [[SimpleParamsExample]] using linear regression. - * Run with - * {{{ - * bin/run-example ml.TrainValidationSplitExample - * }}} - */ -object TrainValidationSplitExample { - - def main(args: Array[String]): Unit = { - val conf = new SparkConf().setAppName("TrainValidationSplitExample") - val sc = new SparkContext(conf) - val sqlContext = new SQLContext(sc) - - // Prepare training and test data. - val data = sqlContext.read.format("libsvm").load("data/mllib/sample_libsvm_data.txt") - val Array(training, test) = data.randomSplit(Array(0.9, 0.1), seed = 12345) - - val lr = new LinearRegression() - - // We use a ParamGridBuilder to construct a grid of parameters to search over. - // TrainValidationSplit will try all combinations of values and determine best model using - // the evaluator. - val paramGrid = new ParamGridBuilder() - .addGrid(lr.regParam, Array(0.1, 0.01)) - .addGrid(lr.fitIntercept, Array(true, false)) - .addGrid(lr.elasticNetParam, Array(0.0, 0.5, 1.0)) - .build() - - // In this case the estimator is simply the linear regression. - // A TrainValidationSplit requires an Estimator, a set of Estimator ParamMaps, and an Evaluator. - val trainValidationSplit = new TrainValidationSplit() - .setEstimator(lr) - .setEvaluator(new RegressionEvaluator) - .setEstimatorParamMaps(paramGrid) - - // 80% of the data will be used for training and the remaining 20% for validation. - trainValidationSplit.setTrainRatio(0.8) - - // Run train validation split, and choose the best set of parameters. - val model = trainValidationSplit.fit(training) - - // Make predictions on test data. model is the model with combination of parameters - // that performed best. 
- model.transform(test) - .select("features", "label", "prediction") - .show() - - sc.stop() - } -} diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala index 49f5df39443e9..ae4dee24c6474 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StreamingTestExample.scala @@ -59,10 +59,10 @@ object StreamingTestExample { val conf = new SparkConf().setMaster("local").setAppName("StreamingTestExample") val ssc = new StreamingContext(conf, batchDuration) - ssc.checkpoint({ + ssc.checkpoint { val dir = Utils.createTempDir() dir.toString - }) + } // $example on$ val data = ssc.textFileStream(dataDir).map(line => line.split(",") match { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index bb2af9cd72e2a..1bcd85e1d533f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -115,7 +115,7 @@ object RecoverableNetworkWordCount { // words in input stream of \n delimited text (eg. generated by 'nc') val lines = ssc.socketTextStream(ip, port) val words = lines.flatMap(_.split(" ")) - val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) + val wordCounts = words.map((_, 1)).reduceByKey(_ + _) wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => // Get or register the blacklist Broadcast val blacklist = WordBlacklist.getInstance(rdd.sparkContext) @@ -158,9 +158,7 @@ object RecoverableNetworkWordCount { } val Array(ip, IntParam(port), checkpointDirectory, outputPath) = args val ssc = StreamingContext.getOrCreate(checkpointDirectory, - () => { - createContext(ip, port, outputPath, checkpointDirectory) - }) + () => createContext(ip, port, outputPath, checkpointDirectory)) ssc.start() ssc.awaitTermination() } diff --git a/external/docker-integration-tests/pom.xml b/external/docker-integration-tests/pom.xml index 17fd7d781c9ab..53a24f3e06e08 100644 --- a/external/docker-integration-tests/pom.xml +++ b/external/docker-integration-tests/pom.xml @@ -36,8 +36,12 @@ - db2 + db https://app.camunda.com/nexus/content/repositories/public/ + + true + warn + @@ -143,14 +147,13 @@ to use a an ojdbc jar for the testcase. The maven dependency here is commented because currently the maven repository does not contain the ojdbc jar mentioned. Once the jar is available in maven, this could be uncommented. --> - + + com.oracle + ojdbc6 + 11.2.0.1.0 + test + + JavaConversions diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 9cf2dd257e5c1..15f556e550260 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -121,6 +121,7 @@ statement | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases | SHOW TBLPROPERTIES table=tableIdentifier ('(' key=tablePropertyKey ')')? #showTblProperties + | SHOW CREATE TABLE tableIdentifier #showCreateTable | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? 
#showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? @@ -652,6 +653,7 @@ nonReserved | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION + | ASC | DESC | LIMIT | RENAME | SETS ; SELECT: 'SELECT'; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java index 1e4e5ede8cc11..110ed460cc8fa 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/types/SQLUserDefinedType.java @@ -24,11 +24,6 @@ /** * ::DeveloperApi:: * A user-defined type which can be automatically recognized by a SQLContext and registered. - *

    - * WARNING: This annotation will only work if both Java and Scala reflection return the same class - * names (after erasure) for the UDT. This will NOT be the case when, e.g., the UDT class - * is enclosed in an object (a singleton). - *

    * WARNING: UDTs are currently only supported from Scala. */ // TODO: Should I used @Documented ? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 2b98aacdd7264..0efe3c4d456ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -19,46 +19,34 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.catalyst.analysis._ -private[spark] trait CatalystConf { +/** + * Interface for configuration options used in the catalyst module. + */ +trait CatalystConf { def caseSensitiveAnalysis: Boolean def orderByOrdinal: Boolean def groupByOrdinal: Boolean + def optimizerMaxIterations: Int + def maxCaseBranchesForCodegen: Int + /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. */ def resolver: Resolver = { - if (caseSensitiveAnalysis) { - caseSensitiveResolution - } else { - caseInsensitiveResolution - } + if (caseSensitiveAnalysis) caseSensitiveResolution else caseInsensitiveResolution } } -/** - * A trivial conf that is empty. Used for testing when all - * relations are already filled in and the analyser needs only to resolve attribute references. - */ -object EmptyConf extends CatalystConf { - override def caseSensitiveAnalysis: Boolean = { - throw new UnsupportedOperationException - } - override def orderByOrdinal: Boolean = { - throw new UnsupportedOperationException - } - override def groupByOrdinal: Boolean = { - throw new UnsupportedOperationException - } -} /** A CatalystConf that can be used for local testing. */ case class SimpleCatalystConf( caseSensitiveAnalysis: Boolean, orderByOrdinal: Boolean = true, - groupByOrdinal: Boolean = true) - + groupByOrdinal: Boolean = true, + optimizerMaxIterations: Int = 100, + maxCaseBranchesForCodegen: Int = 20) extends CatalystConf { } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 4795fc25576aa..bd723135b510d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -374,10 +374,8 @@ object ScalaReflection extends ScalaReflection { newInstance } - case t if Utils.classIsLoadable(className) && - Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - val udt = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), Nil, @@ -432,7 +430,6 @@ object ScalaReflection extends ScalaReflection { if (!inputObject.dataType.isInstanceOf[ObjectType]) { inputObject } else { - val className = getClassNameFromType(tpe) tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -589,9 +586,8 @@ object ScalaReflection extends ScalaReflection { case t if t <:< localTypeOf[java.lang.Boolean] => Invoke(inputObject, "booleanValue", BooleanType) - case t if Utils.classIsLoadable(className) && - 
Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - val udt = Utils.classForName(className) + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t) .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() val obj = NewInstance( udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(), @@ -637,24 +633,6 @@ object ScalaReflection extends ScalaReflection { * Retrieves the runtime class corresponding to the provided type. */ def getClassFromType(tpe: Type): Class[_] = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass) -} - -/** - * Support for generating catalyst schemas for scala objects. Note that unlike its companion - * object, this trait able to work in both the runtime and the compile time (macro) universe. - */ -trait ScalaReflection { - /** The universe we work in (runtime or macro) */ - val universe: scala.reflect.api.Universe - - /** The mirror used to access types in the universe */ - def mirror: universe.Mirror - - import universe._ - - // The Predef.Map is scala.collection.immutable.Map. - // Since the map values can be mutable, we explicitly import scala.collection.Map at here. - import scala.collection.Map case class Schema(dataType: DataType, nullable: Boolean) @@ -668,36 +646,22 @@ trait ScalaReflection { def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) /** - * Return the Scala Type for `T` in the current classloader mirror. - * - * Use this method instead of the convenience method `universe.typeOf`, which - * assumes that all types can be found in the classloader that loaded scala-reflect classes. - * That's not necessarily the case when running using Eclipse launchers or even - * Sbt console or test (without `fork := true`). + * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. * - * @see SPARK-5281 + * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return + * `NullType` silently instead. */ - // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. - def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { - val tag = implicitly[TypeTag[T]] - tag.in(mirror).tpe.normalize + def silentSchemaFor(tpe: `Type`): Schema = try { + schemaFor(tpe) + } catch { + case _: UnsupportedOperationException => Schema(NullType, nullable = true) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { - val className = getClassNameFromType(tpe) - tpe match { - - case t if Utils.classIsLoadable(className) && - Utils.classForName(className).isAnnotationPresent(classOf[SQLUserDefinedType]) => - - // Note: We check for classIsLoadable above since Utils.classForName uses Java reflection, - // whereas className is from Scala reflection. This can make it hard to find classes - // in some cases, such as when a class is enclosed in an object (in which case - // Java appends a '$' to the object name but Scala does not). 
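The replacement cases above drop the Utils.classIsLoadable / Utils.classForName round-trip and instead ask the Scala type itself for the @SQLUserDefinedType annotation, which avoids the object-enclosed-class naming mismatch the removed comment describes. A condensed, hypothetical helper showing the pattern (the mirror-based class lookup is what getClassFromType does in ScalaReflection):

```scala
import scala.reflect.runtime.universe._

import org.apache.spark.sql.types.SQLUserDefinedType

// Sketch of the annotation-based detection: no class-name string ever leaves the
// reflection universe, so Scala vs. Java naming differences cannot bite.
def sqlUserDefinedTypeOf(tpe: Type, mirror: Mirror): Option[Class[_]] = {
  if (tpe.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType])) {
    val cls = mirror.runtimeClass(tpe.erasure.typeSymbol.asClass)
    Some(cls.getAnnotation(classOf[SQLUserDefinedType]).udt())
  } else {
    None
  }
}
```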
- val udt = Utils.classForName(className) - .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() + case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) => + val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance() Schema(udt, nullable = true) case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t @@ -748,17 +712,39 @@ trait ScalaReflection { throw new UnsupportedOperationException(s"Schema for type $other is not supported") } } +} + +/** + * Support for generating catalyst schemas for scala objects. Note that unlike its companion + * object, this trait able to work in both the runtime and the compile time (macro) universe. + */ +trait ScalaReflection { + /** The universe we work in (runtime or macro) */ + val universe: scala.reflect.api.Universe + + /** The mirror used to access types in the universe */ + def mirror: universe.Mirror + + import universe._ + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. + * Return the Scala Type for `T` in the current classloader mirror. * - * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return - * `NullType` silently instead. + * Use this method instead of the convenience method `universe.typeOf`, which + * assumes that all types can be found in the classloader that loaded scala-reflect classes. + * That's not necessarily the case when running using Eclipse launchers or even + * Sbt console or test (without `fork := true`). + * + * @see SPARK-5281 */ - def silentSchemaFor(tpe: `Type`): Schema = try { - schemaFor(tpe) - } catch { - case _: UnsupportedOperationException => Schema(NullType, nullable = true) + // SPARK-13640: Synchronize this because TypeTag.tpe is not thread-safe in Scala 2.10. + def localTypeOf[T: TypeTag]: `Type` = ScalaReflectionLock.synchronized { + val tag = implicitly[TypeTag[T]] + tag.in(mirror).tpe.normalize } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index de40ddde1bdd9..8595762988b4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -39,14 +39,13 @@ import org.apache.spark.sql.types._ * Used for testing when all relations are already filled in and the analyzer needs only * to resolve attribute references. 
*/ -object SimpleAnalyzer - extends SimpleAnalyzer( - EmptyFunctionRegistry, +object SimpleAnalyzer extends Analyzer( + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SimpleCatalystConf(caseSensitiveAnalysis = true)), new SimpleCatalystConf(caseSensitiveAnalysis = true)) -class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) - extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf) - /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and * [[UnresolvedRelation]]s into fully typed objects using information in a @@ -55,9 +54,13 @@ class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) class Analyzer( catalog: SessionCatalog, conf: CatalystConf, - maxIterations: Int = 100) + maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis { + def this(catalog: SessionCatalog, conf: CatalystConf) = { + this(catalog, conf, conf.optimizerMaxIterations) + } + def resolver: Resolver = { if (conf.caseSensitiveAnalysis) { caseSensitiveResolution @@ -66,7 +69,7 @@ class Analyzer( } } - val fixedPoint = FixedPoint(maxIterations) + protected val fixedPoint = FixedPoint(maxIterations) /** * Override to provide additional rules for the "Resolution" batch. @@ -169,8 +172,8 @@ class Analyzer( private def assignAliases(exprs: Seq[NamedExpression]) = { exprs.zipWithIndex.map { case (expr, i) => - expr transformUp { - case u @ UnresolvedAlias(child, optionalAliasName) => child match { + expr.transformUp { case u @ UnresolvedAlias(child, optionalAliasName) => + child match { case ne: NamedExpression => ne case e if !e.resolved => u case g: Generator => MultiAlias(g, Nil) @@ -212,7 +215,7 @@ class Analyzer( * represented as the bit masks. */ def bitmasks(r: Rollup): Seq[Int] = { - Seq.tabulate(r.groupByExprs.length + 1)(idx => {(1 << idx) - 1}) + Seq.tabulate(r.groupByExprs.length + 1)(idx => (1 << idx) - 1) } /* @@ -293,10 +296,13 @@ class Analyzer( val nonNullBitmask = x.bitmasks.reduce(_ & _) - val groupByAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => + val expandedAttributes = groupByAliases.zipWithIndex.map { case (a, idx) => a.toAttribute.withNullability((nonNullBitmask & 1 << idx) == 0) } + val expand = Expand(x.bitmasks, groupByAliases, expandedAttributes, gid, x.child) + val groupingAttrs = expand.output.drop(x.child.output.length) + val aggregations: Seq[NamedExpression] = x.aggregations.map { case expr => // collect all the found AggregateExpression, so we can check an expression is part of // any AggregateExpression or not. @@ -318,15 +324,12 @@ class Analyzer( if (index == -1) { e } else { - groupByAttributes(index) + groupingAttrs(index) } }.asInstanceOf[NamedExpression] } - Aggregate( - groupByAttributes :+ gid, - aggregations, - Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + Aggregate(groupingAttrs, aggregations, expand) case f @ Filter(cond, child) if hasGroupingFunction(cond) => val groupingExprs = findGroupingExprs(child) @@ -852,25 +855,35 @@ class Analyzer( } /** - * This rule resolve subqueries inside expressions. + * This rule resolves sub-queries inside expressions. * - * Note: CTE are handled in CTESubstitution. + * Note: CTEs are handled in CTESubstitution. 
*/ object ResolveSubquery extends Rule[LogicalPlan] with PredicateHelper { - private def hasSubquery(e: Expression): Boolean = { - e.find(_.isInstanceOf[SubqueryExpression]).isDefined - } - - private def hasSubquery(q: LogicalPlan): Boolean = { - q.expressions.exists(hasSubquery) + /** + * Resolve the correlated predicates in the [[Filter]] clauses (e.g. WHERE or HAVING) of a + * sub-query by using the plan the predicates should be correlated to. + */ + private def resolveCorrelatedPredicates(q: LogicalPlan, p: LogicalPlan): LogicalPlan = { + q transformUp { + case f @ Filter(cond, child) if child.resolved && !f.resolved => + val newCond = resolveExpression(cond, p, throws = false) + if (!cond.fastEquals(newCond)) { + Filter(newCond, child) + } else { + f + } + } } def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case q: LogicalPlan if q.childrenResolved && hasSubquery(q) => + case q: LogicalPlan if q.childrenResolved => q transformExpressions { case e: SubqueryExpression if !e.query.resolved => - e.withNewPlan(execute(e.query)) + // First resolve as much of the sub-query as possible. After that we use the children of + // this plan to resolve the remaining correlated predicates. + e.withNewPlan(q.children.foldLeft(execute(e.query))(resolveCorrelatedPredicates)) } } } @@ -1669,9 +1682,9 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. - case o: ObjectOperator => o - case d: DeserializeToObject => d - case s: SerializeFromObject => s + case o: ObjectConsumer => o + case o: ObjectProducer => o + case a: AppendColumns => a case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index d6a8c3eec81aa..45e4d535c18cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.plans.UsingJoin +import org.apache.spark.sql.catalyst.plans.{Inner, RightOuter, UsingJoin} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ /** * Throws user facing errors when passed invalid queries that fail to analyze. */ -trait CheckAnalysis { +trait CheckAnalysis extends PredicateHelper { /** * Override to provide additional checks for correct analysis. @@ -110,6 +110,39 @@ trait CheckAnalysis { s"filter expression '${f.condition.sql}' " + s"of type ${f.condition.dataType.simpleString} is not a boolean.") + case f @ Filter(condition, child) => + // Make sure that no correlated reference is below Aggregates, Outer Joins and on the + // right hand side of Unions. 
+ lazy val attributes = child.outputSet + def failOnCorrelatedReference( + p: LogicalPlan, + message: String): Unit = p.transformAllExpressions { + case e: NamedExpression if attributes.contains(e) => + failAnalysis(s"Accessing outer query column is not allowed in $message: $e") + } + def checkForCorrelatedReferences(p: PredicateSubquery): Unit = p.query.foreach { + case a @ Aggregate(_, _, source) => + failOnCorrelatedReference(source, "an AGGREGATE") + case j @ Join(left, _, RightOuter, _) => + failOnCorrelatedReference(left, "a RIGHT OUTER JOIN") + case j @ Join(_, right, jt, _) if jt != Inner => + failOnCorrelatedReference(right, "a LEFT (OUTER) JOIN") + case Union(_ :: xs) => + xs.foreach(failOnCorrelatedReference(_, "a UNION")) + case s: SetOperation => + failOnCorrelatedReference(s.right, "an INTERSECT/EXCEPT") + case _ => + } + splitConjunctivePredicates(condition).foreach { + case p: PredicateSubquery => + checkForCorrelatedReferences(p) + case Not(p: PredicateSubquery) => + checkForCorrelatedReferences(p) + case e if PredicateSubquery.hasPredicateSubquery(e) => + failAnalysis(s"Predicate sub-queries cannot be used in nested conditions: $e") + case e => + } + case j @ Join(_, _, UsingJoin(_, cols), _) => val from = operator.inputSet.map(_.name).mkString(", ") failAnalysis( @@ -209,6 +242,9 @@ trait CheckAnalysis { | but one table has '${firstError.output.length}' columns and another table has | '${s.children.head.output.length}' columns""".stripMargin) + case p if p.expressions.exists(PredicateSubquery.hasPredicateSubquery) => + failAnalysis(s"Predicate sub-queries can only be used in a Filter: $p") + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f2abf136da685..a44430059dddd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -179,6 +179,7 @@ object FunctionRegistry { expression[Atan]("atan"), expression[Atan2]("atan2"), expression[Bin]("bin"), + expression[BRound]("bround"), expression[Cbrt]("cbrt"), expression[Ceil]("ceil"), expression[Ceil]("ceiling"), @@ -328,6 +329,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), // misc functions + expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Murmur3Hash]("hash"), @@ -337,6 +339,7 @@ object FunctionRegistry { expression[SparkPartitionID]("spark_partition_id"), expression[InputFileName]("input_file_name"), expression[MonotonicallyIncreasingID]("monotonically_increasing_id"), + expression[CurrentDatabase]("current_database"), // grouping sets expression[Cube]("cube"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 823d2495fad80..5323b79c57c4b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -584,10 +584,10 @@ object HiveTypeCoercion { val newRight = if (right.dataType == widestType) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression
unchanged. - // Convert If(null literal, _, _) into boolean type. - // In the optimizer, we should short-circuit this directly into false value. - case If(pred, left, right) if pred.dataType == NullType => + case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) + case If(pred, left, right) if pred.dataType == NullType => + If(Cast(pred, BooleanType), left, right) } } diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala similarity index 79% rename from mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala index 6b3268cdfa25c..a4d387eae3c80 100644 --- a/mllib-local/src/main/scala/org/apache/spark/ml/DummyTesting.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/OutputMode.scala @@ -15,9 +15,9 @@ * limitations under the License. */ -package org.apache.spark.ml +package org.apache.spark.sql.catalyst.analysis -// This is a private class testing if the new build works. To be removed soon. -private[ml] object DummyTesting { - private[ml] def add10(input: Double): Double = input + 10 -} +sealed trait OutputMode + +case object Append extends OutputMode +case object Update extends OutputMode diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala new file mode 100644 index 0000000000000..aadc1d31bd4b2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationChecker.scala @@ -0,0 +1,145 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ + +/** + * Analyzes the presence of unsupported operations in a logical plan. 
+ */ +object UnsupportedOperationChecker { + + def checkForBatch(plan: LogicalPlan): Unit = { + plan.foreachUp { + case p if p.isStreaming => + throwError( + "Queries with streaming sources must be executed with write.startStream()")(p) + + case _ => + } + } + + def checkForStreaming(plan: LogicalPlan, outputMode: OutputMode): Unit = { + + if (!plan.isStreaming) { + throwError( + "Queries without streaming sources cannot be executed with write.startStream()")(plan) + } + + plan.foreachUp { implicit plan => + + // Operations that cannot exists anywhere in a streaming plan + plan match { + + case _: Command => + throwError("Commands like CreateTable*, AlterTable*, Show* are not supported with " + + "streaming DataFrames/Datasets") + + case _: InsertIntoTable => + throwError("InsertIntoTable is not supported with streaming DataFrames/Datasets") + + case Aggregate(_, _, child) if child.isStreaming && outputMode == Append => + throwError( + "Aggregations are not supported on streaming DataFrames/Datasets in " + + "Append output mode. Consider changing output mode to Update.") + + case Join(left, right, joinType, _) => + + joinType match { + + case Inner => + if (left.isStreaming && right.isStreaming) { + throwError("Inner join between two streaming DataFrames/Datasets is not supported") + } + + case FullOuter => + if (left.isStreaming || right.isStreaming) { + throwError("Full outer joins with streaming DataFrames/Datasets are not supported") + } + + + case LeftOuter | LeftSemi | LeftAnti => + if (right.isStreaming) { + throwError("Left outer/semi/anti joins with a streaming DataFrame/Dataset " + + "on the right is not supported") + } + + case RightOuter => + if (left.isStreaming) { + throwError("Right outer join with a streaming DataFrame/Dataset on the left is " + + "not supported") + } + + case NaturalJoin(_) | UsingJoin(_, _) => + // They should not appear in an analyzed plan. + + case _ => + throwError(s"Join type $joinType is not supported with streaming DataFrame/Dataset") + } + + case c: CoGroup if c.children.exists(_.isStreaming) => + throwError("CoGrouping with a streaming DataFrame/Dataset is not supported") + + case u: Union if u.children.map(_.isStreaming).distinct.size == 2 => + throwError("Union between streaming and batch DataFrames/Datasets is not supported") + + case Except(left, right) if right.isStreaming => + throwError("Except with a streaming DataFrame/Dataset on the right is not supported") + + case Intersect(left, right) if left.isStreaming && right.isStreaming => + throwError("Intersect between two streaming DataFrames/Datasets is not supported") + + case GroupingSets(_, _, child, _) if child.isStreaming => + throwError("GroupingSets is not supported on streaming DataFrames/Datasets") + + case GlobalLimit(_, _) | LocalLimit(_, _) if plan.children.forall(_.isStreaming) => + throwError("Limits are not supported on streaming DataFrames/Datasets") + + case Sort(_, _, _) | SortPartitions(_, _) if plan.children.forall(_.isStreaming) => + throwError("Sorting is not supported on streaming DataFrames/Datasets") + + case Sample(_, _, _, _, child) if child.isStreaming => + throwError("Sampling is not supported on streaming DataFrames/Datasets") + + case Window(_, _, _, child) if child.isStreaming => + throwError("Non-time-based windows are not supported on streaming DataFrames/Datasets") + + case ReturnAnswer(child) if child.isStreaming => + throwError("Cannot return immediate result on streaming DataFrames/Dataset. 
Queries " + + "with streaming DataFrames/Datasets must be executed with write.startStream().") + + case _ => + } + } + } + + private def throwErrorIf( + condition: Boolean, + msg: String)(implicit operator: LogicalPlan): Unit = { + if (condition) { + throwError(msg) + } + } + + private def throwError(msg: String)(implicit operator: LogicalPlan): Nothing = { + throw new AnalysisException( + msg, operator.origin.line, operator.origin.startPosition, Some(operator)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 4ec43aba02d66..e83008e86ebef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -153,7 +153,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") override def terminate(): TraversableOnce[InternalRow] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index f8a6fb74cc87d..569614476f614 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -205,6 +205,11 @@ class InMemoryCatalog extends ExternalCatalog { StringUtils.filterPattern(listTables(db), pattern) } + override def showCreateTable(db: String, table: String): String = { + throw new AnalysisException( + "SHOW CREATE TABLE command is not supported for temporary tables created in SQLContext.") + } + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 34e1cb7315a9c..980dda986338a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -187,6 +187,15 @@ class SessionCatalog( externalCatalog.getTableOption(db, table) } + /** + * Generate Create table DDL string for the specified tableIdentifier + */ + def showCreateTable(name: TableIdentifier): String = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.showCreateTable(db, table) + } + // ------------------------------------------------------------- // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index ad989a97e4afa..f20699a58cf7a 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -99,6 +99,8 @@ abstract class ExternalCatalog { def listTables(db: String, pattern: String): Seq[String] + def showCreateTable(db: String, table: String): String + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 1e7296664bb25..085e95f542a16 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -245,6 +245,10 @@ package object dsl { def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + /** Creates a new AttributeReference of object type */ + def obj(cls: Class[_]): AttributeReference = + AttributeReference(s, ObjectType(cls), nullable = true)() + /** Create a function. */ def function(exprs: Expression*): UnresolvedFunction = UnresolvedFunction(s, exprs, isDistinct = false) @@ -297,6 +301,24 @@ package object dsl { condition: Option[Expression] = None): LogicalPlan = Join(logicalPlan, otherPlan, joinType, condition) + def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder]( + otherPlan: LogicalPlan, + func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + leftGroup: Seq[Attribute], + rightGroup: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute] + ): LogicalPlan = { + CoGroup.apply[Key, Left, Right, Result]( + func, + leftGroup, + rightGroup, + leftAttr, + rightAttr, + logicalPlan, + otherPlan) + } + def orderBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, true, logicalPlan) def sortBy(sortExprs: SortOrder*): LogicalPlan = Sort(sortExprs, false, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index c1fd23f28d6b3..99f156a935b50 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -58,7 +58,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { @@ -67,17 +67,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.value = oev.value val code = oev.code oev.code = "" - code + ev.copy(code = code) } else if (nullable) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + $javaType ${ev.value} = ${ev.isNull} ? 
${ctx.defaultValue(dataType)} : ($value);""") } else { - ev.isNull = "false" - s""" - $javaType ${ev.value} = $value; - """ + ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 0f8876a9e6881..b1e89b5de833f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -446,11 +446,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w protected override def nullSafeEval(input: Any): Any = cast(input) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval = child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast) + ev.copy(code = eval.code + + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // three function arguments are: child.primitive, result.primitive and result.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index 8d8cc152ff29c..607c7c877cc14 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -69,8 +69,17 @@ class EquivalentExpressions { */ def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf - // the children of CodegenFallback will not be used to generate code (call eval() instead) - if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) { + // There are some special expressions that we should not recurse into children. + // 1. CodegenFallback: it's children will not be used to generate code (call eval() instead) + // 2. ReferenceToExpressions: it's kind of an explicit sub-expression elimination. + val shouldRecurse = root match { + // TODO: some expressions implements `CodegenFallback` but can still do codegen, + // e.g. `CaseWhen`, we should support them. + case _: CodegenFallback => false + case _: ReferenceToExpressions => false + case _ => true + } + if (!skip && !addExpr(root) && shouldRecurse) { root.children.foreach(addExprTree(_, ignoreLeaf)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index a24a5db8d49cd..7dacdafb7141d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -86,24 +86,23 @@ abstract class Expression extends TreeNode[Expression] { def eval(input: InternalRow = null): Any /** - * Returns an [[ExprCode]], which contains Java source code that - * can be used to generate the result of evaluating the expression on an input row. 
+ * Returns an [[ExprCode]], that contains the Java source code to generate the result of + * evaluating the expression on an input row. * * @param ctx a [[CodegenContext]] * @return [[ExprCode]] */ - def gen(ctx: CodegenContext): ExprCode = { + def genCode(ctx: CodegenContext): ExprCode = { ctx.subExprEliminationExprs.get(this).map { subExprState => - // This expression is repeated meaning the code to evaluated has already been added - // as a function and called in advance. Just use it. + // This expression is repeated which means that the code to evaluate it has already been added + // as a function before. In that case, we just re-use it. val code = s"/* ${toCommentSafeString(this.toString)} */" ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val ve = ExprCode("", isNull, value) - ve.code = genCode(ctx, ve) - if (ve.code != "") { + val ve = doGenCode(ctx, ExprCode("", isNull, value)) + if (ve.code.nonEmpty) { // Add `this` in the comment. ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) } else { @@ -119,9 +118,9 @@ abstract class Expression extends TreeNode[Expression] { * * @param ctx a [[CodegenContext]] * @param ev an [[ExprCode]] with unique terms. - * @return Java source code + * @return an [[ExprCode]] containing the Java source code to generate the given expression */ - protected def genCode(ctx: CodegenContext, ev: ExprCode): String + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -185,7 +184,7 @@ abstract class Expression extends TreeNode[Expression] { * Returns a user-facing string representation of this expression's name. * This should usually match the name of the function in SQL. 
*/ - def prettyName: String = getClass.getSimpleName.toLowerCase + def prettyName: String = nodeName.toLowerCase private def flatArguments = productIterator.flatMap { case t: Traversable[_] => t @@ -216,7 +215,7 @@ trait Unevaluable extends Expression { final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - final override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = + final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } @@ -316,7 +315,7 @@ abstract class UnaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): String = { + f: String => String): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s"${ev.value} = ${f(eval)};" }) @@ -332,25 +331,23 @@ abstract class UnaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): String = { - val childGen = child.gen(ctx) + f: String => String): ExprCode = { + val childGen = child.genCode(ctx) val resultCode = f(childGen.value) if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - s""" + ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $nullSafeEval - """ + """) } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } @@ -406,7 +403,7 @@ abstract class BinaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): String = { + f: (String, String) => String): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.value} = ${f(eval1, eval2)};" }) @@ -423,9 +420,9 @@ abstract class BinaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): String = { - val leftGen = left.gen(ctx) - val rightGen = right.gen(ctx) + f: (String, String) => String): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) val resultCode = f(leftGen.value, rightGen.value) if (nullable) { @@ -439,19 +436,17 @@ abstract class BinaryExpression extends Expression { } } - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $nullSafeEval - """ + """) } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${leftGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } @@ -548,7 +543,7 @@ abstract class TernaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): String = { + f: (String, String, String) => String): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" }) @@ -565,10 +560,10 @@ abstract class TernaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): String = { - val leftGen = children(0).gen(ctx) - val midGen = children(1).gen(ctx) - val 
rightGen = children(2).gen(ctx) + f: (String, String, String) => String): ExprCode = { + val leftGen = children(0).genCode(ctx) + val midGen = children(1).genCode(ctx) + val rightGen = children(2).genCode(ctx) val resultCode = f(leftGen.value, midGen.value, rightGen.value) if (nullable) { @@ -584,20 +579,17 @@ abstract class TernaryExpression extends Expression { } } - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $nullSafeEval - """ + $nullSafeEval""") } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${leftGen.code} ${midGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 2ed6fc0d3824f..96929ecf56375 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -43,9 +43,8 @@ case class InputFileName() extends LeafExpression with Nondeterministic { InputFileNameHolder.getInputFileName() } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 5d28f8fbde8be..75c6bb2d84dfb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -65,18 +65,16 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with partitionMask + currentCount } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") - ev.isNull = "false" - s""" + ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++; - """ + $countTerm++;""", isNull = "false") } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 354311c5e7449..27ad8e4cf22ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -168,9 +168,7 @@ object 
FromUnsafeProjection { * Returns an UnsafeProjection for given Array of DataTypes. */ def apply(fields: Seq[DataType]): Projection = { - create(fields.zipWithIndex.map(x => { - new BoundReference(x._2, x._1, true) - })) + create(fields.zipWithIndex.map(x => new BoundReference(x._2, x._1, true))) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala new file mode 100644 index 0000000000000..c4cc6c39b0477 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -0,0 +1,75 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types.DataType + +/** + * A special expression that evaluates [[BoundReference]]s by given expressions instead of the + * input row. + * + * @param result The expression that contains [[BoundReference]] and produces the final output. + * @param children The expressions that used as input values for [[BoundReference]]. 
+ */ +case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) + extends Expression { + + override def nullable: Boolean = result.nullable + override def dataType: DataType = result.dataType + + override def checkInputDataTypes(): TypeCheckResult = { + if (result.references.nonEmpty) { + return TypeCheckFailure("The result expression cannot reference any attributes.") + } + + var maxOrdinal = -1 + result foreach { + case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal + } + if (maxOrdinal > children.length) { + return TypeCheckFailure(s"The result expression needs $maxOrdinal input expressions, but " + + s"there are only ${children.length} inputs.") + } + + TypeCheckSuccess + } + + private lazy val projection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + result.eval(projection(input)) + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childrenGen = children.map(_.genCode(ctx)) + val childrenVars = childrenGen.zip(children).map { + case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) + } + + val resultGen = result.transform { + case b: BoundReference => childrenVars(b.ordinal) + }.genCode(ctx) + + ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code, + isNull = resultGen.isNull, value = resultGen.value) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 500ff447a9754..0038cf65e2993 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -989,9 +989,9 @@ case class ScalaUDF( converterTerm } - override def genCode( + override def doGenCode( ctx: CodegenContext, - ev: ExprCode): String = { + ev: ExprCode): ExprCode = { ctx.references += this @@ -1024,7 +1024,7 @@ case class ScalaUDF( s"[$funcExpressionIdx]).userDefinedFunc());") // codegen for children expressions - val evals = children.map(_.gen(ctx)) + val evals = children.map(_.genCode(ctx)) // Generate the codes for expressions and calling user-defined function // We need to get the boxedType of dataType's javaType here.
Because for the dataType @@ -1042,7 +1042,7 @@ case class ScalaUDF( s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" - s""" + ev.copy(code = s""" $evalCode ${converters.mkString("\n")} $callFunc @@ -1051,8 +1051,7 @@ case class ScalaUDF( ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $resultTerm; - } - """ + }""") } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index b739361937b6b..e0c3b22a3c389 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -70,8 +70,8 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val childCode = child.child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val childCode = child.child.genCode(ctx) val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName val DoublePrefixCmp = classOf[DoublePrefixComparator].getName @@ -104,14 +104,14 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { case _ => (0L, "0L") } - childCode.code + - s""" - |long ${ev.value} = ${nullValue}L; - |boolean ${ev.isNull} = false; - |if (!${childCode.isNull}) { - | ${ev.value} = $prefixCode; - |} - """.stripMargin + ev.copy(code = childCode.code + + s""" + |long ${ev.value} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.value} = $prefixCode; + |} + """.stripMargin) } override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 377f08eb105fa..71af59a7a8529 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -44,11 +44,10 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm override protected def evalInternal(input: InternalRow): Int = partitionId - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") - ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;" + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index daf3de95dd9ea..83fa447cf8c85 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -158,11 +158,11 @@ 
object TimeWindow { case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = LongType - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval = child.gen(ctx) - eval.code + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; - """.stripMargin + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index f3d42fc0b2164..b2df79a58884b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -70,7 +70,7 @@ case class UnaryPositive(child: Expression) override def dataType: DataType = child.dataType - override def genCode(ctx: CodegenContext, ev: ExprCode): String = + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input @@ -93,7 +93,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => @@ -113,7 +113,7 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -147,7 +147,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => @@ -179,7 +179,7 @@ case class Subtract(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => @@ -241,9 
+241,9 @@ case class Divide(left: Expression, right: Expression) /** * Special case handling due to division by 0 => null. */ - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.value}.isZero()" } else { @@ -256,7 +256,7 @@ case class Divide(left: Expression, right: Expression) s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -265,10 +265,9 @@ case class Divide(left: Expression, right: Expression) } else { ${eval1.code} ${ev.value} = $divide; - } - """ + }""") } else { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -281,8 +280,7 @@ case class Divide(left: Expression, right: Expression) } else { ${ev.value} = $divide; } - } - """ + }""") } } } @@ -320,9 +318,9 @@ case class Remainder(left: Expression, right: Expression) /** * Special case handling for x % 0 ==> null. */ - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { s"${eval2.value}.isZero()" } else { @@ -335,7 +333,7 @@ case class Remainder(left: Expression, right: Expression) s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -344,10 +342,9 @@ case class Remainder(left: Expression, right: Expression) } else { ${eval1.code} ${ev.value} = $remainder; - } - """ + }""") } else { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -360,8 +357,7 @@ case class Remainder(left: Expression, right: Expression) } else { ${ev.value} = $remainder; } - } - """ + }""") } } } @@ -393,12 +389,12 @@ case class MaxOf(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; @@ -415,8 +411,7 @@ case class MaxOf(left: Expression, right: Expression) } else { ${ev.value} = ${eval2.value}; } - } - """ + }""") } override def symbol: String = "max" @@ -449,12 +444,12 @@ case class MinOf(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) val compCode = 
ctx.genComp(dataType, eval1.value, eval2.value) - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; @@ -471,8 +466,7 @@ case class MinOf(left: Expression, right: Expression) } else { ${ev.value} = ${eval2.value}; } - } - """ + }""") } override def symbol: String = "min" @@ -503,7 +497,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { dataType match { case dt: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index a7e1cd66f24aa..3a0a882e3876e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -130,7 +130,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index f43626ca814a0..d29c27c14b0c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -110,13 +110,17 @@ class CodegenContext { } def declareMutableStates(): String = { - mutableStates.map { case (javaType, variableName, _) => + // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in + // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. + mutableStates.distinct.map { case (javaType, variableName, _) => s"private $javaType $variableName;" }.mkString("\n") } def initMutableStates(): String = { - mutableStates.map(_._3).mkString("\n") + // It's possible that we add same mutable state twice, e.g. the `mergeExpressions` in + // `TypedAggregateExpression`, we should call `distinct` here to remove the duplicated ones. + mutableStates.distinct.map(_._3).mkString("\n") } /** @@ -526,7 +530,7 @@ class CodegenContext { val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. 
- val code = expr.gen(this) + val code = expr.genCode(this) val fn = s""" |private void $fnName(InternalRow $INPUT_ROW) { @@ -572,7 +576,7 @@ class CodegenContext { def generateExpressions(expressions: Seq[Expression], doSubexpressionElimination: Boolean = false): Seq[ExprCode] = { if (doSubexpressionElimination) subexpressionElimination(expressions) - expressions.map(e => e.gen(this)) + expressions.map(e => e.genCode(this)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 1365ee4b55634..2bd77c65c31cf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString */ trait CodegenFallback extends Expression { - protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -37,22 +37,20 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") if (nullable) { - s""" + ev.copy(code = s""" /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + }""") } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - """ + """, isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 7f840890f8ae5..f143b40443836 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -29,7 +29,7 @@ abstract class BaseMutableProjection extends MutableProjection * It exposes a `target` method, which is used to set the row that will be updated. * The internal [[MutableRow]] object created internally is used only when `target` is not used. 
*/ -object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableProjection] { protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -40,17 +40,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu def generate( expressions: Seq[Expression], inputSchema: Seq[Attribute], - useSubexprElimination: Boolean): (() => MutableProjection) = { + useSubexprElimination: Boolean): MutableProjection = { create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) } - protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + protected def create(expressions: Seq[Expression]): MutableProjection = { create(expressions, false) } private def create( expressions: Seq[Expression], - useSubexprElimination: Boolean): (() => MutableProjection) = { + useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() val (validExpr, index) = expressions.zipWithIndex.filter { case (NoOp, _) => false @@ -136,8 +136,6 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) - () => { - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] - } + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala index 908c32de4d896..5635c91830f4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateOrdering.scala @@ -70,7 +70,7 @@ object GenerateOrdering extends CodeGenerator[Seq[SortOrder], Ordering[InternalR */ def genComparisons(ctx: CodegenContext, ordering: Seq[SortOrder]): String = { val comparisons = ordering.map { order => - val eval = order.child.gen(ctx) + val eval = order.child.genCode(ctx) val asc = order.direction == Ascending val isNullA = ctx.freshName("isNullA") val primitiveA = ctx.freshName("primitiveA") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala index 58065d956f072..dd8e2a289a661 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratePredicate.scala @@ -39,7 +39,7 @@ object GeneratePredicate extends CodeGenerator[Expression, (InternalRow) => Bool protected def create(predicate: Expression): ((InternalRow) => Boolean) = { val ctx = newCodeGenContext() - val eval = predicate.gen(ctx) + val eval = predicate.genCode(ctx) val code = s""" public SpecificPredicate generate(Object[] references) { return new SpecificPredicate(references); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala index cf73e36d227c1..7be57aca333de 100644 
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala @@ -141,7 +141,7 @@ object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] val expressionCodes = expressions.zipWithIndex.map { case (NoOp, _) => "" case (e, i) => - val evaluationCode = e.gen(ctx) + val evaluationCode = e.genCode(ctx) val converter = convertToSafe(ctx, evaluationCode.value, e.dataType) evaluationCode.code + s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index ab790cf372d9e..c71cb73d65bf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -27,7 +27,8 @@ import org.apache.spark.sql.types._ * Given an array or map, returns its size. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the size of an array or a map.") + usage = "_FUNC_(expr) - Returns the size of an array or a map.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 4") case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -37,7 +38,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType case _: MapType => value.asInstanceOf[MapData].numElements() } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -48,8 +49,8 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.", - extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'") + usage = "_FUNC_(array(obj1, obj2, ...), ascendingOrder) - Sorts the input array in ascending order according to the natural ordering of the array elements.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'), true);\n 'a', 'b', 'c', 'd'") // scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -133,7 +134,7 @@ case class SortArray(base: Expression, ascendingOrder: Expression) * Checks if the array (left) has the element (right) */ @ExpressionDescription( - usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.", + usage = "_FUNC_(array, value) - Returns TRUE if the array contains the value.", extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -180,7 +181,7 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = 
ctx.freshName("i") val getValue = ctx.getValue(arr, right.dataType, i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 74de4a776de89..3d4819c55a2d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -48,15 +48,14 @@ case class CreateArray(children: Seq[Expression]) extends Expression { new GenericArrayData(children.map(_.eval(input)).toArray) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") - s""" + ev.copy(code = s""" final boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + + final Object[] $values = new Object[${children.size}];""" + children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { $values[$i] = null; @@ -65,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } """ }.mkString("\n") + - s"final ArrayData ${ev.value} = new $arrayClass($values);" + s"final ArrayData ${ev.value} = new $arrayClass($values);") } override def prettyName: String = "array" @@ -115,20 +114,19 @@ case class CreateMap(children: Seq[Expression]) extends Expression { new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val mapClass = classOf[ArrayBasedMapData].getName val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") val keyData = s"new $arrayClass($keyArray)" val valueData = s"new $arrayClass($valueArray)" - s""" + ev.copy(code = s""" final boolean ${ev.isNull} = false; final Object[] $keyArray = new Object[${keys.size}]; - final Object[] $valueArray = new Object[${values.size}]; - """ + keys.zipWithIndex.map { - case (key, i) => - val eval = key.gen(ctx) + final Object[] $valueArray = new Object[${values.size}];""" + + keys.zipWithIndex.map { case (key, i) => + val eval = key.genCode(ctx) s""" ${eval.code} if (${eval.isNull}) { @@ -139,7 +137,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { """ }.mkString("\n") + values.zipWithIndex.map { case (value, i) => - val eval = value.gen(ctx) + val eval = value.genCode(ctx) s""" ${eval.code} if (${eval.isNull}) { @@ -148,7 +146,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { $valueArray[$i] = ${eval.value}; } """ - }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);" + }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);") } override def prettyName: String = "map" @@ -181,24 +179,22 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - s""" + 
ev.copy(code = s""" boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + + final Object[] $values = new Object[${children.size}];""" + children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; - } - """ + }""" }.mkString("\n") + - s"final InternalRow ${ev.value} = new $rowClass($values);" + s"final InternalRow ${ev.value} = new $rowClass($values);") } override def prettyName: String = "struct" @@ -262,24 +258,22 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - s""" + ev.copy(code = s""" boolean ${ev.isNull} = false; - final Object[] $values = new Object[${valExprs.size}]; - """ + + final Object[] $values = new Object[${valExprs.size}];""" + valExprs.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; - } - """ + }""" }.mkString("\n") + - s"final InternalRow ${ev.value} = new $rowClass($values);" + s"final InternalRow ${ev.value} = new $rowClass($values);") } override def prettyName: String = "named_struct" @@ -314,11 +308,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, children) - ev.isNull = eval.isNull - ev.value = eval.value - eval.code + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) } override def prettyName: String = "struct_unsafe" @@ -354,11 +346,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression InternalRow(valExprs.map(_.eval(input)): _*) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ev.isNull = eval.isNull - ev.value = eval.value - eval.code + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index c06dcc98674fd..3b4468f55ca73 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -122,7 +122,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { if (nullable) { s""" @@ -179,7 +179,7 @@ case class 
GetArrayStructFields( new GenericArrayData(result) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { val n = ctx.freshName("n") @@ -239,7 +239,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") s""" @@ -302,7 +302,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index ae6a94842f7d0..e97e08947a500 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -55,12 +55,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val condEval = predicate.gen(ctx) - val trueEval = trueValue.gen(ctx) - val falseEval = falseValue.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val condEval = predicate.genCode(ctx) + val trueEval = trueValue.genCode(ctx) + val falseEval = falseValue.genCode(ctx) - s""" + ev.copy(code = s""" ${condEval.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -72,8 +72,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ${falseEval.code} ${ev.isNull} = ${falseEval.isNull}; ${ev.value} = ${falseEval.value}; - } - """ + }""") } override def toString: String = s"if ($predicate) $trueValue else $falseValue" @@ -82,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } /** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * When a = true, returns b; when c = true, returns d; else returns e. + * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. 
* * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") -// scalastyle:on line.size.limit -case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with CodegenFallback { +abstract class CaseWhenBase( + branches: Seq[(Expression, Expression)], + elseValue: Option[Expression]) + extends Expression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -143,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } } - def shouldCodegen: Boolean = { - branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN + override def toString: String = { + val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString + val elseCase = elseValue.map(" ELSE " + _).getOrElse("") + "CASE" + cases + elseCase + " END" } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - if (!shouldCodegen) { - // Fallback to interpreted mode if there are too many branches, as it may reach the - // 64K limit (limit on bytecode size for a single function). - return super[CodegenFallback].genCode(ctx, ev) - } + override def sql: String = { + val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString + val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") + "CASE" + cases + elseCase + " END" + } +} + + +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit +case class CaseWhen( + val branches: Seq[(Expression, Expression)], + val elseValue: Option[Expression] = None) + extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + super[CodegenFallback].doGenCode(ctx, ev) + } + + def toCodegen(): CaseWhenCodegen = { + CaseWhenCodegen(branches, elseValue) + } +} + +/** + * CaseWhen expression used when code generation condition is satisfied. + * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch + */ +case class CaseWhenCodegen( + val branches: Seq[(Expression, Expression)], + val elseValue: Option[Expression] = None) + extends CaseWhenBase(branches, elseValue) with Serializable { + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Generate code that looks like: // // condA = ... 
@@ -172,8 +210,8 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E // } // } val cases = branches.map { case (condExpr, valueExpr) => - val cond = condExpr.gen(ctx) - val res = valueExpr.gen(ctx) + val cond = condExpr.genCode(ctx) + val res = valueExpr.genCode(ctx) s""" ${cond.code} if (!${cond.isNull} && ${cond.value}) { @@ -187,7 +225,7 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E var generatedCode = cases.mkString("", "\nelse {\n", "\nelse {\n") elseValue.foreach { elseExpr => - val res = elseExpr.gen(ctx) + val res = elseExpr.genCode(ctx) generatedCode += s""" ${res.code} @@ -198,38 +236,22 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E generatedCode += "}\n" * cases.size - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode - """ - } - - override def toString: String = { - val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString - val elseCase = elseValue.map(" ELSE " + _).getOrElse("") - "CASE" + cases + elseCase + " END" - } - - override def sql: String = { - val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString - val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") - "CASE" + cases + elseCase + " END" + $generatedCode""") } } /** Factory methods for CaseWhen. */ object CaseWhen { - - // The maximum number of switches supported with codegen. - val MAX_NUM_CASES_FOR_CODEGEN = 20 - def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { CaseWhen(branches, Option(elseValue)) } /** * A factory method to facilitate the creation of this expression when used in parsers. + * * @param branches Expressions at even position are the branch conditions, and expressions at odd * position are branch values. */ @@ -243,7 +265,6 @@ object CaseWhen { } } - /** * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". * When a = b, returns c; when a = d, returns e; else returns f. 
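// Illustrative sketch (REPL-style), not part of the patch: after this change CaseWhen always takes
// the interpreted CodegenFallback path, and the OptimizeCodegen rule mentioned in the scaladoc above
// is expected to substitute CaseWhenCodegen via toCodegen() when code generation is worthwhile.
// Only constructors defined in this diff are used; the literals are placeholder examples.
import org.apache.spark.sql.catalyst.expressions.{CaseWhen, Literal}

val cw = CaseWhen(Seq((Literal(true), Literal(1))), Some(Literal(0)))
val cwCodegen = cw.toCodegen()  // CaseWhenCodegen carrying the same branches and elseValue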
@@ -297,8 +318,8 @@ case class Least(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val evalChildren = children.map(_.gen(ctx)) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evalChildren = children.map(_.genCode(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) def updateEval(eval: ExprCode): String = { @@ -311,12 +332,11 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - s""" + ev.copy(code = s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")} - """ + ${rest.map(updateEval).mkString("\n")}""") } } @@ -358,8 +378,8 @@ case class Greatest(children: Seq[Expression]) extends Expression { }) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val evalChildren = children.map(_.gen(ctx)) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evalChildren = children.map(_.genCode(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) def updateEval(eval: ExprCode): String = { @@ -372,12 +392,11 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - s""" + ev.copy(code = s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")} - """ + ${rest.map(updateEval).mkString("\n")}""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 9135753041f92..69c32f447e867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -91,7 +91,7 @@ case class DateAdd(startDate: Expression, days: Expression) start.asInstanceOf[Int] + d.asInstanceOf[Int] } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd + $d;""" }) @@ -119,7 +119,7 @@ case class DateSub(startDate: Expression, days: Expression) start.asInstanceOf[Int] - d.asInstanceOf[Int] } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd - $d;""" }) @@ -141,7 +141,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } @@ -160,7 +160,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } @@ -179,7 +179,7 @@ case class Second(child: Expression) 
extends UnaryExpression with ImplicitCastIn DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } @@ -198,7 +198,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } @@ -217,7 +217,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getYear(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } @@ -235,7 +235,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI DateTimeUtils.getQuarter(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } @@ -254,7 +254,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp DateTimeUtils.getMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } @@ -273,7 +273,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } @@ -300,7 +300,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -335,7 +335,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) @@ -430,20 +430,19 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + 
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName val fString = if (constFormat == null) null else constFormat.toString val formatter = ctx.freshName("formatter") if (fString == null) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - val eval1 = left.gen(ctx) - s""" + val eval1 = left.genCode(ctx) + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -455,8 +454,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } catch (java.lang.Throwable e) { ${ev.isNull} = true; } - } - """ + }""") } case StringType => val sdf = classOf[SimpleDateFormat].getName @@ -471,26 +469,24 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { """ }) case TimestampType => - val eval1 = left.gen(ctx) - s""" + val eval1 = left.genCode(ctx) + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; - } - """ + }""") case DateType => val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") - val eval1 = left.gen(ctx) - s""" + val eval1 = left.genCode(ctx) + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}) / 1000L; - } - """ + }""") } } @@ -550,17 +546,16 @@ case class FromUnixTime(sec: Expression, format: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { if (constFormat == null) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - val t = left.gen(ctx) - s""" + val t = left.genCode(ctx) + ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -571,8 +566,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } catch (java.lang.Throwable e) { ${ev.isNull} = true; } - } - """ + }""") } } else { nullSafeCodeGen(ctx, ev, (seconds, f) => { @@ -605,7 +599,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } @@ -646,7 +640,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { val dateTimeUtilClass = 
DateTimeUtils.getClass.getName.stripSuffix("$") val dayOfWeekTerm = ctx.freshName("dayOfWeek") @@ -698,7 +692,7 @@ case class TimeAdd(start: Expression, interval: Expression) start.asInstanceOf[Long], itvl.months, itvl.microseconds) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" @@ -725,21 +719,21 @@ case class FromUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() if (tz == null) { - s""" + ev.copy(code = s""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; - """.stripMargin + """.stripMargin) } else { val tzTerm = ctx.freshName("tz") val tzClass = classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - val eval = left.gen(ctx) - s""" + val eval = left.genCode(ctx) + ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -747,7 +741,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) | ${ev.value} = ${eval.value} + | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} - """.stripMargin + """.stripMargin) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { @@ -777,7 +771,7 @@ case class TimeSub(start: Expression, interval: Expression) start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" @@ -805,7 +799,7 @@ case class AddMonths(startDate: Expression, numMonths: Expression) DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, m) => { s"""$dtu.dateAddMonths($sd, $m)""" @@ -835,7 +829,7 @@ case class MonthsBetween(date1: Expression, date2: Expression) DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r)""" @@ -864,21 +858,21 @@ case class ToUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() if (tz == null) { - s""" + ev.copy(code = s""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; - """.stripMargin + """.stripMargin) } else { val tzTerm = ctx.freshName("tz") val tzClass = 
classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") - val eval = left.gen(ctx) - s""" + val eval = left.genCode(ctx) + ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -886,7 +880,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) | ${ev.value} = ${eval.value} - | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} - """.stripMargin + """.stripMargin) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { @@ -912,7 +906,7 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def eval(input: InternalRow): Any = child.eval(input) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, d => d) } @@ -959,25 +953,23 @@ case class TruncDate(date: Expression, format: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { if (truncLevel == -1) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - val d = date.gen(ctx) - s""" + val d = date.genCode(ctx) + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); - } - """ + }""") } } else { nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { @@ -1013,7 +1005,7 @@ case class DateDiff(endDate: Expression, startDate: Expression) end.asInstanceOf[Int] - start.asInstanceOf[Int] } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 74e86f40c0364..fa5dea6841149 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -34,7 +34,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Decimal].toUnscaledLong - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -53,7 +53,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un protected override def nullSafeEval(input: Any): Any = Decimal(input.asInstanceOf[Long], precision, scale) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); @@ -70,8 +70,8 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un case class PromotePrecision(child: 
Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) - override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" + override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def prettyName: String = "promote_precision" override def sql: String = child.sql } @@ -93,7 +93,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { val tmp = ctx.freshName("tmp") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e6804d096cd96..e9dda588de8ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -60,7 +60,8 @@ object Literal { * Constructs a [[Literal]] of [[ObjectType]], for example when you need to pass an object * into code generation. */ - def fromObject(obj: AnyRef): Literal = new Literal(obj, ObjectType(obj.getClass)) + def fromObject(obj: Any, objType: DataType): Literal = new Literal(obj, objType) + def fromObject(obj: Any): Literal = new Literal(obj, ObjectType(obj.getClass)) def fromJSON(json: JValue): Literal = { val dataType = DataType.parseDataType(json \ "dataType") @@ -190,50 +191,50 @@ case class Literal protected (value: Any, dataType: DataType) override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};" + ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};") } else { dataType match { case BooleanType => ev.isNull = "false" ev.value = value.toString - "" + ev.copy("") case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].genCode(ctx, ev) + super[CodegenFallback].doGenCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}f" - "" + ev.copy("") } case DoubleType => val v = value.asInstanceOf[Double] if (v.isNaN || v.isInfinite) { - super[CodegenFallback].genCode(ctx, ev) + super[CodegenFallback].doGenCode(ctx, ev) } else { ev.isNull = "false" ev.value = s"${value}D" - "" + ev.copy("") } case ByteType | ShortType => ev.isNull = "false" ev.value = s"(${ctx.javaType(dataType)})$value" - "" + ev.copy("") case IntegerType | DateType => ev.isNull = "false" ev.value = value.toString - "" + ev.copy("") case TimestampType | LongType => ev.isNull = "false" ev.value = s"${value}L" - "" + ev.copy("") // eval() version may be faster for non-primitive types case other => - super[CodegenFallback].genCode(ctx, ev) + super[CodegenFallback].doGenCode(ctx, ev) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index c8a28e847745c..5152265152aed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,7 +70,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } } @@ -88,7 +88,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if (d <= yAsymptote) null else f(d) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -123,7 +123,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -197,7 +197,7 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -242,7 +242,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre toBase.asInstanceOf[Int]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" @@ -284,7 +284,7 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -346,7 +346,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval > 20 || $eval < 0) { @@ -370,7 +370,7 @@ case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -458,7 +458,7 @@ case class 
Bin(child: Expression) protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } @@ -556,7 +556,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { @@ -584,7 +584,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp protected override def nullSafeEval(num: Any): Any = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" @@ -613,7 +613,7 @@ case class Atan2(left: Expression, right: Expression) math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } @@ -623,7 +623,7 @@ case class Atan2(left: Expression, right: Expression) extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -653,7 +653,7 @@ case class ShiftLeft(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } @@ -683,7 +683,7 @@ case class ShiftRight(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -713,7 +713,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } @@ -753,7 +753,7 @@ case class Logarithm(left: Expression, right: Expression) if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => s""" @@ -779,7 +779,6 @@ case class Logarithm(left: Expression, right: Expression) /** * Round the `child`'s 
result to `scale` decimal place when `scale` >= 0 * or round at integral part when `scale` < 0. - * For example, round(31.415, 2) = 31.42 and round(31.415, -1) = 30. * * Child of IntegralType would round to itself when `scale` >= 0. * Child of FractionalType whose value is NaN or Infinite would always round to itself. * * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime + * @param mode rounding mode (e.g. HALF_UP, HALF_EVEN) + * @param modeStr rounding mode string name (e.g. "ROUND_HALF_UP", "ROUND_HALF_EVEN") */ -@ExpressionDescription( - usage = "_FUNC_(x, d) - Round x to d decimal places.", - extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") -case class Round(child: Expression, scale: Expression) - extends BinaryExpression with ImplicitCastInputTypes { - - import BigDecimal.RoundingMode.HALF_UP - - def this(child: Expression) = this(child, Literal(0)) +abstract class RoundBase(child: Expression, scale: Expression, + mode: BigDecimal.RoundingMode.Value, modeStr: String) + extends BinaryExpression with Serializable with ImplicitCastInputTypes { override def left: Expression = child override def right: Expression = scale @@ -853,39 +848,40 @@ case class Round(child: Expression, scale: Expression) child.dataType match { case _: DecimalType => val decimal = input1.asInstanceOf[Decimal] - if (decimal.changePrecision(decimal.precision, _scale)) decimal else null + if (decimal.changePrecision(decimal.precision, _scale, mode)) decimal else null case ByteType => - BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, HALF_UP).toByte + BigDecimal(input1.asInstanceOf[Byte]).setScale(_scale, mode).toByte case ShortType => - BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, HALF_UP).toShort + BigDecimal(input1.asInstanceOf[Short]).setScale(_scale, mode).toShort case IntegerType => - BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, HALF_UP).toInt + BigDecimal(input1.asInstanceOf[Int]).setScale(_scale, mode).toInt case LongType => - BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, HALF_UP).toLong + BigDecimal(input1.asInstanceOf[Long]).setScale(_scale, mode).toLong case FloatType => val f = input1.asInstanceOf[Float] if (f.isNaN || f.isInfinite) { f } else { - BigDecimal(f.toDouble).setScale(_scale, HALF_UP).toFloat + BigDecimal(f.toDouble).setScale(_scale, mode).toFloat } case DoubleType => val d = input1.asInstanceOf[Double] if (d.isNaN || d.isInfinite) { d } else { - BigDecimal(d).setScale(_scale, HALF_UP).toDouble + BigDecimal(d).setScale(_scale, mode).toDouble } } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val ce = child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val ce = child.genCode(ctx) val evaluationCode = child.dataType match { case _: DecimalType => s""" - if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale})) { + if (${ce.value}.changePrecision(${ce.value}.precision(), ${_scale}, + java.math.BigDecimal.${modeStr})) { ${ev.value} = ${ce.value}; } else { ${ev.isNull} = true; @@ -894,7 +890,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}).
- setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).byteValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).byteValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -902,7 +898,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).shortValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).shortValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -910,7 +906,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).intValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).intValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -918,7 +914,7 @@ case class Round(child: Expression, scale: Expression) if (_scale < 0) { s""" ${ev.value} = new java.math.BigDecimal(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).longValue();""" + setScale(${_scale}, java.math.BigDecimal.${modeStr}).longValue();""" } else { s"${ev.value} = ${ce.value};" } @@ -928,7 +924,7 @@ case class Round(child: Expression, scale: Expression) ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).floatValue(); + setScale(${_scale}, java.math.BigDecimal.${modeStr}).floatValue(); }""" case DoubleType => // if child eval to NaN or Infinity, just return it. s""" @@ -936,24 +932,49 @@ case class Round(child: Expression, scale: Expression) ${ev.value} = ${ce.value}; } else { ${ev.value} = java.math.BigDecimal.valueOf(${ce.value}). - setScale(${_scale}, java.math.BigDecimal.ROUND_HALF_UP).doubleValue(); + setScale(${_scale}, java.math.BigDecimal.${modeStr}).doubleValue(); }""" } if (scaleV == null) { // if scale is null, no need to eval its child at all - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - s""" + ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode - } - """ + }""") } } } + +/** + * Round an expression to d decimal places using HALF_UP rounding mode. + * round(2.5) == 3.0, round(3.5) == 4.0. + */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_UP rounding mode.", + extended = "> SELECT _FUNC_(2.5, 0);\n 3.0") +case class Round(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_UP, "ROUND_HALF_UP") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) +} + +/** + * Round an expression to d decimal places using HALF_EVEN rounding mode, + * also known as Gaussian rounding or bankers' rounding. + * round(2.5) = 2.0, round(3.5) = 4.0. 
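// Illustrative sketch, not part of the patch: the only behavioural difference between the two
// RoundBase subclasses is the BigDecimal rounding mode they pass to setScale, which plain Scala
// reproduces directly (HALF_UP backing round, HALF_EVEN backing bround).
import scala.math.BigDecimal.RoundingMode

BigDecimal("2.5").setScale(0, RoundingMode.HALF_UP)    // 3  -> round(2.5, 0)
BigDecimal("2.5").setScale(0, RoundingMode.HALF_EVEN)  // 2  -> bround(2.5, 0)
BigDecimal("3.5").setScale(0, RoundingMode.HALF_UP)    // 4  -> round(3.5, 0)
BigDecimal("3.5").setScale(0, RoundingMode.HALF_EVEN)  // 4  -> bround(3.5, 0)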
+ */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places using HALF_EVEN rounding mode.", + extended = "> SELECT _FUNC_(2.5, 0);\n 2.0") +case class BRound(child: Expression, scale: Expression) + extends RoundBase(child, scale, BigDecimal.RoundingMode.HALF_EVEN, "ROUND_HALF_EVEN") + with Serializable with ImplicitCastInputTypes { + def this(child: Expression) = this(child, Literal(0)) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 4bd918ed01ae2..1c0787bf9227f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -49,7 +49,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } @@ -102,7 +102,7 @@ case class Sha2(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val digestUtils = "org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -147,7 +147,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]])) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" ) @@ -173,7 +173,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp checksum.getValue } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val CRC32 = "java.util.zip.CRC32" nullSafeCodeGen(ctx, ev, value => { s""" @@ -244,19 +244,18 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childrenHash = children.map { child => - val childGen = child.gen(ctx) + val childGen = child.genCode(ctx) childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") - s""" + ev.copy(code = s""" ${ctx.javaType(dataType)} ${ev.value} = $seed; - $childrenHash - """ + $childrenHash""") } private def nullSafeElementHash( @@ -477,7 +476,7 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" | System.err.println("Result of 
${child.simpleString} is " + $c); @@ -486,6 +485,41 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { } } +/** + * A function throws an exception if 'condition' is not true. + */ +@ExpressionDescription( + usage = "_FUNC_(condition) - Throw an exception if 'condition' is not true.") +case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[DataType] = Seq(BooleanType) + + override def dataType: DataType = NullType + + override def prettyName: String = "assert_true" + + override def eval(input: InternalRow) : Any = { + val v = child.eval(input) + if (v == null || java.lang.Boolean.FALSE.equals(v)) { + throw new RuntimeException(s"'${child.simpleString}' is not true!") + } else { + null + } + } + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + ExprCode(code = s"""${eval.code} + |if (${eval.isNull} || !${eval.value}) { + | throw new RuntimeException("'${child.simpleString}' is not true."); + |}""".stripMargin, isNull = "true", value = "null") + } + + override def sql: String = s"assert_true(${child.sql})" +} + /** * A xxHash64 64-bit hash expression. */ @@ -512,3 +546,15 @@ object XxHash64Function extends InterpretedHashFunction { XXH64.hashUnsafeBytes(base, offset, len, seed) } } + +/** + * Returns the current database of the SessionCatalog. + */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current database.", + extended = "> SELECT _FUNC_()") +private[sql] case class CurrentDatabase() extends LeafExpression with Unevaluable { + override def dataType: DataType = StringType + override def foldable: Boolean = true + override def nullable: Boolean = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 78310fb2f1539..c083f12724dbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -142,8 +142,8 @@ case class Alias(child: Expression, name: String)( override def eval(input: InternalRow): Any = child.eval(input) /** Just a simple passthrough for code generation. 
*/ - override def gen(ctx: CodegenContext): ExprCode = child.gen(ctx) - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = "" + override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index 6a452499430c8..421200e147b7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -64,17 +64,16 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val first = children(0) val rest = children.drop(1) - val firstEval = first.gen(ctx) - s""" + val firstEval = first.genCode(ctx) + ev.copy(code = s""" ${firstEval.code} boolean ${ev.isNull} = ${firstEval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; - """ + + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" + rest.map { e => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) s""" if (${ev.isNull}) { ${eval.code} @@ -84,7 +83,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } """ - }.mkString("\n") + }.mkString("\n")) } } @@ -113,16 +112,15 @@ case class IsNaN(child: Expression) extends UnaryExpression } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval = child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - s""" + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value}); - """ + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""") } } } @@ -155,12 +153,12 @@ case class NaNvl(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val leftGen = left.gen(ctx) - val rightGen = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val leftGen = left.genCode(ctx) + val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - s""" + ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -177,8 +175,7 @@ case class NaNvl(left: Expression, right: Expression) ${ev.value} = ${rightGen.value}; } } - } - """ + }""") } } } @@ -196,11 +193,9 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) == null } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval = child.gen(ctx) - ev.isNull = "false" - ev.value = eval.isNull - eval.code + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + ExprCode(code = eval.code, isNull = "false", value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -219,11 +214,9 @@ case class 
IsNotNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) != null } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval = child.gen(ctx) - ev.isNull = "false" - ev.value = s"(!(${eval.isNull}))" - eval.code + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval = child.genCode(ctx) + ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -259,10 +252,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) e.dataType match { case DoubleType | FloatType => s""" @@ -284,11 +277,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate """ } }.mkString("\n") - s""" + ev.copy(code = s""" int $nonnull = 0; $code boolean ${ev.isNull} = false; - boolean ${ev.value} = $nonnull >= $n; - """ + boolean ${ev.value} = $nonnull >= $n;""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 26b1ff39b3e9f..1e418540a2624 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -59,9 +59,9 @@ case class StaticInvoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.gen(ctx)) + val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") if (propagateNull) { @@ -72,7 +72,7 @@ case class StaticInvoke( } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - s""" + ev.copy(code = s""" ${argGen.map(_.code).mkString("\n")} boolean ${ev.isNull} = !$argsNonNull; @@ -82,14 +82,14 @@ case class StaticInvoke( ${ev.value} = $objectName.$functionName($argString); $objNullCheck } - """ + """) } else { - s""" + ev.copy(code = s""" ${argGen.map(_.code).mkString("\n")} $javaType ${ev.value} = $objectName.$functionName($argString); final boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } } } @@ -148,10 +148,10 @@ case class Invoke( case _ => identity[String] _ } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val obj = targetObject.gen(ctx) - val argGen = arguments.map(_.gen(ctx)) + val obj = targetObject.genCode(ctx) + val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") // If the function can return null, we do an extra check to make sure our null bit is still set @@ -178,12 +178,12 @@ case class Invoke( """ } - s""" + ev.copy(code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} $evaluate $objNullCheck - """ + """) } override def toString: String = s"$targetObject.$functionName" @@ -239,12 +239,12 @@ case class 
NewInstance( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.gen(ctx)) + val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") - val outer = outerPointer.map(func => Literal.fromObject(func()).gen(ctx)) + val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) val setup = s""" @@ -261,7 +261,7 @@ case class NewInstance( if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - s""" + ev.copy(code = s""" $setup boolean ${ev.isNull} = true; @@ -270,14 +270,14 @@ case class NewInstance( ${ev.value} = $constructorCall; ${ev.isNull} = false; } - """ + """) } else { - s""" + ev.copy(code = s""" $setup final $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = false; - """ + """) } } @@ -302,23 +302,24 @@ case class UnwrapOption( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val inputObject = child.gen(ctx) + val inputObject = child.genCode(ctx) - s""" + ev.copy(code = s""" ${inputObject.code} boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); - """ + """) } } /** * Converts the result of evaluating `child` into an option, checking both the isNull bit and * (in the case of reference types) equality with null. + * * @param child The expression to evaluate and wrap. * @param optType The type of this option. */ @@ -334,17 +335,17 @@ case class WrapOption(child: Expression, optType: DataType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val inputObject = child.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val inputObject = child.genCode(ctx) - s""" + ev.copy(code = s""" ${inputObject.code} boolean ${ev.isNull} = false; scala.Option ${ev.value} = ${inputObject.isNull} ? 
scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); - """ + """) } } @@ -357,7 +358,7 @@ case class LambdaVariable(value: String, isNull: String, dataType: DataType) ext override def nullable: Boolean = true - override def gen(ctx: CodegenContext): ExprCode = { + override def genCode(ctx: CodegenContext): ExprCode = { ExprCode(code = "", value = value, isNull = isNull) } } @@ -443,13 +444,13 @@ case class MapObjects private( override def dataType: DataType = ArrayType(lambdaFunction.dataType) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) ctx.addMutableState("boolean", loopVar.isNull, "") ctx.addMutableState(elementJavaType, loopVar.value, "") - val genInputData = inputData.gen(ctx) - val genFunction = lambdaFunction.gen(ctx) + val genInputData = inputData.genCode(ctx) + val genFunction = lambdaFunction.genCode(ctx) val dataLength = ctx.freshName("dataLength") val convertedArray = ctx.freshName("convertedArray") val loopIndex = ctx.freshName("loopIndex") @@ -473,7 +474,7 @@ case class MapObjects private( s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } - s""" + ev.copy(code = s""" ${genInputData.code} boolean ${ev.isNull} = ${genInputData.value} == null; @@ -503,7 +504,7 @@ case class MapObjects private( ${ev.isNull} = false; ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); } - """ + """) } } @@ -523,13 +524,13 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, "") val childrenCodes = children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) + val eval = e.genCode(ctx) eval.code + s""" if (${eval.isNull}) { $values[$i] = null; @@ -540,17 +541,18 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) } val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) - s""" + ev.copy(code = s""" boolean ${ev.isNull} = false; $values = new Object[${children.size}]; $childrenCode final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); - """ + """) } } /** * Serializes an input object using a generic serializer (Kryo or Java). + * * @param kryo if true, use Kryo. Otherwise, use Java. */ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) @@ -559,7 +561,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. 
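    // (Note: the instance created below is registered through `ctx.addMutableState`, so the
    //  generated class builds one serializer per generated class instance and reuses it for
    //  every row; only the serialize call itself runs per input row.)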
val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -576,15 +578,15 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") // Code to serialize. - val input = child.gen(ctx) - s""" + val input = child.genCode(ctx) + ev.copy(code = s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $serializer.serialize(${input.value}, null).array(); } - """ + """) } override def dataType: DataType = BinaryType @@ -593,12 +595,13 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) /** * Serializes an input object using a generic serializer (Kryo or Java). Note that the ClassTag * is not an implicit parameter because TreeNode cannot copy implicit parameters. + * * @param kryo if true, use Kryo. Otherwise, use Java. */ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression { - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -615,8 +618,8 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B s"$serializer = ($serializerInstanceClass) new $serializerClass($sparkConf).newInstance();") // Code to serialize. - val input = child.gen(ctx) - s""" + val input = child.genCode(ctx) + ev.copy(code = s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -624,7 +627,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B ${ev.value} = (${ctx.javaType(dataType)}) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } - """ + """) } override def dataType: DataType = ObjectType(tag.runtimeClass) @@ -643,12 +646,12 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val instanceGen = beanInstance.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val instanceGen = beanInstance.genCode(ctx) val initialize = setters.map { case (setterMethod, fieldValue) => - val fieldGen = fieldValue.gen(ctx) + val fieldGen = fieldValue.genCode(ctx) s""" ${fieldGen.code} ${instanceGen.value}.$setterMethod(${fieldGen.value}); @@ -658,12 +661,12 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ev.isNull = instanceGen.isNull ev.value = instanceGen.value - s""" + ev.copy(code = s""" ${instanceGen.code} if (!${instanceGen.isNull}) { ${initialize.mkString("\n")} } - """ + """) } } @@ -685,8 +688,8 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val childGen = child.gen(ctx) + override protected def doGenCode(ctx: CodegenContext, ev: 
ExprCode): ExprCode = { + val childGen = child.genCode(ctx) val errMsg = "Null value appeared in non-nullable field:" + walkedTypePath.mkString("\n", "\n", "\n") + @@ -695,16 +698,11 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) "(e.g. java.lang.Integer instead of int/scala.Int)." val idx = ctx.references.length ctx.references += errMsg - - ev.isNull = "false" - ev.value = childGen.value - - s""" + ExprCode(code = s""" ${childGen.code} if (${childGen.isNull}) { throw new RuntimeException((String) references[$idx]); - } - """ + }""", isNull = "false", value = childGen.value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 38f1210a4edb5..057c6545ef7a4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -99,7 +99,7 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"!($c)") } @@ -157,9 +157,9 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val valueGen = value.gen(ctx) - val listGen = list.map(_.gen(ctx)) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val valueGen = value.genCode(ctx) + val listGen = list.map(_.genCode(ctx)) val listCode = listGen.map(x => s""" if (!${ev.value}) { @@ -172,14 +172,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } """).mkString("\n") - s""" + ev.copy(code = s""" ${valueGen.code} boolean ${ev.value} = false; boolean ${ev.isNull} = ${valueGen.isNull}; if (!${ev.isNull}) { $listCode } - """ + """) } override def sql: String = { @@ -216,17 +216,17 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with def getHSet(): Set[Any] = hset - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName - val childGen = child.gen(ctx) + val childGen = child.genCode(ctx) ctx.references += this val hsetTerm = ctx.freshName("hset") val hasNullTerm = ctx.freshName("hasNull") ctx.addMutableState(setName, hsetTerm, s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") - s""" + ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; @@ -236,7 +236,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ${ev.isNull} = true; } } - """ + """) } override def sql: String = { @@ -274,24 +274,22 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) // The result should be `false`, if any of 
them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - } - """ + }""", isNull = "false") } else { - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -306,7 +304,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.isNull} = true; } } - """ + """) } } } @@ -339,24 +337,23 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = "false" - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - } - """ + }""", isNull = "false") } else { - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -371,7 +368,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.isNull} = true; } } - """ + """) } } } @@ -379,7 +376,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType @@ -428,7 +425,7 @@ case class EqualTo(left: Expression, right: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -464,15 +461,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val eval1 = left.genCode(ctx) + val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.isNull = "false" - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && $equalCode); - """ + (!${eval1.isNull} && $equalCode);""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 1ec092a5be965..ca200768b2286 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -67,15 +67,13 @@ case class Rand(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be 
an integer literal.") }) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") - ev.isNull = "false" - s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble(); - """ + ev.copy(code = s""" + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } } @@ -92,14 +90,12 @@ case class Randn(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") }) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") - ev.isNull = "false" - s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian(); - """ + ev.copy(code = s""" + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 85a54292639d0..541b8601a344b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -78,7 +78,7 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val pattern = ctx.freshName("pattern") @@ -92,20 +92,20 @@ case class Like(left: Expression, right: Expression) s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. 
- val eval = left.gen(ctx) - s""" + val eval = left.genCode(ctx) + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } - """ + """) } else { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + """) } } else { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -128,7 +128,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val pattern = ctx.freshName("pattern") @@ -141,20 +141,20 @@ case class RLike(left: Expression, right: Expression) s"""$pattern = ${patternClass}.compile("$regexStr");""") // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. - val eval = left.gen(ctx) - s""" + val eval = left.genCode(ctx) + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } - """ + """) } else { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + """) } } else { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -188,7 +188,7 @@ case class StringSplit(str: Expression, pattern: Expression) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. 
@@ -247,7 +247,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def children: Seq[Expression] = subject :: regexp :: rep :: Nil override def prettyName: String = "regexp_replace" - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") @@ -330,7 +330,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def children: Seq[Expression] = subject :: regexp :: idx :: Nil override def prettyName: String = "regexp_extract" - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index a17482697d906..78e846d3f580e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -51,18 +51,18 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas UTF8String.concat(inputs : _*) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val evals = children.map(_.gen(ctx)) + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.value}" }.mkString(", ") - evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = UTF8String.concat($inputs); if (${ev.value} == null) { ${ev.isNull} = true; } - """ + """) } } @@ -106,25 +106,25 @@ case class ConcatWs(children: Seq[Expression]) UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. - val evals = children.map(_.gen(ctx)) + val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? 
(UTF8String) null : ${eval.value}" }.mkString(", ") - evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString("\n") + s""" UTF8String ${ev.value} = UTF8String.concatWs($inputs); boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") val idxInVararg = ctx.freshName("idxInVararg") - val evals = children.map(_.gen(ctx)) + val evals = children.map(_.genCode(ctx)) val (varargCount, varargBuild) = children.tail.zip(evals.tail).map { case (child, eval) => child.dataType match { case StringType => @@ -148,7 +148,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - evals.map(_.code).mkString("\n") + + ev.copy(evals.map(_.code).mkString("\n") + s""" int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxInVararg = 0; @@ -157,7 +157,7 @@ case class ConcatWs(children: Seq[Expression]) ${varargBuild.mkString("\n")} UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } } } @@ -185,7 +185,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -200,7 +200,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -225,7 +225,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -236,7 +236,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -247,7 +247,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } @@ -298,7 +298,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac srcEval.asInstanceOf[UTF8String].translate(dict) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastMatching = ctx.freshName("lastMatching") val 
termLastReplace = ctx.freshName("lastReplace") val termDict = ctx.freshName("dict") @@ -351,7 +351,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override protected def nullSafeEval(word: Any, set: Any): Any = set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);" ) @@ -375,7 +375,7 @@ case class StringTrim(child: Expression) override def prettyName: String = "trim" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trim()") } } @@ -393,7 +393,7 @@ case class StringTrimLeft(child: Expression) override def prettyName: String = "ltrim" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trimLeft()") } } @@ -411,7 +411,7 @@ case class StringTrimRight(child: Expression) override def prettyName: String = "rtrim" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trimRight()") } } @@ -440,7 +440,7 @@ case class StringInstr(str: Expression, substr: Expression) override def prettyName: String = "instr" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } @@ -475,7 +475,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: count.asInstanceOf[Int]) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } } @@ -524,11 +524,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val substrGen = substr.gen(ctx) - val strGen = str.gen(ctx) - val startGen = start.gen(ctx) - s""" + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val substrGen = substr.genCode(ctx) + val strGen = str.genCode(ctx) + val startGen = start.genCode(ctx) + ev.copy(code = s""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -546,7 +546,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) ${ev.isNull} = true; } } - """ + """) } override def prettyName: String = "locate" @@ -571,7 +571,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } @@ -597,7 +597,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def genCode(ctx: CodegenContext, ev: ExprCode): 
String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } @@ -638,10 +638,10 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - val pattern = children.head.gen(ctx) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val pattern = children.head.genCode(ctx) - val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx))) + val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx))) val argListCode = argListGen.map(_._2.code + "\n") val argListString = argListGen.foldLeft("")((s, v) => { @@ -660,7 +660,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - s""" + ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -670,8 +670,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); $form.format(${pattern.value}.toString() $argListString); ${ev.value} = UTF8String.fromString($sb.toString()); - } - """ + }""") } override def prettyName: String = "format_string" @@ -694,7 +693,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def nullSafeEval(string: Any): Any = { string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } } @@ -719,7 +718,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } } @@ -735,7 +734,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 override def prettyName: String = "reverse" - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).reverse()") } } @@ -757,7 +756,7 @@ case class StringSpace(child: Expression) UTF8String.blankString(if (length < 0) 0 else length) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (length) => s"""${ev.value} = UTF8String.blankString(($length < 0) ? 
0 : $length);""") } @@ -799,7 +798,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { @@ -825,7 +824,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => value.asInstanceOf[Array[Byte]].length } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") @@ -848,7 +847,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } @@ -868,7 +867,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } } @@ -894,7 +893,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { val bytes = ctx.freshName("bytes") s""" @@ -924,7 +923,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn bytes.asInstanceOf[Array[Byte]])) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); @@ -945,7 +944,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { s""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); @@ -973,7 +972,7 @@ case class Decode(bin: Expression, charset: Expression) UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { @@ -1005,7 +1004,7 @@ case class Encode(value: Expression, charset: Expression) input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: 
CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { @@ -1088,7 +1087,7 @@ case class FormatNumber(x: Expression, d: Expression) } } - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (num, d) => { def typeHelper(p: String): String = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala index 968bbdb1a5f03..cbee0e61f7a7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala @@ -20,12 +20,12 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.plans.QueryPlan import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ /** * An interface for subquery that is used in expressions. */ -abstract class SubqueryExpression extends LeafExpression { +abstract class SubqueryExpression extends Expression { /** * The logical plan of the query. @@ -61,6 +61,8 @@ case class ScalarSubquery( override def dataType: DataType = query.schema.fields.head.dataType + override def children: Seq[Expression] = Nil + override def checkInputDataTypes(): TypeCheckResult = { if (query.schema.length != 1) { TypeCheckResult.TypeCheckFailure("Scalar subquery must return only one column, but got " + @@ -77,3 +79,81 @@ case class ScalarSubquery( override def toString: String = s"subquery#${exprId.id}" } + +/** + * A predicate subquery checks the existence of a value in a sub-query. We currently only allow + * [[PredicateSubquery]] expressions within a Filter plan (i.e. WHERE or a HAVING clause). This will + * be rewritten into a left semi/anti join during analysis. + */ +abstract class PredicateSubquery extends SubqueryExpression with Unevaluable with Predicate { + override def nullable: Boolean = false + override def plan: LogicalPlan = SubqueryAlias(prettyName, query) +} + +object PredicateSubquery { + def hasPredicateSubquery(e: Expression): Boolean = { + e.find(_.isInstanceOf[PredicateSubquery]).isDefined + } +} + +/** + * The [[InSubQuery]] predicate checks the existence of a value in a sub-query. For example (SQL): + * {{{ + * SELECT * + * FROM a + * WHERE a.id IN (SELECT id + * FROM b) + * }}} + */ +case class InSubQuery(value: Expression, query: LogicalPlan) extends PredicateSubquery { + override def children: Seq[Expression] = value :: Nil + override lazy val resolved: Boolean = value.resolved && query.resolved + override def withNewPlan(plan: LogicalPlan): InSubQuery = InSubQuery(value, plan) + + /** + * The unwrapped value side expressions. + */ + lazy val expressions: Seq[Expression] = value match { + case CreateStruct(cols) => cols + case col => Seq(col) + } + + /** + * Check if the number of columns and the data types on both sides match. + */ + override def checkInputDataTypes(): TypeCheckResult = { + // Check the number of arguments. 
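+    // Illustrative case: `(a, b) IN (SELECT x FROM t)` has two value-side expressions but a
+    // single subquery column, so the arity check below flags a mismatch.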
+ if (expressions.length != query.output.length) { + TypeCheckResult.TypeCheckFailure( + s"The number of fields in the value (${expressions.length}) does not match with " + + s"the number of columns in the subquery (${query.output.length})") + } + + // Check the argument types. + expressions.zip(query.output).zipWithIndex.foreach { + case ((e, a), i) if e.dataType != a.dataType => + TypeCheckResult.TypeCheckFailure( + s"The data type of value[$i](${e.dataType}) does not match " + + s"subquery column '${a.name}' (${a.dataType}).") + case _ => + } + + TypeCheckResult.TypeCheckSuccess + } +} + +/** + * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition. + * For example (SQL): + * {{{ + * SELECT * + * FROM a + * WHERE EXISTS (SELECT * + * FROM b + * WHERE b.id = a.id) + * }}} + */ +case class Exists(query: LogicalPlan) extends PredicateSubquery { + override def children: Seq[Expression] = Nil + override def withNewPlan(plan: LogicalPlan): Exists = Exists(plan) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f5172b213a74b..e6d554565d442 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -19,9 +19,12 @@ package org.apache.spark.sql.catalyst.optimizer import scala.annotation.tailrec import scala.collection.immutable.HashSet +import scala.collection.mutable -import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases} -import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} +import org.apache.spark.sql.catalyst.analysis.{CleanupAliases, DistinctAggregationRewriter, EliminateSubqueryAliases, EmptyFunctionRegistry} +import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{InSubQuery, _} import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.planning.{ExtractFiltersAndInnerJoins, Unions} @@ -34,15 +37,21 @@ import org.apache.spark.sql.types._ * Abstract class all optimizers should inherit of, contains the standard batches (extending * Optimizers can override this. */ -abstract class Optimizer extends RuleExecutor[LogicalPlan] { +abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) + extends RuleExecutor[LogicalPlan] { + + protected val fixedPoint = FixedPoint(conf.optimizerMaxIterations) + def batches: Seq[Batch] = { // Technically some of the rules in Finish Analysis are not optimizer rules and belong more // in the analyzer, because they are needed for correctness (e.g. ComputeCurrentTime). // However, because we also use the analyzer to canonicalized queries (for view definition), // we do not eliminate subqueries or compute current time in the analyzer. 
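    // GetCurrentDatabase below is in the same spirit: it folds the current database into a
    // string literal once, so the optimized plan no longer depends on the session catalog.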
Batch("Finish Analysis", Once, + RewritePredicateSubquery, EliminateSubqueryAliases, ComputeCurrentTime, + GetCurrentDatabase(sessionCatalog), DistinctAggregationRewriter) :: ////////////////////////////////////////////////////////////////////////////////////////// // Optimizer rules start here @@ -54,12 +63,12 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { // since the other rules might make two separate Unions operators adjacent. Batch("Union", Once, CombineUnions) :: - Batch("Replace Operators", FixedPoint(100), + Batch("Replace Operators", fixedPoint, ReplaceIntersectWithSemiJoin, ReplaceDistinctWithAggregate) :: - Batch("Aggregate", FixedPoint(100), + Batch("Aggregate", fixedPoint, RemoveLiteralFromGroupExpressions) :: - Batch("Operator Optimizations", FixedPoint(100), + Batch("Operator Optimizations", fixedPoint, // Operator push down SetOperationPushDown, SamplePushDown, @@ -90,14 +99,16 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { SimplifyCasts, SimplifyCaseConversionExpressions, EliminateSerialization) :: - Batch("Decimal Optimizations", FixedPoint(100), + Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: - Batch("Typed Filter Optimization", FixedPoint(100), + Batch("Typed Filter Optimization", fixedPoint, EmbedSerializerInFilter) :: - Batch("LocalRelation", FixedPoint(100), + Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: Batch("Subquery", Once, - OptimizeSubqueries) :: Nil + OptimizeSubqueries) :: + Batch("OptimizeCodegen", Once, + OptimizeCodegen(conf)) :: Nil } /** @@ -112,12 +123,19 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { } /** - * Non-abstract representation of the standard Spark optimizing strategies + * An optimizer used in test code. * * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while * specific rules go to the subclasses */ -object DefaultOptimizer extends Optimizer +object SimpleTestOptimizer extends SimpleTestOptimizer + +class SimpleTestOptimizer extends Optimizer( + new SessionCatalog( + new InMemoryCatalog, + EmptyFunctionRegistry, + new SimpleCatalystConf(caseSensitiveAnalysis = true)), + new SimpleCatalystConf(caseSensitiveAnalysis = true)) /** * Pushes operations down into a Sample. @@ -137,29 +155,16 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateSerialization extends Rule[LogicalPlan] { - // TODO: find a more general way to do this optimization. def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case d @ DeserializeToObject(_, s: SerializeFromObject) + case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. 
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) + + case a @ AppendColumns(_, _, _, s: SerializeFromObject) + if a.deserializer.dataType == s.inputObjectType => + AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) } } @@ -300,11 +305,9 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(children.nonEmpty) val (deterministic, nondeterministic) = partitionByDeterministic(condition) val newFirstChild = Filter(deterministic, children.head) - val newOtherChildren = children.tail.map { - child => { - val rewrites = buildRewrites(children.head, child) - Filter(pushToRight(deterministic, rewrites), child) - } + val newOtherChildren = children.tail.map { child => + val rewrites = buildRewrites(children.head, child) + Filter(pushToRight(deterministic, rewrites), child) } Filter(nondeterministic, Union(newFirstChild +: newOtherChildren)) @@ -346,15 +349,15 @@ object ColumnPruning extends Rule[LogicalPlan] { case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty => val newOutput = e.output.filter(a.references.contains(_)) val newProjects = e.projections.map { proj => - proj.zip(e.output).filter { case (e, a) => + proj.zip(e.output).filter { case (_, a) => newOutput.contains(a) }.unzip._1 } a.copy(child = Expand(newProjects, newOutput, grandChild)) - // Prunes the unused columns from child of MapPartitions - case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => - mp.copy(child = prunedChild(child, mp.references)) + // Prunes the unused columns from child of `DeserializeToObject` + case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => + d.copy(child = prunedChild(child, d.references)) // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -851,6 +854,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Optimizes expressions by replacing according to CodeGen configuration. + */ +case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen => + e.toCodegen() + } +} + /** * Combines all adjacent [[Union]] operators into a single [[Union]]. */ @@ -1007,8 +1020,6 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { case filter @ Filter(_, f: Filter) => filter // should not push predicates through sample, or will generate different results. case filter @ Filter(_, s: Sample) => filter - // TODO: push predicates through expand - case filter @ Filter(_, e: Expand) => filter case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => pushDownPredicate(filter, u.child) { predicate => @@ -1399,6 +1410,16 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } +/** Replaces the expression of CurrentDatabase with the current database name. 
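+ *
+ * For example (a sketch, assuming the session's current database is "default"):
+ * {{{
+ *   SELECT current_database()   ==>   SELECT "default"
+ * }}}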
*/ +case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = { + plan transformAllExpressions { + case CurrentDatabase() => + Literal.create(sessionCatalog.getCurrentDatabase, StringType) + } + } +} + /** * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed @@ -1419,9 +1440,120 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { s } else { val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer.child + case a: Attribute if a == d.output.head => d.deserializer } Filter(newCondition, d.child) } } } + +/** + * This rule rewrites predicate sub-queries into left semi/anti joins. The following predicates + * are supported: + * a. EXISTS/NOT EXISTS will be rewritten as semi/anti join, unresolved conditions in Filter + * will be pulled out as the join conditions. + * b. IN/NOT IN will be rewritten as semi/anti join, unresolved conditions in the Filter will + * be pulled out as join conditions, value = selected column will also be used as join + * condition. + */ +object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper { + /** + * Pull out all correlated predicates from a given sub-query. This method removes the correlated + * predicates from sub-query [[Filter]]s and adds the references of these predicates to + * all intermediate [[Project]] clauses (if they are missing) in order to be able to evaluate the + * predicates in the join condition. + * + * This method returns the rewritten sub-query and the combined (AND) extracted predicate. + */ + private def pullOutCorrelatedPredicates( + subquery: LogicalPlan, + query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + val references = query.outputSet + val predicateMap = mutable.Map.empty[LogicalPlan, Seq[Expression]] + val transformed = subquery transformUp { + case f @ Filter(cond, child) => + // Find all correlated predicates. + val (correlated, local) = splitConjunctivePredicates(cond).partition { e => + e.references.intersect(references).nonEmpty + } + // Rewrite the filter without the correlated predicates if any. + correlated match { + case Nil => f + case xs if local.nonEmpty => + val newFilter = Filter(local.reduce(And), child) + predicateMap += newFilter -> correlated + newFilter + case xs => + predicateMap += child -> correlated + child + } + case p @ Project(expressions, child) => + // Find all pulled out predicates defined in the Project's subtree. + val localPredicates = p.collect(predicateMap).flatten + + // Determine which correlated predicate references are missing from this project. + val localPredicateReferences = localPredicates + .map(_.references) + .reduceOption(_ ++ _) + .getOrElse(AttributeSet.empty) + val missingReferences = localPredicateReferences -- p.references -- query.outputSet + + // Create a new project if we need to add missing references. + if (missingReferences.nonEmpty) { + Project(expressions ++ missingReferences, child) + } else { + p + } + } + (transformed, predicateMap.values.flatten.toSeq) + } + + /** + * Prepare an [[InSubQuery]] by rewriting it (in case of correlated predicates) and by + * constructing the required join condition. Both the rewritten subquery and the constructed + * join condition are returned. 
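+ *
+ * For example (an illustrative sketch): for `a.id IN (SELECT id FROM b WHERE b.x = a.x)`
+ * the correlated predicate `b.x = a.x` is pulled out of the sub-query and the value/column
+ * pair adds `a.id = id`, so the returned conditions are roughly `Seq(b.x = a.x, a.id = id)`.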
+ */ + private def pullOutCorrelatedPredicates( + in: InSubQuery, + query: LogicalPlan): (LogicalPlan, Seq[Expression]) = { + val (resolved, joinCondition) = pullOutCorrelatedPredicates(in.query, query) + val conditions = joinCondition ++ in.expressions.zip(resolved.output).map(EqualTo.tupled) + (resolved, conditions) + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case f @ Filter(condition, child) => + val (withSubquery, withoutSubquery) = + splitConjunctivePredicates(condition).partition(PredicateSubquery.hasPredicateSubquery) + + // Construct the pruned filter condition. + val newFilter: LogicalPlan = withoutSubquery match { + case Nil => child + case conditions => Filter(conditions.reduce(And), child) + } + + // Filter the plan by applying left semi and left anti joins. + withSubquery.foldLeft(newFilter) { + case (p, Exists(sub)) => + val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) + Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + case (p, Not(Exists(sub))) => + val (resolved, conditions) = pullOutCorrelatedPredicates(sub, p) + Join(p, resolved, LeftAnti, conditions.reduceOption(And)) + case (p, in: InSubQuery) => + val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) + Join(p, resolved, LeftSemi, conditions.reduceOption(And)) + case (p, Not(in: InSubQuery)) => + val (resolved, conditions) = pullOutCorrelatedPredicates(in, p) + // This is a NULL-aware (left) anti join (NAAJ). + // Construct the condition. A NULL in one of the conditions is regarded as a positive + // result; such a row will be filtered out by the Anti-Join operator. + val anyNull = conditions.map(IsNull).reduceLeft(Or) + val condition = conditions.reduceLeft(And) + + // Note that will almost certainly be planned as a Broadcast Nested Loop join. Use EXISTS + // if performance matters to you. + Join(p, resolved, LeftAnti, Option(Or(anyNull, condition))) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index aa59f3fb2a4a4..1c067621df524 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -391,9 +391,13 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { // Having val withHaving = withProject.optional(having) { - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(expression(having), BooleanType), withProject) + // Note that we add a cast to non-predicate expressions. If the expression itself is + // already boolean, the optimizer will get rid of the unnecessary cast. + val predicate = expression(having) match { + case p: Predicate => p + case e => Cast(e, BooleanType) + } + Filter(predicate, withProject) } // Distinct @@ -866,10 +870,10 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a filtering correlated sub-query. This is not supported yet. + * Create a filtering correlated sub-query (EXISTS). 
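+ *
+ * For example, `WHERE EXISTS (SELECT * FROM b WHERE b.id = a.id)` now parses into an
+ * [[Exists]] expression, which the optimizer's RewritePredicateSubquery rule later turns
+ * into a left semi join.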
*/ override def visitExists(ctx: ExistsContext): Expression = { - throw new ParseException("EXISTS clauses are not supported.", ctx) + Exists(plan(ctx.query)) } /** @@ -944,7 +948,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { GreaterThanOrEqual(e, expression(ctx.lower)), LessThanOrEqual(e, expression(ctx.upper)))) case SqlBaseParser.IN if ctx.query != null => - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + invertIfNotDefined(InSubQuery(e, plan(ctx.query))) case SqlBaseParser.IN => invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index aceeb8aadcf68..45ac126a72f5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -42,6 +42,9 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { */ def analyzed: Boolean = _analyzed + /** Returns true if this subtree contains any streaming data sources. */ + def isStreaming: Boolean = children.exists(_.isStreaming == true) + /** * Returns a copy of this node where `rule` has been recursively applied first to all of its * children and then itself (post-order). When `rule` does not apply to a given node, it is left diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d4fc9e4da944a..a445ce694750a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -516,7 +516,10 @@ private[sql] object Expand { // groupingId is the last output, here we use the bit mask as the concrete value for it. } :+ Literal.create(bitmask, IntegerType) } - val output = child.output ++ groupByAttrs :+ gid + + // the `groupByAttrs` has different meaning in `Expand.output`, it could be the original + // grouping expression or null, so here we create new instance of it. + val output = child.output ++ groupByAttrs.map(_.newInstance) :+ gid Expand(projections, output, Project(child.output ++ groupByAliases, child)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala index 47b34d1fa2e49..fcffdbaaf07b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/commands.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.types.StringType trait Command /** - * Returned for the "DESCRIBE [EXTENDED] FUNCTION functionName" command. + * Returned for the "DESCRIBE FUNCTION [EXTENDED] functionName" command. * @param functionName The function to be described. - * @param isExtended True if "DESCRIBE EXTENDED" is used. Otherwise, false. + * @param isExtended True if "DESCRIBE FUNCTION EXTENDED" is used. Otherwise, false. 
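+ *
+ * For example, `DESCRIBE FUNCTION EXTENDED upper` yields
+ * `DescribeFunction("upper", isExtended = true)` (the function name is only illustrative).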
*/ private[sql] case class DescribeFunction( functionName: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6df46189b627c..4a1bdb0b8ac2e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,126 +21,111 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) - DeserializeToObject(Alias(deserializer, "obj")(), child) + DeserializeToObject(deserializer, generateObjAttr[T], child) } def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } + + def generateObjAttr[T : Encoder]: Attribute = { + AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + } } /** - * Takes the input row from child and turns it into object using the given deserializer expression. - * The output of this operator is a single-field safe row containing the deserialized object. + * A trait for logical operators that produces domain objects as output. + * The output of this operator is a single-field safe row containing the produced object. */ -case class DeserializeToObject( - deserializer: Alias, - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = deserializer.toAttribute :: Nil +trait ObjectProducer extends LogicalPlan { + // The attribute that reference to the single object field this operator outputs. + protected def outputObjAttr: Attribute + + override def output: Seq[Attribute] = outputObjAttr :: Nil + + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) - def outputObjectType: DataType = deserializer.dataType + def outputObjectType: DataType = outputObjAttr.dataType } /** - * Takes the input object from child and turns in into unsafe row using the given serializer - * expression. The output of its child must be a single-field row containing the input object. + * A trait for logical operators that consumes domain objects as input. + * The output of its child must be a single-field row containing the input object. */ -case class SerializeFromObject( - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) +trait ObjectConsumer extends UnaryNode { + assert(child.output.length == 1) + + // This operator always need all columns of its child, even it doesn't reference to. + override def references: AttributeSet = child.outputSet def inputObjectType: DataType = child.output.head.dataType } /** - * A trait for logical operators that apply user defined functions to domain objects. + * Takes the input row from child and turns it into object using the given deserializer expression. 
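+ * The single-field output is described by `outputObjAttr`; `CatalystSerde.deserialize` builds
+ * this node as, roughly,
+ * `DeserializeToObject(UnresolvedDeserializer(encoderFor[T].deserializer), generateObjAttr[T], child)`.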
*/ -trait ObjectOperator extends LogicalPlan { +case class DeserializeToObject( + deserializer: Expression, + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer - /** The serializer that is used to produce the output of this operator. */ - def serializer: Seq[NamedExpression] +/** + * Takes the input object from child and turns it into unsafe row using the given serializer + * expression. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { override def output: Seq[Attribute] = serializer.map(_.toAttribute) - - /** - * The object type that is produced by the user defined function. Note that the return type here - * is the same whether or not the operator is output serialized data. - */ - def outputObject: NamedExpression = - Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")() - - /** - * Returns a copy of this operator that will produce an object instead of an encoded row. - * Used in the optimizer when transforming plans to remove unneeded serialization. - */ - def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { - this - } else { - withNewSerializer(outputObject :: Nil) - } - - /** Returns a copy of this operator with a different serializer. */ - def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { - productIterator.map { - case c if c == serializer => newSerializer - case other: AnyRef => other - }.toArray - } } object MapPartitions { def apply[T : Encoder, U : Encoder]( func: Iterator[T] => Iterator[U], - child: LogicalPlan): MapPartitions = { - MapPartitions( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each partition of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer object MapElements { def apply[T : Encoder, U : Encoder]( func: AnyRef, - child: LogicalPlan): MapElements = { - MapElements( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapElements( func, - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each element of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row. - * @param serializer use to serialize the output of `func`. 
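+ * As with [[MapPartitions]], the factory now wraps the node in the serde pair, so a typed
+ * `ds.map(f)` plans roughly as
+ * `SerializeFromObject(MapElements(f, generateObjAttr[U], DeserializeToObject(..., child)))`.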
*/ case class MapElements( func: AnyRef, - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -156,7 +141,7 @@ object AppendColumns { } /** - * A relation produced by applying `func` to each partition of the `child`, concatenating the + * A relation produced by applying `func` to each element of the `child`, concatenating the * resulting columns at the end of the input row. * * @param deserializer used to extract the input to `func` from an input row. @@ -166,28 +151,41 @@ case class AppendColumns( func: Any => Any, deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) } +/** + * An optimized version of [[AppendColumns]], that can be executed on deserialized object directly. + */ +case class AppendColumnsWithObject( + func: Any => Any, + childSerializer: Seq[NamedExpression], + newColumnsSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { + + override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) +} + /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan): MapGroups = { - new MapGroups( + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), - encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, + CatalystSerde.generateObjAttr[U], child) + CatalystSerde.serialize[U](mapped) } } @@ -198,43 +196,43 @@ object MapGroups { * * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, - serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer /** Factory for constructing new `CoGroup` nodes. 
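+ *
+ * As with [[MapGroups]], the factory resolves the key/left/right deserializers, generates the
+ * output object attribute, and wraps the resulting [[CoGroup]] in `CatalystSerde.serialize[OUT]`.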
*/ object CoGroup { - def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( + func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup = { + right: LogicalPlan): LogicalPlan = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) - CoGroup( + val cogrouped = CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to // resolve the `keyDeserializer` based on either of them, here we pick the left one. - UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), - UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), - UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), - encoderFor[Result].namedExpressions, + UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr), leftGroup, rightGroup, leftAttr, rightAttr, + CatalystSerde.generateObjAttr[OUT], left, right) + CatalystSerde.serialize[OUT](cogrouped) } } @@ -247,10 +245,10 @@ case class CoGroup( keyDeserializer: Expression, leftDeserializer: Expression, rightDeserializer: Expression, - serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], + outputObjAttr: Attribute, left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectOperator + right: LogicalPlan) extends BinaryNode with ObjectProducer diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala index a30a3926bb86e..6f4ec6b701919 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/Decimal.scala @@ -201,6 +201,11 @@ final class Decimal extends Ordered[Decimal] with Serializable { changePrecision(precision, scale, ROUND_HALF_UP) } + def changePrecision(precision: Int, scale: Int, mode: Int): Boolean = mode match { + case java.math.BigDecimal.ROUND_HALF_UP => changePrecision(precision, scale, ROUND_HALF_UP) + case java.math.BigDecimal.ROUND_HALF_EVEN => changePrecision(precision, scale, ROUND_HALF_EVEN) + } + /** * Update precision and scale while keeping our value the same, and return true if successful. 
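+ * (Note: the `mode`-based overload added above only matches java.math.BigDecimal.ROUND_HALF_UP
+ * and ROUND_HALF_EVEN; any other rounding constant will fail with a MatchError.)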
* @@ -337,6 +342,7 @@ final class Decimal extends Ordered[Decimal] with Serializable { object Decimal { val ROUND_HALF_UP = BigDecimal.RoundingMode.HALF_UP + val ROUND_HALF_EVEN = BigDecimal.RoundingMode.HALF_EVEN val ROUND_CEILING = BigDecimal.RoundingMode.CEILING val ROUND_FLOOR = BigDecimal.RoundingMode.FLOOR diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala index 06ee0fbfe9642..b7b1acc58242e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ObjectType.scala @@ -41,4 +41,6 @@ private[sql] case class ObjectType(cls: Class[_]) extends DataType { throw new UnsupportedOperationException("No size estimation available for objects.") def asNullable: DataType = this + + override def simpleString: String = cls.getName } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index 5ca5a72512a29..0672551b2972d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -23,7 +23,7 @@ import java.sql.{Date, Timestamp} import scala.reflect.runtime.universe.typeOf import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.expressions.BoundReference +import org.apache.spark.sql.catalyst.expressions.{BoundReference, SpecificMutableRow} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -81,9 +81,44 @@ case class MultipleConstructorsData(a: Int, b: String, c: Double) { def this(b: String, a: Int) = this(a, b, c = 1.0) } +object TestingUDT { + @SQLUserDefinedType(udt = classOf[NestedStructUDT]) + class NestedStruct(val a: Integer, val b: Long, val c: Double) + + class NestedStructUDT extends UserDefinedType[NestedStruct] { + override def sqlType: DataType = new StructType() + .add("a", IntegerType, nullable = true) + .add("b", LongType, nullable = false) + .add("c", DoubleType, nullable = false) + + override def serialize(n: NestedStruct): Any = { + val row = new SpecificMutableRow(sqlType.asInstanceOf[StructType].map(_.dataType)) + row.setInt(0, n.a) + row.setLong(1, n.b) + row.setDouble(2, n.c) + } + + override def userClass: Class[NestedStruct] = classOf[NestedStruct] + + override def deserialize(datum: Any): NestedStruct = datum match { + case row: InternalRow => + new NestedStruct(row.getInt(0), row.getLong(1), row.getDouble(2)) + } + } +} + + class ScalaReflectionSuite extends SparkFunSuite { import org.apache.spark.sql.catalyst.ScalaReflection._ + test("SQLUserDefinedType annotation on Scala structure") { + val schema = schemaFor[TestingUDT.NestedStruct] + assert(schema === Schema( + new TestingUDT.NestedStructUDT, + nullable = true + )) + } + test("primitive data") { val schema = schemaFor[PrimitiveData] assert(schema === Schema( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ad101d1c406b8..a90636d278673 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -24,8 +24,8 @@ import 
org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count} +import org.apache.spark.sql.catalyst.plans.{Inner, LeftOuter, RightOuter} import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} import org.apache.spark.sql.types._ @@ -444,4 +444,60 @@ class AnalysisErrorSuite extends AnalysisTest { assertAnalysisError(plan2, "map type expression `a` cannot be used in join conditions" :: Nil) } + + test("PredicateSubQuery is used outside of a filter") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val plan = Project( + Seq(a, Alias(InSubQuery(a, LocalRelation(b)), "c")()), + LocalRelation(a)) + assertAnalysisError(plan, "Predicate sub-queries can only be used in a Filter" :: Nil) + } + + test("PredicateSubQuery is used is a nested condition") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", BooleanType)() + val plan1 = Filter(Cast(InSubQuery(a, LocalRelation(b)), BooleanType), LocalRelation(a)) + assertAnalysisError(plan1, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + + val plan2 = Filter(Or(InSubQuery(a, LocalRelation(b)), c), LocalRelation(a, c)) + assertAnalysisError(plan2, "Predicate sub-queries cannot be used in nested conditions" :: Nil) + } + + test("PredicateSubQuery correlated predicate is nested in an illegal plan") { + val a = AttributeReference("a", IntegerType)() + val b = AttributeReference("b", IntegerType)() + val c = AttributeReference("c", IntegerType)() + + val plan1 = Filter( + Exists( + Join( + LocalRelation(b), + Filter(EqualTo(a, c), LocalRelation(c)), + LeftOuter, + Option(EqualTo(b, c)))), + LocalRelation(a)) + assertAnalysisError(plan1, "Accessing outer query column is not allowed in" :: Nil) + + val plan2 = Filter( + Exists( + Join( + Filter(EqualTo(a, c), LocalRelation(c)), + LocalRelation(b), + RightOuter, + Option(EqualTo(b, c)))), + LocalRelation(a)) + assertAnalysisError(plan2, "Accessing outer query column is not allowed in" :: Nil) + + val plan3 = Filter( + Exists(Aggregate(Seq.empty, Seq.empty, Filter(EqualTo(a, c), LocalRelation(c)))), + LocalRelation(a)) + assertAnalysisError(plan3, "Accessing outer query column is not allowed in" :: Nil) + + val plan4 = Filter( + Exists(Union(LocalRelation(b), Filter(EqualTo(a, c), LocalRelation(c)))), + LocalRelation(a)) + assertAnalysisError(plan4, "Accessing outer query column is not allowed in" :: Nil) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ace6e10c6ec30..660dc86c3e284 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -192,7 +192,7 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { "values of function map should all be the same type") } - test("check types for ROUND") { + test("check types for ROUND/BROUND") { assertSuccess(Round(Literal(null), Literal(null))) 
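+ // BRound is exercised with the same inputs below; both variants require a numeric value and
+ // a foldable int scale, differing only in the rounding mode they apply.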
assertSuccess(Round('intField, Literal(1))) @@ -200,6 +200,14 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(Round('intField, 'booleanField), "requires int type") assertError(Round('intField, 'mapField), "requires int type") assertError(Round('booleanField, 'intField), "requires numeric type") + + assertSuccess(BRound(Literal(null), Literal(null))) + assertSuccess(BRound('intField, Literal(1))) + + assertError(BRound('intField, 'intField), "Only foldable Expression is allowed") + assertError(BRound('intField, 'booleanField), "requires int type") + assertError(BRound('intField, 'mapField), "requires int type") + assertError(BRound('booleanField, 'intField), "requires numeric type") } test("check types for Greatest/Least") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 883ef48984d79..18de8b152b070 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -348,15 +348,22 @@ class HiveTypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = HiveTypeCoercion.IfCoercion + ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), - If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) - ) + If(Literal(true), Cast(Literal(1), LongType), Literal(1L))) ruleTest(rule, If(Literal.create(null, NullType), Literal(1), Literal(1)), - If(Literal.create(null, BooleanType), Literal(1), Literal(1)) - ) + If(Literal.create(null, BooleanType), Literal(1), Literal(1))) + + ruleTest(rule, + If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), + If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), + If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala new file mode 100644 index 0000000000000..ce00a03e764fd --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/UnsupportedOperationsSuite.scala @@ -0,0 +1,384 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + +class UnsupportedOperationsSuite extends SparkFunSuite { + + val attribute = AttributeReference("a", IntegerType, nullable = true)() + val batchRelation = LocalRelation(attribute) + val streamRelation = new TestStreamingRelation(attribute) + + /* + ======================================================================================= + BATCH QUERIES + ======================================================================================= + */ + + assertSupportedInBatchPlan("local relation", batchRelation) + + assertNotSupportedInBatchPlan( + "streaming source", + streamRelation, + Seq("with streaming source", "startStream")) + + assertNotSupportedInBatchPlan( + "select on streaming source", + streamRelation.select($"count(*)"), + Seq("with streaming source", "startStream")) + + + /* + ======================================================================================= + STREAMING QUERIES + ======================================================================================= + */ + + // Batch plan in streaming query + testError( + "streaming plan - no streaming source", + Seq("without streaming source", "startStream")) { + UnsupportedOperationChecker.checkForStreaming(batchRelation.select($"count(*)"), Append) + } + + // Commands + assertNotSupportedInStreamingPlan( + "commmands", + DescribeFunction("func", true), + outputMode = Append, + expectedMsgs = "commands" :: Nil) + + // Aggregates: Not supported on streams in Append mode + assertSupportedInStreamingPlan( + "aggregate - batch with update output mode", + batchRelation.groupBy("a")("count(*)"), + outputMode = Update) + + assertSupportedInStreamingPlan( + "aggregate - batch with append output mode", + batchRelation.groupBy("a")("count(*)"), + outputMode = Append) + + assertSupportedInStreamingPlan( + "aggregate - stream with update output mode", + streamRelation.groupBy("a")("count(*)"), + outputMode = Update) + + assertNotSupportedInStreamingPlan( + "aggregate - stream with append output mode", + streamRelation.groupBy("a")("count(*)"), + outputMode = Append, + Seq("aggregation", "append output mode")) + + // Inner joins: Stream-stream not supported + testBinaryOperationInStreamingPlan( + "inner join", + _.join(_, joinType = Inner), + streamStreamSupported = false) + + // Full outer joins: only batch-batch is allowed + testBinaryOperationInStreamingPlan( + "full outer join", + _.join(_, joinType = FullOuter), + streamStreamSupported = false, + batchStreamSupported = false, + streamBatchSupported = false) + + // Left outer joins: *-stream not allowed + testBinaryOperationInStreamingPlan( + "left outer join", + _.join(_, joinType = LeftOuter), + streamStreamSupported = false, + batchStreamSupported = false, + expectedMsg = "left outer/semi/anti joins") + + // Left semi joins: stream-* not allowed + testBinaryOperationInStreamingPlan( + "left semi join", + _.join(_, joinType = LeftSemi), + streamStreamSupported = false, + batchStreamSupported = false, + expectedMsg = "left outer/semi/anti joins") 
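+ // (For the left outer/semi/anti join cases above and below, the disallowed combinations are
+ // exactly those in which the streaming relation would sit on the right side of the join.)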
+ + // Left anti joins: stream-* not allowed + testBinaryOperationInStreamingPlan( + "left anti join", + _.join(_, joinType = LeftAnti), + streamStreamSupported = false, + batchStreamSupported = false, + expectedMsg = "left outer/semi/anti joins") + + // Right outer joins: stream-* not allowed + testBinaryOperationInStreamingPlan( + "right outer join", + _.join(_, joinType = RightOuter), + streamStreamSupported = false, + streamBatchSupported = false) + + // Cogroup: only batch-batch is allowed + testBinaryOperationInStreamingPlan( + "cogroup", + genCogroup, + streamStreamSupported = false, + batchStreamSupported = false, + streamBatchSupported = false) + + def genCogroup(left: LogicalPlan, right: LogicalPlan): LogicalPlan = { + def func(k: Int, left: Iterator[Int], right: Iterator[Int]): Iterator[Int] = { + Iterator.empty + } + implicit val intEncoder = ExpressionEncoder[Int] + + left.cogroup[Int, Int, Int, Int]( + right, + func, + AppendColumns[Int, Int]((x: Int) => x, left).newColumns, + AppendColumns[Int, Int]((x: Int) => x, right).newColumns, + left.output, + right.output) + } + + // Union: Mixing between stream and batch not supported + testBinaryOperationInStreamingPlan( + "union", + _.union(_), + streamBatchSupported = false, + batchStreamSupported = false) + + // Except: *-stream not supported + testBinaryOperationInStreamingPlan( + "except", + _.except(_), + streamStreamSupported = false, + batchStreamSupported = false) + + // Intersect: stream-stream not supported + testBinaryOperationInStreamingPlan( + "intersect", + _.intersect(_), + streamStreamSupported = false) + + + // Unary operations + testUnaryOperatorInStreamingPlan("sort", Sort(Nil, true, _)) + testUnaryOperatorInStreamingPlan("sort partitions", SortPartitions(Nil, _), expectedMsg = "sort") + testUnaryOperatorInStreamingPlan( + "sample", Sample(0.1, 1, true, 1L, _)(), expectedMsg = "sampling") + testUnaryOperatorInStreamingPlan( + "window", Window(Nil, Nil, Nil, _), expectedMsg = "non-time-based windows") + + + /* + ======================================================================================= + TESTING FUNCTIONS + ======================================================================================= + */ + + /** + * Test that an unary operator correctly fails support check when it has a streaming child plan, + * but not when it has batch child plan. There can be batch sub-plans inside a streaming plan, + * so it is valid for the operator to have a batch child plan. + * + * This test wraps the logical plan in a fake operator that makes the whole plan look like + * a streaming plan even if the child plan is a batch plan. This is to test that the operator + * supports having a batch child plan, forming a batch subplan inside a streaming plan. + */ + def testUnaryOperatorInStreamingPlan( + operationName: String, + logicalPlanGenerator: LogicalPlan => LogicalPlan, + outputMode: OutputMode = Append, + expectedMsg: String = ""): Unit = { + + val expectedMsgs = if (expectedMsg.isEmpty) Seq(operationName) else Seq(expectedMsg) + + assertNotSupportedInStreamingPlan( + s"$operationName with stream relation", + wrapInStreaming(logicalPlanGenerator(streamRelation)), + outputMode, + expectedMsgs) + + assertSupportedInStreamingPlan( + s"$operationName with batch relation", + wrapInStreaming(logicalPlanGenerator(batchRelation)), + outputMode) + } + + + /** + * Test that a binary operator correctly fails support check when it has combinations of + * streaming and batch child plans. 
There can be batch sub-plans inside a streaming plan, + * so it is valid for the operator to have a batch child plan. + */ + def testBinaryOperationInStreamingPlan( + operationName: String, + planGenerator: (LogicalPlan, LogicalPlan) => LogicalPlan, + outputMode: OutputMode = Append, + streamStreamSupported: Boolean = true, + streamBatchSupported: Boolean = true, + batchStreamSupported: Boolean = true, + expectedMsg: String = ""): Unit = { + + val expectedMsgs = if (expectedMsg.isEmpty) Seq(operationName) else Seq(expectedMsg) + + if (streamStreamSupported) { + assertSupportedInStreamingPlan( + s"$operationName with stream-stream relations", + planGenerator(streamRelation, streamRelation), + outputMode) + } else { + assertNotSupportedInStreamingPlan( + s"$operationName with stream-stream relations", + planGenerator(streamRelation, streamRelation), + outputMode, + expectedMsgs) + } + + if (streamBatchSupported) { + assertSupportedInStreamingPlan( + s"$operationName with stream-batch relations", + planGenerator(streamRelation, batchRelation), + outputMode) + } else { + assertNotSupportedInStreamingPlan( + s"$operationName with stream-batch relations", + planGenerator(streamRelation, batchRelation), + outputMode, + expectedMsgs) + } + + if (batchStreamSupported) { + assertSupportedInStreamingPlan( + s"$operationName with batch-stream relations", + planGenerator(batchRelation, streamRelation), + outputMode) + } else { + assertNotSupportedInStreamingPlan( + s"$operationName with batch-stream relations", + planGenerator(batchRelation, streamRelation), + outputMode, + expectedMsgs) + } + + assertSupportedInStreamingPlan( + s"$operationName with batch-batch relations", + planGenerator(batchRelation, batchRelation), + outputMode) + } + + /** + * Assert that the logical plan is supported as subplan insider a streaming plan. + * + * To test this correctly, the given logical plan is wrapped in a fake operator that makes the + * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported + * exception simply for not being a streaming plan, even though that plan could exists as batch + * subplan inside some streaming plan. + */ + def assertSupportedInStreamingPlan( + name: String, + plan: LogicalPlan, + outputMode: OutputMode): Unit = { + test(s"streaming plan - $name: supported") { + UnsupportedOperationChecker.checkForStreaming(wrapInStreaming(plan), outputMode) + } + } + + /** + * Assert that the logical plan is not supported inside a streaming plan. + * + * To test this correctly, the given logical plan is wrapped in a fake operator that makes the + * whole plan look like a streaming plan. Otherwise, a batch plan may throw not supported + * exception simply for not being a streaming plan, even though that plan could exists as batch + * subplan inside some streaming plan. 
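+ *
+ * A sketch of a typical call (mirroring the binary-operation tests above):
+ * {{{
+ *   assertNotSupportedInStreamingPlan(
+ *     "inner join with stream-stream relations",
+ *     streamRelation.join(streamRelation, joinType = Inner),
+ *     outputMode = Append,
+ *     expectedMsgs = Seq("inner join"))
+ * }}}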
+ */ + def assertNotSupportedInStreamingPlan( + name: String, + plan: LogicalPlan, + outputMode: OutputMode, + expectedMsgs: Seq[String]): Unit = { + testError( + s"streaming plan - $name: not supported", + expectedMsgs :+ "streaming" :+ "DataFrame" :+ "Dataset" :+ "not supported") { + UnsupportedOperationChecker.checkForStreaming(wrapInStreaming(plan), outputMode) + } + } + + /** Assert that the logical plan is supported as a batch plan */ + def assertSupportedInBatchPlan(name: String, plan: LogicalPlan): Unit = { + test(s"batch plan - $name: supported") { + UnsupportedOperationChecker.checkForBatch(plan) + } + } + + /** Assert that the logical plan is not supported as a batch plan */ + def assertNotSupportedInBatchPlan( + name: String, + plan: LogicalPlan, + expectedMsgs: Seq[String]): Unit = { + testError(s"batch plan - $name: not supported", expectedMsgs) { + UnsupportedOperationChecker.checkForBatch(plan) + } + } + + /** + * Test whether the body of code will fail. If it does fail, then check if it has expected + * messages. + */ + def testError(testName: String, expectedMsgs: Seq[String])(testBody: => Unit): Unit = { + + test(testName) { + val e = intercept[AnalysisException] { + testBody + } + + if (!expectedMsgs.map(_.toLowerCase).forall(e.getMessage.toLowerCase.contains)) { + fail( + s"""Exception message should contain the following substrings: + | + | ${expectedMsgs.mkString("\n ")} + | + |Actual exception message: + | + | ${e.getMessage} + """.stripMargin) + } + } + } + + def wrapInStreaming(plan: LogicalPlan): LogicalPlan = { + new StreamingPlanWrapper(plan) + } + + case class StreamingPlanWrapper(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output + override def isStreaming: Boolean = true + } + + case class TestStreamingRelation(output: Seq[Attribute]) extends LeafNode { + def this(attribute: Attribute) = this(Seq(attribute)) + override def isStreaming: Boolean = true + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 260dfb3f42244..b682e7d2b1d8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.util.ThreadUtils /** * Additional tests for code generation. 
@@ -43,13 +44,13 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { } } - futures.foreach(Await.result(_, 10.seconds)) + futures.foreach(ThreadUtils.awaitResult(_, 10.seconds)) } test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) - val plan = GenerateMutableProjection.generate(expressions)() + val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) @@ -72,7 +73,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expression = CaseWhen((1 to cases).map(generateCase(_))) - val plan = GenerateMutableProjection.generate(Seq(expression))() + val plan = GenerateMutableProjection.generate(Seq(expression)) val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) val actual = plan(input).toSeq(Seq(expression.dataType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index cf26d4843d84f..8a9617cfbf5df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -24,7 +24,7 @@ import org.scalatest.prop.GeneratorDrivenPropertyChecks import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions.codegen._ -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types.DataType import org.apache.spark.util.Utils @@ -110,7 +110,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { inputRow: InternalRow = EmptyRow): Unit = { val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) val actual = plan(inputRow).get(0, expression.dataType) @@ -153,7 +153,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { expected: Any, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkEvaluationWithoutCodegen(optimizedPlan.expressions.head, expected, inputRow) } @@ -166,7 +166,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) @@ -259,7 +259,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } val plan = generateProject( - GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: 
Nil)(), + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil), expr) val codegen = plan(inputRow).get(0, expr.dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 27195d3458b8e..f88c9e8df16d0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -25,7 +25,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection -import org.apache.spark.sql.catalyst.optimizer.DefaultOptimizer +import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project} import org.apache.spark.sql.types._ @@ -138,7 +138,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { inputRow: InternalRow = EmptyRow): Unit = { val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) val actual = plan(inputRow).get(0, expression.dataType) @@ -151,7 +151,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { expression: Expression, inputRow: InternalRow = EmptyRow): Unit = { val plan = Project(Alias(expression, s"Optimized($expression)")() :: Nil, OneRowRelation) - val optimizedPlan = DefaultOptimizer.execute(plan) + val optimizedPlan = SimpleTestOptimizer.execute(plan) checkNaNWithoutCodegen(optimizedPlan.expressions.head, inputRow) } @@ -508,7 +508,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Logarithm, DoubleType, DoubleType) } - test("round") { + test("round/bround") { val scales = -6 to 6 val doublePi: Double = math.Pi val shortPi: Short = 31415 @@ -529,11 +529,18 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { 31415926535900000L, 31415926535898000L, 31415926535897900L, 31415926535897930L) ++ Seq.fill(7)(31415926535897932L) + val intResultsB: Seq[Int] = Seq(314000000, 314200000, 314160000, 314159000, 314159300, + 314159260) ++ Seq.fill(7)(314159265) + scales.zipWithIndex.foreach { case (scale, i) => checkEvaluation(Round(doublePi, scale), doubleResults(i), EmptyRow) checkEvaluation(Round(shortPi, scale), shortResults(i), EmptyRow) checkEvaluation(Round(intPi, scale), intResults(i), EmptyRow) checkEvaluation(Round(longPi, scale), longResults(i), EmptyRow) + checkEvaluation(BRound(doublePi, scale), doubleResults(i), EmptyRow) + checkEvaluation(BRound(shortPi, scale), shortResults(i), EmptyRow) + checkEvaluation(BRound(intPi, scale), intResultsB(i), EmptyRow) + checkEvaluation(BRound(longPi, scale), longResults(i), EmptyRow) } val bdResults: Seq[BigDecimal] = Seq(BigDecimal(3.0), BigDecimal(3.1), BigDecimal(3.14), @@ -543,19 +550,33 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { // and not allowed by o.a.s.s.types.Decimal.changePrecision, therefore null (0 to 7).foreach { i => checkEvaluation(Round(bdPi, i), bdResults(i), EmptyRow) + checkEvaluation(BRound(bdPi, i), bdResults(i), 
EmptyRow) } (8 to 10).foreach { scale => checkEvaluation(Round(bdPi, scale), null, EmptyRow) + checkEvaluation(BRound(bdPi, scale), null, EmptyRow) } DataTypeTestUtils.numericTypes.foreach { dataType => checkEvaluation(Round(Literal.create(null, dataType), Literal(2)), null) checkEvaluation(Round(Literal.create(null, dataType), Literal.create(null, IntegerType)), null) + checkEvaluation(BRound(Literal.create(null, dataType), Literal(2)), null) + checkEvaluation(BRound(Literal.create(null, dataType), + Literal.create(null, IntegerType)), null) } + checkEvaluation(Round(2.5, 0), 3.0) + checkEvaluation(Round(3.5, 0), 4.0) + checkEvaluation(Round(-2.5, 0), -3.0) checkEvaluation(Round(-3.5, 0), -4.0) checkEvaluation(Round(-0.35, 1), -0.4) checkEvaluation(Round(-35, -1), -40) + checkEvaluation(BRound(2.5, 0), 2.0) + checkEvaluation(BRound(3.5, 0), 4.0) + checkEvaluation(BRound(-2.5, 0), -2.0) + checkEvaluation(BRound(-3.5, 0), -4.0) + checkEvaluation(BRound(-0.35, 1), -0.4) + checkEvaluation(BRound(-35, -1), -40) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index f5bafcc6a783e..33916c0891866 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -69,6 +69,23 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + test("assert_true") { + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(false, BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) + } + checkEvaluation(AssertTrue(Literal.create(true, BooleanType)), null) + checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) + } + private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index ff34b1e37be93..3a24b4d7d52c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -35,8 +35,8 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpres override def eval(input: InternalRow): Any = value - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - Literal.create(value, dataType).genCode(ctx, ev) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + Literal.create(value, dataType).doGenCode(ctx, ev) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index 
c9616cdb26c20..06dc3bd33b90e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -36,7 +36,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateMutableProjection.generate(Seq(expr))() + val instance = GenerateMutableProjection.generate(Seq(expr)) assert(instance.apply(null).getBoolean(0) === false) } @@ -60,12 +60,12 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() - val instance1 = GenerateMutableProjection.generate(Seq(expr1))() + val instance1 = GenerateMutableProjection.generate(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) val expr2 = MutableExpression() expr2.mutableState = true - val instance2 = GenerateMutableProjection.generate(Seq(expr2))() + val instance2 = GenerateMutableProjection.generate(Seq(expr2)) assert(instance1.apply(null).getBoolean(0) === false) assert(instance2.apply(null).getBoolean(0) === true) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index e2a8eb8ee1d34..b69b74b4240bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -76,7 +76,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) } - val mutableProj = GenerateMutableProjection.generate(exprs)() + val mutableProj = GenerateMutableProjection.generate(exprs) val row1 = mutableProj(result) assert(result === row1) val row2 = mutableProj(result) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 91777375608fd..3c033ddc374cf 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -22,8 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.NewInstance -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,40 +36,45 @@ class EliminateSerializationSuite extends PlanTest { } implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() - private val func = identity[Iterator[(Int, Int)]] _ - private val func2 = identity[Iterator[OtherTuple]] _ + implicit 
private def intEncoder = ExpressionEncoder[Int]() - def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = { - val newInstances = plan.flatMap(_.expressions.collect { - case n: NewInstance => n - }) + test("back to back serialization") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + val expected = input.select('obj.as("obj")).analyze + comparePlans(optimized, expected) + } - if (newInstances.size != count) { - fail( - s""" - |Wrong number of object creations in plan: ${newInstances.size} != $count - |$plan - """.stripMargin) - } + test("back to back serialization with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } - test("back to back MapPartitions") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func, input)) + test("back to back serialization in AppendColumns") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze + + val optimized = Optimize.execute(plan) + + val expected = AppendColumnsWithObject( + func.asInstanceOf[Any => Any], + productEncoder[(Int, Int)].namedExpressions, + intEncoder.namedExpressions, + input).analyze - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(1, optimized) + comparePlans(optimized, expected) } - test("back to back with object change") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func2, input)) + test("back to back serialization in AppendColumns with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(2, optimized) + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index df7529d83f7c8..9174b4e649a6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -743,4 +743,19 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("expand") { + val agg = testRelation + .groupBy(Cube(Seq('a, 'b)))('a, 'b, sum('c)) + .analyze + .asInstanceOf[Aggregate] + + val a = agg.output(0) + val b = agg.output(1) + + val query = agg.where(a > 1 && b > 2) + val optimized = Optimize.execute(query) + val correctedAnswer = agg.copy(child = agg.child.where(a > 1 && b > 2)).analyze + comparePlans(optimized, correctedAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala new file mode 100644 index 0000000000000..4385b0e019f25 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ 
-0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class OptimizeCodegenSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("Codegen only when the number of branches is small.") { + assertEquivalent( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen()) + + assertEquivalent( + CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)), + CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2))) + } + + test("Nested CaseWhen Codegen.") { + assertEquivalent( + CaseWhen( + Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))), + CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), + CaseWhen( + Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))), + CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) + } + + test("Multiple CaseWhen in one operator.") { + val plan = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze + val correctAnswer = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, correctAnswer) + } + + test("Multiple CaseWhen in different operators") { + val plan = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + .where( + LessThan( + 
CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + ).analyze + val correctAnswer = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + .where( + LessThan( + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + ).analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 6e5672ddc36bd..7112c033eabce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -15,10 +15,9 @@ * limitations under the License. */ -package org.apache.spark.sql.catalyst +package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -38,7 +37,7 @@ class OptimizerExtendableSuite extends SparkFunSuite { * This class represents a dummy extended optimizer that takes the batches of the * Optimizer and adds custom ones. */ - class ExtendedOptimizer extends Optimizer { + class ExtendedOptimizer extends SimpleTestOptimizer { // rules set to DummyRule, would not be executed anyways val myBatches: Seq[Batch] = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala index db96bfb652120..6da3eaea3d850 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -60,8 +60,8 @@ class ErrorParserSuite extends SparkFunSuite { intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", "^^^") - intercept("select * from r where a in (select * from t)", 1, 24, - "IN with a Sub-query is currently not supported", - "------------------------^^^") + intercept("select * from r except all select * from t", 1, 0, + "EXCEPT ALL is not supported", + "^^^") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index 6f40ec67ec6e0..d1dc8d621fb4d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -113,7 +113,9 @@ class ExpressionParserSuite extends PlanTest { } test("exists expression") { - intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + assertEqual( + "exists (select 1 from b where b.x = a.x)", + Exists(table("b").where(Symbol("b.x") === Symbol("a.x")).select(1))) } test("comparison expressions") { @@ -139,7 +141,9 @@ class 
ExpressionParserSuite extends PlanTest { } test("in sub-query") { - intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + assertEqual( + "a in (select b from c)", + InSubQuery('a, table("c").select('b))) } test("like expressions") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 411e2372f2e07..a1ca55c262fba 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -107,7 +107,7 @@ class PlanParserSuite extends PlanTest { assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) assertEqual( "select a, b from db.c having x < 1", - table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + table("db", "c").select('a, 'b).where('x < 1)) assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) } @@ -405,7 +405,7 @@ class PlanParserSuite extends PlanTest { "select g from t group by g having a > (select b from s)", table("t") .groupBy('g)('g) - .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + .where('a > ScalarSubquery(table("s").select('b)))) } test("table reference") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala index faef9ed274593..cc86f1f6e2f48 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/LogicalPlanSuite.scala @@ -18,7 +18,9 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType /** * This suite is used to test [[LogicalPlan]]'s `resolveOperators` and make sure it can correctly @@ -68,4 +70,23 @@ class LogicalPlanSuite extends SparkFunSuite { assert(invocationCount === 1) } + + test("isStreaming") { + val relation = LocalRelation(AttributeReference("a", IntegerType, nullable = true)()) + val incrementalRelation = new LocalRelation( + Seq(AttributeReference("a", IntegerType, nullable = true)())) { + override def isStreaming(): Boolean = true + } + + case class TestBinaryRelation(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { + override def output: Seq[Attribute] = left.output ++ right.output + } + + require(relation.isStreaming === false) + require(incrementalRelation.isStreaming === true) + assert(TestBinaryRelation(relation, relation).isStreaming === false) + assert(TestBinaryRelation(incrementalRelation, relation).isStreaming === true) + assert(TestBinaryRelation(relation, incrementalRelation).isStreaming === true) + assert(TestBinaryRelation(incrementalRelation, incrementalRelation).isStreaming) + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 5c257bc260873..b224a868454a5 100644 --- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -178,7 +178,7 @@ protected void initialize(String path, List columns) throws IOException config.set("spark.sql.parquet.writeLegacyFormat", "false"); this.file = new Path(path); - long length = FileSystem.get(config).getFileStatus(this.file).getLen(); + long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); ParquetMetadata footer = readFooter(config, file, range(0, length)); List blocks = footer.getBlocks(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index d64736e11110b..bd96941da798d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -59,14 +59,14 @@ class TypedColumn[-T, U]( * on a decoded object. */ private[sql] def withInputType( - inputEncoder: ExpressionEncoder[_], - schema: Seq[Attribute]): TypedColumn[T, U] = { - val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] - new TypedColumn[T, U]( - expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => - ta.copy(aEncoder = Some(boundEncoder), children = schema) - }, - encoder) + inputDeserializer: Expression, + inputAttributes: Seq[Attribute]): TypedColumn[T, U] = { + val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes) + val newExpr = expr transform { + case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty => + ta.copy(inputDeserializer = Some(unresolvedDeserializer)) + } + new TypedColumn[T, U](newExpr, encoder) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index d9973b092dc11..953169b63604f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -56,7 +56,7 @@ trait ContinuousQuery { * Returns current status of all the sources. * @since 2.0.0 */ - def sourceStatuses: Array[SourceStatus] + def sourceStatuses: Array[SourceStatus] /** Returns current status of the sink. */ def sinkStatus: SinkStatus @@ -77,7 +77,7 @@ trait ContinuousQuery { /** * Waits for the termination of `this` query, either by `query.stop()` or by an exception. - * If the query has terminated with an exception, then the exception will be throw. + * If the query has terminated with an exception, then the exception will be thrown. * Otherwise, it returns whether the query has terminated or not within the `timeoutMs` * milliseconds. 
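The awaitTermination wording fix above only touches the scaladoc, but the timed form it documents is easy to misuse. A minimal usage sketch, assuming a started ContinuousQuery handle named query and the timed overload awaitTermination(timeoutMs: Long): Boolean that the scaladoc describes:

    import org.apache.spark.sql.ContinuousQuery

    // Wait up to timeoutMs for the query to finish; stop it explicitly if it is still running.
    // If the query has already failed, awaitTermination rethrows the original exception.
    def waitOrStop(query: ContinuousQuery, timeoutMs: Long): Unit = {
      val terminated = query.awaitTermination(timeoutMs)
      if (!terminated) {
        query.stop()
      }
    }
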
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala index 1343e81569cbd..39d04ed8c24f8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQueryManager.scala @@ -20,8 +20,10 @@ package org.apache.spark.sql import scala.collection.mutable import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode, UnsupportedOperationChecker} import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.util.ContinuousQueryListener /** @@ -172,14 +174,23 @@ class ContinuousQueryManager(sqlContext: SQLContext) { checkpointLocation: String, df: DataFrame, sink: Sink, - trigger: Trigger = ProcessingTime(0)): ContinuousQuery = { + trigger: Trigger = ProcessingTime(0), + outputMode: OutputMode = Append): ContinuousQuery = { activeQueriesLock.synchronized { if (activeQueries.contains(name)) { throw new IllegalArgumentException( s"Cannot start query with name $name as a query with that name is already active") } + val analyzedPlan = df.queryExecution.analyzed + df.queryExecution.assertAnalyzed() + + if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + UnsupportedOperationChecker.checkForStreaming(analyzedPlan, outputMode) + } + var nextSourceId = 0L - val logicalPlan = df.logicalPlan.transform { + + val logicalPlan = analyzedPlan.transform { case StreamingRelation(dataSource, _, output) => // Materialize source to avoid creating it in every batch val metadataPath = s"$checkpointLocation/sources/$nextSourceId" @@ -195,6 +206,7 @@ class ContinuousQueryManager(sqlContext: SQLContext) { checkpointLocation, logicalPlan, sink, + outputMode, trigger) query.start() activeQueries.put(name, query) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 54d250867fbb3..0745ef47ffdac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -27,11 +27,10 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, Project} -import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource} +import org.apache.spark.sql.execution.datasources.{BucketSpec, CreateTableUsingAsSelect, DataSource, HadoopFsRelation} import org.apache.spark.sql.execution.datasources.jdbc.JdbcUtils import org.apache.spark.sql.execution.streaming.{MemoryPlan, MemorySink, StreamExecution} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.util.Utils /** @@ -86,18 +85,18 @@ final class DataFrameWriter private[sql](df: DataFrame) { * * Scala Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * 
def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} * * @since 2.0.0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index e216945fbe5e8..3c708cbf29851 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -461,9 +461,7 @@ class Dataset[T] private[sql]( * @since 2.0.0 */ @Experimental - def isStreaming: Boolean = logicalPlan.find { n => - n.isInstanceOf[StreamingRelation] || n.isInstanceOf[StreamingExecutionRelation] - }.isDefined + def isStreaming: Boolean = logicalPlan.isStreaming /** * Displays the [[Dataset]] in a tabular form. Strings more than 20 characters will be truncated, @@ -992,7 +990,7 @@ class Dataset[T] private[sql]( sqlContext, Project( c1.withInputType( - boundTEncoder, + unresolvedTEncoder.deserializer, logicalPlan.output).named :: Nil, logicalPlan), implicitly[Encoder[U1]]) @@ -1006,7 +1004,7 @@ class Dataset[T] private[sql]( protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map(_.withInputType(resolvedTEncoder, logicalPlan.output).named) + columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) @@ -1502,7 +1500,9 @@ class Dataset[T] private[sql]( // constituent partitions each time a split is materialized which could result in // overlapping splits. To prevent this, we explicitly sort each input partition to make the // ordering deterministic. - val sorted = Sort(logicalPlan.output.map(SortOrder(_, Ascending)), global = false, logicalPlan) + // MapType cannot be sorted. + val sorted = Sort(logicalPlan.output.filterNot(_.dataType.isInstanceOf[MapType]) + .map(SortOrder(_, Ascending)), global = false, logicalPlan) val sum = weights.sum val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _) normalizedCumWeights.sliding(2).map { x => @@ -2251,16 +2251,16 @@ class Dataset[T] private[sql]( def unpersist(): this.type = unpersist(blocking = false) /** - * Represents the content of the [[Dataset]] as an [[RDD]] of [[Row]]s. Note that the RDD is - * memoized. Once called, it won't change even if you change any query planning related Spark SQL - * configurations (e.g. `spark.sql.shuffle.partitions`). + * Represents the content of the [[Dataset]] as an [[RDD]] of [[T]]. 
* * @group rdd * @since 1.6.0 */ lazy val rdd: RDD[T] = { - queryExecution.toRdd.mapPartitions { rows => - rows.map(boundTEncoder.fromRow) + val objectType = unresolvedTEncoder.deserializer.dataType + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + sqlContext.executePlan(deserialized).toRdd.mapPartitions { rows => + rows.map(_.get(0, objectType).asInstanceOf[T]) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala index f19ad6e707526..05e13e66d137c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala @@ -209,8 +209,7 @@ class KeyValueGroupedDataset[K, V] private[sql]( protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) val namedColumns = - columns.map( - _.withInputType(resolvedVEncoder, dataAttributes).named) + columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named) val keyColumn = if (resolvedKEncoder.flat) { assert(groupingAttributes.length == 1) groupingAttributes.head diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7dbf2e6c7c798..0ffb136c248e1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map { + case typed: TypedColumn[_, _] => + typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + case c => c.expr + }) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 9259ff40625c9..f3f84144ad93e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.command.ShowTablesCommand import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} -import org.apache.spark.sql.internal.{SessionState, SQLConf} +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager @@ -63,17 +63,18 @@ import org.apache.spark.util.Utils * @since 1.0.0 */ class SQLContext private[sql]( - @transient val sparkContext: SparkContext, - @transient protected[sql] val cacheManager: CacheManager, - @transient private[sql] val listener: SQLListener, - val isRootContext: Boolean, - @transient private[sql] val externalCatalog: ExternalCatalog) + @transient private val sparkSession: SparkSession, + val isRootContext: Boolean) extends Logging with Serializable { self => + private[sql] def this(sparkSession: SparkSession) = { + this(sparkSession, true) + } + def this(sc: SparkContext) = { - this(sc, new CacheManager, SQLContext.createListenerAndUI(sc), true, new InMemoryCatalog) + this(new SparkSession(sc)) } def this(sparkContext: JavaSparkContext) = 
this(sparkContext.sc) @@ -100,28 +101,26 @@ class SQLContext private[sql]( } } + protected[sql] def sessionState: SessionState = sparkSession.sessionState + protected[sql] def sharedState: SharedState = sparkSession.sharedState + protected[sql] def conf: SQLConf = sessionState.conf + protected[sql] def cacheManager: CacheManager = sharedState.cacheManager + protected[sql] def listener: SQLListener = sharedState.listener + protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog + + def sparkContext: SparkContext = sharedState.sparkContext + /** - * Returns a SQLContext as new session, with separated SQL configurations, temporary tables, - * registered functions, but sharing the same SparkContext, CacheManager, SQLListener and SQLTab. + * Returns a [[SQLContext]] as new session, with separated SQL configurations, temporary + * tables, registered functions, but sharing the same [[SparkContext]], cached data and + * other things. * * @since 1.6.0 */ def newSession(): SQLContext = { - new SQLContext( - sparkContext = sparkContext, - cacheManager = cacheManager, - listener = listener, - isRootContext = false, - externalCatalog = externalCatalog) + new SQLContext(sparkSession.newSession(), isRootContext = false) } - /** - * Per-session state, e.g. configuration, functions, temporary tables etc. - */ - @transient - protected[sql] lazy val sessionState: SessionState = new SessionState(self) - protected[spark] def conf: SQLConf = sessionState.conf - /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. @@ -135,10 +134,14 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(props: Properties): Unit = conf.setConf(props) + def setConf(props: Properties): Unit = sessionState.setConf(props) - /** Set the given Spark SQL configuration property. */ - private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = conf.setConf(entry, value) + /** + * Set the given Spark SQL configuration property. + */ + private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + sessionState.setConf(entry, value) + } /** * Set the given Spark SQL configuration property. @@ -146,7 +149,7 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = conf.setConfString(key, value) + def setConf(key: String, value: String): Unit = sessionState.setConf(key, value) /** * Return the value of Spark SQL configuration property for the given key. @@ -189,23 +192,19 @@ class SQLContext private[sql]( */ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs - // Extract `spark.sql.*` entries and put it in our SQLConf. - // Subclasses may additionally set these entries in other confs. - SQLContext.getSQLProperties(sparkContext.getConf).asScala.foreach { case (k, v) => - setConf(k, v) - } - protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) - protected[sql] def executePlan(plan: LogicalPlan) = new QueryExecution(this, plan) + protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { + sessionState.executePlan(plan) + } /** * Add a jar to SQLContext */ protected[sql] def addJar(path: String): Unit = { - sparkContext.addJar(path) + sessionState.addJar(path) } /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. 
*/ @@ -771,7 +770,7 @@ class SQLContext private[sql]( * as Spark can parse all supported Hive DDLs itself. */ private[sql] def runNativeSql(sqlText: String): Seq[Row] = { - throw new UnsupportedOperationException + sessionState.runNativeSql(sqlText).map { r => Row(r) } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index ad69e23540a91..f423e7d6b5765 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -72,6 +72,29 @@ abstract class SQLImplicits { /** @since 1.6.0 */ implicit def newStringEncoder: Encoder[String] = Encoders.STRING + // Boxed primitives + + /** @since 2.0.0 */ + implicit def newBoxedIntEncoder: Encoder[java.lang.Integer] = Encoders.INT + + /** @since 2.0.0 */ + implicit def newBoxedLongEncoder: Encoder[java.lang.Long] = Encoders.LONG + + /** @since 2.0.0 */ + implicit def newBoxedDoubleEncoder: Encoder[java.lang.Double] = Encoders.DOUBLE + + /** @since 2.0.0 */ + implicit def newBoxedFloatEncoder: Encoder[java.lang.Float] = Encoders.FLOAT + + /** @since 2.0.0 */ + implicit def newBoxedByteEncoder: Encoder[java.lang.Byte] = Encoders.BYTE + + /** @since 2.0.0 */ + implicit def newBoxedShortEncoder: Encoder[java.lang.Short] = Encoders.SHORT + + /** @since 2.0.0 */ + implicit def newBoxedBooleanEncoder: Encoder[java.lang.Boolean] = Encoders.BOOLEAN + // Seqs /** @since 1.6.1 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala new file mode 100644 index 0000000000000..17ba2998250f6 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.internal.{SessionState, SharedState} +import org.apache.spark.util.Utils + + +/** + * The entry point to Spark execution. + */ +class SparkSession private( + sparkContext: SparkContext, + existingSharedState: Option[SharedState]) { self => + + def this(sc: SparkContext) { + this(sc, None) + } + + /** + * Start a new session where configurations, temp tables, temp functions etc. are isolated. + */ + def newSession(): SparkSession = { + // Note: materialize the shared state here to ensure the parent and child sessions are + // initialized with the same shared state. 
+ new SparkSession(sparkContext, Some(sharedState)) + } + + @transient + protected[sql] lazy val sharedState: SharedState = { + existingSharedState.getOrElse( + SparkSession.reflect[SharedState, SparkContext]( + SparkSession.sharedStateClassName(sparkContext.conf), + sparkContext)) + } + + @transient + protected[sql] lazy val sessionState: SessionState = { + SparkSession.reflect[SessionState, SQLContext]( + SparkSession.sessionStateClassName(sparkContext.conf), + new SQLContext(self, isRootContext = false)) + } + +} + + +private object SparkSession { + + private def sharedStateClassName(conf: SparkConf): String = { + conf.get(CATALOG_IMPLEMENTATION) match { + case "hive" => "org.apache.spark.sql.hive.HiveSharedState" + case "in-memory" => classOf[SharedState].getCanonicalName + } + } + + private def sessionStateClassName(conf: SparkConf): String = { + conf.get(CATALOG_IMPLEMENTATION) match { + case "hive" => "org.apache.spark.sql.hive.HiveSessionState" + case "in-memory" => classOf[SessionState].getCanonicalName + } + } + + /** + * Helper method to create an instance of [[T]] using a single-arg constructor that + * accepts an [[Arg]]. + */ + private def reflect[T, Arg <: AnyRef]( + className: String, + ctorArg: Arg)(implicit ctorArgTag: ClassTag[Arg]): T = { + try { + val clazz = Utils.classForName(className) + val ctor = clazz.getDeclaredConstructor(ctorArgTag.runtimeClass) + ctor.newInstance(ctorArg).asInstanceOf[T] + } catch { + case NonFatal(e) => + throw new IllegalArgumentException(s"Error while instantiating '$className':", e) + } + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala index c4e54b3f90ac5..256e8a47a4665 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Trigger.scala @@ -35,23 +35,23 @@ sealed trait Trigger {} /** * :: Experimental :: - * A trigger that runs a query periodically based on the processing time. If `intervalMs` is 0, + * A trigger that runs a query periodically based on the processing time. If `interval` is 0, * the query will run as fast as possible. * * Scala Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} * * Java Example: * {{{ - * def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} */ @Experimental @@ -67,11 +67,11 @@ case class ProcessingTime(intervalMs: Long) extends Trigger { object ProcessingTime { /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ - * def.writer.trigger(ProcessingTime("10 seconds")) + * df.write.trigger(ProcessingTime("10 seconds")) * }}} */ def apply(interval: String): ProcessingTime = { @@ -94,12 +94,12 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. 
* * Example: * {{{ * import scala.concurrent.duration._ - * def.writer.trigger(ProcessingTime(10.seconds)) + * df.write.trigger(ProcessingTime(10.seconds)) * }}} */ def apply(interval: Duration): ProcessingTime = { @@ -107,11 +107,11 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ - * def.writer.trigger(ProcessingTime.create("10 seconds")) + * df.write.trigger(ProcessingTime.create("10 seconds")) * }}} */ def create(interval: String): ProcessingTime = { @@ -119,12 +119,12 @@ object ProcessingTime { } /** - * Create a [[ProcessingTime]]. If `intervalMs` is 0, the query will run as fast as possible. + * Create a [[ProcessingTime]]. If `interval` is 0, the query will run as fast as possible. * * Example: * {{{ * import java.util.concurrent.TimeUnit - * def.writer.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) + * df.write.trigger(ProcessingTime.create(10, TimeUnit.SECONDS)) * }}} */ def create(interval: Long, unit: TimeUnit): ProcessingTime = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index 392c48fb7b93b..12d03a7df8c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -26,11 +26,12 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCo import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.toCommentSafeString +import org.apache.spark.sql.execution.datasources.HadoopFsRelation import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} -import org.apache.spark.sql.types.{AtomicType, DataType} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.DataType object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -177,7 +178,7 @@ private[sql] case class RowDataSourceScan( s"Scan $nodeName${output.mkString("[", ",", "]")}${metadataEntries.mkString(" ", ", ", "")}" } - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { rdd :: Nil } @@ -192,7 +193,7 @@ private[sql] case class RowDataSourceScan( val row = ctx.freshName("row") ctx.INPUT_ROW = row ctx.currentVars = null - val columnsRowInput = exprRows.map(_.gen(ctx)) + val columnsRowInput = exprRows.map(_.genCode(ctx)) val inputRow = if (outputUnsafeRows) row else null s""" |while ($input.hasNext()) { @@ -228,7 +229,7 @@ private[sql] case class BatchedDataSourceScan( s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr" } - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { rdd :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala index bd23b7e3ad683..3966af542e397 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala @@ -85,8 +85,8 @@ case class Expand( } } - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -149,7 +149,7 @@ case class Expand( val firstExpr = projections.head(col) if (sameOutput(col)) { // This column is the same across all output rows. Just generate code for it here. - BindReferences.bindReference(firstExpr, child.output).gen(ctx) + BindReferences.bindReference(firstExpr, child.output).genCode(ctx) } else { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") @@ -166,7 +166,7 @@ case class Expand( var updateCode = "" for (col <- exprs.indices) { if (!sameOutput(col)) { - val ev = BindReferences.bindReference(exprs(col), child.output).gen(ctx) + val ev = BindReferences.bindReference(exprs(col), child.output).genCode(ctx) updateCode += s""" |${ev.code} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index f5e1e77263b5b..35228643a5f49 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.execution import org.apache.spark.rdd.RDD import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnsupportedOperationChecker import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} +import org.apache.spark.sql.internal.SQLConf /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -43,10 +45,20 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { throw ae } - lazy val analyzed: LogicalPlan = sqlContext.sessionState.analyzer.execute(logical) + def assertSupported(): Unit = { + if (sqlContext.conf.getConf(SQLConf.UNSUPPORTED_OPERATION_CHECK_ENABLED)) { + UnsupportedOperationChecker.checkForBatch(analyzed) + } + } + + lazy val analyzed: LogicalPlan = { + SQLContext.setActive(sqlContext) + sqlContext.sessionState.analyzer.execute(logical) + } lazy val withCachedData: LogicalPlan = { assertAnalyzed() + assertSupported() sqlContext.cacheManager.useCachedData(analyzed) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala index efd8760cd2474..80255fafbec0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala @@ -100,8 +100,8 @@ case class Sort( override def usedInputs: AttributeSet = AttributeSet(Seq.empty) - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } // Name of sorter variable used in codegen. 
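The SparkOptimizer change just below keeps the "User Provided Optimizers" batch, now built with the shared fixedPoint strategy. A minimal sketch of how a user-supplied rule reaches that batch, assuming only the existing sqlContext.experimental hook; the rule itself is a placeholder no-op:

    import org.apache.spark.sql.SQLContext
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.catalyst.rules.Rule

    // Placeholder rule: returns the plan unchanged, only to show the registration path.
    object NoopOptimization extends Rule[LogicalPlan] {
      override def apply(plan: LogicalPlan): LogicalPlan = plan
    }

    def registerExtraOptimization(sqlContext: SQLContext): Unit = {
      // Rules added here are picked up by SparkOptimizer's "User Provided Optimizers" batch.
      sqlContext.experimental.extraOptimizations = Seq(NoopOptimization)
    }
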
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala index cbde777d98415..08b2d7fcd4882 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala @@ -18,9 +18,16 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.ExperimentalMethods +import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer +import org.apache.spark.sql.internal.SQLConf + +class SparkOptimizer( + catalog: SessionCatalog, + conf: SQLConf, + experimentalMethods: ExperimentalMethods) + extends Optimizer(catalog, conf) { -class SparkOptimizer(experimentalMethods: ExperimentalMethods) extends Optimizer { override def batches: Seq[Batch] = super.batches :+ Batch( - "User Provided Optimizers", FixedPoint(100), experimentalMethods.extraOptimizations: _*) + "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 4091f65aecb50..b64352a9e0dc2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.ArrayBuffer import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration._ +import scala.util.control.NonFatal import org.apache.spark.{broadcast, SparkEnv} import org.apache.spark.internal.Logging @@ -167,7 +168,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ protected def waitForSubqueries(): Unit = { // fill in the result of subqueries subqueryResults.foreach { case (e, futureResult) => - val rows = Await.result(futureResult, Duration.Inf) + val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf) if (rows.length > 1) { sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}") } @@ -351,12 +352,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } - private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute], - useSubexprElimination: Boolean = false): () => MutableProjection = { + useSubexprElimination: Boolean = false): MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8ed6ed21d0170..b140a608c8dfe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -108,6 +108,13 @@ class SparkSqlAstBuilder extends AstBuilder { Option(ctx.key).map(visitTablePropertyKey)) } + /** + * Create a [[ShowCreateTableCommand]] logical plan + */ + override def visitShowCreateTable(ctx: ShowCreateTableContext): LogicalPlan = withOrigin(ctx) { + 
ShowCreateTableCommand(visitTableIdentifier(ctx.tableIdentifier)) + } + /** * Create a [[RefreshTable]] logical plan. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index c15aaed3654ff..a4b0fa59dbb24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,21 +346,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") - case logical.DeserializeToObject(deserializer, child) => - execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.DeserializeToObject(deserializer, objAttr, child) => + execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil case logical.SerializeFromObject(serializer, child) => execution.SerializeFromObject(serializer, planLater(child)) :: Nil - case logical.MapPartitions(f, in, out, child) => - execution.MapPartitions(f, in, out, planLater(child)) :: Nil - case logical.MapElements(f, in, out, child) => - execution.MapElements(f, in, out, planLater(child)) :: Nil + case logical.MapPartitions(f, objAttr, child) => + execution.MapPartitions(f, objAttr, planLater(child)) :: Nil + case logical.MapElements(f, objAttr, child) => + execution.MapElements(f, objAttr, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil - case logical.MapGroups(f, key, in, out, grouping, data, child) => - execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil - case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) => + case logical.AppendColumnsWithObject(f, childSer, newSer, child) => + execution.AppendColumnsWithObject(f, childSer, newSer, planLater(child)) :: Nil + case logical.MapGroups(f, key, value, grouping, data, objAttr, child) => + execution.MapGroups(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil + case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) => execution.CoGroup( - f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, + f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, planLater(left), planLater(right)) :: Nil case logical.Repartition(numPartitions, shuffle, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 447dbe701815b..23b2eabd0c809 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -74,10 +74,10 @@ trait CodegenSupport extends SparkPlan { * * Note: right now we support up to two RDDs. */ - def upstreams(): Seq[RDD[InternalRow]] + def inputRDDs(): Seq[RDD[InternalRow]] /** - * Returns Java source code to process the rows from upstream. + * Returns Java source code to process the rows from input RDD. 
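To make the upstreams() to inputRDDs() rename concrete, this is the class-body fragment a pass-through unary operator ends up with after this patch. It mirrors the Sort and Expand changes and assumes the operator's child also implements CodegenSupport; it is a fragment of an operator body, not a compilable class on its own:

    // Forward both the RDD request and code generation to the child; only source-like
    // nodes such as InputAdapter or the data source scans return an actual RDD here.
    override def inputRDDs(): Seq[RDD[InternalRow]] = {
      child.asInstanceOf[CodegenSupport].inputRDDs()
    }

    override protected def doProduce(ctx: CodegenContext): String = {
      child.asInstanceOf[CodegenSupport].produce(ctx, this)
    }
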
*/ final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent @@ -118,7 +118,7 @@ trait CodegenSupport extends SparkPlan { ctx.currentVars = null ctx.INPUT_ROW = row output.zipWithIndex.map { case (attr, i) => - BoundReference(i, attr.dataType, attr.nullable).gen(ctx) + BoundReference(i, attr.dataType, attr.nullable).genCode(ctx) } } else { assert(outputVars != null) @@ -126,6 +126,7 @@ trait CodegenSupport extends SparkPlan { // outputVars will be used to generate the code for UnsafeRow, so we should copy them outputVars.map(_.copy()) } + val rowVar = if (row != null) { ExprCode("", "false", row) } else { @@ -233,13 +234,13 @@ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport child.doExecuteBroadcast() } - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { child.execute() :: Nil } override def doProduce(ctx: CodegenContext): String = { val input = ctx.freshName("input") - // Right now, InputAdapter is only used when there is one upstream. + // Right now, InputAdapter is only used when there is one input RDD. ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") val row = ctx.freshName("row") s""" @@ -271,7 +272,7 @@ object WholeStageCodegen { * * -> execute() * | - * doExecute() ---------> upstreams() -------> upstreams() ------> execute() + * doExecute() ---------> inputRDDs() -------> inputRDDs() ------> execute() * | * +-----------------> produce() * | @@ -349,8 +350,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup val durationMs = longMetric("pipelineTime") - val rdds = child.asInstanceOf[CodegenSupport].upstreams() - assert(rdds.size <= 2, "Up to two upstream RDDs can be supported") + val rdds = child.asInstanceOf[CodegenSupport].inputRDDs() + assert(rdds.size <= 2, "Up to two input RDDs can be supported") if (rdds.length == 1) { rdds.head.mapPartitionsWithIndex { (index, iter) => val clazz = CodeGenerator.compile(cleanedSource) @@ -366,7 +367,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup } } } else { - // Right now, we support up to two upstreams. + // Right now, we support up to two input RDDs. rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) => val partitionIndex = TaskContext.getPartitionId() val clazz = CodeGenerator.compile(cleanedSource) @@ -384,7 +385,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup } } - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { throw new UnsupportedOperationException } @@ -428,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true - case e: CaseWhen => e.shouldCodegen // CodegenFallback requires the input to be an InternalRow case e: CodegenFallback => false case _ => true @@ -473,6 +473,10 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { * Inserts a WholeStageCodegen on top of those that support codegen. */ private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match { + // For operators that will output domain object, do not insert WholeStageCodegen for it as + // domain object can not be written into unsafe row. 
+ case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] => + plan.withNewChildren(plan.children.map(insertWholeStageCodegen)) case plan: CodegenSupport if supportCodegen(plan) => WholeStageCodegen(insertInputAdapter(plan)) case other => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 8e9214fa258b2..85ce388de0aa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -120,14 +120,14 @@ case class Window( val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. val exprs = orderSpec.map(_.child) - val projection = newMutableProjection(exprs, child.output) - (orderSpec, projection(), projection()) + val buildProjection = () => newMutableProjection(exprs, child.output) + (orderSpec, buildProjection(), buildProjection()) } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output)() + val current = newMutableProjection(expr :: Nil, child.output) // Flip the sign of the offset when processing the order is descending val boundOffset = sortExpr.direction match { case Descending => -offset @@ -135,7 +135,7 @@ case class Window( } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output)() + val bound = newMutableProjection(boundExpr :: Nil, child.output) (sortExpr :: Nil, current, bound) } else { sys.error("Non-Zero range offsets are not supported for windows " + @@ -564,7 +564,7 @@ private[execution] final class OffsetWindowFunctionFrame( ordinal: Int, expressions: Array[Expression], inputSchema: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ @@ -604,7 +604,7 @@ private[execution] final class OffsetWindowFunctionFrame( } // Create the projection. - newMutableProjection(boundExpressions, Nil)().target(target) + newMutableProjection(boundExpressions, Nil).target(target) } override def prepare(rows: RowBuffer): Unit = { @@ -886,7 +886,7 @@ private[execution] object AggregateProcessor { functions: Array[Expression], ordinal: Int, inputAttributes: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection): AggregateProcessor = { val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] val initialValues = mutable.Buffer.empty[Expression] @@ -938,13 +938,13 @@ private[execution] object AggregateProcessor { // Create the projections. 
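
// Illustrative sketch, not from the patch: why the newMutableProjection signature above can
// drop the extra () => layer. A projection instance owns a mutable output buffer, so callers
// that still need two independent instances (the offset == 0 branch in Window) now build
// their own closure, as `buildProjection` does. ToyProjection and newToyProjection are toys.
final class ToyProjection(expr: Int => Int) {
  private val out = Array.ofDim[Int](1)                    // per-instance mutable output buffer
  def apply(row: Int): Array[Int] = { out(0) = expr(row); out }
}

object ProjectionDemo {
  // Plays the role of the changed factory: it returns a projection directly, not a thunk.
  def newToyProjection(expr: Int => Int): ToyProjection = new ToyProjection(expr)

  def main(args: Array[String]): Unit = {
    val buildProjection = () => newToyProjection(_ * 2)    // caller-side closure, as in Window
    val current = buildProjection()
    val bound = buildProjection()
    val firstResult = current(3)
    current(7)                 // reusing one instance overwrites its shared buffer...
    println(firstResult(0))    // 14 -- which is why two separate instances are built above
    println(bound(5)(0))       // 10
  }
}
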
val initialProjection = newMutableProjection( initialValues, - partitionSize.toSeq)() + partitionSize.toSeq) val updateProjection = newMutableProjection( updateExpressions, - aggBufferAttributes ++ inputAttributes)() + aggBufferAttributes ++ inputAttributes) val evaluateProjection = newMutableProjection( evaluateExpressions, - aggBufferAttributes)() + aggBufferAttributes) // Create the processor new AggregateProcessor( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 042c7319018be..81aacb437ba54 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -39,7 +39,7 @@ abstract class AggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) extends Iterator[UnsafeRow] with Logging { /////////////////////////////////////////////////////////////////////////// @@ -139,7 +139,7 @@ abstract class AggregationIterator( // no-op expressions which are ignored during projection code-generation. case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } - newMutableProjection(initExpressions, Nil)() + newMutableProjection(initExpressions, Nil) } // All imperative AggregateFunctions. @@ -175,7 +175,7 @@ abstract class AggregationIterator( // This projection is used to merge buffer values for all expression-based aggregates. val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = - newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes) (currentBuffer: MutableRow, row: InternalRow) => { // Process all expression-based aggregate functions. 
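
// Sketch only (invented names): a drastically simplified analogue of what AggregationIterator
// does with its generated projections -- initialize a mutable buffer per group, fold every
// input row into it, then project each buffer into a result row.
import scala.collection.mutable

object ToyAggregationIterator {
  final case class SumBuffer(var sum: Long, var count: Long)   // the "aggregation buffer"

  def aggregate(rows: Iterator[(String, Long)]): Iterator[(String, Double)] = {
    val buffers = mutable.LinkedHashMap.empty[String, SumBuffer]
    rows.foreach { case (key, value) =>
      // initialize-on-first-sight plus the per-row update projection, rolled into one step
      val buf = buffers.getOrElseUpdate(key, SumBuffer(0L, 0L))
      buf.sum += value
      buf.count += 1
    }
    // result projection: buffer -> output row (here: the average per key)
    buffers.iterator.map { case (key, buf) => key -> buf.sum.toDouble / buf.count }
  }

  def main(args: Array[String]): Unit =
    aggregate(Iterator("a" -> 1L, "b" -> 10L, "a" -> 3L)).foreach(println)
}
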
@@ -211,7 +211,7 @@ abstract class AggregationIterator( case agg: AggregateFunction => NoOp } val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) expressionAggEvalProjection.target(aggregateResult) val resultProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index de1491d357405..c35d781d3ebf5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -34,7 +34,7 @@ class SortBasedAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, numOutputRows: LongSQLMetric) extends AggregationIterator( groupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 253592028c7f9..d4cef8f310dac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -70,12 +70,14 @@ case class TungstenAggregate( } } - // This is for testing. We force TungstenAggregationIterator to fall back to sort-based - // aggregation once it has processed a given number of input rows. - private val testFallbackStartsAt: Option[Int] = { + // This is for testing. We force TungstenAggregationIterator to fall back to the unsafe row hash + // map and/or the sort-based aggregation once it has processed a given number of input rows. 
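
// Sketch (not the patch itself): how a "n1,n2" test value like the one parsed just below can
// be turned into an Option[(Int, Int)] -- the first number forces fallback from the vectorized
// hash map, the second from the unsafe-row hash map. FallbackConf and parseFallback are
// invented names for illustration.
object FallbackConf {
  def parseFallback(conf: String): Option[(Int, Int)] = conf match {
    case null | "" => None
    case s =>
      val splits = s.split(",").map(_.trim)
      // A single number acts as both thresholds, matching the head/last access below.
      Some((splits.head.toInt, splits.last.toInt))
  }

  def main(args: Array[String]): Unit = {
    println(parseFallback(""))       // None
    println(parseFallback("100"))    // Some((100,100))
    println(parseFallback("2, 5"))   // Some((2,5))
  }
}
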
+ private val testFallbackStartsAt: Option[(Int, Int)] = { sqlContext.getConf("spark.sql.TungstenAggregate.testFallbackStartsAt", null) match { case null | "" => None - case fallbackStartsAt => Some(fallbackStartsAt.toInt) + case fallbackStartsAt => + val splits = fallbackStartsAt.split(",").map(_.trim) + Some((splits.head.toInt, splits.last.toInt)) } } @@ -127,8 +129,8 @@ case class TungstenAggregate( !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -163,7 +165,7 @@ case class TungstenAggregate( ctx.addMutableState("boolean", isNull, "") ctx.addMutableState(ctx.javaType(e.dataType), value, "") // The initial expression should not access any column - val ev = e.gen(ctx) + val ev = e.genCode(ctx) val initVars = s""" | $isNull = ${ev.isNull}; | $value = ${ev.value}; @@ -177,13 +179,13 @@ case class TungstenAggregate( // evaluate aggregate results ctx.currentVars = bufVars val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx) + BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) } val evaluateAggResults = evaluateVariables(aggResults) // evaluate result expressions ctx.currentVars = aggResults val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, aggregateAttributes).gen(ctx) + BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) } (resultVars, s""" |$evaluateAggResults @@ -194,7 +196,7 @@ case class TungstenAggregate( (bufVars, "") } else { // no aggregate function, the result should be literals - val resultVars = resultExpressions.map(_.gen(ctx)) + val resultVars = resultExpressions.map(_.genCode(ctx)) (resultVars, evaluateVariables(resultVars)) } @@ -238,7 +240,7 @@ case class TungstenAggregate( } ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination - val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) + val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).genCode(ctx)) // aggregate buffer should be updated atomic val updates = aggVals.zipWithIndex.map { case (ev, i) => s""" @@ -261,7 +263,15 @@ case class TungstenAggregate( .map(_.asInstanceOf[DeclarativeAggregate]) private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - // The name for HashMap + // The name for Vectorized HashMap + private var vectorizedHashMapTerm: String = _ + + // We currently only enable vectorized hashmap for long key/value types and partial aggregates + private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled && + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) && + modes.forall(mode => mode == Partial || mode == PartialMerge) + + // The name for UnsafeRow HashMap private var hashMapTerm: String = _ private var sorterTerm: String = _ @@ -325,7 +335,7 @@ case class TungstenAggregate( val mergeProjection = newMutableProjection( mergeExpr, aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), - subexpressionEliminationEnabled)() + subexpressionEliminationEnabled) val joinedRow = new JoinedRow() var currentKey: UnsafeRow = null @@ -384,25 +394,25 @@ case class TungstenAggregate( ctx.currentVars 
= null ctx.INPUT_ROW = keyTerm val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).genCode(ctx) } val evaluateKeyVars = evaluateVariables(keyVars) ctx.INPUT_ROW = bufferTerm val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).gen(ctx) + BoundReference(i, e.dataType, e.nullable).genCode(ctx) } val evaluateBufferVars = evaluateVariables(bufferVars) // evaluate the aggregation result ctx.currentVars = bufferVars val aggResults = declFunctions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).gen(ctx) + BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) } val evaluateAggResults = evaluateVariables(aggResults) // generate the final result ctx.currentVars = keyVars ++ aggResults val inputAttrs = groupingAttributes ++ aggregateAttributes val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).gen(ctx) + BindReferences.bindReference(e, inputAttrs).genCode(ctx) } s""" $evaluateKeyVars @@ -427,7 +437,7 @@ case class TungstenAggregate( ctx.INPUT_ROW = keyTerm ctx.currentVars = null val eval = resultExpressions.map{ e => - BindReferences.bindReference(e, groupingAttributes).gen(ctx) + BindReferences.bindReference(e, groupingAttributes).genCode(ctx) } consume(ctx, eval) } @@ -437,17 +447,18 @@ case class TungstenAggregate( val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") - // create AggregateHashMap - val isAggregateHashMapEnabled: Boolean = false - val isAggregateHashMapSupported: Boolean = - (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) - val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") - val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") - val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, + vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap") + val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap") + val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName, groupingKeySchema, bufferSchema) - if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { - ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, - s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + // Create a name for iterator from vectorized HashMap + val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter") + if (isVectorizedHashMapEnabled) { + ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm, + s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();") + ctx.addMutableState( + "java.util.Iterator", + iterTermForVectorizedHashMap, "") } // create hashMap @@ -465,11 +476,14 @@ case class TungstenAggregate( val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, s""" - ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} + ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + ${if (isVectorizedHashMapEnabled) { + s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""} + $iterTerm = $thisPlan.finishAggregate($hashMapTerm, 
$sorterTerm); } """) @@ -484,6 +498,34 @@ case class TungstenAggregate( // so `copyResult` should be reset to `false`. ctx.copyResult = false + // Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow + def outputFromGeneratedMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + val row = ctx.freshName("vectorizedHashMapRow") + ctx.currentVars = null + ctx.INPUT_ROW = row + var schema: StructType = groupingKeySchema + bufferSchema.foreach(i => schema = schema.add(i)) + val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex + .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }) + Option( + s""" + | while ($iterTermForVectorizedHashMap.hasNext()) { + | $numOutput.add(1); + | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row = + | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row) + | $iterTermForVectorizedHashMap.next(); + | ${generateRow.code} + | ${consume(ctx, Seq.empty, {generateRow.value})} + | + | if (shouldStop()) return; + | } + | + | $vectorizedHashMapTerm.close(); + """.stripMargin) + } else None + } + s""" if (!$initAgg) { $initAgg = true; @@ -491,6 +533,8 @@ case class TungstenAggregate( } // output the result + ${outputFromGeneratedMap.getOrElse("")} + while ($iterTerm.next()) { $numOutput.add(1); UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); @@ -511,10 +555,13 @@ case class TungstenAggregate( // create grouping key ctx.currentVars = input - val keyCode = GenerateUnsafeProjection.createCode( + val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) - val key = keyCode.value - val buffer = ctx.freshName("aggBuffer") + val vectorizedRowKeys = ctx.generateExpressions( + groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val unsafeRowKeys = unsafeRowKeyCode.value + val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer") + val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer") // only have DeclarativeAggregate val updateExpr = aggregateExpressions.flatMap { e => @@ -529,60 +576,129 @@ case class TungstenAggregate( // generate hash code for key val hashExpr = Murmur3Hash(groupingExpressions, 42) ctx.currentVars = input - val hashEval = BindReferences.bindReference(hashExpr, child.output).gen(ctx) + val hashEval = BindReferences.bindReference(hashExpr, child.output).genCode(ctx) val inputAttr = aggregateBufferAttributes ++ child.output ctx.currentVars = new Array[ExprCode](aggregateBufferAttributes.length) ++ input - ctx.INPUT_ROW = buffer - // TODO: support subexpression elimination - val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) - val updates = evals.zipWithIndex.map { case (ev, i) => - val dt = updateExpr(i).dataType - ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) - } - val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) { + val (checkFallbackForGeneratedHashMap, checkFallbackForBytesToBytesMap, resetCounter, + incCounter) = if (testFallbackStartsAt.isDefined) { val countTerm = ctx.freshName("fallbackCounter") ctx.addMutableState("int", countTerm, s"$countTerm = 0;") - (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;") + (s"$countTerm < ${testFallbackStartsAt.get._1}", + s"$countTerm < ${testFallbackStartsAt.get._2}", s"$countTerm = 0;", s"$countTerm += 1;") } else { - ("true", "", "") + 
("true", "true", "", "") } + // We first generate code to probe and update the vectorized hash map. If the probe is + // successful the corresponding vectorized row buffer will hold the mutable row + val findOrInsertInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + Option( + s""" + |if ($checkFallbackForGeneratedHashMap) { + | ${vectorizedRowKeys.map(_.code).mkString("\n")} + | if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) { + | $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert( + | ${vectorizedRowKeys.map(_.value).mkString(", ")}); + | } + |} + """.stripMargin) + } else { + None + } + } + + val updateRowInVectorizedHashMap: Option[String] = { + if (isVectorizedHashMapEnabled) { + ctx.INPUT_ROW = vectorizedRowBuffer + val vectorizedRowEvals = + updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) + val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + Option( + s""" + |// evaluate aggregate function + |${evaluateVariables(vectorizedRowEvals)} + |// update vectorized row + |${updateVectorizedRow.mkString("\n").trim} + """.stripMargin) + } else None + } + + // Next, we generate code to probe and update the unsafe row hash map. + val findOrInsertInUnsafeRowMap: String = { + s""" + | if ($vectorizedRowBuffer == null) { + | // generate grouping key + | ${unsafeRowKeyCode.code.trim} + | ${hashEval.code.trim} + | if ($checkFallbackForBytesToBytesMap) { + | // try to get the buffer from hash map + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | } + | if ($unsafeRowBuffer == null) { + | if ($sorterTerm == null) { + | $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); + | } else { + | $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); + | } + | $resetCounter + | // the hash map had be spilled, it should have enough memory now, + | // try to allocate buffer again. + | $unsafeRowBuffer = + | $hashMapTerm.getAggregationBufferFromUnsafeRow($unsafeRowKeys, ${hashEval.value}); + | if ($unsafeRowBuffer == null) { + | // failed to allocate the first page + | throw new OutOfMemoryError("No enough memory for aggregation"); + | } + | } + | } + """.stripMargin + } + + val updateRowInUnsafeRowMap: String = { + ctx.INPUT_ROW = unsafeRowBuffer + val unsafeRowBufferEvals = + updateExpr.map(BindReferences.bindReference(_, inputAttr).genCode(ctx)) + val updateUnsafeRowBuffer = unsafeRowBufferEvals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(unsafeRowBuffer, dt, i, ev, updateExpr(i).nullable) + } + s""" + |// evaluate aggregate function + |${evaluateVariables(unsafeRowBufferEvals)} + |// update unsafe row buffer + |${updateUnsafeRowBuffer.mkString("\n").trim} + """.stripMargin + } + + // We try to do hash map based in-memory aggregation first. If there is not enough memory (the // hash map will return null for new key), we spill the hash map to disk to free memory, then // continue to do in-memory aggregation and spilling until all the rows had been processed. // Finally, sort the spilled aggregate buffers by key, and merge them together for same key. 
s""" - // generate grouping key - ${keyCode.code.trim} - ${hashEval.code.trim} - UnsafeRow $buffer = null; - if ($checkFallback) { - // try to get the buffer from hash map - $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); - } - if ($buffer == null) { - if ($sorterTerm == null) { - $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter(); - } else { - $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter()); - } - $resetCoulter - // the hash map had be spilled, it should have enough memory now, - // try to allocate buffer again. - $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key, ${hashEval.value}); - if ($buffer == null) { - // failed to allocate the first page - throw new OutOfMemoryError("No enough memory for aggregation"); - } - } + UnsafeRow $unsafeRowBuffer = null; + org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null; + + ${findOrInsertInVectorizedHashMap.getOrElse("")} + + $findOrInsertInUnsafeRowMap + $incCounter - // evaluate aggregate function - ${evaluateVariables(evals)} - // update aggregate buffer - ${updates.mkString("\n").trim} + if ($vectorizedRowBuffer != null) { + // update vectorized row + ${updateRowInVectorizedHashMap.getOrElse("")} + } else { + // update unsafe row + $updateRowInUnsafeRowMap + } """ } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index ce504e20e6dd3..c3687266109f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -82,10 +82,10 @@ class TungstenAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], - testFallbackStartsAt: Option[Int], + testFallbackStartsAt: Option[(Int, Int)], numOutputRows: LongSQLMetric, dataSize: LongSQLMetric, spillSize: LongSQLMetric) @@ -171,7 +171,7 @@ class TungstenAggregationIterator( // hashMap. If there is not enough memory, it will multiple hash-maps, spilling // after each becomes full then using sort to merge these spills, finally do sort // based aggregation. - private def processInputs(fallbackStartsAt: Int): Unit = { + private def processInputs(fallbackStartsAt: (Int, Int)): Unit = { if (groupingExpressions.isEmpty) { // If there is no grouping expressions, we can just reuse the same buffer over and over again. // Note that it would be better to eliminate the hash map entirely in the future. @@ -187,7 +187,7 @@ class TungstenAggregationIterator( val newInput = inputIter.next() val groupingKey = groupingProjection.apply(newInput) var buffer: UnsafeRow = null - if (i < fallbackStartsAt) { + if (i < fallbackStartsAt._2) { buffer = hashMap.getAggregationBufferFromUnsafeRow(groupingKey) } if (buffer == null) { @@ -352,7 +352,7 @@ class TungstenAggregationIterator( /** * Start processing input rows. 
*/ - processInputs(testFallbackStartsAt.getOrElse(Int.MaxValue)) + processInputs(testFallbackStartsAt.getOrElse((Int.MaxValue, Int.MaxValue))) // If we did not switch to sort-based aggregation in processInputs, // we pre-load the first key-value pair from the map (to make hasNext idempotent). diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala index 9abae5357973f..535e64cb34442 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala @@ -19,133 +19,153 @@ package org.apache.spark.sql.execution.aggregate import scala.language.existentials -import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, OuterScopes} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate +import org.apache.spark.sql.catalyst.expressions.aggregate.DeclarativeAggregate import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.types._ object TypedAggregateExpression { - def apply[A, B : Encoder, C : Encoder]( - aggregator: Aggregator[A, B, C]): TypedAggregateExpression = { + def apply[BUF : Encoder, OUT : Encoder]( + aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = { + val bufferEncoder = encoderFor[BUF] + // We will insert the deserializer and function call expression at the bottom of each serializer + // expression while executing `TypedAggregateExpression`, which means multiply serializer + // expressions will all evaluate the same sub-expression at bottom. To avoid the re-evaluating, + // here we always use one single serializer expression to serialize the buffer object into a + // single-field row, no matter whether the encoder is flat or not. We also need to update the + // deserializer to read in all fields from that single-field row. + // TODO: remove this trick after we have better integration of subexpression elimination and + // whole stage codegen. 
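
// Sketch: a toy stand-in for the Aggregator contract that TypedAggregateExpression wires into
// the declarative aggregate hooks that follow (zero -> initialValues, reduce -> updateExpressions,
// merge -> mergeExpressions, finish -> evaluateExpression). ToyAggregator and Average are
// invented here; they are not Spark's org.apache.spark.sql.expressions.Aggregator.
trait ToyAggregator[IN, BUF, OUT] {
  def zero: BUF
  def reduce(b: BUF, a: IN): BUF
  def merge(b1: BUF, b2: BUF): BUF
  def finish(b: BUF): OUT
}

object Average extends ToyAggregator[Long, (Long, Long), Double] {
  def zero: (Long, Long) = (0L, 0L)
  def reduce(b: (Long, Long), a: Long): (Long, Long) = (b._1 + a, b._2 + 1)
  def merge(b1: (Long, Long), b2: (Long, Long)): (Long, Long) = (b1._1 + b2._1, b1._2 + b2._2)
  def finish(b: (Long, Long)): Double = b._1.toDouble / b._2

  def main(args: Array[String]): Unit = {
    // Two "partitions" reduced independently and then merged -- the shape that a partial
    // aggregation followed by a final merge takes.
    val partial1 = Seq(1L, 2L, 3L).foldLeft(zero)(reduce)
    val partial2 = Seq(10L).foldLeft(zero)(reduce)
    println(finish(merge(partial1, partial2)))   // 4.0
  }
}
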
+ val bufferSerializer = if (bufferEncoder.flat) { + bufferEncoder.namedExpressions.head + } else { + Alias(CreateStruct(bufferEncoder.serializer), "buffer")() + } + + val bufferDeserializer = if (bufferEncoder.flat) { + bufferEncoder.deserializer transformUp { + case b: BoundReference => bufferSerializer.toAttribute + } + } else { + bufferEncoder.deserializer transformUp { + case UnresolvedAttribute(nameParts) => + assert(nameParts.length == 1) + UnresolvedExtractValue(bufferSerializer.toAttribute, Literal(nameParts.head)) + case BoundReference(ordinal, dt, _) => GetStructField(bufferSerializer.toAttribute, ordinal) + } + } + + val outputEncoder = encoderFor[OUT] + val outputType = if (outputEncoder.flat) { + outputEncoder.schema.head.dataType + } else { + outputEncoder.schema + } + new TypedAggregateExpression( aggregator.asInstanceOf[Aggregator[Any, Any, Any]], None, - encoderFor[B].asInstanceOf[ExpressionEncoder[Any]], - encoderFor[C].asInstanceOf[ExpressionEncoder[Any]], - Nil, - 0, - 0) + bufferSerializer, + bufferDeserializer, + outputEncoder.serializer, + outputEncoder.deserializer.dataType, + outputType) } } /** - * This class is a rough sketch of how to hook `Aggregator` into the Aggregation system. It has - * the following limitations: - * - It assumes the aggregator has a zero, `0`. + * A helper class to hook [[Aggregator]] into the aggregation system. */ case class TypedAggregateExpression( aggregator: Aggregator[Any, Any, Any], - aEncoder: Option[ExpressionEncoder[Any]], // Should be bound. - unresolvedBEncoder: ExpressionEncoder[Any], - cEncoder: ExpressionEncoder[Any], - children: Seq[Attribute], - mutableAggBufferOffset: Int, - inputAggBufferOffset: Int) - extends ImperativeAggregate with Logging { - - override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = - copy(mutableAggBufferOffset = newMutableAggBufferOffset) - - override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = - copy(inputAggBufferOffset = newInputAggBufferOffset) + inputDeserializer: Option[Expression], + bufferSerializer: NamedExpression, + bufferDeserializer: Expression, + outputSerializer: Seq[Expression], + outputExternalType: DataType, + dataType: DataType) extends DeclarativeAggregate with NonSQLExpression { override def nullable: Boolean = true - override def dataType: DataType = if (cEncoder.flat) { - cEncoder.schema.head.dataType - } else { - cEncoder.schema - } - override def deterministic: Boolean = true - override lazy val resolved: Boolean = aEncoder.isDefined - - override lazy val inputTypes: Seq[DataType] = Nil + override def children: Seq[Expression] = inputDeserializer.toSeq :+ bufferDeserializer - override val aggBufferSchema: StructType = unresolvedBEncoder.schema + override lazy val resolved: Boolean = inputDeserializer.isDefined && childrenResolved - override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + override def references: AttributeSet = AttributeSet(inputDeserializer.toSeq) - val bEncoder = unresolvedBEncoder - .resolve(aggBufferAttributes, OuterScopes.outerScopes) - .bind(aggBufferAttributes) + override def inputTypes: Seq[AbstractDataType] = Nil - // Note: although this simply copies aggBufferAttributes, this common code can not be placed - // in the superclass because that will lead to initialization ordering issues. 
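
// Sketch: the transformUp calls above rewrite leaf references inside the buffer deserializer so
// that it reads from the single serialized buffer column. This toy expression tree (Expr, Ref,
// Field, Add, RewriteDemo) is invented purely to show the shape of such a bottom-up rewrite.
sealed trait Expr {
  def transformUp(rule: PartialFunction[Expr, Expr]): Expr = {
    val withNewChildren = this match {
      case Add(l, r) => Add(l.transformUp(rule), r.transformUp(rule))
      case leaf      => leaf
    }
    rule.applyOrElse(withNewChildren, (e: Expr) => e)
  }
}
case class Ref(ordinal: Int) extends Expr                     // stands in for BoundReference
case class Field(buffer: String, ordinal: Int) extends Expr   // stands in for GetStructField
case class Add(left: Expr, right: Expr) extends Expr

object RewriteDemo {
  def main(args: Array[String]): Unit = {
    val deserializer: Expr = Add(Ref(0), Ref(1))
    // Point every bound reference at a field of the single "buffer" column instead.
    val rewritten = deserializer.transformUp { case Ref(i) => Field("buffer", i) }
    println(rewritten)   // Add(Field(buffer,0),Field(buffer,1))
  }
}
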
- override val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) + private def aggregatorLiteral = + Literal.create(aggregator, ObjectType(classOf[Aggregator[Any, Any, Any]])) - // We let the dataset do the binding for us. - lazy val boundA = aEncoder.get + private def bufferExternalType = bufferDeserializer.dataType - private def updateBuffer(buffer: MutableRow, value: InternalRow): Unit = { - var i = 0 - while (i < aggBufferAttributes.length) { - val offset = mutableAggBufferOffset + i - aggBufferSchema(i).dataType match { - case BooleanType => buffer.setBoolean(offset, value.getBoolean(i)) - case ByteType => buffer.setByte(offset, value.getByte(i)) - case ShortType => buffer.setShort(offset, value.getShort(i)) - case IntegerType => buffer.setInt(offset, value.getInt(i)) - case LongType => buffer.setLong(offset, value.getLong(i)) - case FloatType => buffer.setFloat(offset, value.getFloat(i)) - case DoubleType => buffer.setDouble(offset, value.getDouble(i)) - case other => buffer.update(offset, value.get(i, other)) - } - i += 1 - } - } + override lazy val aggBufferAttributes: Seq[AttributeReference] = + bufferSerializer.toAttribute.asInstanceOf[AttributeReference] :: Nil - override def initialize(buffer: MutableRow): Unit = { - val zero = bEncoder.toRow(aggregator.zero) - updateBuffer(buffer, zero) + override lazy val initialValues: Seq[Expression] = { + val zero = Literal.fromObject(aggregator.zero, bufferExternalType) + ReferenceToExpressions(bufferSerializer, zero :: Nil) :: Nil } - override def update(buffer: MutableRow, input: InternalRow): Unit = { - val inputA = boundA.fromRow(input) - val currentB = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val merged = aggregator.reduce(currentB, inputA) - val returned = bEncoder.toRow(merged) + override lazy val updateExpressions: Seq[Expression] = { + val reduced = Invoke( + aggregatorLiteral, + "reduce", + bufferExternalType, + bufferDeserializer :: inputDeserializer.get :: Nil) - updateBuffer(buffer, returned) + ReferenceToExpressions(bufferSerializer, reduced :: Nil) :: Nil } - override def merge(buffer1: MutableRow, buffer2: InternalRow): Unit = { - val b1 = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer1) - val b2 = bEncoder.shift(inputAggBufferOffset).fromRow(buffer2) - val merged = aggregator.merge(b1, b2) - val returned = bEncoder.toRow(merged) + override lazy val mergeExpressions: Seq[Expression] = { + val leftBuffer = bufferDeserializer transform { + case a: AttributeReference => a.left + } + val rightBuffer = bufferDeserializer transform { + case a: AttributeReference => a.right + } + val merged = Invoke( + aggregatorLiteral, + "merge", + bufferExternalType, + leftBuffer :: rightBuffer :: Nil) - updateBuffer(buffer1, returned) + ReferenceToExpressions(bufferSerializer, merged :: Nil) :: Nil } - override def eval(buffer: InternalRow): Any = { - val b = bEncoder.shift(mutableAggBufferOffset).fromRow(buffer) - val result = cEncoder.toRow(aggregator.finish(b)) + override lazy val evaluateExpression: Expression = { + val resultObj = Invoke( + aggregatorLiteral, + "finish", + outputExternalType, + bufferDeserializer :: Nil) + dataType match { - case _: StructType => result - case _ => result.get(0, dataType) + case s: StructType => + ReferenceToExpressions(CreateStruct(outputSerializer), resultObj :: Nil) + case _ => + assert(outputSerializer.length == 1) + outputSerializer.head transform { + case b: BoundReference => resultObj + } } } override def toString: String = { - 
s"""${aggregator.getClass.getSimpleName}(${children.mkString(",")})""" + val input = inputDeserializer match { + case Some(UnresolvedDeserializer(deserializer, _)) => deserializer.dataType.simpleString + case Some(deserializer) => deserializer.dataType.simpleString + case _ => "unknown" + } + + s"$nodeName($input)" } - override def nodeName: String = aggregator.getClass.getSimpleName + override def nodeName: String = aggregator.getClass.getSimpleName.stripSuffix("$") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala similarity index 67% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index e415dd8e6ac9f..dd9b2f097e121 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -21,19 +21,24 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.types.StructType /** - * This is a helper object to generate an append-only single-key/single value aggregate hash - * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates - * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in - * TungstenAggregate to speed up aggregates w/ key. + * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' + * for extremely fast key-value lookups while evaluating aggregates (and fall back to the + * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to speed + * up aggregates w/ key. * * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the * key-value pairs. The index lookups in the array rely on linear probing (with a small number of * maximum tries) and use an inexpensive hash function which makes it really efficient for a * majority of lookups. However, using linear probing and an inexpensive hash function also makes it * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even - * for certain distribution of keys) and requires us to fall back on the latter for correctness. + * for certain distribution of keys) and requires us to fall back on the latter for correctness. We + * also use a secondary columnar batch that logically projects over the original columnar batch and + * is equivalent to the `BytesToBytesMap` aggregate buffer. + * + * NOTE: This vectorized hash map currently doesn't support nullable keys and falls back to the + * `BytesToBytesMap` to store them. 
*/ -class ColumnarAggMapCodeGenerator( +class VectorizedHashMapGenerator( ctx: CodegenContext, generatedClassName: String, groupingKeySchema: StructType, @@ -52,6 +57,10 @@ class ColumnarAggMapCodeGenerator( |${generateEquals()} | |${generateHashFunction()} + | + |${generateRowIterator()} + | + |${generateClose()} |} """.stripMargin } @@ -65,27 +74,40 @@ class ColumnarAggMapCodeGenerator( .mkString("\n")}; """.stripMargin + val generatedAggBufferSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${bufferSchema.map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + s""" | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch; | private int[] buckets; - | private int numBuckets; - | private int maxSteps; + | private int capacity = 1 << 16; + | private double loadFactor = 0.5; + | private int numBuckets = (int) (capacity / loadFactor); + | private int maxSteps = 2; | private int numRows = 0; | private org.apache.spark.sql.types.StructType schema = $generatedSchema + | private org.apache.spark.sql.types.StructType aggregateBufferSchema = + | $generatedAggBufferSchema | - | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { - | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); - | this.maxSteps = maxSteps; - | numBuckets = (int) (capacity / loadFactor); + | public $generatedClassName() { | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | // TODO: Possibly generate this projection in TungstenAggregate directly + | aggregateBufferBatch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate( + | aggregateBufferSchema, org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | for (int i = 0 ; i < aggregateBufferBatch.numCols(); i++) { + | aggregateBufferBatch.setColumn(i, batch.column(i+${groupingKeys.length})); + | } + | | buckets = new int[numBuckets]; | java.util.Arrays.fill(buckets, -1); | } - | - | public $generatedClassName() { - | new $generatedClassName(1 << 16, 0.25, 5); - | } """.stripMargin } @@ -101,9 +123,11 @@ class ColumnarAggMapCodeGenerator( */ private def generateHashFunction(): String = { s""" - |// TODO: Improve this hash function |private long hash($groupingKeySignature) { - | return ${groupingKeys.map(_._2).mkString(" ^ ")}; + | long h = 0; + | ${groupingKeys.map(key => s"h = (h ^ (0x9e3779b9)) + ${key._2} + (h << 6) + (h >>> 2);") + .mkString("\n")} + | return h; |} """.stripMargin } @@ -172,15 +196,22 @@ class ColumnarAggMapCodeGenerator( | while (step < maxSteps) { | // Return bucket index if it's either an empty slot or already contains the key | if (buckets[idx] == -1) { - | ${groupingKeys.zipWithIndex.map(k => - s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} - | ${bufferValues.zipWithIndex.map(k => - s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);") - .mkString("\n")} - | buckets[idx] = numRows++; - | return batch.getRow(buckets[idx]); + | if (numRows < capacity) { + | ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(k => + s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);") + .mkString("\n")} + | buckets[idx] = numRows++; + | batch.setNumRows(numRows); + | 
aggregateBufferBatch.setNumRows(numRows); + | return aggregateBufferBatch.getRow(buckets[idx]); + | } else { + | // No more space + | return null; + | } | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { - | return batch.getRow(buckets[idx]); + | return aggregateBufferBatch.getRow(buckets[idx]); | } | idx = (idx + 1) & (numBuckets - 1); | step++; @@ -190,4 +221,21 @@ class ColumnarAggMapCodeGenerator( |} """.stripMargin } + + private def generateRowIterator(): String = { + s""" + |public java.util.Iterator + | rowIterator() { + | return batch.rowIterator(); + |} + """.stripMargin + } + + private def generateClose(): String = { + s""" + |public void close() { + | batch.close(); + |} + """.stripMargin + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index f5776e7b8d49a..4ceb710f4b2b4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -361,7 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - GenerateMutableProjection.generate(children, inputAttributes)() + GenerateMutableProjection.generate(children, inputAttributes) } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 344aaff348e77..892c57ae7d7c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -31,8 +31,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) override def output: Seq[Attribute] = projectList.map(_.toAttribute) - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -53,7 +53,7 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) val exprs = projectList.map(x => ExpressionCanonicalizer.execute(BindReferences.bindReference(x, child.output))) ctx.currentVars = input - val resultVars = exprs.map(_.gen(ctx)) + val resultVars = exprs.map(_.genCode(ctx)) // Evaluation of non-deterministic expressions can't be deferred. val nonDeterministicAttrs = projectList.filterNot(_.deterministic).map(_.toAttribute) s""" @@ -103,8 +103,8 @@ case class Filter(condition: Expression, child: SparkPlan) private[sql] override lazy val metrics = Map( "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -122,7 +122,7 @@ case class Filter(condition: Expression, child: SparkPlan) val evaluated = evaluateRequiredVariables(child.output, in, c.references) // Generate the code for the predicate. 
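
// Sketch (invented helper, not the generated code): the null guard that the predicate code
// generated just below boils down to -- a row survives the filter only when its condition is
// neither null nor false.
object FilterNullCheck {
  // Three-valued result of evaluating a nullable condition on one row: None plays SQL NULL.
  def evalPredicate(row: Int): Option[Boolean] =
    if (row == 0) None else Some(row > 2)

  def main(args: Array[String]): Unit = {
    val rows = Seq(0, 1, 2, 3, 4)
    // Rows whose predicate is null or false are dropped, i.e. the `isNull || !value` shape
    // that the nullCheck string builds.
    val kept = rows.filter(r => evalPredicate(r).contains(true))
    println(kept)   // List(3, 4)
  }
}
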
- val ev = ExpressionCanonicalizer.execute(bound).gen(ctx) + val ev = ExpressionCanonicalizer.execute(bound).genCode(ctx) val nullCheck = if (bound.nullable) { s"${ev.isNull} || " } else { @@ -243,8 +243,8 @@ case class Sample( } } - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -315,7 +315,7 @@ case class Range( // output attributes should not affect the results override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements) - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { sqlContext.sparkContext.parallelize(0 until numSlices, numSlices) .map(i => InternalRow(i)) :: Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 5d00c805a6afe..bd1cbbe5fd0cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -422,6 +422,37 @@ case class ShowTablePropertiesCommand( } } +/** + * A command for users to get the DDL of an existing table + * The syntax of using this command in SQL is: + * {{{ + * SHOW CREATE TABLE tableIdentifier + * }}} + */ +case class ShowCreateTableCommand(tableIdentifier: TableIdentifier) + extends RunnableCommand{ + + // The result of SHOW CREATE TABLE is the whole string of DDL command + override val output: Seq[Attribute] = { + val schema = StructType( + StructField("DDL", StringType, nullable = false)::Nil) + schema.toAttributes + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val ddl = catalog.showCreateTable(tableIdentifier) + if (ddl.contains("CLUSTERED BY") || ddl.contains("SKEWED BY") || ddl.contains("STORED BY")) { + Seq(Row("WARN: This DDL is not supported by Spark SQL natively, " + + "because it contains 'CLUSTERED BY', 'SKEWED BY' or 'STORED BY' clause."), + Row(ddl)) + } else { + Seq(Row(ddl)) + } + } + +} + /** * A command for users to list all of the registered functions. 
* The syntax of using this command in SQL is: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index 10fde152ab2a9..0dfe7dba1e5c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -186,7 +186,8 @@ case class DataSource( userSpecifiedSchema = Some(dataSchema), className = className, options = - new CaseInsensitiveMap(options.filterKeys(_ != "path"))).resolveRelation())) + new CaseInsensitiveMap( + options.filterKeys(_ != "path") + ("basePath" -> path))).resolveRelation())) } new FileStreamSource( @@ -310,7 +311,17 @@ case class DataSource( val fileCatalog: FileCatalog = new HDFSFileCatalog(sqlContext, options, globbedPaths, partitionSchema) - val dataSchema = userSpecifiedSchema.orElse { + + val dataSchema = userSpecifiedSchema.map { schema => + val equality = + if (sqlContext.conf.caseSensitiveAnalysis) { + org.apache.spark.sql.catalyst.analysis.caseSensitiveResolution + } else { + org.apache.spark.sql.catalyst.analysis.caseInsensitiveResolution + } + + StructType(schema.filterNot(f => partitionColumns.exists(equality(_, f.name)))) + }.orElse { format.inferSchema( sqlContext, caseInsensitiveOptions, @@ -318,7 +329,7 @@ case class DataSource( }.getOrElse { throw new AnalysisException( s"Unable to infer schema for $format at ${allPaths.take(2).mkString(",")}. " + - "It must be specified manually") + "It must be specified manually") } val enrichedOptions = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 468e101fedb8b..90694d9af4e01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -17,15 +17,17 @@ package org.apache.spark.sql.execution.datasources -import org.apache.spark.{Partition, TaskContext} +import org.apache.spark.{Partition => RDDPartition, TaskContext} +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.vectorized.ColumnarBatch /** * A single file that should be read, along with partition column values that * need to be prepended to each row. The reading should start at the first - * valid record found after `offset`. + * valid record found after `start`. */ case class PartitionedFile( partitionValues: InternalRow, @@ -43,7 +45,7 @@ case class PartitionedFile( * * TODO: This currently does not take locality information about the files into account. 
*/ -case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends Partition +case class FilePartition(index: Int, files: Seq[PartitionedFile]) extends RDDPartition class FileScanRDD( @transient val sqlContext: SQLContext, @@ -51,38 +53,82 @@ class FileScanRDD( @transient val filePartitions: Seq[FilePartition]) extends RDD[InternalRow](sqlContext.sparkContext, Nil) { - override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { + override def compute(split: RDDPartition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { + private val inputMetrics = context.taskMetrics().inputMetrics + private val existingBytesRead = inputMetrics.bytesRead + + // Find a function that will return the FileSystem bytes read by this thread. Do this before + // apply readFunction, because it might read some bytes. + private val getBytesReadCallback: Option[() => Long] = + SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() + + // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. + // If we do a coalesce, however, we are likely to compute multiple partitions in the same + // task and in the same thread, in which case we need to avoid override values written by + // previous partitions (SPARK-13071). + private def updateBytesRead(): Unit = { + getBytesReadCallback.foreach { getBytesRead => + inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) + } + } + + // If we can't get the bytes read from the FS stats, fall back to the file size, + // which may be inaccurate. + private def updateBytesReadWithFileSize(): Unit = { + if (getBytesReadCallback.isEmpty && currentFile != null) { + inputMetrics.incBytesRead(currentFile.length) + } + } + private[this] val files = split.asInstanceOf[FilePartition].files.toIterator + private[this] var currentFile: PartitionedFile = null private[this] var currentIterator: Iterator[Object] = null def hasNext = (currentIterator != null && currentIterator.hasNext) || nextIterator() - def next() = currentIterator.next() + def next() = { + val nextElement = currentIterator.next() + // TODO: we should have a better separation of row based and batch based scan, so that we + // don't need to run this `if` for every record. + if (nextElement.isInstanceOf[ColumnarBatch]) { + inputMetrics.incRecordsRead(nextElement.asInstanceOf[ColumnarBatch].numRows()) + } else { + inputMetrics.incRecordsRead(1) + } + if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { + updateBytesRead() + } + nextElement + } /** Advances to the next file. Returns true if a new non-empty iterator is available. */ private def nextIterator(): Boolean = { + updateBytesReadWithFileSize() if (files.hasNext) { - val nextFile = files.next() - logInfo(s"Reading File $nextFile") - InputFileNameHolder.setInputFileName(nextFile.filePath) - currentIterator = readFunction(nextFile) + currentFile = files.next() + logInfo(s"Reading File $currentFile") + InputFileNameHolder.setInputFileName(currentFile.filePath) + currentIterator = readFunction(currentFile) hasNext } else { + currentFile = null InputFileNameHolder.unsetInputFileName() false } } override def close() = { + updateBytesRead() + updateBytesReadWithFileSize() InputFileNameHolder.unsetInputFileName() } } // Register an on-task-completion callback to close the input stream. 
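
// Sketch only: the shape of the FileScanRDD iterator above -- chain per-file iterators, count
// records as they stream past, and flush the read metrics periodically plus once on close.
// FileChunk, Metrics and ToyFileScanIterator are invented toys standing in for PartitionedFile,
// the input metrics, and the anonymous iterator in compute().
final case class FileChunk(name: String, rows: Seq[String])

final class Metrics {
  var recordsRead = 0L
  var flushes = 0L
  def flush(): Unit = flushes += 1   // stands in for updateBytesRead()
}

final class ToyFileScanIterator(
    files: Iterator[FileChunk],
    metrics: Metrics,
    updateInterval: Int = 2) extends Iterator[String] with AutoCloseable {

  private var current: Iterator[String] = Iterator.empty

  def hasNext: Boolean = current.hasNext || nextFile()

  def next(): String = {
    val row = current.next()
    metrics.recordsRead += 1
    if (metrics.recordsRead % updateInterval == 0) metrics.flush()   // periodic metric update
    row
  }

  private def nextFile(): Boolean =
    if (files.hasNext) { current = files.next().rows.iterator; hasNext } else false

  def close(): Unit = metrics.flush()   // final update, as in the task-completion listener
}

object ToyFileScanDemo {
  def main(args: Array[String]): Unit = {
    val metrics = new Metrics
    val files = Iterator(FileChunk("a", Seq("r1", "r2")), FileChunk("b", Seq("r3")))
    val it = new ToyFileScanIterator(files, metrics)
    while (it.hasNext) println(it.next())
    it.close()
    println(s"records=${metrics.recordsRead}, flushes=${metrics.flushes}")   // records=3, flushes=2
  }
}
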
- context.addTaskCompletionListener(context => iterator.close()) + context.addTaskCompletionListener(_ => iterator.close()) iterator.asInstanceOf[Iterator[InternalRow]] // This is an erasure hack. } - override protected def getPartitions: Array[Partition] = filePartitions.toArray + override protected def getPartitions: Array[RDDPartition] = filePartitions.toArray } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala index 815d1d01ef343..b9527db6d0092 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala @@ -33,7 +33,6 @@ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.UnsafeKVExternalSorter import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{OutputWriter, OutputWriterFactory} import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala index 54fb03b6d3bf7..ed40cd0c812ab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVRelation.scala @@ -30,8 +30,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.execution.datasources.PartitionedFile -import org.apache.spark.sql.sources._ +import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory, PartitionedFile} import org.apache.spark.sql.types._ object CSVRelation extends Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index 06a371b88bc02..34db10f822554 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -25,17 +25,15 @@ import org.apache.hadoop.io.{LongWritable, Text} import org.apache.hadoop.mapred.TextInputFormat import org.apache.hadoop.mapreduce._ -import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection} +import org.apache.spark.sql.catalyst.expressions.JoinedRow import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet /** * Provides access to CSV data from 
pure SQL statements. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala new file mode 100644 index 0000000000000..d37a939b544aa --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/fileSourceInterfaces.scala @@ -0,0 +1,530 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources + +import scala.collection.mutable +import scala.util.Try + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.mapred.{FileInputFormat, JobConf} +import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} + +import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental +import org.apache.spark.internal.Logging +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.FileRelation +import org.apache.spark.sql.sources.{BaseRelation, Filter} +import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration + +/** + * ::Experimental:: + * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver + * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized + * to executor side to create actual [[OutputWriter]]s on the fly. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriterFactory extends Serializable { + /** + * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side + * to instantiate new [[OutputWriter]]s. + * + * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that + * this may not point to the final output file. For example, `FileOutputFormat` writes to + * temporary directories and then merge written files back to the final destination. In + * this case, `path` points to a temporary output file under the temporary directory. + * @param dataSchema Schema of the rows to be written. Partition columns are not included in the + * schema if the relation being written is partitioned. + * @param context The Hadoop MapReduce task context. + * @since 1.4.0 + */ + private[sql] def newInstance( + path: String, + bucketId: Option[Int], // TODO: This doesn't belong here... 
+ dataSchema: StructType, + context: TaskAttemptContext): OutputWriter +} + +/** + * ::Experimental:: + * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the + * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. + * An [[OutputWriter]] instance is created and initialized when a new output file is opened on + * executor side. This instance is used to persist rows to this single output file. + * + * @since 1.4.0 + */ +@Experimental +abstract class OutputWriter { + /** + * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned + * tables, dynamic partition columns are not included in rows to be written. + * + * @since 1.4.0 + */ + def write(row: Row): Unit + + /** + * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before + * the task output is committed. + * + * @since 1.4.0 + */ + def close(): Unit + + private var converter: InternalRow => Row = _ + + protected[sql] def initConverter(dataSchema: StructType) = { + converter = + CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] + } + + protected[sql] def writeInternal(row: InternalRow): Unit = { + write(converter(row)) + } +} + +/** + * Acts as a container for all of the metadata required to read from a datasource. All discovery, + * resolution and merging logic for schemas and partitions has been removed. + * + * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise + * this relation. + * @param partitionSchema The schema of the columns (if any) that are used to partition the relation + * @param dataSchema The schema of any remaining columns. Note that if any partition columns are + * present in the actual data files as well, they are preserved. + * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). + * @param fileFormat A file format that can be used to read and write the data in files. + * @param options Configuration used when reading / writing data. + */ +case class HadoopFsRelation( + sqlContext: SQLContext, + location: FileCatalog, + partitionSchema: StructType, + dataSchema: StructType, + bucketSpec: Option[BucketSpec], + fileFormat: FileFormat, + options: Map[String, String]) extends BaseRelation with FileRelation { + + val schema: StructType = { + val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet + StructType(dataSchema ++ partitionSchema.filterNot { column => + dataSchemaColumnNames.contains(column.name.toLowerCase) + }) + } + + def partitionSchemaOption: Option[StructType] = + if (partitionSchema.isEmpty) None else Some(partitionSchema) + def partitionSpec: PartitionSpec = location.partitionSpec() + + def refresh(): Unit = location.refresh() + + override def toString: String = + s"HadoopFiles" + + /** Returns the list of files that will be read when scanning this relation. */ + override def inputFiles: Array[String] = + location.allFiles().map(_.getPath.toUri.toString).toArray + + override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum +} + +/** + * Used to read and write data stored in files to/from the [[InternalRow]] format. + */ +trait FileFormat { + /** + * When possible, this method should return the schema of the given `files`. When the format + * does not support inference, or no valid files are given should return None. In these cases + * Spark will require that user specify the schema manually. 
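To make the inferSchema contract described above concrete, here is a minimal, hypothetical format whose schema never depends on the file contents: it exposes a single string column, and returns None when no files are given so that Spark asks the user for a schema. The class name and behaviour are illustrative only; the method signatures follow the FileFormat trait introduced in this file.

import org.apache.hadoop.fs.FileStatus
import org.apache.hadoop.mapreduce.Job
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriterFactory}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

class SingleStringColumnFormat extends FileFormat {
  // The schema never depends on the files, so inference is trivial; with no
  // input files we return None and let Spark require a user-supplied schema.
  override def inferSchema(
      sqlContext: SQLContext,
      options: Map[String, String],
      files: Seq[FileStatus]): Option[StructType] = {
    if (files.isEmpty) None
    else Some(StructType(StructField("value", StringType) :: Nil))
  }

  // The write path is out of scope for this sketch.
  override def prepareWrite(
      sqlContext: SQLContext,
      job: Job,
      options: Map[String, String],
      dataSchema: StructType): OutputWriterFactory = {
    throw new UnsupportedOperationException("read-only sketch")
  }
}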
+ */ + def inferSchema( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Option[StructType] + + /** + * Prepares a read job and returns a potentially updated data source option [[Map]]. This method + * can be useful for collecting necessary global information for scanning input data. + */ + def prepareRead( + sqlContext: SQLContext, + options: Map[String, String], + files: Seq[FileStatus]): Map[String, String] = options + + /** + * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can + * be put here. For example, user defined output committer can be configured here + * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. + */ + def prepareWrite( + sqlContext: SQLContext, + job: Job, + options: Map[String, String], + dataSchema: StructType): OutputWriterFactory + + /** + * Returns whether this format support returning columnar batch or not. + * + * TODO: we should just have different traits for the different formats. + */ + def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { + false + } + + /** + * Returns a function that can be used to read a single file in as an Iterator of InternalRow. + * + * @param dataSchema The global data schema. It can be either specified by the user, or + * reconciled/merged from all underlying data files. If any partition columns + * are contained in the files, they are preserved in this schema. + * @param partitionSchema The schema of the partition column row that will be present in each + * PartitionedFile. These columns should be appended to the rows that + * are produced by the iterator. + * @param requiredSchema The schema of the data that should be output for each row. This may be a + * subset of the columns that are present in the file if column pruning has + * occurred. + * @param filters A set of filters than can optionally be used to reduce the number of rows output + * @param options A set of string -> string configuration options. + * @return + */ + def buildReader( + sqlContext: SQLContext, + dataSchema: StructType, + partitionSchema: StructType, + requiredSchema: StructType, + filters: Seq[Filter], + options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { + // TODO: Remove this default implementation when the other formats have been ported + // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. + throw new UnsupportedOperationException(s"buildReader is not supported for $this") + } +} + +/** + * A collection of data files from a partitioned relation, along with the partition values in the + * form of an [[InternalRow]]. + */ +case class Partition(values: InternalRow, files: Seq[FileStatus]) + +/** + * An interface for objects capable of enumerating the files that comprise a relation as well + * as the partitioning characteristics of those files. + */ +trait FileCatalog { + def paths: Seq[Path] + + def partitionSpec(): PartitionSpec + + /** + * Returns all valid files grouped into partitions when the data is partitioned. If the data is + * unpartitioned, this will return a single partition with not partition values. + * + * @param filters the filters used to prune which partitions are returned. These filters must + * only refer to partition columns and this method will only return files + * where these predicates are guaranteed to evaluate to `true`. Thus, these + * filters will not need to be evaluated again on the returned data. 
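As a usage sketch of the filter contract just described: assuming `catalog` is a FileCatalog over a table partitioned by a string column named "date" (both the column and the value are made up), a caller hands listFiles a predicate over partition columns only, and the returned partitions do not need that predicate re-evaluated.

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.StringType

val dateCol = AttributeReference("date", StringType)()   // assumed partition column
val selectedPartitions = catalog.listFiles(Seq(EqualTo(dateCol, Literal("2016-03-14"))))
// Every returned Partition is guaranteed to satisfy date = '2016-03-14'.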
+ */ + def listFiles(filters: Seq[Expression]): Seq[Partition] + + def allFiles(): Seq[FileStatus] + + def getStatus(path: Path): Array[FileStatus] + + def refresh(): Unit +} + +/** + * A file catalog that caches metadata gathered by scanning all the files present in `paths` + * recursively. + * + * @param parameters as set of options to control discovery + * @param paths a list of paths to scan + * @param partitionSchema an optional partition schema that will be use to provide types for the + * discovered partitions + */ +class HDFSFileCatalog( + val sqlContext: SQLContext, + val parameters: Map[String, String], + val paths: Seq[Path], + val partitionSchema: Option[StructType]) + extends FileCatalog with Logging { + + private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) + + var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] + var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] + var cachedPartitionSpec: PartitionSpec = _ + + def partitionSpec(): PartitionSpec = { + if (cachedPartitionSpec == null) { + cachedPartitionSpec = inferPartitioning(partitionSchema) + } + + cachedPartitionSpec + } + + refresh() + + override def listFiles(filters: Seq[Expression]): Seq[Partition] = { + if (partitionSpec().partitionColumns.isEmpty) { + Partition(InternalRow.empty, allFiles().filterNot(_.getPath.getName startsWith "_")) :: Nil + } else { + prunePartitions(filters, partitionSpec()).map { + case PartitionDirectory(values, path) => + Partition( + values, + getStatus(path).filterNot(_.getPath.getName startsWith "_")) + } + } + } + + protected def prunePartitions( + predicates: Seq[Expression], + partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { + val PartitionSpec(partitionColumns, partitions) = partitionSpec + val partitionColumnNames = partitionColumns.map(_.name).toSet + val partitionPruningPredicates = predicates.filter { + _.references.map(_.name).toSet.subsetOf(partitionColumnNames) + } + + if (partitionPruningPredicates.nonEmpty) { + val predicate = partitionPruningPredicates.reduce(expressions.And) + + val boundPredicate = InterpretedPredicate.create(predicate.transform { + case a: AttributeReference => + val index = partitionColumns.indexWhere(a.name == _.name) + BoundReference(index, partitionColumns(index).dataType, nullable = true) + }) + + val selected = partitions.filter { + case PartitionDirectory(values, _) => boundPredicate(values) + } + logInfo { + val total = partitions.length + val selectedSize = selected.length + val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 + s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." 
+ } + + selected + } else { + partitions + } + } + + def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq + + def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) + + private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { + if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { + HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) + } else { + val statuses = paths.flatMap { path => + val fs = path.getFileSystem(hadoopConf) + logInfo(s"Listing $path on driver") + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(hadoopConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + if (pathFilter != null) { + Try(fs.listStatus(path, pathFilter)).getOrElse(Array.empty) + } else { + Try(fs.listStatus(path)).getOrElse(Array.empty) + } + }.filterNot { status => + val name = status.getPath.getName + HadoopFsRelation.shouldFilterOut(name) + } + + val (dirs, files) = statuses.partition(_.isDirectory) + + // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) + if (dirs.isEmpty) { + mutable.LinkedHashSet(files: _*) + } else { + mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) + } + } + } + + def inferPartitioning(schema: Option[StructType]): PartitionSpec = { + // We use leaf dirs containing data files to discover the schema. + val leafDirs = leafDirToChildrenFiles.keys.toSeq + schema match { + case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => + val spec = PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = false, + basePaths = basePaths) + + // Without auto inference, all of value in the `row` should be null or in StringType, + // we need to cast into the data type that user specified. + def castPartitionValuesToUserSchema(row: InternalRow) = { + InternalRow((0 until row.numFields).map { i => + Cast( + Literal.create(row.getUTF8String(i), StringType), + userProvidedSchema.fields(i).dataType).eval() + }: _*) + } + + PartitionSpec(userProvidedSchema, spec.partitions.map { part => + part.copy(values = castPartitionValuesToUserSchema(part.values)) + }) + case _ => + PartitioningUtils.parsePartitions( + leafDirs, + PartitioningUtils.DEFAULT_PARTITION_NAME, + typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), + basePaths = basePaths) + } + } + + /** + * Contains a set of paths that are considered as the base dirs of the input datasets. + * The partitioning discovery logic will make sure it will stop when it reaches any + * base path. By default, the paths of the dataset provided by users will be base paths. + * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path + * will be `/path/something=true/`, and the returned DataFrame will not contain a column of + * `something`. If users want to override the basePath. They can set `basePath` in the options + * to pass the new base path to the data source. + * For the above example, if the user-provided base path is `/path/`, the returned + * DataFrame will have the column of `something`. + */ + private def basePaths: Set[Path] = { + val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) + userDefinedBasePath.getOrElse { + // If the user does not provide basePath, we will just use paths. 
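For illustration of the basePath behaviour described in the comment above (paths are made up; `sqlContext` is assumed to already exist):

// Base path defaults to the input path, so `something` is not exposed as a column:
val withoutPartitionCol = sqlContext.read.parquet("/path/something=true/")

// Overriding basePath keeps `something` as a partition column in the result:
val withPartitionCol = sqlContext.read.option("basePath", "/path/").parquet("/path/something=true/")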
+ paths.toSet + }.map { hdfsPath => + // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). + val fs = hdfsPath.getFileSystem(hadoopConf) + hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + } + } + + def refresh(): Unit = { + val files = listLeafFiles(paths) + + leafFiles.clear() + leafDirToChildrenFiles.clear() + + leafFiles ++= files.map(f => f.getPath -> f) + leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) + + cachedPartitionSpec = null + } + + override def equals(other: Any): Boolean = other match { + case hdfs: HDFSFileCatalog => paths.toSet == hdfs.paths.toSet + case _ => false + } + + override def hashCode(): Int = paths.toSet.hashCode() +} + +/** + * Helper methods for gathering metadata from HDFS. + */ +private[sql] object HadoopFsRelation extends Logging { + + /** Checks if we should filter out this path name. */ + def shouldFilterOut(pathName: String): Boolean = { + // TODO: We should try to filter out all files/dirs starting with "." or "_". + // The only reason that we are not doing it now is that Parquet needs to find those + // metadata files from leaf files returned by this methods. We should refactor + // this logic to not mix metadata files with data files. + pathName == "_SUCCESS" || pathName == "_temporary" || pathName.startsWith(".") + } + + // We don't filter files/directories whose name start with "_" except "_temporary" here, as + // specific data sources may take advantages over them (e.g. Parquet _metadata and + // _common_metadata files). "_temporary" directories are explicitly ignored since failed + // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name + // start with "." are also ignored. + def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { + logInfo(s"Listing ${status.getPath}") + val name = status.getPath.getName.toLowerCase + if (shouldFilterOut(name)) { + Array.empty + } else { + // Dummy jobconf to get to the pathFilter defined in configuration + val jobConf = new JobConf(fs.getConf, this.getClass()) + val pathFilter = FileInputFormat.getInputPathFilter(jobConf) + val statuses = + if (pathFilter != null) { + val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } else { + val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) + files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) + } + statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) + } + } + + // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play + // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. + // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from + // executor side and reconstruct it on driver side. 
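A quick illustration of what the shouldFilterOut helper above accepts and rejects; since the HadoopFsRelation object is private[sql], these calls only compile from code inside org.apache.spark.sql, for example a test suite.

import org.apache.spark.sql.execution.datasources.HadoopFsRelation

assert(HadoopFsRelation.shouldFilterOut("_SUCCESS"))
assert(HadoopFsRelation.shouldFilterOut("_temporary"))
assert(HadoopFsRelation.shouldFilterOut(".hiddenFile"))
assert(!HadoopFsRelation.shouldFilterOut("_metadata"))    // kept so Parquet can find its summary files
assert(!HadoopFsRelation.shouldFilterOut("part-00000"))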
+ case class FakeFileStatus( + path: String, + length: Long, + isDir: Boolean, + blockReplication: Short, + blockSize: Long, + modificationTime: Long, + accessTime: Long) + + def listLeafFilesInParallel( + paths: Seq[Path], + hadoopConf: Configuration, + sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { + logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") + + val serializableConfiguration = new SerializableConfiguration(hadoopConf) + val serializedPaths = paths.map(_.toString) + + val fakeStatuses = sparkContext.parallelize(serializedPaths).map(new Path(_)).flatMap { path => + val fs = path.getFileSystem(serializableConfiguration.value) + Try(listLeafFiles(fs, fs.getFileStatus(path))).getOrElse(Array.empty) + }.map { status => + FakeFileStatus( + status.getPath.toString, + status.getLen, + status.isDirectory, + status.getReplication, + status.getBlockSize, + status.getModificationTime, + status.getAccessTime) + }.collect() + + val hadoopFakeStatuses = fakeStatuses.map { f => + new FileStatus( + f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) + } + mutable.LinkedHashSet(hadoopFakeStatuses: _*) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index b91e892f8f211..bfe7aefe4100c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -784,7 +784,7 @@ private[sql] object ParquetRelation extends Logging { // scalastyle:on classforname redirect(JLogger.getLogger("parquet")) } catch { case _: Throwable => - // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly jar + // SPARK-9974: com.twitter:parquet-hadoop-bundle:1.6.0 is not packaged into the assembly // when Spark is built with SBT. So `parquet.Log` may not be found. This try/catch block // should be removed after this issue is fixed. } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala index 28ac4583e9b25..5b8dc4a3ee723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/rules.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation, InsertableRelation} +import org.apache.spark.sql.sources.{BaseRelation, InsertableRelation} /** * Try to replaces [[UnresolvedRelation]]s with [[ResolvedDataSource]]. 
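The net effect of the import changes in the hunks above: the file-source abstractions now live in org.apache.spark.sql.execution.datasources, while the stable public interfaces stay in org.apache.spark.sql.sources. A format implementation written against this patch would therefore import them roughly like this (sketch):

import org.apache.spark.sql.execution.datasources.{FileFormat, OutputWriter, OutputWriterFactory, PartitionedFile}
import org.apache.spark.sql.sources.{BaseRelation, Filter}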
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala index 94ecb7a28663c..fa0df61ca5f2c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/text/DefaultSource.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter} -import org.apache.spark.sql.execution.datasources.{CompressionCodecs, HadoopFileLinesReader, PartitionedFile} +import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources._ import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 17eae88b49dec..e6079ecaadc7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -164,8 +164,8 @@ package object debug { } } - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala index 102a9356df311..a4f42133425ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.exchange -import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.{ExecutionContext, Future} import scala.concurrent.duration._ import org.apache.spark.broadcast @@ -81,8 +81,7 @@ case class BroadcastExchange( } override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = { - val result = Await.result(relationFuture, timeout) - result.asInstanceOf[broadcast.Broadcast[T]] + ThreadUtils.awaitResult(relationFuture, timeout).asInstanceOf[broadcast.Broadcast[T]] } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala index 7e35db7dd8a79..d7deac93374c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala @@ -22,7 +22,6 @@ import java.util.Random import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer -import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.shuffle.sort.SortShuffleManager import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ @@ -179,9 +178,6 @@ object ShuffleExchange { // copy. 
true } - } else if (shuffleManager.isInstanceOf[HashShuffleManager]) { - // We're using hash-based shuffle, so we don't need to copy. - false } else { // Catch-all case to safely handle any future ShuffleManager implementations. true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index a8f854136c1f9..89487c6b87150 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -71,8 +71,8 @@ case class BroadcastHashJoin( } } - override def upstreams(): Seq[RDD[InternalRow]] = { - streamedPlan.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + streamedPlan.asInstanceOf[CodegenSupport].inputRDDs() } override def doProduce(ctx: CodegenContext): String = { @@ -118,7 +118,7 @@ case class BroadcastHashJoin( ctx.currentVars = input if (streamedKeys.length == 1 && streamedKeys.head.dataType == LongType) { // generate the join key as Long - val ev = streamedKeys.head.gen(ctx) + val ev = streamedKeys.head.genCode(ctx) (ev, ev.isNull) } else { // generate the join key as UnsafeRow @@ -134,7 +134,7 @@ case class BroadcastHashJoin( ctx.currentVars = null ctx.INPUT_ROW = matched buildPlan.output.zipWithIndex.map { case (a, i) => - val ev = BoundReference(i, a.dataType, a.nullable).gen(ctx) + val ev = BoundReference(i, a.dataType, a.nullable).genCode(ctx) if (joinType == Inner) { ev } else { @@ -170,7 +170,8 @@ case class BroadcastHashJoin( val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) // filter the output via condition ctx.currentVars = input ++ buildVars - val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) s""" |$eval |${ev.code} @@ -244,7 +245,8 @@ case class BroadcastHashJoin( // evaluate the variables from build side that used by condition val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) ctx.currentVars = input ++ buildVars - val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) + val ev = + BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx) s""" |boolean $conditionPassed = true; |${eval.trim} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 0c3e3c3fc18a1..f021f3758c52c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -60,9 +60,7 @@ case class ShuffledHashJoin( val context = TaskContext.get() val relation = HashedRelation(iter, buildKeys, taskMemoryManager = context.taskMemoryManager()) // This relation is usually used until the end of task. 
- context.addTaskCompletionListener((t: TaskContext) => - relation.close() - ) + context.addTaskCompletionListener(_ => relation.close()) relation } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 0e7b2f2f3187f..4e45fd656007f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -216,7 +216,7 @@ case class SortMergeJoin( joinType == Inner } - override def upstreams(): Seq[RDD[InternalRow]] = { + override def inputRDDs(): Seq[RDD[InternalRow]] = { left.execute() :: right.execute() :: Nil } @@ -226,7 +226,7 @@ case class SortMergeJoin( keys: Seq[Expression], input: Seq[Attribute]): Seq[ExprCode] = { ctx.INPUT_ROW = row - keys.map(BindReferences.bindReference(_, input).gen(ctx)) + keys.map(BindReferences.bindReference(_, input).genCode(ctx)) } private def copyKeys(ctx: CodegenContext, vars: Seq[ExprCode]): Seq[ExprCode] = { @@ -376,7 +376,7 @@ case class SortMergeJoin( private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { ctx.INPUT_ROW = rightRow right.output.zipWithIndex.map { case (a, i) => - BoundReference(i, a.dataType, a.nullable).gen(ctx) + BoundReference(i, a.dataType, a.nullable).genCode(ctx) } } @@ -427,7 +427,7 @@ case class SortMergeJoin( val (rightBefore, rightAfter) = splitVarsByCondition(right.output, rightVars) // Generate code for condition ctx.currentVars = leftVars ++ rightVars - val cond = BindReferences.bindReference(condition.get, output).gen(ctx) + val cond = BindReferences.bindReference(condition.get, output).genCode(ctx) // evaluate the columns those used by condition before loop val before = s""" |boolean $loaded = false; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index 9643b52f96544..c9a14593fb400 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -57,8 +57,8 @@ trait BaseLimit extends UnaryNode with CodegenSupport { iter.take(limit) } - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index d2ab18ef0e189..7c8bc7fed8313 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -25,19 +25,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.types.ObjectType +import org.apache.spark.sql.types.{DataType, ObjectType} /** * Takes the input row from child and turns it into object using the given deserializer expression. * The output of this operator is a single-field safe row containing the deserialized object. 
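The "single-field safe row containing the deserialized object" convention can be seen in isolation with a small sketch that mirrors the wrapObjectToRow / unwrapObjectFromRow helpers introduced below; the string payload is arbitrary.

import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
import org.apache.spark.sql.types.ObjectType

val objType = ObjectType(classOf[String])
val row = new SpecificMutableRow(objType :: Nil)   // a single-field row
row(0) = "some deserialized object"                // wrap: object -> row
val back = row.get(0, objType)                     // unwrap: row -> object
assert(back == "some deserialized object")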
*/ case class DeserializeToObject( - deserializer: Alias, + deserializer: Expression, + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with CodegenSupport { - override def output: Seq[Attribute] = deserializer.toAttribute :: Nil - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -48,7 +51,7 @@ case class DeserializeToObject( val bound = ExpressionCanonicalizer.execute( BindReferences.bindReference(deserializer, child.output)) ctx.currentVars = input - val resultVars = bound.gen(ctx) :: Nil + val resultVars = bound.genCode(ctx) :: Nil consume(ctx, resultVars) } @@ -67,10 +70,11 @@ case class DeserializeToObject( case class SerializeFromObject( serializer: Seq[NamedExpression], child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -82,7 +86,7 @@ case class SerializeFromObject( ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) } ctx.currentVars = input - val resultVars = bound.map(_.gen(ctx)) + val resultVars = bound.map(_.genCode(ctx)) consume(ctx, resultVars) } @@ -98,63 +102,74 @@ case class SerializeFromObject( * Helper functions for physical operators that work with user defined objects. */ trait ObjectOperator extends SparkPlan { - def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = { - val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema) - (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType) + def deserializeRowToObject( + deserializer: Expression, + inputSchema: Seq[Attribute]): InternalRow => Any = { + val proj = GenerateSafeProjection.generate(deserializer :: Nil, inputSchema) + (i: InternalRow) => proj(i).get(0, deserializer.dataType) } - def generateToRow(serializer: Seq[Expression]): Any => InternalRow = { - val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) { - GenerateSafeProjection.generate(serializer) - } else { - GenerateUnsafeProjection.generate(serializer) + def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = { + val proj = GenerateUnsafeProjection.generate(serializer) + val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head + val objRow = new SpecificMutableRow(objType :: Nil) + (o: Any) => { + objRow(0) = o + proj(objRow) } - val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head - val outputRow = new SpecificMutableRow(inputType :: Nil) + } + + def wrapObjectToRow(objType: DataType): Any => InternalRow = { + val outputRow = new SpecificMutableRow(objType :: Nil) (o: Any) => { outputRow(0) = o - outputProjection(outputRow) + outputRow } } + + def unwrapObjectFromRow(objType: DataType): InternalRow => Any = { + (i: InternalRow) => i.get(0, objType) + } } /** - * Applies the given function to each input row and encodes the result. 
+ * Applies the given function to input object iterator. + * The output of its child must be a single-field row containing the input object. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - deserializer: Expression, - serializer: Seq[NamedExpression], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) - val outputObject = generateToRow(serializer) + val getObject = unwrapObjectFromRow(child.output.head.dataType) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) func(iter.map(getObject)).map(outputObject) } } } /** - * Applies the given function to each input row and encodes the result. + * Applies the given function to each input object. + * The output of its child must be a single-field row containing the input object. * - * Note that, each serializer expression needs the result object which is returned by the given - * function, as input. This operator uses some tricks to make sure we only calculate the result - * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with - * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of - * a project while explain. + * This operator is kind of a safe version of [[Project]], as it's output is custom object, we need + * to use safe row to contain it. */ case class MapElements( func: AnyRef, - deserializer: Expression, - serializer: Seq[NamedExpression], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) - override def upstreams(): Seq[RDD[InternalRow]] = { - child.asInstanceOf[CodegenSupport].upstreams() + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() } protected override def doProduce(ctx: CodegenContext): String = { @@ -167,23 +182,14 @@ case class MapElements( case _ => classOf[Any => Any] -> "apply" } val funcObj = Literal.create(func, ObjectType(funcClass)) - val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType - val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output) val bound = ExpressionCanonicalizer.execute( BindReferences.bindReference(callFunc, child.output)) ctx.currentVars = input - val evaluated = bound.gen(ctx) - - val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) - val outputFields = serializer.map(_ transform { - case _: BoundReference => resultObj - }) - val resultVars = outputFields.map(_.gen(ctx)) - s""" - ${evaluated.code} - ${consume(ctx, resultVars)} - """ + val resultVars = bound.genCode(ctx) :: Nil + + consume(ctx, resultVars) } override protected def doExecute(): RDD[InternalRow] = { @@ -191,9 +197,10 @@ case class MapElements( case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) case _ => 
func.asInstanceOf[Any => Any] } + child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) - val outputObject = generateToRow(serializer) + val getObject = unwrapObjectFromRow(child.output.head.dataType) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) iter.map(row => outputObject(callFunc(getObject(row)))) } } @@ -216,15 +223,43 @@ case class AppendColumns( override protected def doExecute(): RDD[InternalRow] = { child.execute().mapPartitionsInternal { iter => - val getObject = generateToObject(deserializer, child.output) + val getObject = deserializeRowToObject(deserializer, child.output) val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema) - val outputObject = generateToRow(serializer) + val outputObject = serializeObjectToRow(serializer) iter.map { row => val newColumns = outputObject(func(getObject(row))) + combiner.join(row.asInstanceOf[UnsafeRow], newColumns): InternalRow + } + } + } +} + +/** + * An optimized version of [[AppendColumns]], that can be executed on deserialized object directly. + */ +case class AppendColumnsWithObject( + func: Any => Any, + inputSerializer: Seq[NamedExpression], + newColumnsSerializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator { + + override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute) - // This operates on the assumption that we always serialize the result... - combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow + private def inputSchema = inputSerializer.map(_.toAttribute).toStructType + private def newColumnSchema = newColumnsSerializer.map(_.toAttribute).toStructType + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val getChildObject = unwrapObjectFromRow(child.output.head.dataType) + val outputChildObject = serializeObjectToRow(inputSerializer) + val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer) + val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema) + + iter.map { row => + val childObj = getChildObject(row) + val newColumns = outputNewColumnOjb(func(childObj)) + combiner.join(outputChildObject(childObj), newColumns): InternalRow } } } @@ -232,19 +267,19 @@ case class AppendColumns( /** * Groups the input rows together and calls the function with each group and an iterator containing - * all elements in the group. The result of this function is encoded and flattened before - * being output. + * all elements in the group. The result of this function is flattened before being output. 
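For orientation, the kind of user code that ends up planned as a MapGroups node looks roughly like the following, assuming `ds: Dataset[(String, Int)]`, a stable `sqlContext` value, and the usual encoder imports; the exact name of the grouping method has varied between Dataset versions (groupBy vs. groupByKey).

import sqlContext.implicits._

// Group by the first tuple field and count the elements in each group.
val counts = ds.groupByKey(_._1).mapGroups { (key, pairs) => (key, pairs.size) }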
*/ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, - serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], + outputObjAttr: Attribute, child: SparkPlan) extends UnaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(groupingAttributes) :: Nil @@ -256,9 +291,9 @@ case class MapGroups( child.execute().mapPartitionsInternal { iter => val grouped = GroupedIterator(iter, groupingAttributes, child.output) - val getKey = generateToObject(keyDeserializer, groupingAttributes) - val getValue = generateToObject(valueDeserializer, dataAttributes) - val outputObject = generateToRow(serializer) + val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes) + val getValue = deserializeRowToObject(valueDeserializer, dataAttributes) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) grouped.flatMap { case (key, rowIter) => val result = func( @@ -273,22 +308,23 @@ case class MapGroups( /** * Co-groups the data from left and right children, and calls the function with each group and 2 * iterators containing all elements in the group from left and right side. - * The result of this function is encoded and flattened before being output. + * The result of this function is flattened before being output. */ case class CoGroup( func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, leftDeserializer: Expression, rightDeserializer: Expression, - serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], + outputObjAttr: Attribute, left: SparkPlan, right: SparkPlan) extends BinaryNode with ObjectOperator { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) + override def output: Seq[Attribute] = outputObjAttr :: Nil + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil @@ -301,10 +337,10 @@ case class CoGroup( val leftGrouped = GroupedIterator(leftData, leftGroup, left.output) val rightGrouped = GroupedIterator(rightData, rightGroup, right.output) - val getKey = generateToObject(keyDeserializer, leftGroup) - val getLeft = generateToObject(leftDeserializer, leftAttr) - val getRight = generateToObject(rightDeserializer, rightAttr) - val outputObject = generateToRow(serializer) + val getKey = deserializeRowToObject(keyDeserializer, leftGroup) + val getLeft = deserializeRowToObject(leftDeserializer, leftAttr) + val getRight = deserializeRowToObject(rightDeserializer, rightAttr) + val outputObject = wrapObjectToRow(outputObjAttr.dataType) new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap { case (key, leftResult, rightResult) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index c9ab40a0a9abf..c49f173ad6dff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -86,7 +86,7 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c } }.toArray }.toArray - val projection = newMutableProjection(allInputs, child.output)() + val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.map(dt => StructField("", dt))) val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 6921ae584dd84..4f722a514ba69 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -23,7 +23,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, SQLContext} -import org.apache.spark.sql.sources.FileFormat +import org.apache.spark.sql.execution.datasources.FileFormat object FileStreamSink { // The name of the subdirectory that is used to store metadata about which files are valid. @@ -44,28 +44,33 @@ class FileStreamSink( private val basePath = new Path(path) private val logPath = new Path(basePath, FileStreamSink.metadataDir) - private val fileLog = new HDFSMetadataLog[Seq[String]](sqlContext, logPath.toUri.toString) + private val fileLog = new FileStreamSinkLog(sqlContext, logPath.toUri.toString) + private val fs = basePath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) override def addBatch(batchId: Long, data: DataFrame): Unit = { - if (fileLog.get(batchId).isDefined) { + if (batchId <= fileLog.getLatest().map(_._1).getOrElse(-1L)) { logInfo(s"Skipping already committed batch $batchId") } else { - val files = writeFiles(data) + val files = fs.listStatus(writeFiles(data)).map { f => + SinkFileStatus( + path = f.getPath.toUri.toString, + size = f.getLen, + isDir = f.isDirectory, + modificationTime = f.getModificationTime, + blockReplication = f.getReplication, + blockSize = f.getBlockSize, + action = FileStreamSinkLog.ADD_ACTION) + } if (fileLog.add(batchId, files)) { logInfo(s"Committed batch $batchId") } else { - logWarning(s"Race while writing batch $batchId") + throw new IllegalStateException(s"Race while writing batch $batchId") } } } /** Writes the [[DataFrame]] to a UUID-named dir, returning the list of files paths. 
*/ - private def writeFiles(data: DataFrame): Seq[String] = { - val ctx = sqlContext - val outputDir = path - val format = fileFormat - val schema = data.schema - + private def writeFiles(data: DataFrame): Array[Path] = { val file = new Path(basePath, UUID.randomUUID().toString).toUri.toString data.write.parquet(file) sqlContext.read @@ -74,7 +79,6 @@ class FileStreamSink( .inputFiles .map(new Path(_)) .filterNot(_.getName.startsWith("_")) - .map(_.toUri.toString) } override def toString: String = s"FileSink[$path]" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala new file mode 100644 index 0000000000000..6c5449a928293 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLog.scala @@ -0,0 +1,278 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import java.io.IOException +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.hadoop.fs.{FileStatus, Path, PathFilter} +import org.json4s.NoTypeHints +import org.json4s.jackson.Serialization +import org.json4s.jackson.Serialization.{read, write} + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.internal.SQLConf + +/** + * The status of a file outputted by [[FileStreamSink]]. A file is visible only if it appears in + * the sink log and its action is not "delete". + * + * @param path the file path. + * @param size the file size. + * @param isDir whether this file is a directory. + * @param modificationTime the file last modification time. + * @param blockReplication the block replication. + * @param blockSize the block size. + * @param action the file action. Must be either "add" or "delete". + */ +case class SinkFileStatus( + path: String, + size: Long, + isDir: Boolean, + modificationTime: Long, + blockReplication: Int, + blockSize: Long, + action: String) { + + def toFileStatus: FileStatus = { + new FileStatus(size, isDir, blockReplication, blockSize, modificationTime, new Path(path)) + } +} + +/** + * A special log for [[FileStreamSink]]. It will write one log file for each batch. The first line + * of the log file is the version number, and there are multiple JSON lines following. Each JSON + * line is a JSON format of [[SinkFileStatus]]. + * + * As reading from many small files is usually pretty slow, [[FileStreamSinkLog]] will compact log + * files every "spark.sql.sink.file.log.compactLen" batches into a big file. When doing a + * compaction, it will read all old log files and merge them with the new batch. 
During the + * compaction, it will also delete the files that are deleted (marked by [[SinkFileStatus.action]]). + * When the reader uses `allFiles` to list all files, this method only returns the visible files + * (drops the deleted files). + */ +class FileStreamSinkLog(sqlContext: SQLContext, path: String) + extends HDFSMetadataLog[Seq[SinkFileStatus]](sqlContext, path) { + + import FileStreamSinkLog._ + + private implicit val formats = Serialization.formats(NoTypeHints) + + /** + * If we delete the old files after compaction at once, there is a race condition in S3: other + * processes may see the old files are deleted but still cannot see the compaction file using + * "list". The `allFiles` handles this by looking for the next compaction file directly, however, + * a live lock may happen if the compaction happens too frequently: one processing keeps deleting + * old files while another one keeps retrying. Setting a reasonable cleanup delay could avoid it. + */ + private val fileCleanupDelayMs = sqlContext.getConf(SQLConf.FILE_SINK_LOG_CLEANUP_DELAY) + + private val isDeletingExpiredLog = sqlContext.getConf(SQLConf.FILE_SINK_LOG_DELETION) + + private val compactInterval = sqlContext.getConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL) + require(compactInterval > 0, + s"Please set ${SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key} (was $compactInterval) " + + "to a positive value.") + + override def batchIdToPath(batchId: Long): Path = { + if (isCompactionBatch(batchId, compactInterval)) { + new Path(metadataPath, s"$batchId$COMPACT_FILE_SUFFIX") + } else { + new Path(metadataPath, batchId.toString) + } + } + + override def pathToBatchId(path: Path): Long = { + getBatchIdFromFileName(path.getName) + } + + override def isBatchFile(path: Path): Boolean = { + try { + getBatchIdFromFileName(path.getName) + true + } catch { + case _: NumberFormatException => false + } + } + + override def serialize(logData: Seq[SinkFileStatus]): Array[Byte] = { + (VERSION +: logData.map(write(_))).mkString("\n").getBytes(UTF_8) + } + + override def deserialize(bytes: Array[Byte]): Seq[SinkFileStatus] = { + val lines = new String(bytes, UTF_8).split("\n") + if (lines.length == 0) { + throw new IllegalStateException("Incomplete log file") + } + val version = lines(0) + if (version != VERSION) { + throw new IllegalStateException(s"Unknown log version: ${version}") + } + lines.toSeq.slice(1, lines.length).map(read[SinkFileStatus](_)) + } + + override def add(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { + if (isCompactionBatch(batchId, compactInterval)) { + compact(batchId, logs) + } else { + super.add(batchId, logs) + } + } + + /** + * Returns all files except the deleted ones. + */ + def allFiles(): Array[SinkFileStatus] = { + var latestId = getLatest().map(_._1).getOrElse(-1L) + // There is a race condition when `FileStreamSink` is deleting old files and `StreamFileCatalog` + // is calling this method. This loop will retry the reading to deal with the + // race condition. + while (true) { + if (latestId >= 0) { + val startId = getAllValidBatches(latestId, compactInterval)(0) + try { + val logs = get(Some(startId), Some(latestId)).flatMap(_._2) + return compactLogs(logs).toArray + } catch { + case e: IOException => + // Another process using `FileStreamSink` may delete the batch files when + // `StreamFileCatalog` are reading. However, it only happens when a compaction is + // deleting old files. If so, let's try the next compaction batch and we should find it. 
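Given the serialize/deserialize pair above, a sink log file for one batch would look roughly like the following: the version marker on the first line, then one JSON-encoded SinkFileStatus per line. The path, sizes, and field ordering here are made up for illustration.

val exampleSinkLogFile =
  """v1
    |{"path":"hdfs://host/out/part-00000","size":1024,"isDir":false,"modificationTime":1458000000000,"blockReplication":1,"blockSize":134217728,"action":"add"}""".stripMargin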
+ // Otherwise, this is a real IO issue and we should throw it. + latestId = nextCompactionBatchId(latestId, compactInterval) + get(latestId).getOrElse { + throw e + } + } + } else { + return Array.empty + } + } + Array.empty + } + + /** + * Compacts all logs before `batchId` plus the provided `logs`, and writes them into the + * corresponding `batchId` file. It will delete expired files as well if enabled. + */ + private def compact(batchId: Long, logs: Seq[SinkFileStatus]): Boolean = { + val validBatches = getValidBatchesBeforeCompactionBatch(batchId, compactInterval) + val allLogs = validBatches.flatMap(batchId => get(batchId)).flatten ++ logs + if (super.add(batchId, compactLogs(allLogs))) { + if (isDeletingExpiredLog) { + deleteExpiredLog(batchId) + } + true + } else { + // Return false as there is another writer. + false + } + } + + /** + * Since all logs before `compactionBatchId` are compacted and written into the + * `compactionBatchId` log file, they can be removed. However, due to the eventual consistency of + * S3, the compaction file may not be seen by other processes at once. So we only delete files + * created `fileCleanupDelayMs` milliseconds ago. + */ + private def deleteExpiredLog(compactionBatchId: Long): Unit = { + val expiredTime = System.currentTimeMillis() - fileCleanupDelayMs + fileManager.list(metadataPath, new PathFilter { + override def accept(path: Path): Boolean = { + try { + val batchId = getBatchIdFromFileName(path.getName) + batchId < compactionBatchId + } catch { + case _: NumberFormatException => + false + } + } + }).foreach { f => + if (f.getModificationTime <= expiredTime) { + fileManager.delete(f.getPath) + } + } + } +} + +object FileStreamSinkLog { + val VERSION = "v1" + val COMPACT_FILE_SUFFIX = ".compact" + val DELETE_ACTION = "delete" + val ADD_ACTION = "add" + + def getBatchIdFromFileName(fileName: String): Long = { + fileName.stripSuffix(COMPACT_FILE_SUFFIX).toLong + } + + /** + * Returns if this is a compaction batch. FileStreamSinkLog will compact old logs every + * `compactInterval` commits. + * + * E.g., if `compactInterval` is 3, then 2, 5, 8, ... are all compaction batches. + */ + def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean = { + (batchId + 1) % compactInterval == 0 + } + + /** + * Returns all valid batches before the specified `compactionBatchId`. They contain all logs we + * need to do a new compaction. + * + * E.g., if `compactInterval` is 3 and `compactionBatchId` is 5, this method should returns + * `Seq(2, 3, 4)` (Note: it includes the previous compaction batch 2). + */ + def getValidBatchesBeforeCompactionBatch( + compactionBatchId: Long, + compactInterval: Int): Seq[Long] = { + assert(isCompactionBatch(compactionBatchId, compactInterval), + s"$compactionBatchId is not a compaction batch") + (math.max(0, compactionBatchId - compactInterval)) until compactionBatchId + } + + /** + * Returns all necessary logs before `batchId` (inclusive). If `batchId` is a compaction, just + * return itself. Otherwise, it will find the previous compaction batch and return all batches + * between it and `batchId`. + */ + def getAllValidBatches(batchId: Long, compactInterval: Long): Seq[Long] = { + assert(batchId >= 0) + val start = math.max(0, (batchId + 1) / compactInterval * compactInterval - 1) + start to batchId + } + + /** + * Removes all deleted files from logs. It assumes once one file is deleted, it won't be added to + * the log in future. 
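Replaying the batch-id arithmetic documented above with a compact interval of 3 (values chosen only to match the examples in the comments):

import org.apache.spark.sql.execution.streaming.FileStreamSinkLog

val interval = 3
// Batches 2, 5, 8, ... are compaction batches.
assert((0L to 8L).filter(id => FileStreamSinkLog.isCompactionBatch(id, interval)) == Seq(2L, 5L, 8L))
// Compaction batch 5 folds in the previous compaction batch 2 plus batches 3 and 4.
assert(FileStreamSinkLog.getValidBatchesBeforeCompactionBatch(5, interval) == Seq(2L, 3L, 4L))
// Reading batch 4 requires everything since the previous compaction batch 2.
assert(FileStreamSinkLog.getAllValidBatches(4, interval) == Seq(2L, 3L, 4L))
assert(FileStreamSinkLog.nextCompactionBatchId(3, interval) == 5L)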
+ */ + def compactLogs(logs: Seq[SinkFileStatus]): Seq[SinkFileStatus] = { + val deletedFiles = logs.filter(_.action == DELETE_ACTION).map(_.path).toSet + if (deletedFiles.isEmpty) { + logs + } else { + logs.filter(f => !deletedFiles.contains(f.path)) + } + } + + /** + * Returns the next compaction batch id after `batchId`. + */ + def nextCompactionBatchId(batchId: Long, compactInterval: Long): Long = { + (batchId + compactInterval + 1) / compactInterval * compactInterval - 1 + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 1b70055f346b3..6448cb6e902fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -39,7 +39,7 @@ class FileStreamSource( providerName: String, dataFrameBuilder: Array[String] => DataFrame) extends Source with Logging { - private val fs = FileSystem.get(sqlContext.sparkContext.hadoopConfiguration) + private val fs = new Path(path).getFileSystem(sqlContext.sparkContext.hadoopConfiguration) private val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataPath) private var maxBatchId = metadataLog.getLatest().map(_._1).getOrElse(-1L) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala index 9663fee18d364..b52f7a28b408a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/HDFSMetadataLog.scala @@ -51,8 +51,8 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) import HDFSMetadataLog._ - private val metadataPath = new Path(path) - private val fileManager = createFileManager() + val metadataPath = new Path(path) + protected val fileManager = createFileManager() if (!fileManager.exists(metadataPath)) { fileManager.mkdirs(metadataPath) @@ -62,7 +62,21 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) * A `PathFilter` to filter only batch files */ private val batchFilesFilter = new PathFilter { - override def accept(path: Path): Boolean = try { + override def accept(path: Path): Boolean = isBatchFile(path) + } + + private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance() + + protected def batchIdToPath(batchId: Long): Path = { + new Path(metadataPath, batchId.toString) + } + + protected def pathToBatchId(path: Path) = { + path.getName.toLong + } + + protected def isBatchFile(path: Path) = { + try { path.getName.toLong true } catch { @@ -70,18 +84,19 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) } } - private val serializer = new JavaSerializer(sqlContext.sparkContext.conf).newInstance() + protected def serialize(metadata: T): Array[Byte] = { + JavaUtils.bufferToArray(serializer.serialize(metadata)) + } - private def batchFile(batchId: Long): Path = { - new Path(metadataPath, batchId.toString) + protected def deserialize(bytes: Array[Byte]): T = { + serializer.deserialize[T](ByteBuffer.wrap(bytes)) } override def add(batchId: Long, metadata: T): Boolean = { get(batchId).map(_ => false).getOrElse { // Only write metadata when the batch has not yet been written. 
- val buffer = serializer.serialize(metadata) try { - writeBatch(batchId, JavaUtils.bufferToArray(buffer)) + writeBatch(batchId, serialize(metadata)) true } catch { case e: IOException if "java.lang.InterruptedException" == e.getMessage => @@ -113,8 +128,8 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) try { // Try to commit the batch // It will fail if there is an existing file (someone has committed the batch) - logDebug(s"Attempting to write log #${batchFile(batchId)}") - fileManager.rename(tempPath, batchFile(batchId)) + logDebug(s"Attempting to write log #${batchIdToPath(batchId)}") + fileManager.rename(tempPath, batchIdToPath(batchId)) return } catch { case e: IOException if isFileAlreadyExistsException(e) => @@ -158,11 +173,11 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) } override def get(batchId: Long): Option[T] = { - val batchMetadataFile = batchFile(batchId) + val batchMetadataFile = batchIdToPath(batchId) if (fileManager.exists(batchMetadataFile)) { val input = fileManager.open(batchMetadataFile) val bytes = IOUtils.toByteArray(input) - Some(serializer.deserialize[T](ByteBuffer.wrap(bytes))) + Some(deserialize(bytes)) } else { logDebug(s"Unable to find batch $batchMetadataFile") None @@ -172,7 +187,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) override def get(startId: Option[Long], endId: Option[Long]): Array[(Long, T)] = { val files = fileManager.list(metadataPath, batchFilesFilter) val batchIds = files - .map(_.getPath.getName.toLong) + .map(f => pathToBatchId(f.getPath)) .filter { batchId => (endId.isEmpty || batchId <= endId.get) && (startId.isEmpty || batchId >= startId.get) } @@ -184,7 +199,7 @@ class HDFSMetadataLog[T: ClassTag](sqlContext: SQLContext, path: String) override def getLatest(): Option[(Long, T)] = { val batchIds = fileManager.list(metadataPath, batchFilesFilter) - .map(_.getPath.getName.toLong) + .map(f => pathToBatchId(f.getPath)) .sorted .reverse for (batchId <- batchIds) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala index aaced49dd16ce..81244ed874498 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -18,9 +18,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.analysis.{OutputMode, UnsupportedOperationChecker} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} +import org.apache.spark.sql.internal.SQLConf /** * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] @@ -29,6 +31,7 @@ import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, class IncrementalExecution( ctx: SQLContext, logicalPlan: LogicalPlan, + outputMode: OutputMode, checkpointLocation: String, currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { @@ -69,4 +72,7 @@ class IncrementalExecution( } override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations + + /** No need assert supported, as this check has already been done */ + 
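To make the purpose of the new protected hooks in `HDFSMetadataLog` (`batchIdToPath`, `isBatchFile`, `serialize`, `deserialize`) concrete, here is a minimal, hypothetical subclass that swaps only the on-disk encoding. It is not part of the patch; `FileStreamSinkLog` is the real consumer of these extension points, using them to add a `.compact` suffix and its own text format.

import java.nio.charset.StandardCharsets

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.HDFSMetadataLog

// Hypothetical example: store each batch's metadata as UTF-8 text instead of
// Java-serialized bytes, by overriding only the serialization hooks.
class StringMetadataLog(sqlContext: SQLContext, path: String)
  extends HDFSMetadataLog[String](sqlContext, path) {

  override protected def serialize(metadata: String): Array[Byte] =
    metadata.getBytes(StandardCharsets.UTF_8)

  override protected def deserialize(bytes: Array[Byte]): String =
    new String(bytes, StandardCharsets.UTF_8)
}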
override def assertSupported(): Unit = { } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 87dd27a2b1aed..2a1fa1ba627c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.analysis.OutputMode import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} @@ -48,6 +49,7 @@ class StreamExecution( checkpointRoot: String, private[sql] val logicalPlan: LogicalPlan, val sink: Sink, + val outputMode: OutputMode, val trigger: Trigger) extends ContinuousQuery with Logging { /** An monitor used to wait/notify when batches complete. */ @@ -314,8 +316,13 @@ class StreamExecution( } val optimizerStart = System.nanoTime() - lastExecution = - new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution = new IncrementalExecution( + sqlContext, + newPlan, + outputMode, + checkpointFile("state"), + currentBatchId) + lastExecution.executedPlan val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala index b8d69b18450cf..95b5129351ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamFileCatalog.scala @@ -23,14 +23,13 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.execution.datasources.PartitionSpec -import org.apache.spark.sql.sources.{FileCatalog, Partition} +import org.apache.spark.sql.execution.datasources.{FileCatalog, Partition, PartitionSpec} import org.apache.spark.sql.types.StructType class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog with Logging { val metadataDirectory = new Path(path, FileStreamSink.metadataDir) logInfo(s"Reading streaming file log from $metadataDirectory") - val metadataLog = new HDFSMetadataLog[Seq[String]](sqlContext, metadataDirectory.toUri.toString) + val metadataLog = new FileStreamSinkLog(sqlContext, metadataDirectory.toUri.toString) val fs = path.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) override def paths: Seq[Path] = path :: Nil @@ -54,6 +53,6 @@ class StreamFileCatalog(sqlContext: SQLContext, path: Path) extends FileCatalog override def refresh(): Unit = {} override def allFiles(): Seq[FileStatus] = { - fs.listStatus(metadataLog.get(None, None).flatMap(_._2).map(new Path(_))) + metadataLog.allFiles().map(_.toFileStatus) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 
d2872e49ce28a..c29291eb584a0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -37,6 +37,7 @@ object StreamingRelation { */ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: Seq[Attribute]) extends LeafNode { + override def isStreaming: Boolean = true override def toString: String = sourceName } @@ -45,6 +46,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. */ case class StreamingExecutionRelation(source: Source, output: Seq[Attribute]) extends LeafNode { + override def isStreaming: Boolean = true override def toString: String = source.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 4b3091ba22c60..71b6a97852966 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -18,8 +18,9 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.catalyst.{expressions, InternalRow} -import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule @@ -42,6 +43,7 @@ case class ScalarSubquery( override def plan: SparkPlan = Subquery(simpleString, executedPlan) override def dataType: DataType = executedPlan.schema.fields.head.dataType + override def children: Seq[Expression] = Nil override def nullable: Boolean = true override def toString: String = s"subquery#${exprId.id}" @@ -54,8 +56,8 @@ case class ScalarSubquery( override def eval(input: InternalRow): Any = result - override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - Literal.create(result, dataType).genCode(ctx, ev) + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + Literal.create(result, dataType).doGenCode(ctx, ev) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 7da8379c9aa9a..baae9dd2d5e3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.expressions -import org.apache.spark.sql.{DataFrame, Dataset, Encoder, TypedColumn} +import org.apache.spark.sql.{Dataset, Encoder, TypedColumn} import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression /** - * A base class for user-defined aggregations, which can be used in [[DataFrame]] and [[Dataset]] - * operations to take all of the elements of a group and reduce them to a single value. 
+ * A base class for user-defined aggregations, which can be used in [[Dataset]] operations to take + * all of the elements of a group and reduce them to a single value. * * For example, the following aggregator extracts an `int` from a specific class and adds them up: * {{{ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 223122300dbb3..8e2e94669b8c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1776,6 +1776,23 @@ object functions { */ def round(e: Column, scale: Int): Column = withExpr { Round(e.expr, Literal(scale)) } + /** + * Returns the value of the column `e` rounded to 0 decimal places with HALF_EVEN round mode. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column): Column = bround(e, 0) + + /** + * Round the value of `e` to `scale` decimal places with HALF_EVEN round mode + * if `scale` >= 0 or at integral part when `scale` < 0. + * + * @group math_funcs + * @since 2.0.0 + */ + def bround(e: Column, scale: Int): Column = withExpr { BRound(e.expr, Literal(scale)) } + /** * Shift the given value numBits left. If the given value is a long value, this function * will return a long value else it will return an integer value. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2f9d63c2e8134..80e2c1986d758 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.internal import java.util.{NoSuchElementException, Properties} +import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ import scala.collection.immutable @@ -51,6 +52,12 @@ object SQLConf { } + val OPTIMIZER_MAX_ITERATIONS = SQLConfigBuilder("spark.sql.optimizer.maxIterations") + .internal() + .doc("The max number of iterations the optimizer and analyzer runs") + .intConf + .createWithDefault(100) + val ALLOW_MULTIPLE_CONTEXTS = SQLConfigBuilder("spark.sql.allowMultipleContexts") .doc("When set to true, creating multiple SQLContexts/HiveContexts is allowed. 
" + "When set to false, only one SQLContext/HiveContext is allowed to be created " + @@ -396,6 +403,12 @@ object SQLConf { .intConf .createWithDefault(200) + val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches") + .internal() + .doc("The maximum number of switches supported with codegen.") + .intConf + .createWithDefault(20) + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -436,6 +449,42 @@ object SQLConf { .stringConf .createOptional + val UNSUPPORTED_OPERATION_CHECK_ENABLED = + SQLConfigBuilder("spark.sql.streaming.unsupportedOperationCheck") + .internal() + .doc("When true, the logical plan for continuous query will be checked for unsupported" + + " operations.") + .booleanConf + .createWithDefault(true) + + // TODO: This is still WIP and shouldn't be turned on without extensive test coverage + val COLUMNAR_AGGREGATE_MAP_ENABLED = SQLConfigBuilder("spark.sql.codegen.aggregate.map.enabled") + .internal() + .doc("When true, aggregate with keys use an in-memory columnar map to speed up execution.") + .booleanConf + .createWithDefault(false) + + val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion") + .internal() + .doc("Whether to delete the expired log files in file stream sink.") + .booleanConf + .createWithDefault(true) + + val FILE_SINK_LOG_COMPACT_INTERVAL = + SQLConfigBuilder("spark.sql.streaming.fileSink.log.compactInterval") + .internal() + .doc("Number of log files after which all the previous files " + + "are compacted into the next log file.") + .intConf + .createWithDefault(10) + + val FILE_SINK_LOG_CLEANUP_DELAY = + SQLConfigBuilder("spark.sql.streaming.fileSink.log.cleanupDelay") + .internal() + .doc("How long in milliseconds a file is guaranteed to be visible for all readers.") + .timeConf(TimeUnit.MILLISECONDS) + .createWithDefault(60 * 1000L) // 10 minutes + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" val EXTERNAL_SORT = "spark.sql.planner.externalSort" @@ -466,6 +515,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { /** ************************ Spark SQL Params/Hints ******************* */ + def optimizerMaxIterations: Int = getConf(OPTIMIZER_MAX_ITERATIONS) + def checkpointLocation: String = getConf(CHECKPOINT_LOCATION) def filesMaxPartitionBytes: Long = getConf(FILES_MAX_PARTITION_BYTES) @@ -506,6 +557,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) @@ -560,6 +613,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES) + def columnarAggregateMapEnabled: Boolean = getConf(COLUMNAR_AGGREGATE_MAP_ENABLED) + override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL) override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 69e3358d4eb9e..42915d5887f44 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.internal +import java.util.Properties + +import scala.collection.JavaConverters._ + +import org.apache.spark.internal.config.ConfigEntry import org.apache.spark.sql.{ContinuousQueryManager, ExperimentalMethods, SQLContext, UDFRegistration} import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.optimizer.Optimizer import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource} -import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} import org.apache.spark.sql.util.ExecutionListenerManager + /** * A class that holds all session-specific state in a given [[SQLContext]]. */ @@ -39,7 +44,10 @@ private[sql] class SessionState(ctx: SQLContext) { /** * SQL-specific key-value configurations. */ - lazy val conf = new SQLConf + lazy val conf: SQLConf = new SQLConf + + // Automatically extract `spark.sql.*` entries and put it in our SQLConf + setConf(SQLContext.getSQLProperties(ctx.sparkContext.getConf)) lazy val experimentalMethods = new ExperimentalMethods @@ -80,7 +88,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Logical query plan optimizer. */ - lazy val optimizer: Optimizer = new SparkOptimizer(experimentalMethods) + lazy val optimizer: Optimizer = new SparkOptimizer(catalog, conf, experimentalMethods) /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. @@ -103,5 +111,45 @@ private[sql] class SessionState(ctx: SQLContext) { * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. 
*/ lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) -} + + // ------------------------------------------------------ + // Helper methods, partially leftover from pre-2.0 days + // ------------------------------------------------------ + + def executePlan(plan: LogicalPlan): QueryExecution = new QueryExecution(ctx, plan) + + def refreshTable(tableName: String): Unit = { + catalog.refreshTable(sqlParser.parseTableIdentifier(tableName)) + } + + def invalidateTable(tableName: String): Unit = { + catalog.invalidateTable(sqlParser.parseTableIdentifier(tableName)) + } + + final def setConf(properties: Properties): Unit = { + properties.asScala.foreach { case (k, v) => setConf(k, v) } + } + + final def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + conf.setConf(entry, value) + setConf(entry.key, entry.stringConverter(value)) + } + + def setConf(key: String, value: String): Unit = { + conf.setConfString(key, value) + } + + def addJar(path: String): Unit = { + ctx.sparkContext.addJar(path) + } + + def analyze(tableName: String): Unit = { + throw new UnsupportedOperationException + } + + def runNativeSql(sql: String): Seq[String] = { + throw new UnsupportedOperationException + } + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala new file mode 100644 index 0000000000000..9a30c7de1f8f2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.internal + +import org.apache.spark.SparkContext +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} +import org.apache.spark.sql.execution.CacheManager +import org.apache.spark.sql.execution.ui.SQLListener + + +/** + * A class that holds all state shared across sessions in a given [[SQLContext]]. + */ +private[sql] class SharedState(val sparkContext: SparkContext) { + + /** + * Class for caching query results reused in future executions. + */ + val cacheManager: CacheManager = new CacheManager + + /** + * A listener for SQL-specific [[org.apache.spark.scheduler.SparkListenerEvent]]s. + */ + val listener: SQLListener = SQLContext.createListenerAndUI(sparkContext) + + /** + * A catalog that interacts with external systems. 
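Session construction now copies `spark.sql.*` entries from the `SparkConf` into the session's `SQLConf` via `SQLContext.getSQLProperties` and the `setConf(properties)` helper above. The sketch below is only a guess at what that extraction step looks like, written out to make the data flow explicit; the real `getSQLProperties` implementation may differ.

import java.util.Properties

import org.apache.spark.SparkConf

// Hypothetical sketch: copy every "spark.sql.*" entry from a SparkConf into a
// java.util.Properties, which SessionState.setConf(properties) then applies.
def sqlPropertiesFrom(conf: SparkConf): Properties = {
  val props = new Properties()
  conf.getAll.foreach { case (key, value) =>
    if (key.startsWith("spark.sql.")) props.setProperty(key, value)
  }
  props
}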
+ */ + lazy val externalCatalog: ExternalCatalog = new InMemoryCatalog + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 4b9bf8daae37c..26285bde31ad0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -17,28 +17,13 @@ package org.apache.spark.sql.sources -import scala.collection.mutable -import scala.util.Try - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} -import org.apache.hadoop.mapred.{FileInputFormat, JobConf} -import org.apache.hadoop.mapreduce.{Job, TaskAttemptContext} - -import org.apache.spark.SparkContext import org.apache.spark.annotation.{DeveloperApi, Experimental} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.{expressions, CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.streaming.{Sink, Source} -import org.apache.spark.sql.types.{StringType, StructType} -import org.apache.spark.util.SerializableConfiguration -import org.apache.spark.util.collection.BitSet +import org.apache.spark.sql.types.StructType /** * ::DeveloperApi:: @@ -318,496 +303,3 @@ trait InsertableRelation { trait CatalystScan { def buildScan(requiredColumns: Seq[Attribute], filters: Seq[Expression]): RDD[Row] } - -/** - * ::Experimental:: - * A factory that produces [[OutputWriter]]s. A new [[OutputWriterFactory]] is created on driver - * side for each write job issued when writing to a [[HadoopFsRelation]], and then gets serialized - * to executor side to create actual [[OutputWriter]]s on the fly. - * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriterFactory extends Serializable { - /** - * When writing to a [[HadoopFsRelation]], this method gets called by each task on executor side - * to instantiate new [[OutputWriter]]s. - * - * @param path Path of the file to which this [[OutputWriter]] is supposed to write. Note that - * this may not point to the final output file. For example, `FileOutputFormat` writes to - * temporary directories and then merge written files back to the final destination. In - * this case, `path` points to a temporary output file under the temporary directory. - * @param dataSchema Schema of the rows to be written. Partition columns are not included in the - * schema if the relation being written is partitioned. - * @param context The Hadoop MapReduce task context. - * @since 1.4.0 - */ - private[sql] def newInstance( - path: String, - bucketId: Option[Int], // TODO: This doesn't belong here... - dataSchema: StructType, - context: TaskAttemptContext): OutputWriter -} - -/** - * ::Experimental:: - * [[OutputWriter]] is used together with [[HadoopFsRelation]] for persisting rows to the - * underlying file system. Subclasses of [[OutputWriter]] must provide a zero-argument constructor. - * An [[OutputWriter]] instance is created and initialized when a new output file is opened on - * executor side. This instance is used to persist rows to this single output file. 
- * - * @since 1.4.0 - */ -@Experimental -abstract class OutputWriter { - /** - * Persists a single row. Invoked on the executor side. When writing to dynamically partitioned - * tables, dynamic partition columns are not included in rows to be written. - * - * @since 1.4.0 - */ - def write(row: Row): Unit - - /** - * Closes the [[OutputWriter]]. Invoked on the executor side after all rows are persisted, before - * the task output is committed. - * - * @since 1.4.0 - */ - def close(): Unit - - private var converter: InternalRow => Row = _ - - protected[sql] def initConverter(dataSchema: StructType) = { - converter = - CatalystTypeConverters.createToScalaConverter(dataSchema).asInstanceOf[InternalRow => Row] - } - - protected[sql] def writeInternal(row: InternalRow): Unit = { - write(converter(row)) - } -} - -/** - * Acts as a container for all of the metadata required to read from a datasource. All discovery, - * resolution and merging logic for schemas and partitions has been removed. - * - * @param location A [[FileCatalog]] that can enumerate the locations of all the files that comprise - * this relation. - * @param partitionSchema The schema of the columns (if any) that are used to partition the relation - * @param dataSchema The schema of any remaining columns. Note that if any partition columns are - * present in the actual data files as well, they are preserved. - * @param bucketSpec Describes the bucketing (hash-partitioning of the files by some column values). - * @param fileFormat A file format that can be used to read and write the data in files. - * @param options Configuration used when reading / writing data. - */ -case class HadoopFsRelation( - sqlContext: SQLContext, - location: FileCatalog, - partitionSchema: StructType, - dataSchema: StructType, - bucketSpec: Option[BucketSpec], - fileFormat: FileFormat, - options: Map[String, String]) extends BaseRelation with FileRelation { - - val schema: StructType = { - val dataSchemaColumnNames = dataSchema.map(_.name.toLowerCase).toSet - StructType(dataSchema ++ partitionSchema.filterNot { column => - dataSchemaColumnNames.contains(column.name.toLowerCase) - }) - } - - def partitionSchemaOption: Option[StructType] = - if (partitionSchema.isEmpty) None else Some(partitionSchema) - def partitionSpec: PartitionSpec = location.partitionSpec() - - def refresh(): Unit = location.refresh() - - override def toString: String = - s"HadoopFiles" - - /** Returns the list of files that will be read when scanning this relation. */ - override def inputFiles: Array[String] = - location.allFiles().map(_.getPath.toUri.toString).toArray - - override def sizeInBytes: Long = location.allFiles().map(_.getLen).sum -} - -/** - * Used to read and write data stored in files to/from the [[InternalRow]] format. - */ -trait FileFormat { - /** - * When possible, this method should return the schema of the given `files`. When the format - * does not support inference, or no valid files are given should return None. In these cases - * Spark will require that user specify the schema manually. - */ - def inferSchema( - sqlContext: SQLContext, - options: Map[String, String], - files: Seq[FileStatus]): Option[StructType] - - /** - * Prepares a read job and returns a potentially updated data source option [[Map]]. This method - * can be useful for collecting necessary global information for scanning input data. 
- */ - def prepareRead( - sqlContext: SQLContext, - options: Map[String, String], - files: Seq[FileStatus]): Map[String, String] = options - - /** - * Prepares a write job and returns an [[OutputWriterFactory]]. Client side job preparation can - * be put here. For example, user defined output committer can be configured here - * by setting the output committer class in the conf of spark.sql.sources.outputCommitterClass. - */ - def prepareWrite( - sqlContext: SQLContext, - job: Job, - options: Map[String, String], - dataSchema: StructType): OutputWriterFactory - - /** - * Returns whether this format support returning columnar batch or not. - * - * TODO: we should just have different traits for the different formats. - */ - def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { - false - } - - /** - * Returns a function that can be used to read a single file in as an Iterator of InternalRow. - * - * @param dataSchema The global data schema. It can be either specified by the user, or - * reconciled/merged from all underlying data files. If any partition columns - * are contained in the files, they are preserved in this schema. - * @param partitionSchema The schema of the partition column row that will be present in each - * PartitionedFile. These columns should be appended to the rows that - * are produced by the iterator. - * @param requiredSchema The schema of the data that should be output for each row. This may be a - * subset of the columns that are present in the file if column pruning has - * occurred. - * @param filters A set of filters than can optionally be used to reduce the number of rows output - * @param options A set of string -> string configuration options. - * @return - */ - def buildReader( - sqlContext: SQLContext, - dataSchema: StructType, - partitionSchema: StructType, - requiredSchema: StructType, - filters: Seq[Filter], - options: Map[String, String]): PartitionedFile => Iterator[InternalRow] = { - // TODO: Remove this default implementation when the other formats have been ported - // Until then we guard in [[FileSourceStrategy]] to only call this method on supported formats. - throw new UnsupportedOperationException(s"buildReader is not supported for $this") - } -} - -/** - * A collection of data files from a partitioned relation, along with the partition values in the - * form of an [[InternalRow]]. - */ -case class Partition(values: InternalRow, files: Seq[FileStatus]) - -/** - * An interface for objects capable of enumerating the files that comprise a relation as well - * as the partitioning characteristics of those files. - */ -trait FileCatalog { - def paths: Seq[Path] - - def partitionSpec(): PartitionSpec - - /** - * Returns all valid files grouped into partitions when the data is partitioned. If the data is - * unpartitioned, this will return a single partition with not partition values. - * - * @param filters the filters used to prune which partitions are returned. These filters must - * only refer to partition columns and this method will only return files - * where these predicates are guaranteed to evaluate to `true`. Thus, these - * filters will not need to be evaluated again on the returned data. - */ - def listFiles(filters: Seq[Expression]): Seq[Partition] - - def allFiles(): Seq[FileStatus] - - def getStatus(path: Path): Array[FileStatus] - - def refresh(): Unit -} - -/** - * A file catalog that caches metadata gathered by scanning all the files present in `paths` - * recursively. 
- * - * @param parameters as set of options to control discovery - * @param paths a list of paths to scan - * @param partitionSchema an optional partition schema that will be use to provide types for the - * discovered partitions - */ -class HDFSFileCatalog( - val sqlContext: SQLContext, - val parameters: Map[String, String], - val paths: Seq[Path], - val partitionSchema: Option[StructType]) - extends FileCatalog with Logging { - - private val hadoopConf = new Configuration(sqlContext.sparkContext.hadoopConfiguration) - - var leafFiles = mutable.LinkedHashMap.empty[Path, FileStatus] - var leafDirToChildrenFiles = mutable.Map.empty[Path, Array[FileStatus]] - var cachedPartitionSpec: PartitionSpec = _ - - def partitionSpec(): PartitionSpec = { - if (cachedPartitionSpec == null) { - cachedPartitionSpec = inferPartitioning(partitionSchema) - } - - cachedPartitionSpec - } - - refresh() - - override def listFiles(filters: Seq[Expression]): Seq[Partition] = { - if (partitionSpec().partitionColumns.isEmpty) { - Partition(InternalRow.empty, allFiles().filterNot(_.getPath.getName startsWith "_")) :: Nil - } else { - prunePartitions(filters, partitionSpec()).map { - case PartitionDirectory(values, path) => - Partition( - values, - getStatus(path).filterNot(_.getPath.getName startsWith "_")) - } - } - } - - protected def prunePartitions( - predicates: Seq[Expression], - partitionSpec: PartitionSpec): Seq[PartitionDirectory] = { - val PartitionSpec(partitionColumns, partitions) = partitionSpec - val partitionColumnNames = partitionColumns.map(_.name).toSet - val partitionPruningPredicates = predicates.filter { - _.references.map(_.name).toSet.subsetOf(partitionColumnNames) - } - - if (partitionPruningPredicates.nonEmpty) { - val predicate = partitionPruningPredicates.reduce(expressions.And) - - val boundPredicate = InterpretedPredicate.create(predicate.transform { - case a: AttributeReference => - val index = partitionColumns.indexWhere(a.name == _.name) - BoundReference(index, partitionColumns(index).dataType, nullable = true) - }) - - val selected = partitions.filter { - case PartitionDirectory(values, _) => boundPredicate(values) - } - logInfo { - val total = partitions.length - val selectedSize = selected.length - val percentPruned = (1 - selectedSize.toDouble / total.toDouble) * 100 - s"Selected $selectedSize partitions out of $total, pruned $percentPruned% partitions." 
- } - - selected - } else { - partitions - } - } - - def allFiles(): Seq[FileStatus] = leafFiles.values.toSeq - - def getStatus(path: Path): Array[FileStatus] = leafDirToChildrenFiles(path) - - private def listLeafFiles(paths: Seq[Path]): mutable.LinkedHashSet[FileStatus] = { - if (paths.length >= sqlContext.conf.parallelPartitionDiscoveryThreshold) { - HadoopFsRelation.listLeafFilesInParallel(paths, hadoopConf, sqlContext.sparkContext) - } else { - val statuses = paths.flatMap { path => - val fs = path.getFileSystem(hadoopConf) - logInfo(s"Listing $path on driver") - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(hadoopConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - if (pathFilter != null) { - Try(fs.listStatus(path, pathFilter)).getOrElse(Array.empty) - } else { - Try(fs.listStatus(path)).getOrElse(Array.empty) - } - }.filterNot { status => - val name = status.getPath.getName - HadoopFsRelation.shouldFilterOut(name) - } - - val (dirs, files) = statuses.partition(_.isDirectory) - - // It uses [[LinkedHashSet]] since the order of files can affect the results. (SPARK-11500) - if (dirs.isEmpty) { - mutable.LinkedHashSet(files: _*) - } else { - mutable.LinkedHashSet(files: _*) ++ listLeafFiles(dirs.map(_.getPath)) - } - } - } - - def inferPartitioning(schema: Option[StructType]): PartitionSpec = { - // We use leaf dirs containing data files to discover the schema. - val leafDirs = leafDirToChildrenFiles.keys.toSeq - schema match { - case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => - val spec = PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = false, - basePaths = basePaths) - - // Without auto inference, all of value in the `row` should be null or in StringType, - // we need to cast into the data type that user specified. - def castPartitionValuesToUserSchema(row: InternalRow) = { - InternalRow((0 until row.numFields).map { i => - Cast( - Literal.create(row.getUTF8String(i), StringType), - userProvidedSchema.fields(i).dataType).eval() - }: _*) - } - - PartitionSpec(userProvidedSchema, spec.partitions.map { part => - part.copy(values = castPartitionValuesToUserSchema(part.values)) - }) - case _ => - PartitioningUtils.parsePartitions( - leafDirs, - PartitioningUtils.DEFAULT_PARTITION_NAME, - typeInference = sqlContext.conf.partitionColumnTypeInferenceEnabled(), - basePaths = basePaths) - } - } - - /** - * Contains a set of paths that are considered as the base dirs of the input datasets. - * The partitioning discovery logic will make sure it will stop when it reaches any - * base path. By default, the paths of the dataset provided by users will be base paths. - * For example, if a user uses `sqlContext.read.parquet("/path/something=true/")`, the base path - * will be `/path/something=true/`, and the returned DataFrame will not contain a column of - * `something`. If users want to override the basePath. They can set `basePath` in the options - * to pass the new base path to the data source. - * For the above example, if the user-provided base path is `/path/`, the returned - * DataFrame will have the column of `something`. - */ - private def basePaths: Set[Path] = { - val userDefinedBasePath = parameters.get("basePath").map(basePath => Set(new Path(basePath))) - userDefinedBasePath.getOrElse { - // If the user does not provide basePath, we will just use paths. 
- paths.toSet - }.map { hdfsPath => - // Make the path qualified (consistent with listLeafFiles and listLeafFilesInParallel). - val fs = hdfsPath.getFileSystem(hadoopConf) - hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) - } - } - - def refresh(): Unit = { - val files = listLeafFiles(paths) - - leafFiles.clear() - leafDirToChildrenFiles.clear() - - leafFiles ++= files.map(f => f.getPath -> f) - leafDirToChildrenFiles ++= files.toArray.groupBy(_.getPath.getParent) - - cachedPartitionSpec = null - } - - override def equals(other: Any): Boolean = other match { - case hdfs: HDFSFileCatalog => paths.toSet == hdfs.paths.toSet - case _ => false - } - - override def hashCode(): Int = paths.toSet.hashCode() -} - -/** - * Helper methods for gathering metadata from HDFS. - */ -private[sql] object HadoopFsRelation extends Logging { - - /** Checks if we should filter out this path name. */ - def shouldFilterOut(pathName: String): Boolean = { - // TODO: We should try to filter out all files/dirs starting with "." or "_". - // The only reason that we are not doing it now is that Parquet needs to find those - // metadata files from leaf files returned by this methods. We should refactor - // this logic to not mix metadata files with data files. - pathName == "_SUCCESS" || pathName == "_temporary" || pathName.startsWith(".") - } - - // We don't filter files/directories whose name start with "_" except "_temporary" here, as - // specific data sources may take advantages over them (e.g. Parquet _metadata and - // _common_metadata files). "_temporary" directories are explicitly ignored since failed - // tasks/jobs may leave partial/corrupted data files there. Files and directories whose name - // start with "." are also ignored. - def listLeafFiles(fs: FileSystem, status: FileStatus): Array[FileStatus] = { - logInfo(s"Listing ${status.getPath}") - val name = status.getPath.getName.toLowerCase - if (shouldFilterOut(name)) { - Array.empty - } else { - // Dummy jobconf to get to the pathFilter defined in configuration - val jobConf = new JobConf(fs.getConf, this.getClass()) - val pathFilter = FileInputFormat.getInputPathFilter(jobConf) - val statuses = - if (pathFilter != null) { - val (dirs, files) = fs.listStatus(status.getPath, pathFilter).partition(_.isDirectory) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } else { - val (dirs, files) = fs.listStatus(status.getPath).partition(_.isDirectory) - files ++ dirs.flatMap(dir => listLeafFiles(fs, dir)) - } - statuses.filterNot(status => shouldFilterOut(status.getPath.getName)) - } - } - - // `FileStatus` is Writable but not serializable. What make it worse, somehow it doesn't play - // well with `SerializableWritable`. So there seems to be no way to serialize a `FileStatus`. - // Here we use `FakeFileStatus` to extract key components of a `FileStatus` to serialize it from - // executor side and reconstruct it on driver side. 
- case class FakeFileStatus( - path: String, - length: Long, - isDir: Boolean, - blockReplication: Short, - blockSize: Long, - modificationTime: Long, - accessTime: Long) - - def listLeafFilesInParallel( - paths: Seq[Path], - hadoopConf: Configuration, - sparkContext: SparkContext): mutable.LinkedHashSet[FileStatus] = { - logInfo(s"Listing leaf files and directories in parallel under: ${paths.mkString(", ")}") - - val serializableConfiguration = new SerializableConfiguration(hadoopConf) - val serializedPaths = paths.map(_.toString) - - val fakeStatuses = sparkContext.parallelize(serializedPaths).map(new Path(_)).flatMap { path => - val fs = path.getFileSystem(serializableConfiguration.value) - Try(listLeafFiles(fs, fs.getFileStatus(path))).getOrElse(Array.empty) - }.map { status => - FakeFileStatus( - status.getPath.toString, - status.getLen, - status.isDirectory, - status.getReplication, - status.getBlockSize, - status.getModificationTime, - status.getAccessTime) - }.collect() - - val hadoopFakeStatuses = fakeStatuses.map { f => - new FileStatus( - f.length, f.isDir, f.blockReplication, f.blockSize, f.modificationTime, new Path(f.path)) - } - mutable.LinkedHashSet(hadoopFakeStatuses: _*) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3a7215ee39728..6eae3ed7ad6c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -73,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] { } +object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] { + def zero: Seq[Int] = Nil + def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b + def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2 + def finish(r: Seq[Int]): Seq[Int] = r + override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder() + override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder() +} + + class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) extends Aggregator[IN, OUT, OUT] { @@ -85,6 +95,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } +object RowAgg extends Aggregator[Row, Int, Int] { + def zero: Int = 0 + def reduce(b: Int, a: Row): Int = a.getInt(0) + b + def merge(b1: Int, b2: Int): Int = b1 + b2 + def finish(r: Int): Int = r + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { @@ -200,4 +219,17 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { (1279869254, "Some String")) } + test("aggregator in DataFrame/Dataset[Row]") { + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) + } + + test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") { + val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS() + + checkDataset( + ds.groupByKey(_.b).agg(SeqAgg.toColumn), + "a" -> Seq(1, 2) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 5f3dd906feb0f..ae9fb80c68f42 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -18,6 +18,9 @@ package org.apache.spark.sql import org.apache.spark.SparkContext +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.sql.expressions.scala.typed +import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.StringType import org.apache.spark.util.Benchmark @@ -33,16 +36,17 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back map", numRows) - val func = (d: Data) => Data(d.l + 1, d.s) - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.map(func) + res = rdd.map(func) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -55,15 +59,14 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.map(func) + res = res.map(func) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -74,19 +77,20 @@ object DatasetBenchmark { val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) val benchmark = new Benchmark("back-to-back filter", numRows) - val func = (d: Data, i: Int) => d.l % (100L + i) == 0L val funcs = 0.until(numChains).map { i => (d: Data) => func(d, i) } - benchmark.addCase("Dataset") { iter => - var res = df.as[Data] + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd var i = 0 while (i < numChains) { - res = res.filter(funcs(i)) + res = rdd.filter(funcs(i)) i += 1 } - res.queryExecution.toRdd.foreach(_ => Unit) + res.foreach(_ => Unit) } benchmark.addCase("DataFrame") { iter => @@ -99,15 +103,54 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) - benchmark.addCase("RDD") { iter => - var res = rdd + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] var i = 0 while (i < numChains) { - res = rdd.filter(funcs(i)) + res = res.filter(funcs(i)) i += 1 } - res.foreach(_ => Unit) + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark + } + + object ComplexAggregator extends Aggregator[Data, Data, Long] { + override def zero: Data = Data(0, "") + + override def reduce(b: Data, a: Data): Data = Data(b.l + a.l, "") + + override def finish(reduction: Data): Long = reduction.l + + override def merge(b1: Data, b2: Data): Data = Data(b1.l + b2.l, "") + + override def bufferEncoder: Encoder[Data] = Encoders.product[Data] + + override def outputEncoder: Encoder[Long] = Encoders.scalaLong + } + + def aggregate(sqlContext: SQLContext, numRows: Long): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("aggregate", numRows) + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD sum") { 
iter => + rdd.aggregate(0L)(_ + _.l, _ + _) + } + + benchmark.addCase("DataFrame sum") { iter => + df.select(sum($"l")).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset sum using Aggregator") { iter => + df.as[Data].select(typed.sumLong((d: Data) => d.l)).queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("Dataset complex Aggregator") { iter => + df.as[Data].select(ComplexAggregator.toColumn).queryExecution.toRdd.foreach(_ => Unit) } benchmark @@ -117,30 +160,45 @@ object DatasetBenchmark { val sparkContext = new SparkContext("local[*]", "Dataset benchmark") val sqlContext = new SQLContext(sparkContext) - val numRows = 10000000 + val numRows = 100000000 val numChains = 10 val benchmark = backToBackMap(sqlContext, numRows, numChains) val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + val benchmark3 = aggregate(sqlContext, numRows) /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 902 / 995 11.1 90.2 1.0X - DataFrame 132 / 167 75.5 13.2 6.8X - RDD 216 / 237 46.3 21.6 4.2X + RDD 1935 / 2105 51.7 19.3 1.0X + DataFrame 756 / 799 132.3 7.6 2.6X + Dataset 7359 / 7506 13.6 73.6 0.3X */ benchmark.run() /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative ------------------------------------------------------------------------------------------- - Dataset 585 / 628 17.1 58.5 1.0X - DataFrame 62 / 80 160.7 6.2 9.4X - RDD 205 / 220 48.7 20.5 2.8X + RDD 1974 / 2036 50.6 19.7 1.0X + DataFrame 103 / 127 967.4 1.0 19.1X + Dataset 4343 / 4477 23.0 43.4 0.5X */ benchmark2.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + aggregate: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + RDD sum 2130 / 2166 46.9 21.3 1.0X + DataFrame sum 92 / 128 1085.3 0.9 23.1X + Dataset sum using Aggregator 4111 / 4282 24.3 41.1 0.5X + Dataset complex Aggregator 8782 / 9036 11.4 87.8 0.2X + */ + benchmark3.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d074535bf6265..2a1867f67c178 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -471,6 +471,10 @@ class DatasetSuite extends QueryTest with SharedSQLContext { (JavaData(2), JavaData(2)))) } + test("SPARK-14696: implicit encoders for boxed types") { + assert(sqlContext.range(1).map { i => i : java.lang.Long }.head == 0L) + } + test("SPARK-11894: Incorrect results are returned when using null") { val nullInt = null.asInstanceOf[java.lang.Integer] val ds1 = Seq((nullInt, "1"), (new java.lang.Integer(22), "2")).toDS() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index f5a67fd782d63..0de7f2321f398 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -207,12 +207,16 @@ 
class MathExpressionsSuite extends QueryTest with SharedSQLContext { testOneToOneMathFunction(rint, math.rint) } - test("round") { + test("round/bround") { val df = Seq(5, 55, 555).map(Tuple1(_)).toDF("a") checkAnswer( df.select(round('a), round('a, -1), round('a, -2)), Seq(Row(5, 10, 0), Row(55, 60, 100), Row(555, 560, 600)) ) + checkAnswer( + df.select(bround('a), bround('a, -1), bround('a, -2)), + Seq(Row(5, 0, 0), Row(55, 60, 100), Row(555, 560, 600)) + ) val pi = "3.1415" checkAnswer( @@ -221,6 +225,12 @@ class MathExpressionsSuite extends QueryTest with SharedSQLContext { Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) ) + checkAnswer( + sql(s"SELECT bround($pi, -3), bround($pi, -2), bround($pi, -1), " + + s"bround($pi, 0), bround($pi, 1), bround($pi, 2), bround($pi, 3)"), + Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3), + BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142"))) + ) } test("exp") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 826862835a709..cbacb5e1033f0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.util.{Locale, TimeZone} +import java.util.{ArrayDeque, Locale, TimeZone} import scala.collection.JavaConverters._ import scala.util.control.NonFatal @@ -29,11 +29,14 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.LogicalRDD +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan import org.apache.spark.sql.types.ObjectType + + abstract class QueryTest extends PlanTest { protected def sqlContext: SQLContext @@ -46,6 +49,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer contains all of the keywords, or the * none of keywords are listed in the answer + * * @param df the [[DataFrame]] to be executed * @param exists true for make sure the keywords are listed in the output, otherwise * to make sure none of the keyword are not listed in the output @@ -118,6 +122,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer matches the expected result. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -157,6 +162,7 @@ abstract class QueryTest extends PlanTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param dataFrame the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. * @param absTol the absolute tolerance between actual and expected answers. @@ -197,14 +203,20 @@ abstract class QueryTest extends PlanTest { } private def checkJsonFormat(df: DataFrame): Unit = { + // Get the analyzed plan and rewrite the PredicateSubqueries in order to make sure that + // RDD and Data resolution does not break. val logicalPlan = df.queryExecution.analyzed + // bypass some cases that we can't handle currently. 
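The `round`/`bround` expectations in the test above come down to rounding modes: `round` keeps its existing HALF_UP behaviour, while the new `bround` uses HALF_EVEN (banker's rounding), so a tie is resolved towards the even multiple. The plain-JDK snippet below reproduces the `5` with scale `-1` case from the test; the commented DataFrame call at the end is just a reminder of the API added in functions.scala (`df` is assumed to have an integer column "a").

import java.math.{BigDecimal => JBigDecimal, RoundingMode}

val five = new JBigDecimal(5)
// HALF_UP (round): the tie 5 -> {0, 10} is resolved away from zero, giving 10.
assert(five.setScale(-1, RoundingMode.HALF_UP).intValueExact() == 10)
// HALF_EVEN (bround): the tie is resolved towards the even multiple, giving 0.
assert(five.setScale(-1, RoundingMode.HALF_EVEN).intValueExact() == 0)

// Equivalent DataFrame usage with the functions added above:
// df.select(round(col("a"), -1), bround(col("a"), -1))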
logicalPlan.transform { - case _: ObjectOperator => return + case _: ObjectConsumer => return + case _: ObjectProducer => return + case _: AppendColumns => return case _: LogicalRelation => return case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case _: TypedAggregateExpression => return case Literal(_, _: ObjectType) => return } @@ -232,9 +244,27 @@ abstract class QueryTest extends PlanTest { // RDDs/data are not serializable to JSON, so we need to collect LogicalPlans that contains // these non-serializable stuff, and use these original ones to replace the null-placeholders // in the logical plans parsed from JSON. - var logicalRDDs = logicalPlan.collect { case l: LogicalRDD => l } - var localRelations = logicalPlan.collect { case l: LocalRelation => l } - var inMemoryRelations = logicalPlan.collect { case i: InMemoryRelation => i } + val logicalRDDs = new ArrayDeque[LogicalRDD]() + val localRelations = new ArrayDeque[LocalRelation]() + val inMemoryRelations = new ArrayDeque[InMemoryRelation]() + def collectData: (LogicalPlan => Unit) = { + case l: LogicalRDD => + logicalRDDs.offer(l) + case l: LocalRelation => + localRelations.offer(l) + case i: InMemoryRelation => + inMemoryRelations.offer(i) + case p => + p.expressions.foreach { + _.foreach { + case s: SubqueryExpression => + s.query.foreach(collectData) + case _ => + } + } + } + logicalPlan.foreach(collectData) + val jsonBackPlan = try { TreeNode.fromJSON[LogicalPlan](jsonString, sqlContext.sparkContext) @@ -249,18 +279,15 @@ abstract class QueryTest extends PlanTest { """.stripMargin, e) } - val normalized2 = jsonBackPlan transformDown { + def renormalize: PartialFunction[LogicalPlan, LogicalPlan] = { case l: LogicalRDD => - val origin = logicalRDDs.head - logicalRDDs = logicalRDDs.drop(1) + val origin = logicalRDDs.pop() LogicalRDD(l.output, origin.rdd)(sqlContext) case l: LocalRelation => - val origin = localRelations.head - localRelations = localRelations.drop(1) + val origin = localRelations.pop() l.copy(data = origin.data) case l: InMemoryRelation => - val origin = inMemoryRelations.head - inMemoryRelations = inMemoryRelations.drop(1) + val origin = inMemoryRelations.pop() InMemoryRelation( l.output, l.useCompression, @@ -271,7 +298,13 @@ abstract class QueryTest extends PlanTest { origin.cachedColumnBuffers, l._statistics, origin._batchStats) + case p => + p.transformExpressions { + case s: SubqueryExpression => + s.withNewPlan(s.query.transformDown(renormalize)) + } } + val normalized2 = jsonBackPlan.transformDown(renormalize) assert(logicalRDDs.isEmpty) assert(localRelations.isEmpty) @@ -305,6 +338,7 @@ object QueryTest { * If there was exception during the execution or the contents of the DataFrame does not * match the expected result, an error message will be returned. Otherwise, a [[None]] will * be returned. + * * @param df the [[DataFrame]] to be executed * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. */ @@ -379,6 +413,7 @@ object QueryTest { /** * Runs the plan and makes sure the answer is within absTol of the expected result. + * * @param actualAnswer the actual result in a [[Row]]. * @param expectedAnswer the expected result in a[[Row]]. * @param absTol the absolute tolerance between actual and expected answers. 
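The switch in `checkJsonFormat` from repeatedly taking `.head` and `.drop(1)` on immutable lists to a `java.util.ArrayDeque` keeps the same first-in-first-out pairing of collected plan nodes with their JSON placeholders. The snippet below is purely illustrative of why `offer` plus `pop` behaves as a queue here.

import java.util.ArrayDeque

val queue = new ArrayDeque[String]()
Seq("LogicalRDD#1", "LogicalRDD#2").foreach(queue.offer) // offer appends at the tail
assert(queue.pop() == "LogicalRDD#1")                    // pop removes from the head
assert(queue.pop() == "LogicalRDD#2")                    // so encounter order is preserved
assert(queue.isEmpty)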
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index 6ccc99fe179d7..242ea9cb27361 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -33,6 +33,7 @@ import org.scalatest.exceptions.TestFailedDueToTimeoutException import org.scalatest.time.Span import org.scalatest.time.SpanSugar._ +import org.apache.spark.sql.catalyst.analysis.{Append, OutputMode} import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder, RowEncoder} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ @@ -75,6 +76,8 @@ trait StreamTest extends QueryTest with Timeouts { /** How long to wait for an active stream to catch up when checking a result. */ val streamingTimeout = 10.seconds + val outputMode: OutputMode = Append + /** A trait for actions that can be performed while testing a streaming DataFrame. */ trait StreamAction @@ -228,6 +231,7 @@ trait StreamTest extends QueryTest with Timeouts { |$testActions | |== Stream == + |Output Mode: $outputMode |Stream state: $currentOffsets |Thread state: $threadState |${if (streamDeathCause != null) stackTraceToString(streamDeathCause) else ""} @@ -235,6 +239,7 @@ trait StreamTest extends QueryTest with Timeouts { |== Sink == |${sink.toDebugString} | + | |== Plan == |${if (currentStream != null) currentStream.lastExecution else ""} """.stripMargin @@ -293,7 +298,8 @@ trait StreamTest extends QueryTest with Timeouts { StreamExecution.nextName, metadataRoot, stream, - sink) + sink, + outputMode = outputMode) .asInstanceOf[StreamExecution] currentStream.microBatchThread.setUncaughtExceptionHandler( new UncaughtExceptionHandler { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala index 21b19fe7df8b2..d69ef087357d8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala @@ -22,33 +22,70 @@ import org.apache.spark.sql.test.SharedSQLContext class SubquerySuite extends QueryTest with SharedSQLContext { import testImplicits._ + setupTestData() + + val row = identity[(java.lang.Integer, java.lang.Double)](_) + + lazy val l = Seq( + row(1, 2.0), + row(1, 2.0), + row(2, 1.0), + row(2, 1.0), + row(3, 3.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("a", "b") + + lazy val r = Seq( + row(2, 3.0), + row(2, 3.0), + row(3, 2.0), + row(4, 1.0), + row(null, null), + row(null, 5.0), + row(6, null)).toDF("c", "d") + + lazy val t = r.filter($"c".isNotNull && $"d".isNotNull) + + protected override def beforeAll(): Unit = { + super.beforeAll() + l.registerTempTable("l") + r.registerTempTable("r") + t.registerTempTable("t") + } + test("simple uncorrelated scalar subquery") { - assertResult(Array(Row(1))) { - sql("select (select 1 as b) as b").collect() - } + checkAnswer( + sql("select (select 1 as b) as b"), + Array(Row(1)) + ) - assertResult(Array(Row(3))) { - sql("select (select (select 1) + 1) + 1").collect() - } + checkAnswer( + sql("select (select (select 1) + 1) + 1"), + Array(Row(3)) + ) // string type - assertResult(Array(Row("s"))) { - sql("select (select 's' as s) as b").collect() - } + checkAnswer( + sql("select (select 's' as s) as b"), + Array(Row("s")) + ) } test("uncorrelated scalar subquery in CTE") { - assertResult(Array(Row(1))) { + 
checkAnswer( sql("with t2 as (select 1 as b, 2 as c) " + "select a from (select 1 as a union all select 2 as a) t " + - "where a = (select max(b) from t2) ").collect() - } + "where a = (select max(b) from t2) "), + Array(Row(1)) + ) } test("uncorrelated scalar subquery should return null if there is 0 rows") { - assertResult(Array(Row(null))) { - sql("select (select 's' as s limit 0) as b").collect() - } + checkAnswer( + sql("select (select 's' as s limit 0) as b"), + Array(Row(null)) + ) } test("runtime error when the number of rows is greater than 1") { @@ -56,28 +93,99 @@ class SubquerySuite extends QueryTest with SharedSQLContext { sql("select (select a from (select 1 as a union all select 2 as a) t) as b").collect() } assert(error2.getMessage.contains( - "more than one row returned by a subquery used as an expression")) + "more than one row returned by a subquery used as an expression") + ) } test("uncorrelated scalar subquery on a DataFrame generated query") { val df = Seq((1, "one"), (2, "two"), (3, "three")).toDF("key", "value") df.registerTempTable("subqueryData") - assertResult(Array(Row(4))) { - sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1").collect() - } + checkAnswer( + sql("select (select key from subqueryData where key > 2 order by key limit 1) + 1"), + Array(Row(4)) + ) - assertResult(Array(Row(-3))) { - sql("select -(select max(key) from subqueryData)").collect() - } + checkAnswer( + sql("select -(select max(key) from subqueryData)"), + Array(Row(-3)) + ) - assertResult(Array(Row(null))) { - sql("select (select value from subqueryData limit 0)").collect() - } + checkAnswer( + sql("select (select value from subqueryData limit 0)"), + Array(Row(null)) + ) - assertResult(Array(Row("two"))) { + checkAnswer( sql("select (select min(value) from subqueryData" + - " where key = (select max(key) from subqueryData) - 1)").collect() - } + " where key = (select max(key) from subqueryData) - 1)"), + Array(Row("two")) + ) + } + + test("EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where exists(select * from r where l.a = r.c) and l.a <= 2"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + } + + test("NOT EXISTS predicate subquery") { + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(null, null) :: Row(null, 5.0) :: Nil) + + checkAnswer( + sql("select * from l where not exists(select * from r where l.a = r.c and l.b < r.d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(3, 3.0) :: + Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + } + + test("IN predicate subquery") { + checkAnswer( + sql("select * from l where l.a in (select c from r)"), + Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Row(6, null) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r where l.b < r.d)"), + Row(2, 1.0) :: Row(2, 1.0) :: Nil) + + checkAnswer( + sql("select * from l where l.a in (select c from r) and l.a > 2 and l.b is not null"), + Row(3, 3.0) :: Nil) + } + + test("NOT IN predicate subquery") { + checkAnswer( + sql("select * from l where a not in(select c from r)"), + Nil) + + checkAnswer( + sql("select * from l where a not in(select c from r where c is not null)"), + Row(1, 2.0) :: Row(1, 2.0) :: Nil) + + checkAnswer( + sql("select * from l where a not in(select c from t where b < d)"), + Row(1, 2.0) 
:: Row(1, 2.0) :: Row(3, 3.0) :: Nil) + + // Empty sub-query + checkAnswer( + sql("select * from l where a not in(select c from r where c > 10 and b < d)"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: + Row(3, 3.0) :: Row(null, null) :: Row(null, 5.0) :: Row(6, null) :: Nil) + + } + + test("complex IN predicate subquery") { + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from r)"), + Nil) + + checkAnswer( + sql("select * from l where (a, b) not in(select c, d from t) and (a + b) is not null"), + Row(1, 2.0) :: Row(1, 2.0) :: Row(2, 1.0) :: Row(2, 1.0) :: Row(3, 3.0) :: Nil) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index 8c4afb605b01f..acc9f48d7e08f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -27,51 +27,56 @@ import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ -@SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) -private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { - override def equals(other: Any): Boolean = other match { - case v: MyDenseVector => - java.util.Arrays.equals(this.data, v.data) - case _ => false - } -} - @BeanInfo private[sql] case class MyLabeledPoint( - @BeanProperty label: Double, - @BeanProperty features: MyDenseVector) + @BeanProperty label: Double, + @BeanProperty features: UDT.MyDenseVector) + +// Wrapped in an object to check Scala compatibility. See SPARK-13929 +object UDT { + + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) + private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { + override def equals(other: Any): Boolean = other match { + case v: MyDenseVector => + java.util.Arrays.equals(this.data, v.data) + case _ => false + } + } -private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { + private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { - override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) + override def sqlType: DataType = ArrayType(DoubleType, containsNull = false) - override def serialize(features: MyDenseVector): ArrayData = { - new GenericArrayData(features.data.map(_.asInstanceOf[Any])) - } + override def serialize(features: MyDenseVector): ArrayData = { + new GenericArrayData(features.data.map(_.asInstanceOf[Any])) + } - override def deserialize(datum: Any): MyDenseVector = { - datum match { - case data: ArrayData => - new MyDenseVector(data.toDoubleArray()) + override def deserialize(datum: Any): MyDenseVector = { + datum match { + case data: ArrayData => + new MyDenseVector(data.toDoubleArray()) + } } - } - override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] + override def userClass: Class[MyDenseVector] = classOf[MyDenseVector] - private[spark] override def asNullable: MyDenseVectorUDT = this + private[spark] override def asNullable: MyDenseVectorUDT = this - override def equals(other: Any): Boolean = other match { - case _: MyDenseVectorUDT => true - case _ => false + override def equals(other: Any): Boolean = other match { + case _: MyDenseVectorUDT => true + case _ => false + } } + } class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetTest { import testImplicits._ private lazy val pointsRDD = Seq( - MyLabeledPoint(1.0, 
new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() + MyLabeledPoint(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new UDT.MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -80,16 +85,16 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT assert(labelsArrays.contains(1.0)) assert(labelsArrays.contains(0.0)) - val features: RDD[MyDenseVector] = - pointsRDD.select('features).rdd.map { case Row(v: MyDenseVector) => v } - val featuresArrays: Array[MyDenseVector] = features.collect() + val features: RDD[UDT.MyDenseVector] = + pointsRDD.select('features).rdd.map { case Row(v: UDT.MyDenseVector) => v } + val featuresArrays: Array[UDT.MyDenseVector] = features.collect() assert(featuresArrays.size === 2) - assert(featuresArrays.contains(new MyDenseVector(Array(0.1, 1.0)))) - assert(featuresArrays.contains(new MyDenseVector(Array(0.2, 2.0)))) + assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.1, 1.0)))) + assert(featuresArrays.contains(new UDT.MyDenseVector(Array(0.2, 2.0)))) } test("UDTs and UDFs") { - sqlContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + sqlContext.udf.register("testType", (d: UDT.MyDenseVector) => d.isInstanceOf[UDT.MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( sql("SELECT testType(features) from points"), @@ -103,8 +108,8 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( sqlContext.read.parquet(path), Seq( - Row(1.0, new MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } } @@ -115,18 +120,19 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT checkAnswer( sqlContext.read.parquet(path), Seq( - Row(1.0, new MyDenseVector(Array(0.1, 1.0))), - Row(0.0, new MyDenseVector(Array(0.2, 2.0))))) + Row(1.0, new UDT.MyDenseVector(Array(0.1, 1.0))), + Row(0.0, new UDT.MyDenseVector(Array(0.2, 2.0))))) } } // Tests to make sure that all operators correctly convert types on the way out. 
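Back in the SubquerySuite changes above, the first NOT IN case expects an empty result precisely because r.c contains a NULL: under SQL's three-valued logic, a NOT IN (c1, ..., NULL) can only evaluate to FALSE or UNKNOWN, and UNKNOWN rows are filtered out, so no row of l survives. A small Spark-free sketch of that evaluation, with Option[Boolean] standing in for TRUE/FALSE/UNKNOWN:

object NotInNullSketch {
  type TriBool = Option[Boolean]

  // NULL on either side of an equality makes the comparison UNKNOWN.
  def eq3(a: Option[Int], b: Option[Int]): TriBool =
    for (x <- a; y <- b) yield x == y

  def or3(l: TriBool, r: TriBool): TriBool = (l, r) match {
    case (Some(true), _) | (_, Some(true)) => Some(true)
    case (Some(false), Some(false))        => Some(false)
    case _                                 => None
  }

  def not3(v: TriBool): TriBool = v.map(!_)

  // a NOT IN (values) is NOT(a = v1 OR a = v2 OR ...).
  def notIn(a: Option[Int], values: Seq[Option[Int]]): TriBool =
    not3(values.map(eq3(a, _)).foldLeft(Some(false): TriBool)(or3))

  def main(args: Array[String]): Unit = {
    // The distinct values of r.c from the fixture above, including its NULL.
    val rC = Seq(Some(2), Some(3), Some(4), None, Some(6))
    // No value of l.a can make "a NOT IN rC" TRUE, hence the expected empty answer.
    assert(Seq(Some(1), Some(2), None).forall(a => notIn(a, rC) != Some(true)))
  }
}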
test("Local UDTs") { - val df = Seq((1, new MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") - df.collect()(0).getAs[MyDenseVector](1) - df.take(1)(0).getAs[MyDenseVector](1) - df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) - df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[MyDenseVector](0) + val df = Seq((1, new UDT.MyDenseVector(Array(0.1, 1.0)))).toDF("int", "vec") + df.collect()(0).getAs[UDT.MyDenseVector](1) + df.take(1)(0).getAs[UDT.MyDenseVector](1) + df.limit(1).groupBy('int).agg(first('vec)).collect()(0).getAs[UDT.MyDenseVector](0) + df.orderBy('int).limit(1).groupBy('int).agg(first('vec)).collect()(0) + .getAs[UDT.MyDenseVector](0) } test("UDTs with JSON") { @@ -136,26 +142,47 @@ class UserDefinedTypeSuite extends QueryTest with SharedSQLContext with ParquetT ) val schema = StructType(Seq( StructField("id", IntegerType, false), - StructField("vec", new MyDenseVectorUDT, false) + StructField("vec", new UDT.MyDenseVectorUDT, false) )) val stringRDD = sparkContext.parallelize(data) val jsonRDD = sqlContext.read.schema(schema).json(stringRDD) checkAnswer( jsonRDD, - Row(1, new MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: - Row(2, new MyDenseVector(Array(2.25, 4.5, 8.75))) :: + Row(1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))) :: + Row(2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) :: Nil ) } + test("UDTs with JSON and Dataset") { + val data = Seq( + "{\"id\":1,\"vec\":[1.1,2.2,3.3,4.4]}", + "{\"id\":2,\"vec\":[2.25,4.5,8.75]}" + ) + + val schema = StructType(Seq( + StructField("id", IntegerType, false), + StructField("vec", new UDT.MyDenseVectorUDT, false) + )) + + val stringRDD = sparkContext.parallelize(data) + val jsonDataset = sqlContext.read.schema(schema).json(stringRDD) + .as[(Int, UDT.MyDenseVector)] + checkDataset( + jsonDataset, + (1, new UDT.MyDenseVector(Array(1.1, 2.2, 3.3, 4.4))), + (2, new UDT.MyDenseVector(Array(2.25, 4.5, 8.75))) + ) + } + test("SPARK-10472 UserDefinedType.typeName") { assert(IntegerType.typeName === "integer") - assert(new MyDenseVectorUDT().typeName === "mydensevector") + assert(new UDT.MyDenseVectorUDT().typeName === "mydensevector") } test("Catalyst type converter null handling for UDTs") { - val udt = new MyDenseVectorUDT() + val udt = new UDT.MyDenseVectorUDT() val toScalaConverter = CatalystTypeConverters.createToScalaConverter(udt) assert(toScalaConverter(null) === null) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 352fd07d0e8b0..3fb70f2eb6ae3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -150,19 +150,77 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite { */ } - ignore("aggregate with keys") { + ignore("aggregate with linear keys") { val N = 20 << 20 - runBenchmark("Aggregate w keys", N) { - sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + val benchmark = new Benchmark("Aggregate w keys", N) + def f(): Unit = sqlContext.range(N).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + + benchmark.addCase(s"codegen = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + 
sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true") + f() } + benchmark.run() + /* - Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative - ------------------------------------------------------------------------------------------- - Aggregate w keys codegen=false 2429 / 2644 8.6 115.8 1.0X - Aggregate w keys codegen=true 1535 / 1571 13.7 73.2 1.6X + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 2067 / 2166 10.1 98.6 1.0X + codegen = T hashmap = F 1149 / 1321 18.3 54.8 1.8X + codegen = T hashmap = T 388 / 475 54.0 18.5 5.3X + */ + } + + ignore("aggregate with randomized keys") { + val N = 20 << 20 + + val benchmark = new Benchmark("Aggregate w keys", N) + sqlContext.range(N).selectExpr("id", "floor(rand() * 10000) as k").registerTempTable("test") + + def f(): Unit = sqlContext.sql("select k, k, sum(id) from test group by k, k").collect() + + benchmark.addCase(s"codegen = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = F") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "false") + f() + } + + benchmark.addCase(s"codegen = T hashmap = T") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.setConf("spark.sql.codegen.aggregate.map.enabled", "true") + f() + } + + benchmark.run() + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_73-b02 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + Aggregate w keys: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + codegen = F 2517 / 2608 8.3 120.0 1.0X + codegen = T hashmap = F 1484 / 1560 14.1 70.8 1.7X + codegen = T hashmap = T 794 / 908 26.4 37.9 3.2X */ } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index 01687877eeed6..53105e0b24959 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -21,6 +21,7 @@ import java.io.{ByteArrayInputStream, ByteArrayOutputStream, File} import java.util.Properties import org.apache.spark._ +import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.rdd.RDD import org.apache.spark.sql.Row @@ -113,8 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl( - 0, 0, 0, 0, taskMemoryManager, new Properties, null, InternalAccumulator.createAll(sc)) + val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, 
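The two rewritten benchmarks above share one shape: several named variants of the same workload are registered (codegen off, codegen on without the aggregate hash map, codegen on with it), each is run, and best/average wall time is reported. Spark's Benchmark helper does the real measuring; the hand-rolled harness below only illustrates that structure and is not the Spark class:

object MicroBenchSketch {
  final case class Case(name: String, body: () => Unit)
  private val cases = scala.collection.mutable.ArrayBuffer.empty[Case]

  def addCase(name: String)(body: => Unit): Unit = cases += Case(name, () => body)

  def run(iters: Int = 5): Unit = cases.foreach { c =>
    val timesMs = (1 to iters).map { _ =>
      val start = System.nanoTime(); c.body(); (System.nanoTime() - start) / 1e6
    }
    println(f"${c.name}%-24s best ${timesMs.min}%8.1f ms   avg ${timesMs.sum / iters}%8.1f ms")
  }

  def main(args: Array[String]): Unit = {
    val n = 1 << 22
    addCase("sum: while loop") { var i = 0; var s = 0L; while (i < n) { s += i; i += 1 } }
    addCase("sum: Range.foreach") { var s = 0L; (0 until n).foreach(s += _) }
    run()
  }
}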
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 4474cfcf6e41b..d7cf1dc6aadb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin +import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions.{avg, broadcast, col, max} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StringType, StructType} @@ -78,7 +79,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegen] && - p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } @@ -99,4 +100,17 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) assert(ds.collect() === Array(0, 6)) } + + test("simple typed UDAF should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = Seq(("a", 10), ("b", 1), ("b", 2), ("c", 1)).toDS() + .groupByKey(_._1).agg(typed.sum(_._2)) + + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[TungstenAggregate]).isDefined) + assert(ds.collect() === Array(("a", 10.0), ("b", 3.0), ("c", 1.0))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index e17340c70b7e6..1a7b62ca0ac77 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1421,7 +1421,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5), DateType, TimestampType, ArrayType(IntegerType), MapType(StringType, LongType), struct, - new MyDenseVectorUDT()) + new UDT.MyDenseVectorUDT()) val fields = dataTypes.zipWithIndex.map { case (dataType, index) => StructField(s"col$index", dataType, nullable = true) } @@ -1445,7 +1445,7 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { Seq(2, 3, 4), Map("a string" -> 2000L), Row(4.75.toFloat, Seq(false, true)), - new MyDenseVector(Array(0.25, 2.25, 4.25))) + new UDT.MyDenseVector(Array(0.25, 2.25, 4.25))) val data = Row.fromSeq(Seq("Spark " + sqlContext.sparkContext.version) ++ constantValues) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala index 51183e970d965..65635e3c066d9 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala @@ -27,10 +27,9 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, LogicalRelation} +import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index f875b54cd6649..5bffb307ec80e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -29,9 +29,8 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} +import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, PartitionDirectory => Partition, PartitioningUtils, PartitionSpec} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.HadoopFsRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala new file mode 100644 index 0000000000000..70c2a82990ba9 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/FileStreamSinkLogSuite.scala @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.streaming + +import java.nio.charset.StandardCharsets.UTF_8 + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class FileStreamSinkLogSuite extends SparkFunSuite with SharedSQLContext { + + import FileStreamSinkLog._ + + test("getBatchIdFromFileName") { + assert(1234L === getBatchIdFromFileName("1234")) + assert(1234L === getBatchIdFromFileName("1234.compact")) + intercept[NumberFormatException] { + FileStreamSinkLog.getBatchIdFromFileName("1234a") + } + } + + test("isCompactionBatch") { + assert(false === isCompactionBatch(0, compactInterval = 3)) + assert(false === isCompactionBatch(1, compactInterval = 3)) + assert(true === isCompactionBatch(2, compactInterval = 3)) + assert(false === isCompactionBatch(3, compactInterval = 3)) + assert(false === isCompactionBatch(4, compactInterval = 3)) + assert(true === isCompactionBatch(5, compactInterval = 3)) + } + + test("nextCompactionBatchId") { + assert(2 === nextCompactionBatchId(0, compactInterval = 3)) + assert(2 === nextCompactionBatchId(1, compactInterval = 3)) + assert(5 === nextCompactionBatchId(2, compactInterval = 3)) + assert(5 === nextCompactionBatchId(3, compactInterval = 3)) + assert(5 === nextCompactionBatchId(4, compactInterval = 3)) + assert(8 === nextCompactionBatchId(5, compactInterval = 3)) + } + + test("getValidBatchesBeforeCompactionBatch") { + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(0, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(1, compactInterval = 3) + } + assert(Seq(0, 1) === getValidBatchesBeforeCompactionBatch(2, compactInterval = 3)) + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(3, compactInterval = 3) + } + intercept[AssertionError] { + getValidBatchesBeforeCompactionBatch(4, compactInterval = 3) + } + assert(Seq(2, 3, 4) === getValidBatchesBeforeCompactionBatch(5, compactInterval = 3)) + } + + test("getAllValidBatches") { + assert(Seq(0) === getAllValidBatches(0, compactInterval = 3)) + assert(Seq(0, 1) === getAllValidBatches(1, compactInterval = 3)) + assert(Seq(2) === getAllValidBatches(2, compactInterval = 3)) + assert(Seq(2, 3) === getAllValidBatches(3, compactInterval = 3)) + assert(Seq(2, 3, 4) === getAllValidBatches(4, compactInterval = 3)) + assert(Seq(5) === getAllValidBatches(5, compactInterval = 3)) + assert(Seq(5, 6) === getAllValidBatches(6, compactInterval = 3)) + assert(Seq(5, 6, 7) === getAllValidBatches(7, compactInterval = 3)) + assert(Seq(8) === getAllValidBatches(8, compactInterval = 3)) + } + + test("compactLogs") { + val logs = Seq( + newFakeSinkFileStatus("/a/b/x", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/y", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.ADD_ACTION)) + assert(logs === compactLogs(logs)) + + val logs2 = Seq( + newFakeSinkFileStatus("/a/b/m", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/n", FileStreamSinkLog.ADD_ACTION), + newFakeSinkFileStatus("/a/b/z", FileStreamSinkLog.DELETE_ACTION)) + assert(logs.dropRight(1) ++ logs2.dropRight(1) === compactLogs(logs ++ logs2)) + } + + test("serialize") { + withFileStreamSinkLog { sinkLog => + val logs = Seq( + SinkFileStatus( + path = "/a/b/x", + size = 100L, + isDir = false, + modificationTime = 1000L, + blockReplication = 1, + blockSize = 10000L, + action = FileStreamSinkLog.ADD_ACTION), + SinkFileStatus( + path = "/a/b/y", + size 
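For the batch-id assertions above (isCompactionBatch, nextCompactionBatchId, getValidBatchesBeforeCompactionBatch, getAllValidBatches), the arithmetic can be reconstructed directly from the expected values: every compactInterval-th batch is a compaction, and a compaction log subsumes everything written since the previous compaction. The helpers below reproduce exactly the behaviour the tests assert; they are a re-derivation for illustration, not a copy of Spark's implementation:

object CompactionMathSketch {
  def isCompactionBatch(batchId: Long, compactInterval: Int): Boolean =
    (batchId + 1) % compactInterval == 0

  // Smallest compaction batch id strictly greater than batchId.
  def nextCompactionBatchId(batchId: Long, compactInterval: Int): Long =
    (batchId + compactInterval + 1) / compactInterval * compactInterval - 1

  // Batches a compaction batch subsumes: the previous compaction (if any) and
  // every batch after it, excluding this one.
  def getValidBatchesBeforeCompactionBatch(compactionBatchId: Long, compactInterval: Int): Seq[Long] = {
    assert(isCompactionBatch(compactionBatchId, compactInterval), s"$compactionBatchId is not a compaction batch")
    val previous = compactionBatchId - compactInterval
    (if (previous < 0) 0L else previous) until compactionBatchId
  }

  // Log files still live as of batchId: the latest compaction plus everything after it.
  def getAllValidBatches(batchId: Long, compactInterval: Int): Seq[Long] = {
    val previous = nextCompactionBatchId(batchId, compactInterval) - compactInterval
    (if (previous < 0) 0L else previous) to batchId
  }

  def main(args: Array[String]): Unit = {
    assert(!isCompactionBatch(1, 3) && isCompactionBatch(2, 3) && isCompactionBatch(5, 3))
    assert(nextCompactionBatchId(3, 3) == 5 && nextCompactionBatchId(5, 3) == 8)
    assert(getValidBatchesBeforeCompactionBatch(5, 3) == Seq(2L, 3L, 4L))
    assert(getAllValidBatches(4, 3) == Seq(2L, 3L, 4L) && getAllValidBatches(5, 3) == Seq(5L))
  }
}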
= 200L, + isDir = false, + modificationTime = 2000L, + blockReplication = 2, + blockSize = 20000L, + action = FileStreamSinkLog.DELETE_ACTION), + SinkFileStatus( + path = "/a/b/z", + size = 300L, + isDir = false, + modificationTime = 3000L, + blockReplication = 3, + blockSize = 30000L, + action = FileStreamSinkLog.ADD_ACTION)) + + // scalastyle:off + val expected = s"""${FileStreamSinkLog.VERSION} + |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} + |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} + |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin + // scalastyle:on + assert(expected === new String(sinkLog.serialize(logs), UTF_8)) + + assert(FileStreamSinkLog.VERSION === new String(sinkLog.serialize(Nil), UTF_8)) + } + } + + test("deserialize") { + withFileStreamSinkLog { sinkLog => + // scalastyle:off + val logs = s"""${FileStreamSinkLog.VERSION} + |{"path":"/a/b/x","size":100,"isDir":false,"modificationTime":1000,"blockReplication":1,"blockSize":10000,"action":"add"} + |{"path":"/a/b/y","size":200,"isDir":false,"modificationTime":2000,"blockReplication":2,"blockSize":20000,"action":"delete"} + |{"path":"/a/b/z","size":300,"isDir":false,"modificationTime":3000,"blockReplication":3,"blockSize":30000,"action":"add"}""".stripMargin + // scalastyle:on + + val expected = Seq( + SinkFileStatus( + path = "/a/b/x", + size = 100L, + isDir = false, + modificationTime = 1000L, + blockReplication = 1, + blockSize = 10000L, + action = FileStreamSinkLog.ADD_ACTION), + SinkFileStatus( + path = "/a/b/y", + size = 200L, + isDir = false, + modificationTime = 2000L, + blockReplication = 2, + blockSize = 20000L, + action = FileStreamSinkLog.DELETE_ACTION), + SinkFileStatus( + path = "/a/b/z", + size = 300L, + isDir = false, + modificationTime = 3000L, + blockReplication = 3, + blockSize = 30000L, + action = FileStreamSinkLog.ADD_ACTION)) + + assert(expected === sinkLog.deserialize(logs.getBytes(UTF_8))) + + assert(Nil === sinkLog.deserialize(FileStreamSinkLog.VERSION.getBytes(UTF_8))) + } + } + + test("batchIdToPath") { + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withFileStreamSinkLog { sinkLog => + assert("0" === sinkLog.batchIdToPath(0).getName) + assert("1" === sinkLog.batchIdToPath(1).getName) + assert("2.compact" === sinkLog.batchIdToPath(2).getName) + assert("3" === sinkLog.batchIdToPath(3).getName) + assert("4" === sinkLog.batchIdToPath(4).getName) + assert("5.compact" === sinkLog.batchIdToPath(5).getName) + } + } + } + + test("compact") { + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withFileStreamSinkLog { sinkLog => + for (batchId <- 0 to 10) { + sinkLog.add( + batchId, + Seq(newFakeSinkFileStatus("/a/b/" + batchId, FileStreamSinkLog.ADD_ACTION))) + val expectedFiles = (0 to batchId).map { + id => newFakeSinkFileStatus("/a/b/" + id, FileStreamSinkLog.ADD_ACTION) + } + assert(sinkLog.allFiles() === expectedFiles) + if (isCompactionBatch(batchId, 3)) { + // Since batchId is a compaction batch, the batch log file should contain all logs + assert(sinkLog.get(batchId).getOrElse(Nil) === expectedFiles) + } + } + } + } + } + + test("delete expired file") { + // Set FILE_SINK_LOG_CLEANUP_DELAY to 0 so that we can detect the deleting behaviour + // deterministically + withSQLConf( + SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> 
"3", + SQLConf.FILE_SINK_LOG_CLEANUP_DELAY.key -> "0") { + withFileStreamSinkLog { sinkLog => + val fs = sinkLog.metadataPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration) + + def listBatchFiles(): Set[String] = { + fs.listStatus(sinkLog.metadataPath).map(_.getPath.getName).filter { fileName => + try { + getBatchIdFromFileName(fileName) + true + } catch { + case _: NumberFormatException => false + } + }.toSet + } + + sinkLog.add(0, Seq(newFakeSinkFileStatus("/a/b/0", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0") === listBatchFiles()) + sinkLog.add(1, Seq(newFakeSinkFileStatus("/a/b/1", FileStreamSinkLog.ADD_ACTION))) + assert(Set("0", "1") === listBatchFiles()) + sinkLog.add(2, Seq(newFakeSinkFileStatus("/a/b/2", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact") === listBatchFiles()) + sinkLog.add(3, Seq(newFakeSinkFileStatus("/a/b/3", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact", "3") === listBatchFiles()) + sinkLog.add(4, Seq(newFakeSinkFileStatus("/a/b/4", FileStreamSinkLog.ADD_ACTION))) + assert(Set("2.compact", "3", "4") === listBatchFiles()) + sinkLog.add(5, Seq(newFakeSinkFileStatus("/a/b/5", FileStreamSinkLog.ADD_ACTION))) + assert(Set("5.compact") === listBatchFiles()) + } + } + } + + /** + * Create a fake SinkFileStatus using path and action. Most of tests don't care about other fields + * in SinkFileStatus. + */ + private def newFakeSinkFileStatus(path: String, action: String): SinkFileStatus = { + SinkFileStatus( + path = path, + size = 100L, + isDir = false, + modificationTime = 100L, + blockReplication = 1, + blockSize = 100L, + action = action) + } + + private def withFileStreamSinkLog(f: FileStreamSinkLog => Unit): Unit = { + withTempDir { file => + val sinkLog = new FileStreamSinkLog(sqlContext, file.getCanonicalPath) + f(sinkLog) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 73d1b1b1d507d..64cddf0deecb0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -281,6 +281,30 @@ class FileStreamSourceSuite extends FileStreamSourceTest with SharedSQLContext { Utils.deleteRecursively(tmp) } + + test("reading from json files inside partitioned directory") { + val src = { + val base = Utils.createTempDir(namePrefix = "streaming.src") + new File(base, "type=X") + } + val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") + src.mkdirs() + + + // Add a file so that we can infer its schema + stringToFile(new File(src, "existing"), "{'c': 'drop1'}\n{'c': 'keep2'}\n{'c': 'keep3'}") + + val textSource = createFileStreamSource("json", src.getCanonicalPath) + + // FileStreamSource should infer the column "c" + val filtered = textSource.toDF().filter($"c" contains "keep") + + testStream(filtered)( + AddTextFileData(textSource, "{'c': 'drop4'}\n{'c': 'keep5'}\n{'c': 'keep6'}", src, tmp), + CheckAnswer("keep2", "keep3", "keep5", "keep6") + ) + } + test("read from parquet files") { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 2bd27c7efdbdc..6f3149dbc5033 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,10 +17,9 @@ package org.apache.spark.sql.streaming -import org.scalatest.concurrent.Eventually._ - -import org.apache.spark.sql.{DataFrame, Row, SQLContext, StreamTest} +import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructField, StructType} @@ -108,6 +107,35 @@ class StreamSuite extends StreamTest with SharedSQLContext { assertDF(df) assertDF(df) } + + test("unsupported queries") { + val streamInput = MemoryStream[Int] + val batchInput = Seq(1, 2, 3).toDS() + + def assertError(expectedMsgs: Seq[String])(body: => Unit): Unit = { + val e = intercept[AnalysisException] { + body + } + expectedMsgs.foreach { s => assert(e.getMessage.contains(s)) } + } + + // Running streaming plan as a batch query + assertError("startStream" :: Nil) { + streamInput.toDS.map { i => i }.count() + } + + // Running non-streaming plan with as a streaming query + assertError("without streaming sources" :: "startStream" :: Nil) { + val ds = batchInput.map { i => i } + testStream(ds)() + } + + // Running streaming plan that cannot be incrementalized + assertError("not supported" :: "streaming" :: Nil) { + val ds = streamInput.toDS.map { i => i }.sort() + testStream(ds)() + } + } } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 3af7c01e525ad..fa3b122f6d2da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.streaming import org.apache.spark.SparkException import org.apache.spark.sql.StreamTest +import org.apache.spark.sql.catalyst.analysis.Update import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -32,6 +33,8 @@ class StreamingAggregationSuite extends StreamTest with SharedSQLContext { import testImplicits._ + override val outputMode = Update + test("simple count") { val inputData = MemoryStream[Int] diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala index ee0d23a6e57c4..6703cdbac3d17 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2.scala @@ -55,7 +55,7 @@ object HiveThriftServer2 extends Logging { @DeveloperApi def startWithContext(sqlContext: HiveContext): Unit = { val server = new HiveThriftServer2(sqlContext) - server.init(sqlContext.hiveconf) + server.init(sqlContext.sessionState.hiveconf) server.start() listener = new HiveThriftServer2Listener(server, sqlContext.conf) sqlContext.sparkContext.addSparkListener(listener) @@ -83,7 +83,7 @@ object HiveThriftServer2 extends Logging { try { val server = new HiveThriftServer2(SparkSQLEnv.hiveContext) - 
server.init(SparkSQLEnv.hiveContext.hiveconf) + server.init(SparkSQLEnv.hiveContext.sessionState.hiveconf) server.start() logInfo("HiveThriftServer2 started") listener = new HiveThriftServer2Listener(server, SparkSQLEnv.hiveContext.conf) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 673a293ce2601..d89c3b4ab2d1c 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -195,7 +195,7 @@ private[hive] class SparkExecuteStatementOperation( setState(OperationState.RUNNING) // Always use the latest class loader provided by executionHive's state. val executionHiveClassLoader = - hiveContext.executionHive.state.getConf.getClassLoader + hiveContext.sessionState.executionHive.state.getConf.getClassLoader Thread.currentThread().setContextClassLoader(executionHiveClassLoader) HiveThriftServer2.listener.onStatementStart( diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala index b8bc8ea44dc84..7e8eada5adb4f 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLDriver.scala @@ -28,7 +28,7 @@ import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes} +import org.apache.spark.sql.hive.{HiveContext, HiveMetastoreTypes, HiveQueryExecution} private[hive] class SparkSQLDriver( val context: HiveContext = SparkSQLEnv.hiveContext) @@ -41,7 +41,7 @@ private[hive] class SparkSQLDriver( override def init(): Unit = { } - private def getResultSetSchema(query: context.QueryExecution): Schema = { + private def getResultSetSchema(query: HiveQueryExecution): Schema = { val analyzed = query.analyzed logDebug(s"Result Schema: ${analyzed.output}") if (analyzed.output.isEmpty) { @@ -59,7 +59,8 @@ private[hive] class SparkSQLDriver( // TODO unify the error code try { context.sparkContext.setJobDescription(command) - val execution = context.executePlan(context.sql(command).logicalPlan) + val execution = + context.executePlan(context.sql(command).logicalPlan).asInstanceOf[HiveQueryExecution] hiveResponse = execution.stringResult() tableSchema = getResultSetSchema(execution) new CommandProcessorResponse(0) diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala index 2594c5bfdb3af..2679ac1854bb8 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLEnv.scala @@ -58,16 +58,15 @@ private[hive] object SparkSQLEnv extends Logging { sparkContext.addSparkListener(new StatsReportListener()) hiveContext = new HiveContext(sparkContext) - hiveContext.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) 
- hiveContext.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) - hiveContext.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setOut(new PrintStream(System.out, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setInfo(new PrintStream(System.err, true, "UTF-8")) + hiveContext.sessionState.metadataHive.setError(new PrintStream(System.err, true, "UTF-8")) hiveContext.setConf("spark.sql.hive.version", HiveContext.hiveExecutionVersion) if (log.isDebugEnabled) { - hiveContext.hiveconf.getAllProperties.asScala.toSeq.sorted.foreach { case (k, v) => - logDebug(s"HiveConf var: $k=$v") - } + hiveContext.sessionState.hiveconf.getAllProperties.asScala.toSeq.sorted + .foreach { case (k, v) => logDebug(s"HiveConf var: $k=$v") } } } } diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala index de4e9c62b57a4..f492b5656c3c3 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkSQLSessionManager.scala @@ -71,7 +71,7 @@ private[hive] class SparkSQLSessionManager(hiveServer: HiveServer2, hiveContext: val session = super.getSession(sessionHandle) HiveThriftServer2.listener.onSessionCreated( session.getIpAddress, sessionHandle.getSessionId.toString, session.getUsername) - val ctx = if (hiveContext.hiveThriftServerSingleSession) { + val ctx = if (hiveContext.sessionState.hiveThriftServerSingleSession) { hiveContext } else { hiveContext.newSession() diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala index 0c468a408ba98..da410c68c851d 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/server/SparkSQLOperationManager.scala @@ -47,7 +47,7 @@ private[thriftserver] class SparkSQLOperationManager() confOverlay: JMap[String, String], async: Boolean): ExecuteStatementOperation = synchronized { val hiveContext = sessionToContexts(parentSession.getSessionHandle) - val runInBackground = async && hiveContext.hiveThriftServerAsync + val runInBackground = async && hiveContext.sessionState.hiveThriftServerAsync val operation = new SparkExecuteStatementOperation(parentSession, statement, confOverlay, runInBackground)(hiveContext, sessionToActivePool) handleToOperation.put(operation.getHandle, operation) diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index eb49eabcb1ba9..0d0f556d9eae3 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -23,7 +23,7 @@ import java.sql.Timestamp import java.util.Date import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, Promise} +import scala.concurrent.Promise import scala.concurrent.duration._ import 
org.apache.hadoop.hive.conf.HiveConf.ConfVars @@ -32,7 +32,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.internal.Logging import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer -import org.apache.spark.util.Utils +import org.apache.spark.util.{ThreadUtils, Utils} /** * A test suite for the `spark-sql` CLI tool. Note that all test cases share the same temporary @@ -132,7 +132,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { new ProcessOutputCapturer(process.getErrorStream, captureOutput("stderr")).start() try { - Await.result(foundAllExpectedAnswers.future, timeout) + ThreadUtils.awaitResult(foundAllExpectedAnswers.future, timeout) } catch { case cause: Throwable => val message = s""" diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index a1268b8e94f56..ee14b6dc8d01d 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -24,7 +24,7 @@ import java.sql.{Date, DriverManager, SQLException, Statement} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import scala.concurrent.{Await, ExecutionContext, Future, Promise} +import scala.concurrent.{ExecutionContext, Future, Promise} import scala.concurrent.duration._ import scala.io.Source import scala.util.{Random, Try} @@ -40,7 +40,7 @@ import org.apache.thrift.protocol.TBinaryProtocol import org.apache.thrift.transport.TSocket import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkException, SparkFunSuite} import org.apache.spark.internal.Logging import org.apache.spark.sql.hive.HiveContext import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -373,9 +373,10 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // slightly more conservatively than may be strictly necessary. Thread.sleep(1000) statement.cancel() - val e = intercept[SQLException] { - Await.result(f, 3.minute) - } + val e = intercept[SparkException] { + ThreadUtils.awaitResult(f, 3.minute) + }.getCause + assert(e.isInstanceOf[SQLException]) assert(e.getMessage.contains("cancelled")) // Cancellation is a no-op if spark.sql.hive.thriftServer.async=false @@ -391,7 +392,7 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { // might race and complete before we issue the cancel. 
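The cancellation test above now intercepts the outer exception and asserts on getCause, since awaitResult surfaces failures wrapped rather than raw. The snippet below mimics that wrap-then-unwrap pattern with plain futures and a toy WrappedFailure class; it illustrates the shape only and is not Spark's ThreadUtils or SparkException:

import scala.concurrent.{Await, Future}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.duration._
import scala.util.control.NonFatal

final class WrappedFailure(cause: Throwable)
  extends RuntimeException("Exception thrown while awaiting result", cause)

object AwaitCauseSketch {
  // Await a future, but rethrow any failure wrapped, so callers catch one exception
  // type and then inspect getCause for the real error.
  def awaitWrapped[T](f: Future[T], atMost: Duration): T =
    try Await.result(f, atMost) catch {
      case NonFatal(t) => throw new WrappedFailure(t)
    }

  def main(args: Array[String]): Unit = {
    val failing = Future[Int] { throw new IllegalStateException("query cancelled") }
    val cause =
      try { awaitWrapped(failing, 10.seconds); None }
      catch { case w: WrappedFailure => Some(w.getCause) }
    assert(cause.exists(c => c.isInstanceOf[IllegalStateException] && c.getMessage.contains("cancelled")))
  }
}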
Thread.sleep(1000) statement.cancel() - val rs1 = Await.result(sf, 3.minute) + val rs1 = ThreadUtils.awaitResult(sf, 3.minute) rs1.next() assert(rs1.getInt(1) === math.pow(5, 5)) rs1.close() @@ -814,7 +815,7 @@ abstract class HiveThriftServer2Test extends SparkFunSuite with BeforeAndAfterAl process } - Await.result(serverStarted.future, SERVER_STARTUP_TIMEOUT) + ThreadUtils.awaitResult(serverStarted.future, SERVER_STARTUP_TIMEOUT) } private def stopThriftServer(): Unit = { diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 989e68aebed9b..49fd19873017d 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -39,7 +39,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { private val originalLocale = Locale.getDefault private val originalColumnBatchSize = TestHive.conf.columnBatchSize private val originalInMemoryPartitionPruning = TestHive.conf.inMemoryPartitionPruning - private val originalConvertMetastoreOrc = TestHive.convertMetastoreOrc + private val originalConvertMetastoreOrc = TestHive.sessionState.convertMetastoreOrc def testCases: Seq[(String, File)] = { hiveQueryDir.listFiles.map(f => f.getName.stripSuffix(".q") -> f) @@ -47,7 +47,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def beforeAll() { super.beforeAll() - TestHive.cacheTables = true + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -66,7 +66,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { override def afterAll() { try { - TestHive.cacheTables = false + TestHive.setCacheTables(false) TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala index d0b4cbe401eb3..de592f8d937dd 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveWindowFunctionQuerySuite.scala @@ -38,7 +38,8 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + super.beforeAll() + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -100,11 +101,14 @@ class HiveWindowFunctionQuerySuite extends HiveComparisonTest with BeforeAndAfte } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - TestHive.reset() - super.afterAll() + try { + TestHive.setCacheTables(false) + TimeZone.setDefault(originalTimeZone) + 
Locale.setDefault(originalLocale) + TestHive.reset() + } finally { + super.afterAll() + } } ///////////////////////////////////////////////////////////////////////////// @@ -773,7 +777,8 @@ class HiveWindowFunctionQueryFileSuite private val testTempDir = Utils.createTempDir() override def beforeAll() { - TestHive.cacheTables = true + super.beforeAll() + TestHive.setCacheTables(true) // Timezone is fixed to America/Los_Angeles for those timezone sensitive tests (timestamp_*) TimeZone.setDefault(TimeZone.getTimeZone("America/Los_Angeles")) // Add Locale setting @@ -790,10 +795,14 @@ class HiveWindowFunctionQueryFileSuite } override def afterAll() { - TestHive.cacheTables = false - TimeZone.setDefault(originalTimeZone) - Locale.setDefault(originalLocale) - TestHive.reset() + try { + TestHive.setCacheTables(false) + TimeZone.setDefault(originalTimeZone) + Locale.setDefault(originalLocale) + TestHive.reset() + } finally { + super.afterAll() + } } override def blackList: Seq[String] = Seq( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 505e5c0bb62f1..b2ce3e0df25b4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -22,57 +22,29 @@ import java.net.{URL, URLClassLoader} import java.nio.charset.StandardCharsets import java.sql.Timestamp import java.util.concurrent.TimeUnit -import java.util.regex.Pattern import scala.collection.JavaConverters._ import scala.collection.mutable.HashMap import scala.language.implicitConversions import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileSystem, Path} -import org.apache.hadoop.hive.common.StatsSetupConst import org.apache.hadoop.hive.common.`type`.HiveDecimal import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.metadata.Table -import org.apache.hadoop.hive.ql.parse.VariableSubstitution import org.apache.hadoop.hive.serde2.io.{DateWritable, TimestampWritable} import org.apache.hadoop.util.VersionInfo import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.api.java.JavaSparkContext import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.ConfigEntry +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions.{Expression, LeafExpression} -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive.client._ -import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils -/** - * Returns the current database of metadataHive. 
- */ -private[hive] case class CurrentDatabase(ctx: HiveContext) - extends LeafExpression with CodegenFallback { - override def dataType: DataType = StringType - override def foldable: Boolean = true - override def nullable: Boolean = false - override def eval(input: InternalRow): Any = { - UTF8String.fromString(ctx.sessionState.catalog.getCurrentDatabase) - } -} - /** * An instance of the Spark SQL execution engine that integrates with data stored in Hive. * Configuration for Hive is read from hive-site.xml on the classpath. @@ -80,336 +52,45 @@ private[hive] case class CurrentDatabase(ctx: HiveContext) * @since 1.0.0 */ class HiveContext private[hive]( - sc: SparkContext, - cacheManager: CacheManager, - listener: SQLListener, - @transient private[hive] val executionHive: HiveClientImpl, - @transient private[hive] val metadataHive: HiveClient, - isRootContext: Boolean, - @transient private[sql] val hiveCatalog: HiveExternalCatalog) - extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging { - self => + @transient private val sparkSession: SparkSession, + isRootContext: Boolean) + extends SQLContext(sparkSession, isRootContext) with Logging { - private def this(sc: SparkContext, execHive: HiveClientImpl, metaHive: HiveClient) { - this( - sc, - new CacheManager, - SQLContext.createListenerAndUI(sc), - execHive, - metaHive, - true, - new HiveExternalCatalog(metaHive)) - } + self => def this(sc: SparkContext) = { - this( - sc, - HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), - HiveContext.newClientForMetadata(sc.conf, sc.hadoopConfiguration)) + this(new SparkSession(HiveContext.withHiveExternalCatalog(sc)), true) } def this(sc: JavaSparkContext) = this(sc.sc) - import org.apache.spark.sql.hive.HiveContext._ - - logDebug("create HiveContext") - /** * Returns a new HiveContext as new session, which will have separated SQLConf, UDF/UDAF, * temporary tables and SessionState, but sharing the same CacheManager, IsolatedClientLoader * and Hive client (both of execution and metadata) with existing HiveContext. */ override def newSession(): HiveContext = { - new HiveContext( - sc = sc, - cacheManager = cacheManager, - listener = listener, - executionHive = executionHive.newSession(), - metadataHive = metadataHive.newSession(), - isRootContext = false, - hiveCatalog = hiveCatalog) + new HiveContext(sparkSession.newSession(), isRootContext = false) } - @transient - protected[sql] override lazy val sessionState = new HiveSessionState(self) - - // The Hive UDF current_database() is foldable, will be evaluated by optimizer, - // but the optimizer can't access the SessionState of metadataHive. - sessionState.functionRegistry.registerFunction( - "current_database", (e: Seq[Expression]) => new CurrentDatabase(self)) - - /** - * When true, enables an experimental feature where metastore tables that use the parquet SerDe - * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive - * SerDe. - */ - protected[sql] def convertMetastoreParquet: Boolean = getConf(CONVERT_METASTORE_PARQUET) - - /** - * When true, also tries to merge possibly different but compatible Parquet schemas in different - * Parquet data files. - * - * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. 
- */ - protected[sql] def convertMetastoreParquetWithSchemaMerging: Boolean = - getConf(CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) - - /** - * When true, enables an experimental feature where metastore tables that use the Orc SerDe - * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive - * SerDe. - */ - protected[sql] def convertMetastoreOrc: Boolean = getConf(CONVERT_METASTORE_ORC) - - /** - * When true, a table created by a Hive CTAS statement (no USING clause) will be - * converted to a data source table, using the data source set by spark.sql.sources.default. - * The table in CTAS statement will be converted when it meets any of the following conditions: - * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or - * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml - * is either TextFile or SequenceFile. - * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe - * is specified (no ROW FORMAT SERDE clause). - * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format - * and no SerDe is specified (no ROW FORMAT SERDE clause). - */ - protected[sql] def convertCTAS: Boolean = getConf(CONVERT_CTAS) - - /* - * hive thrift server use background spark sql thread pool to execute sql queries - */ - protected[hive] def hiveThriftServerAsync: Boolean = getConf(HIVE_THRIFT_SERVER_ASYNC) - - protected[hive] def hiveThriftServerSingleSession: Boolean = - sc.conf.get("spark.sql.hive.thriftServer.singleSession", "false").toBoolean - - @transient - protected[sql] lazy val substitutor = new VariableSubstitution() - - /** - * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. - * - allow SQL11 keywords to be used as identifiers - */ - private[sql] def defaultOverrides() = { - setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") - } - - defaultOverrides() - - protected[sql] override def parseSql(sql: String): LogicalPlan = { - executionHive.withHiveState { - super.parseSql(substitutor.substitute(hiveconf, sql)) - } - } - - override protected[sql] def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - /** - * Invalidate and refresh all the cached the metadata of the given table. For performance reasons, - * Spark SQL or the external data source library it uses might cache certain metadata about a - * table, such as the location of blocks. When those change outside of Spark SQL, users should - * call this function to invalidate the cache. - * - * @since 1.3.0 - */ - def refreshTable(tableName: String): Unit = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - sessionState.catalog.refreshTable(tableIdent) + protected[sql] override def sessionState: HiveSessionState = { + sparkSession.sessionState.asInstanceOf[HiveSessionState] } - protected[hive] def invalidateTable(tableName: String): Unit = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - sessionState.catalog.invalidateTable(tableIdent) + protected[sql] override def sharedState: HiveSharedState = { + sparkSession.sharedState.asInstanceOf[HiveSharedState] } - /** - * Analyzes the given table in the current database to generate statistics, which will be - * used in query optimizations. - * - * Right now, it only supports Hive tables and it only updates the size of a Hive table - * in the Hive metastore. 
- * - * @since 1.2.0 - */ - def analyze(tableName: String) { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) - - relation match { - case relation: MetastoreRelation => - // This method is mainly based on - // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) - // in Hive 0.13 (except that we do not use fs.getContentSummary). - // TODO: Generalize statistics collection. - // TODO: Why fs.getContentSummary returns wrong size on Jenkins? - // Can we use fs.getContentSummary in future? - // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use - // countFileSize to count the table size. - val stagingDir = metadataHive.getConf(HiveConf.ConfVars.STAGINGDIR.varname, - HiveConf.ConfVars.STAGINGDIR.defaultStrVal) - - def calculateTableSize(fs: FileSystem, path: Path): Long = { - val fileStatus = fs.getFileStatus(path) - val size = if (fileStatus.isDirectory) { - fs.listStatus(path) - .map { status => - if (!status.getPath().getName().startsWith(stagingDir)) { - calculateTableSize(fs, status.getPath) - } else { - 0L - } - } - .sum - } else { - fileStatus.getLen - } - - size - } - - def getFileSizeForTable(conf: HiveConf, table: Table): Long = { - val path = table.getPath - var size: Long = 0L - try { - val fs = path.getFileSystem(conf) - size = calculateTableSize(fs, path) - } catch { - case e: Exception => - logWarning( - s"Failed to get the size of table ${table.getTableName} in the " + - s"database ${table.getDbName} because of ${e.toString}", e) - size = 0L - } - - size - } - - val tableParameters = relation.hiveQlTable.getParameters - val oldTotalSize = - Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) - .map(_.toLong) - .getOrElse(0L) - val newTotalSize = getFileSizeForTable(hiveconf, relation.hiveQlTable) - // Update the Hive metastore if the total size of the table is different than the size - // recorded in the Hive metastore. - // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). - if (newTotalSize > 0 && newTotalSize != oldTotalSize) { - sessionState.catalog.alterTable( - relation.table.copy( - properties = relation.table.properties + - (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) - } - case otherRelation => - throw new UnsupportedOperationException( - s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") - } - } - - override def setConf(key: String, value: String): Unit = { - super.setConf(key, value) - executionHive.runSqlHive(s"SET $key=$value") - metadataHive.runSqlHive(s"SET $key=$value") - // If users put any Spark SQL setting in the spark conf (e.g. spark-defaults.conf), - // this setConf will be called in the constructor of the SQLContext. - // Also, calling hiveconf will create a default session containing a HiveConf, which - // will interfer with the creation of executionHive (which is a lazy val). So, - // we put hiveconf.set at the end of this method. - hiveconf.set(key, value) - } - - override private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { - setConf(entry.key, entry.stringConverter(value)) - } - - /** - * SQLConf and HiveConf contracts: - * - * 1. create a new o.a.h.hive.ql.session.SessionState for each HiveContext - * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the - * SQLConf. 
Additionally, any properties set by set() or a SET command inside sql() will be - * set in the SQLConf *as well as* in the HiveConf. - */ - @transient - protected[hive] lazy val hiveconf: HiveConf = { - val c = executionHive.conf - setConf(c.getAllProperties) - c - } - - private def functionOrMacroDDLPattern(command: String) = Pattern.compile( - ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL).matcher(command) - - protected[hive] def runSqlHive(sql: String): Seq[String] = { - val command = sql.trim.toLowerCase - if (functionOrMacroDDLPattern(command).matches()) { - executionHive.runSqlHive(sql) - } else if (command.startsWith("set")) { - metadataHive.runSqlHive(sql) - executionHive.runSqlHive(sql) - } else { - metadataHive.runSqlHive(sql) - } - } - - /** - * Executes a SQL query without parsing it, but instead passing it directly to Hive. - * This is currently only used for DDLs and will be removed as soon as Spark can parse - * all supported Hive DDLs itself. - */ - protected[sql] override def runNativeSql(sqlText: String): Seq[Row] = { - runSqlHive(sqlText).map { s => Row(s) } - } +} - /** Extends QueryExecution with hive specific features. */ - protected[sql] class QueryExecution(logicalPlan: LogicalPlan) - extends org.apache.spark.sql.execution.QueryExecution(this, logicalPlan) { - - /** - * Returns the result as a hive compatible sequence of strings. For native commands, the - * execution is simply passed back to Hive. - */ - def stringResult(): Seq[String] = executedPlan match { - case ExecutedCommand(desc: DescribeHiveTableCommand) => - // If it is a describe command for a Hive table, we want to have the output format - // be similar with Hive. - desc.run(self).map { - case Row(name: String, dataType: String, comment) => - Seq(name, dataType, - Option(comment.asInstanceOf[String]).getOrElse("")) - .map(s => String.format(s"%-20s", s)) - .mkString("\t") - } - case command: ExecutedCommand => - command.executeCollect().map(_.getString(0)) - - case other => - val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq - // We need the types so we can output struct field names - val types = analyzed.output.map(_.dataType) - // Reformat to match hive tab delimited output. - result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq - } - override def simpleString: String = - logical match { - case _: HiveNativeCommand => "" - case _: SetCommand => "" - case _ => super.simpleString - } - } +private[hive] object HiveContext extends Logging { - protected[sql] override def addJar(path: String): Unit = { - // Add jar to Hive and classloader - executionHive.addJar(path) - metadataHive.addJar(path) - Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) - super.addJar(path) + def withHiveExternalCatalog(sc: SparkContext): SparkContext = { + sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive") + sc } -} - -private[hive] object HiveContext extends Logging { /** The version of hive used internally by Spark SQL. */ val hiveExecutionVersion: String = "1.2.1" @@ -429,7 +110,7 @@ private[hive] object HiveContext extends Logging { | Location of the jars that should be used to instantiate the HiveMetastoreClient. | This property can be one of three options: " | 1. "builtin" - | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly jar when + | Use Hive ${hiveExecutionVersion}, which is bundled with the Spark assembly when | -Phive is enabled. 
When this option is chosen, | spark.sql.hive.metastore.version must be either | ${hiveExecutionVersion} or not defined. @@ -619,7 +300,9 @@ private[hive] object HiveContext extends Logging { * The version of the Hive client that is used here must match the metastore that is configured * in the hive-site.xml file. */ - private def newClientForMetadata(conf: SparkConf, hadoopConf: Configuration): HiveClient = { + protected[hive] def newClientForMetadata( + conf: SparkConf, + hadoopConf: Configuration): HiveClient = { val hiveConf = new HiveConf(hadoopConf, classOf[HiveConf]) val configurations = hiveClientConfigurations(hiveConf) newClientForMetadata(conf, hiveConf, hadoopConf, configurations) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f627384253aa9..ff52e6ad74c33 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -197,6 +197,11 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.listTables(db, pattern) } + override def showCreateTable(db: String, table: String): String = withClient { + require(this.tableExists(db, table), s"The table $db.$table does not exist.") + client.showCreateTable(db, table) + } + // -------------------------------------------------------------------------- // Partitions // -------------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index ccc8345d7375d..33a926e4d2551 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -41,12 +41,11 @@ import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.execution.FileRelation -import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.{Partition => _, _} import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetDefaultSource, ParquetRelation} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.hive.orc.{DefaultSource => OrcDefaultSource} -import org.apache.spark.sql.sources.{FileFormat, HadoopFsRelation, HDFSFileCatalog} import org.apache.spark.sql.types._ private[hive] case class HiveSerDe( @@ -116,17 +115,16 @@ private[hive] object HiveSerDe { * This is still used for things like creating data source tables, but in the future will be * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. 
*/ -private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) - extends Logging { - - val conf = hive.conf +private[hive] class HiveMetastoreCatalog(hive: SQLContext) extends Logging { + private val conf = hive.conf + private val sessionState = hive.sessionState.asInstanceOf[HiveSessionState] + private val client = hive.sharedState.asInstanceOf[HiveSharedState].metadataHive + private val hiveconf = sessionState.hiveconf /** A fully qualified identifier for a table (i.e., database.tableName) */ case class QualifiedTableName(database: String, name: String) - private def getCurrentDatabase: String = { - hive.sessionState.catalog.getCurrentDatabase - } + private def getCurrentDatabase: String = hive.sessionState.catalog.getCurrentDatabase def getQualifiedTableName(tableIdent: TableIdentifier): QualifiedTableName = { QualifiedTableName( @@ -299,7 +297,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte CatalogTableType.MANAGED_TABLE } - val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hive.hiveconf) + val maybeSerDe = HiveSerDe.sourceToSerDe(provider, hiveconf) val dataSource = DataSource( hive, @@ -504,11 +502,12 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte } } - private def convertToLogicalRelation(metastoreRelation: MetastoreRelation, - options: Map[String, String], - defaultSource: FileFormat, - fileFormatClass: Class[_ <: FileFormat], - fileType: String): LogicalRelation = { + private def convertToLogicalRelation( + metastoreRelation: MetastoreRelation, + options: Map[String, String], + defaultSource: FileFormat, + fileFormatClass: Class[_ <: FileFormat], + fileType: String): LogicalRelation = { val metastoreSchema = StructType.fromAttributes(metastoreRelation.output) val tableIdentifier = QualifiedTableName(metastoreRelation.databaseName, metastoreRelation.tableName) @@ -600,14 +599,14 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte object ParquetConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreParquet(relation: MetastoreRelation): Boolean = { relation.tableDesc.getSerdeClassName.toLowerCase.contains("parquet") && - hive.convertMetastoreParquet + sessionState.convertMetastoreParquet } private def convertToParquetRelation(relation: MetastoreRelation): LogicalRelation = { val defaultSource = new ParquetDefaultSource() val fileFormatClass = classOf[ParquetDefaultSource] - val mergeSchema = hive.convertMetastoreParquetWithSchemaMerging + val mergeSchema = sessionState.convertMetastoreParquetWithSchemaMerging val options = Map( ParquetRelation.MERGE_SCHEMA -> mergeSchema.toString, ParquetRelation.METASTORE_TABLE_NAME -> TableIdentifier( @@ -652,7 +651,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte object OrcConversions extends Rule[LogicalPlan] { private def shouldConvertMetastoreOrc(relation: MetastoreRelation): Boolean = { relation.tableDesc.getSerdeClassName.toLowerCase.contains("orc") && - hive.convertMetastoreOrc + sessionState.convertMetastoreOrc } private def convertToOrcRelation(relation: MetastoreRelation): LogicalRelation = { @@ -727,7 +726,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val desc = table.copy(schema = schema) - if (hive.convertCTAS && table.storage.serde.isEmpty) { + if (sessionState.convertCTAS && table.storage.serde.isEmpty) { // Do the conversion when spark.sql.hive.convertCTAS is true and the query // does not specify any 
storage format (file format and storage handler). if (table.identifier.database.isDefined) { @@ -815,14 +814,13 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte * the information from the metastore. */ class MetaStoreFileCatalog( - hive: HiveContext, + ctx: SQLContext, paths: Seq[Path], partitionSpecFromHive: PartitionSpec) - extends HDFSFileCatalog(hive, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { - + extends HDFSFileCatalog(ctx, Map.empty, paths, Some(partitionSpecFromHive.partitionColumns)) { override def getStatus(path: Path): Array[FileStatus] = { - val fs = path.getFileSystem(hive.sparkContext.hadoopConfiguration) + val fs = path.getFileSystem(ctx.sparkContext.hadoopConfiguration) fs.listStatus(path) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala new file mode 100644 index 0000000000000..1c1bfb610c29e --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQueryExecution.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.command.{ExecutedCommand, SetCommand} +import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} + + +/** + * A [[QueryExecution]] with hive specific features. + */ +protected[hive] class HiveQueryExecution(ctx: SQLContext, logicalPlan: LogicalPlan) + extends QueryExecution(ctx, logicalPlan) { + + /** + * Returns the result as a hive compatible sequence of strings. For native commands, the + * execution is simply passed back to Hive. + */ + def stringResult(): Seq[String] = executedPlan match { + case ExecutedCommand(desc: DescribeHiveTableCommand) => + // If it is a describe command for a Hive table, we want to have the output format + // be similar with Hive. + desc.run(ctx).map { + case Row(name: String, dataType: String, comment) => + Seq(name, dataType, + Option(comment.asInstanceOf[String]).getOrElse("")) + .map(s => String.format(s"%-20s", s)) + .mkString("\t") + } + case command: ExecutedCommand => + command.executeCollect().map(_.getString(0)) + + case other => + val result: Seq[Seq[Any]] = other.executeCollectPublic().map(_.toSeq).toSeq + // We need the types so we can output struct field names + val types = analyzed.output.map(_.dataType) + // Reformat to match hive tab delimited output. 
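// Illustrative sketch (hypothetical values, not from this patch): a collected row Seq(1, "a")
// zipped with the analyzed output types Seq(IntegerType, StringType) gives the pairs
// Seq((1, IntegerType), ("a", StringType)); toHiveString renders each pair and
// mkString("\t") joins them into the single Hive-style line "1\ta".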
+ result.map(_.zip(types).map(HiveContext.toHiveString)).map(_.mkString("\t")).toSeq + } + + override def simpleString: String = + logical match { + case _: HiveNativeCommand => "" + case _: SetCommand => "" + case _ => super.simpleString + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0cccc22e5a624..4f9513389c8c2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} +import org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder @@ -33,7 +34,6 @@ import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCat import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient @@ -45,10 +45,11 @@ import org.apache.spark.util.Utils private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, - context: HiveContext, + context: SQLContext, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: SQLConf) + conf: SQLConf, + hiveconf: HiveConf) extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) { override def setCurrentDatabase(db: String): Unit = { @@ -75,7 +76,7 @@ private[sql] class HiveSessionCatalog( // ---------------------------------------------------------------- override def getDefaultDBPath(db: String): String = { - val defaultPath = context.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) + val defaultPath = hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) new Path(new Path(defaultPath), db + ".db").toString } @@ -83,7 +84,7 @@ private[sql] class HiveSessionCatalog( // essentially a cache for metastore tables. However, it relies on a lot of session-specific // things so it would be a lot of work to split its functionality between HiveSessionCatalog // and HiveCatalog. We should still do it at some point... 
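// Construction note for the line below: the rewritten HiveMetastoreCatalog takes only the
// SQLContext and obtains the metastore client from hive.sharedState (a HiveSharedState) and
// the session HiveConf from hive.sessionState, so the explicit client argument goes away.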
- private val metastoreCatalog = new HiveMetastoreCatalog(client, context) + private val metastoreCatalog = new HiveMetastoreCatalog(context) val ParquetConversions: Rule[LogicalPlan] = metastoreCatalog.ParquetConversions val OrcConversions: Rule[LogicalPlan] = metastoreCatalog.OrcConversions diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index b992fda18cef7..09297c27dc5bc 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -17,35 +17,80 @@ package org.apache.spark.sql.hive +import java.util.regex.Pattern + +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.parse.VariableSubstitution + import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.Analyzer import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.execution.{python, SparkPlanner} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.hive.execution.HiveSqlParser +import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} +import org.apache.spark.sql.hive.execution.{AnalyzeTable, HiveSqlParser} import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A class that holds all session-specific state in a given [[HiveContext]]. */ -private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) { +private[hive] class HiveSessionState(ctx: SQLContext) extends SessionState(ctx) { + + self => + + private lazy val sharedState: HiveSharedState = ctx.sharedState.asInstanceOf[HiveSharedState] + + /** + * A Hive client used for execution. + */ + lazy val executionHive: HiveClientImpl = sharedState.executionHive.newSession() + + /** + * A Hive client used for interacting with the metastore. + */ + lazy val metadataHive: HiveClient = sharedState.metadataHive.newSession() + + /** + * A Hive helper class for substituting variables in a SQL statement. + */ + lazy val substitutor = new VariableSubstitution override lazy val conf: SQLConf = new SQLConf { override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } + + /** + * SQLConf and HiveConf contracts: + * + * 1. create a new o.a.h.hive.ql.session.SessionState for each HiveContext + * 2. when the Hive session is first initialized, params in HiveConf will get picked up by the + * SQLConf. Additionally, any properties set by set() or a SET command inside sql() will be + * set in the SQLConf *as well as* in the HiveConf. + */ + lazy val hiveconf: HiveConf = { + val c = executionHive.conf + conf.setConf(c.getAllProperties) + c + } + + setDefaultOverrideConfs() + /** * Internal catalog for managing table and database states. */ override lazy val catalog = { new HiveSessionCatalog( - ctx.hiveCatalog, - ctx.metadataHive, + sharedState.externalCatalog, + metadataHive, ctx, ctx.functionResourceLoader, functionRegistry, - conf) + conf, + hiveconf) } /** @@ -69,7 +114,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Parser for HiveQl query texts. 
*/ - override lazy val sqlParser: ParserInterface = HiveSqlParser + override lazy val sqlParser: ParserInterface = new HiveSqlParser(substitutor, hiveconf) /** * Planner that takes into account Hive-specific strategies. @@ -77,13 +122,14 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def planner: SparkPlanner = { new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) with HiveStrategies { - override val hiveContext = ctx + override val context: SQLContext = ctx + override val hiveconf: HiveConf = self.hiveconf override def strategies: Seq[Strategy] = { experimentalMethods.extraStrategies ++ Seq( FileSourceStrategy, DataSourceStrategy, - HiveCommandStrategy(ctx), + HiveCommandStrategy, HiveDDLStrategy, DDLStrategy, SpecialLimits, @@ -103,4 +149,119 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) } } + + // ------------------------------------------------------ + // Helper methods, partially leftover from pre-2.0 days + // ------------------------------------------------------ + + override def executePlan(plan: LogicalPlan): HiveQueryExecution = { + new HiveQueryExecution(ctx, plan) + } + + /** + * Overrides default Hive configurations to avoid breaking changes to Spark SQL users. + * - allow SQL11 keywords to be used as identifiers + */ + def setDefaultOverrideConfs(): Unit = { + setConf(ConfVars.HIVE_SUPPORT_SQL11_RESERVED_KEYWORDS.varname, "false") + } + + override def setConf(key: String, value: String): Unit = { + super.setConf(key, value) + executionHive.runSqlHive(s"SET $key=$value") + metadataHive.runSqlHive(s"SET $key=$value") + hiveconf.set(key, value) + } + + override def addJar(path: String): Unit = { + super.addJar(path) + executionHive.addJar(path) + metadataHive.addJar(path) + Thread.currentThread().setContextClassLoader(executionHive.clientLoader.classLoader) + } + + /** + * Analyzes the given table in the current database to generate statistics, which will be + * used in query optimizations. + * + * Right now, it only supports Hive tables and it only updates the size of a Hive table + * in the Hive metastore. + */ + override def analyze(tableName: String): Unit = { + AnalyzeTable(tableName).run(ctx) + } + + /** + * Execute a SQL statement by passing the query text directly to Hive. + */ + override def runNativeSql(sql: String): Seq[String] = { + val command = sql.trim.toLowerCase + val functionOrMacroDDLPattern = Pattern.compile( + ".*(create|drop)\\s+(temporary\\s+)?(function|macro).+", Pattern.DOTALL) + if (functionOrMacroDDLPattern.matcher(command).matches()) { + executionHive.runSqlHive(sql) + } else if (command.startsWith("set")) { + metadataHive.runSqlHive(sql) + executionHive.runSqlHive(sql) + } else { + metadataHive.runSqlHive(sql) + } + } + + /** + * When true, enables an experimental feature where metastore tables that use the parquet SerDe + * are automatically converted to use the Spark SQL parquet table scan, instead of the Hive + * SerDe. + */ + def convertMetastoreParquet: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_PARQUET) + } + + /** + * When true, also tries to merge possibly different but compatible Parquet schemas in different + * Parquet data files. + * + * This configuration is only effective when "spark.sql.hive.convertMetastoreParquet" is true. 
+ */ + def convertMetastoreParquetWithSchemaMerging: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_PARQUET_WITH_SCHEMA_MERGING) + } + + /** + * When true, enables an experimental feature where metastore tables that use the Orc SerDe + * are automatically converted to use the Spark SQL ORC table scan, instead of the Hive + * SerDe. + */ + def convertMetastoreOrc: Boolean = { + conf.getConf(HiveContext.CONVERT_METASTORE_ORC) + } + + /** + * When true, a table created by a Hive CTAS statement (no USING clause) will be + * converted to a data source table, using the data source set by spark.sql.sources.default. + * The table in CTAS statement will be converted when it meets any of the following conditions: + * - The CTAS does not specify any of a SerDe (ROW FORMAT SERDE), a File Format (STORED AS), or + * a Storage Hanlder (STORED BY), and the value of hive.default.fileformat in hive-site.xml + * is either TextFile or SequenceFile. + * - The CTAS statement specifies TextFile (STORED AS TEXTFILE) as the file format and no SerDe + * is specified (no ROW FORMAT SERDE clause). + * - The CTAS statement specifies SequenceFile (STORED AS SEQUENCEFILE) as the file format + * and no SerDe is specified (no ROW FORMAT SERDE clause). + */ + def convertCTAS: Boolean = { + conf.getConf(HiveContext.CONVERT_CTAS) + } + + /** + * When true, Hive Thrift server will execute SQL queries asynchronously using a thread pool." + */ + def hiveThriftServerAsync: Boolean = { + conf.getConf(HiveContext.HIVE_THRIFT_SERVER_ASYNC) + } + + def hiveThriftServerSingleSession: Boolean = { + ctx.sparkContext.conf.getBoolean( + "spark.sql.hive.thriftServer.singleSession", defaultValue = false) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSharedState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSharedState.scala new file mode 100644 index 0000000000000..11097c33df2d5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSharedState.scala @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.spark.SparkContext +import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} +import org.apache.spark.sql.internal.SharedState + + +/** + * A class that holds all state shared across sessions in a given [[HiveContext]]. + */ +private[hive] class HiveSharedState(override val sparkContext: SparkContext) + extends SharedState(sparkContext) { + + // TODO: just share the IsolatedClientLoader instead of the client instances themselves + + /** + * A Hive client used for execution. 
+ */ + val executionHive: HiveClientImpl = { + HiveContext.newClientForExecution(sparkContext.conf, sparkContext.hadoopConfiguration) + } + + /** + * A Hive client used to interact with the metastore. + */ + // This needs to be a lazy val at here because TestHiveSharedState is overriding it. + lazy val metadataHive: HiveClient = { + HiveContext.newClientForMetadata(sparkContext.conf, sparkContext.hadoopConfiguration) + } + + /** + * A catalog that interacts with the Hive metastore. + */ + override lazy val externalCatalog = new HiveExternalCatalog(metadataHive) + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 010361a32eb34..bbdcc8c6c2fff 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ @@ -31,12 +33,13 @@ private[hive] trait HiveStrategies { // Possibly being too clever with types here... or not clever enough. self: SparkPlanner => - val hiveContext: HiveContext + val context: SQLContext + val hiveconf: HiveConf object Scripts extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case logical.ScriptTransformation(input, script, output, child, schema: HiveScriptIOSchema) => - ScriptTransformation(input, script, output, planLater(child), schema)(hiveContext) :: Nil + ScriptTransformation(input, script, output, planLater(child), schema)(hiveconf) :: Nil case _ => Nil } } @@ -74,7 +77,7 @@ private[hive] trait HiveStrategies { projectList, otherPredicates, identity[Seq[Expression]], - HiveTableScan(_, relation, pruningPredicates)(hiveContext)) :: Nil + HiveTableScan(_, relation, pruningPredicates)(context, hiveconf)) :: Nil case _ => Nil } @@ -103,7 +106,7 @@ private[hive] trait HiveStrategies { } } - case class HiveCommandStrategy(context: HiveContext) extends Strategy { + case object HiveCommandStrategy extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case describe: DescribeCommand => ExecutedCommand( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index e54358e657690..2d44813f0eac5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -288,8 +288,9 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = { assert(a.child == e && e.child == p) - a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && - sameOutput(e.output, p.child.output ++ a.groupingExpressions.map(_.asInstanceOf[Attribute])) + a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput( + e.output.drop(p.child.output.length), + a.groupingExpressions.map(_.asInstanceOf[Attribute])) } private def groupingSetToSQL( @@ -303,25 +304,28 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val numOriginalOutput = project.child.output.length // Assumption: Aggregate's groupingExpressions is composed of - // 1) the attributes of aliased group by expressions + // 1) the grouping 
attributes // 2) gid, which is always the last one val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute]) // Assumption: Project's projectList is composed of // 1) the original output (Project's child.output), // 2) the aliased group by expressions. + val expandedAttributes = project.output.drop(numOriginalOutput) val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child) val groupingSQL = groupByExprs.map(_.sql).mkString(", ") // a map from group by attributes to the original group by expressions. val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs)) + // a map from expanded attributes to the original group by expressions. + val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs)) val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project => // Assumption: expand.projections is composed of // 1) the original output (Project's child.output), - // 2) group by attributes(or null literal) + // 2) expanded attributes(or null literal) // 3) gid, which is always the last one in each project in Expand project.drop(numOriginalOutput).dropRight(1).collect { - case attr: Attribute if groupByAttrMap.contains(attr) => groupByAttrMap(attr) + case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr) } } val groupingSetSQL = "GROUPING SETS(" + diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala index 54afe9c2a3550..6a20d7c25b682 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/TableReader.scala @@ -37,6 +37,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, RDD, UnionRDD} +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.DateTimeUtils @@ -61,8 +62,8 @@ private[hive] class HadoopTableReader( @transient private val attributes: Seq[Attribute], @transient private val relation: MetastoreRelation, - @transient private val sc: HiveContext, - hiveExtraConf: HiveConf) + @transient private val sc: SQLContext, + hiveconf: HiveConf) extends TableReader with Logging { // Hadoop honors "mapred.map.tasks" as hint, but will ignore when mapred.job.tracker is "local". @@ -72,12 +73,12 @@ class HadoopTableReader( private val _minSplitsPerRDD = if (sc.sparkContext.isLocal) { 0 // will splitted based on block by default. 
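// Illustrative sketch for the non-local branch below (hypothetical numbers): with
// "mapred.map.tasks" set to 4 in the HiveConf that is now passed in directly, and a
// sparkContext.defaultMinPartitions of 2, the minimum splits per RDD is math.max(4, 2) = 4.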
} else { - math.max(sc.hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) + math.max(hiveconf.getInt("mapred.map.tasks", 1), sc.sparkContext.defaultMinPartitions) } - SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveExtraConf) + SparkHadoopUtil.get.appendS3AndSparkHadoopConfigurations(sc.sparkContext.conf, hiveconf) private val _broadcastedHiveConf = - sc.sparkContext.broadcast(new SerializableConfiguration(hiveExtraConf)) + sc.sparkContext.broadcast(new SerializableConfiguration(hiveconf)) override def makeRDDForTable(hiveTable: HiveTable): RDD[InternalRow] = makeRDDForTable( @@ -162,7 +163,7 @@ class HadoopTableReader( case (partition, partDeserializer) => def updateExistPathSetByPathPattern(pathPatternStr: String) { val pathPattern = new Path(pathPatternStr) - val fs = pathPattern.getFileSystem(sc.hiveconf) + val fs = pathPattern.getFileSystem(hiveconf) val matches = fs.globStatus(pathPattern) matches.foreach(fileStatus => existPathSet += fileStatus.getPath.toString) } @@ -259,7 +260,7 @@ class HadoopTableReader( private def applyFilterIfNeeded(path: Path, filterOpt: Option[PathFilter]): String = { filterOpt match { case Some(filter) => - val fs = path.getFileSystem(sc.hiveconf) + val fs = path.getFileSystem(hiveconf) val filteredFiles = fs.listStatus(path, filter).map(_.getPath.toString) filteredFiles.mkString(",") case None => path.toString diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 6f7e7bf45106f..7280ebc6302f9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -249,4 +249,5 @@ private[hive] trait HiveClient { /** Used for testing only. Removes all metadata from this instance of Hive. 
*/ def reset(): Unit + def showCreateTable(db: String, table: String): String } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 2a1fff92b570a..e54f888bbd148 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -28,9 +28,11 @@ import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.hive_metastoreConstants import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} +import org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState @@ -43,7 +45,9 @@ import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.QueryExecutionException +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.{CircularBuffer, Utils} /** @@ -151,6 +155,8 @@ private[hive] class HiveClientImpl( } /** Returns the configuration for the current session. */ + // TODO: We should not use it because HiveSessionState has a hiveconf + // for the current Session. def conf: HiveConf = SessionState.get().getConf override def getConf(key: String, defaultValue: String): String = { @@ -625,11 +631,270 @@ private[hive] class HiveClientImpl( } } + override def showCreateTable(db: String, tableName: String): String = withHiveState { + Option(client.getTable(db, tableName, false)).map { hiveTable => + val tblProperties = hiveTable.getParameters.asScala.toMap + if (tblProperties.get("spark.sql.sources.provider").isDefined) { + generateDataSourceDDL(hiveTable) + } else { + generateHiveDDL(hiveTable) + } + }.get + } + /* -------------------------------------------------------- * | Helper methods for converting to and from Hive classes | * -------------------------------------------------------- */ + private def generateCreateTableHeader( + hiveTable: HiveTable, + processedProps: scala.collection.mutable.ArrayBuffer[String]): String = { + val sb = new StringBuilder("CREATE ") + if(hiveTable.isTemporary) { + sb.append("TEMPORARY ") + } + if (hiveTable.getTableType == HiveTableType.EXTERNAL_TABLE) { + processedProps += "EXTERNAL" + sb.append("EXTERNAL TABLE " + + quoteIdentifier(hiveTable.getDbName) + "." + quoteIdentifier(hiveTable.getTableName)) + } else { + sb.append("TABLE " + + quoteIdentifier(hiveTable.getDbName) + "." 
+ quoteIdentifier(hiveTable.getTableName)) + } + sb.toString() + } + + private def generateColsDataSource( + hiveTable: HiveTable, + processedProps: scala.collection.mutable.ArrayBuffer[String]): String = { + val schemaStringFromParts: Option[String] = { + val props = hiveTable.getParameters.asScala + props.get("spark.sql.sources.schema.numParts").map { numParts => + val parts = (0 until numParts.toInt).map { index => + val part = props.get(s"spark.sql.sources.schema.part.$index").orNull + if (part == null) { + throw new AnalysisException( + "Could not read schema from the metastore because it is corrupted " + + s"(missing part $index of the schema, $numParts parts are expected).") + } + part + } + // Stick all parts back to a single schema string. + parts.mkString + } + } + + if (schemaStringFromParts.isDefined) { + (schemaStringFromParts.map(s => DataType.fromJson(s).asInstanceOf[StructType]). + get map { f => s"${quoteIdentifier(f.name)} ${f.dataType.sql}" }) + .mkString("( ", ", ", " )") + } else { + "" + } + } + + private def generateDataSourceDDL(hiveTable: HiveTable): String = { + + /** + * Currently, when users use a DataFrame to create a table, such as df.write.partitionBy("a") + * .saveAsTable("t1"), HiveMetaStoreCatalog puts the partitioning, bucketing and sorting + * information into TBLPROPERTIES when creating the corresponding Hive table. So for a table + * that is known to carry a "spark.sql.sources.provider" value, the following properties + * indicate whether it was created from a DataFrame. + * @param tblProperties the table properties stored in the metastore + * @return true if the table was created through the DataFrame writer API + */ + def createdByDataframe(tblProperties: Map[String, String]): Boolean = { + tblProperties.get("spark.sql.sources.schema.numPartCols").isDefined || + tblProperties.get("spark.sql.sources.schema.numBucketCols").isDefined || + tblProperties.get("spark.sql.sources.schema.numSortCols").isDefined + } + + val tblProperties = hiveTable.getParameters.asScala.toMap + if (createdByDataframe(tblProperties)) { + def getColumnNames(colType: String): Seq[String] = { + tblProperties.get(s"spark.sql.sources.schema.num${colType.capitalize}Cols").map { + numCols => (0 until numCols.toInt).map { index => + tblProperties.getOrElse(s"spark.sql.sources.schema.${colType}Col.$index", + throw new AnalysisException( + s"Could not read $colType columns from the metastore because it is corrupted " + + s"(missing part $index of it, $numCols parts are expected).")) + } + }.getOrElse(Nil) + } + + // generate the DataFrame writer syntax that would recreate this table + val sb = new StringBuilder(".write") + val partitionCols = getColumnNames("part") + val bucketCols = getColumnNames("bucket") + if (partitionCols.size > 0) { + sb.append(".partitionBy" + partitionCols.map("\"" + _ + "\"").mkString("(", ", ", ")")) + } + if (bucketCols.size > 0) { + tblProperties.get("spark.sql.sources.schema.numBuckets").map { n => + sb.append(s".bucketBy($n, " + bucketCols.map("\"" + _ + "\"").mkString(", ") + ")") + } + val sortCols = getColumnNames("sort") + if (sortCols.size > 0) { + sb.append(".sortBy" + sortCols.map("\"" + _ + "\"").mkString("(", ", ", ")")) + } + } + + // file format + tblProperties.get("spark.sql.sources.provider").map { provider => + sb.append(".format(\"" + provider + "\")") + } + + sb.append(".saveAsTable(\"" + hiveTable.getDbName + "." 
+ hiveTable.getTableName + "\")") + sb.toString + } else { + val processedProperties = scala.collection.mutable.ArrayBuffer.empty[String] + val sb = new StringBuilder(generateCreateTableHeader(hiveTable, processedProperties)) + // It is possible that the column list returned from hive metastore is just a dummy + // one, such as "col array", because the metastore was created as spark sql + // specific metastore (refer to HiveMetaStoreCatalog.createDataSourceTable. + // newSparkSQLSpecificMetastoreTable). In such case, the column schema information + // is located in tblproperties in json format. + sb.append(generateColsDataSource(hiveTable, processedProperties) + "\n") + sb.append("USING " + hiveTable.getProperty("spark.sql.sources.provider") + "\n") + val options = scala.collection.mutable.ArrayBuffer.empty[String] + hiveTable.getSd.getSerdeInfo.getParameters.asScala.foreach { e => + options += "" + escapeHiveCommand(e._1) + " '" + escapeHiveCommand(e._2) + "'" + } + if (options.size > 0) { + sb.append("OPTIONS " + options.mkString("( ", ", \n", " )")) + } + // create table using syntax does not support EXTERNAL keyword + sb.toString.replace("EXTERNAL TABLE", "TABLE") + } + } + + private def generateHiveDDL(hiveTable: HiveTable): String = { + val sb = new StringBuilder + val tblProperties = hiveTable.getParameters.asScala.toMap + val duplicateProps = scala.collection.mutable.ArrayBuffer.empty[String] + + if (hiveTable.getTableType == HiveTableType.VIRTUAL_VIEW) { + sb.append("CREATE VIEW " + quoteIdentifier(hiveTable.getDbName) + "." + + quoteIdentifier(hiveTable.getTableName) + " AS " + hiveTable.getViewOriginalText) + } else { + // create table header + sb.append(generateCreateTableHeader(hiveTable, duplicateProps)) + + // column list + sb.append( + hiveTable.getCols.asScala.map(fromHiveColumn).map { col => + quoteIdentifier(col.name) + " " + col.dataType + ( col.comment.getOrElse("") match { + case cmt: String if cmt.length > 0 => + " COMMENT '" + escapeHiveCommand(cmt) + "'" + case _ => "" + }) + }.mkString("( ", ", ", " ) \n")) + + // partition + val partCols = hiveTable.getPartitionKeys + if (partCols != null && partCols.size() > 0) { + sb.append("PARTITIONED BY ") + sb.append( + partCols.asScala.map(fromHiveColumn).map { col => + quoteIdentifier(col.name) + " " + col.dataType + (col.comment.getOrElse("") match { + case cmt: String if cmt.length > 0 => + " COMMENT '" + escapeHiveCommand(cmt) + "'" + case _ => "" + }) + }.mkString("( ", ", ", " )\n")) + } + + // sort bucket + val bucketCols = hiveTable.getBucketCols + if (bucketCols != null && bucketCols.size() > 0) { + sb.append("CLUSTERED BY ") + sb.append(bucketCols.asScala.map(quoteIdentifier(_)).mkString("( ", ", ", " ) \n")) + // SORTing columns + val sortCols = hiveTable.getSortCols + if (sortCols != null && sortCols.size() > 0) { + sb.append("SORTED BY ") + sb.append( + sortCols.asScala.map { col => + quoteIdentifier(col.getCol) + " " + (col.getOrder match { + case o if o == BaseSemanticAnalyzer.HIVE_COLUMN_ORDER_DESC => "DESC" + case _ => "ASC" + }) + }.mkString("( ", ", ", " ) \n")) + } + if (hiveTable.getNumBuckets > 0) { + sb.append("INTO " + hiveTable.getNumBuckets + " BUCKETS \n") + } + } + + // skew spec + val skewCols = hiveTable.getSkewedColNames + if (skewCols != null && skewCols.size() > 0) { + sb.append("SKEWED BY ") + sb.append(skewCols.asScala.map(quoteIdentifier(_)).mkString("( ", ", ", " ) \n")) + val skewColValues = hiveTable.getSkewedColValues + sb.append("ON ") + sb.append(skewColValues.asScala.map { values 
=> + values.asScala.map("" + _).mkString("(", ", ", ")") + }.mkString("(", ", ", ")\n")) + if (hiveTable.isStoredAsSubDirectories) { + sb.append("STORED AS DIRECTORIES\n") + } + } + + // ROW FORMAT + val storageHandler = hiveTable.getStorageHandler + val serdeProps = hiveTable.getSd.getSerdeInfo.getParameters.asScala + sb.append("ROW FORMAT SERDE '" + escapeHiveCommand(hiveTable.getSerializationLib) + "'\n") + if (storageHandler == null) { + sb.append("WITH SERDEPROPERTIES ") + sb.append(serdeProps.map { serdeProp => + "'" + escapeHiveCommand(serdeProp._1) + "'='" + escapeHiveCommand(serdeProp._2) + "'" + }.mkString("( ", ", ", " )\n")) + + sb.append("STORED AS INPUTFORMAT '" + + escapeHiveCommand(hiveTable.getInputFormatClass.getName) + "' \n") + + sb.append("OUTPUTFORMAT '" + + escapeHiveCommand(hiveTable.getOutputFormatClass.getName) + "' \n") + } else { + // storage handler case + duplicateProps += hive_metastoreConstants.META_TABLE_STORAGE + sb.append("STORED BY '" + escapeHiveCommand( + tblProperties.getOrElse(hive_metastoreConstants.META_TABLE_STORAGE, "")) + "'\n") + sb.append("WITH SERDEPROPERTIES ") + sb.append(serdeProps.map { serdeProp => + "'" + escapeHiveCommand(serdeProp._1) + "'='" + escapeHiveCommand(serdeProp._2) + "'" + }.mkString("( ", ", ", " )\n")) + } + + // table location + sb.append("LOCATION '" + + escapeHiveCommand(shim.getDataLocation(hiveTable).get) + "' \n") + + // table properties + val propertyPairs = hiveTable.getParameters.asScala.collect { + case (k, v) if !duplicateProps.contains(k) => + "'" + escapeHiveCommand(k) + "'='" + escapeHiveCommand(v) + "'" + } + if (propertyPairs.size > 0) { + sb.append("TBLPROPERTIES " + propertyPairs.mkString("( ", ", \n", " )") + "\n") + } + } + sb.toString() + } + + // Escape single quotes and semicolons by prefixing them with a backslash so the + // generated DDL stays parseable. + private def escapeHiveCommand(str: String): String = { + str.flatMap { c => + if (c == '\'' || c == ';') { + s"\\$c" + } else { + c.toString + } + } + } + private def toInputFormat(name: String) = Utils.classForName(name).asInstanceOf[Class[_ <: org.apache.hadoop.mapred.InputFormat[_, _]]] @@ -753,5 +1018,4 @@ private[hive] class HiveClientImpl( serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), serdeProperties = apiPartition.getSd.getSerdeInfo.getParameters.asScala.toMap)) } - } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 29f7dc2997d26..ceb7f3b890949 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -43,7 +43,6 @@ case class CreateTableAsSelect( override def children: Seq[LogicalPlan] = Seq(query) override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] lazy val metastoreRelation: MetastoreRelation = { import org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -69,24 +68,24 @@ case class CreateTableAsSelect( withFormat } - hiveContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) + sqlContext.sessionState.catalog.createTable(withSchema, ignoreIfExists = false) // Get the Metastore Relation - hiveContext.sessionState.catalog.lookupRelation(tableIdentifier) match { + sqlContext.sessionState.catalog.lookupRelation(tableIdentifier) match { case r: MetastoreRelation => r } } // TODO ideally, we should get the output data ready first and 
then // add the relation into catalog, just in case of failure occurs while data // processing. - if (hiveContext.sessionState.catalog.tableExists(tableIdentifier)) { + if (sqlContext.sessionState.catalog.tableExists(tableIdentifier)) { if (allowExisting) { // table already exists, will do nothing, to keep consistent with Hive } else { throw new AnalysisException(s"$tableIdentifier already exists.") } } else { - hiveContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd + sqlContext.executePlan(InsertIntoTable(metastoreRelation, Map(), query, true, false)).toRdd } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 33cd8b44805b8..1e234d8508b40 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -20,12 +20,11 @@ package org.apache.spark.sql.hive.execution import scala.util.control.NonFatal import org.apache.spark.sql.{AnalysisException, Row, SQLContext} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable} import org.apache.spark.sql.catalyst.expressions.Alias import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.hive.{ HiveContext, HiveMetastoreTypes, SQLBuilder} +import org.apache.spark.sql.hive.{HiveMetastoreTypes, HiveSessionState, SQLBuilder} /** * Create Hive view on non-hive-compatible tables by specifying schema ourselves instead of @@ -47,16 +46,16 @@ private[hive] case class CreateViewAsSelect( private val tableIdentifier = tableDesc.identifier override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - hiveContext.sessionState.catalog.tableExists(tableIdentifier) match { + sessionState.catalog.tableExists(tableIdentifier) match { case true if allowExisting => // Handles `CREATE VIEW IF NOT EXISTS v0 AS SELECT ...`. Does nothing when the target view // already exists. case true if orReplace => // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` - hiveContext.metadataHive.alertView(prepareTable(sqlContext)) + sessionState.metadataHive.alertView(prepareTable(sqlContext)) case true => // Handles `CREATE VIEW v0 AS SELECT ...`. 
Throws exception when the target view already @@ -66,7 +65,7 @@ private[hive] case class CreateViewAsSelect( "CREATE OR REPLACE VIEW AS") case false => - hiveContext.metadataHive.createView(prepareTable(sqlContext)) + sessionState.metadataHive.createView(prepareTable(sqlContext)) } Seq.empty[Row] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala index 9bb971992d0d1..8c1f4a8dc5139 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveNativeCommand.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.hive.execution import org.apache.spark.sql.{Row, SQLContext} import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.hive.HiveSessionState import org.apache.spark.sql.types.StringType private[hive] @@ -29,6 +29,8 @@ case class HiveNativeCommand(sql: String) extends RunnableCommand { override def output: Seq[AttributeReference] = Seq(AttributeReference("result", StringType, nullable = false)()) - override def run(sqlContext: SQLContext): Seq[Row] = - sqlContext.asInstanceOf[HiveContext].runSqlHive(sql).map(Row(_)) + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.asInstanceOf[HiveSessionState].runNativeSql(sql).map(Row(_)) + } + } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index a97b65e27bc59..4ff02cdbd0b39 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -21,8 +21,7 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.parse.EximUtil -import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.ql.parse.{EximUtil, VariableSubstitution} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe @@ -39,36 +38,28 @@ import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** * Concrete parser for HiveQl statements. */ -object HiveSqlParser extends AbstractSqlParser { - val astBuilder = new HiveSqlAstBuilder +class HiveSqlParser( + substitutor: VariableSubstitution, + hiveconf: HiveConf) + extends AbstractSqlParser { - override protected def nativeCommand(sqlText: String): LogicalPlan = { - HiveNativeCommand(sqlText) + val astBuilder = new HiveSqlAstBuilder(hiveconf) + + protected override def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + super.parse(substitutor.substitute(hiveconf, command))(toResult) + } + + protected override def nativeCommand(sqlText: String): LogicalPlan = { + HiveNativeCommand(substitutor.substitute(hiveconf, sqlText)) } } /** * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ -class HiveSqlAstBuilder extends SparkSqlAstBuilder { +class HiveSqlAstBuilder(hiveConf: HiveConf) extends SparkSqlAstBuilder { import ParserUtils._ - /** - * Get the current Hive Configuration. 
- */ - private[this] def hiveConf: HiveConf = { - var ss = SessionState.get() - // SessionState is lazy initialization, it can be null here - if (ss == null) { - val original = Thread.currentThread().getContextClassLoader - val conf = new HiveConf(classOf[SessionState]) - conf.setClassLoader(original) - ss = new SessionState(conf) - SessionState.start(ss) - } - ss.getConf - } - /** * Pass a command to Hive using a [[HiveNativeCommand]]. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala index 235b80b7c697c..9a834660f953f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils.Object import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution._ @@ -47,7 +48,8 @@ case class HiveTableScan( requestedAttributes: Seq[Attribute], relation: MetastoreRelation, partitionPruningPred: Seq[Expression])( - @transient val context: HiveContext) + @transient val context: SQLContext, + @transient val hiveconf: HiveConf) extends LeafNode { require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned, @@ -75,7 +77,7 @@ case class HiveTableScan( // Create a local copy of hiveconf,so that scan specific modifications should not impact // other queries @transient - private[this] val hiveExtraConf = new HiveConf(context.hiveconf) + private[this] val hiveExtraConf = new HiveConf(hiveconf) // append columns ids and names before broadcast addColumnMetadataToConf(hiveExtraConf) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 430fa4616fc2b..e614daadf3918 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -43,9 +43,10 @@ case class InsertIntoHiveTable( overwrite: Boolean, ifNotExists: Boolean) extends UnaryNode { - @transient val sc: HiveContext = sqlContext.asInstanceOf[HiveContext] - @transient private lazy val hiveContext = new Context(sc.hiveconf) - @transient private lazy val client = sc.metadataHive + @transient private val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] + @transient private val client = sessionState.metadataHive + @transient private val hiveconf = sessionState.hiveconf + @transient private lazy val hiveContext = new Context(hiveconf) def output: Seq[Attribute] = Seq.empty @@ -67,7 +68,7 @@ case class InsertIntoHiveTable( SparkHiveWriterContainer.createPathFromString(fileSinkConf.getDirName, conf.value)) log.debug("Saving as hadoop file of type " + valueClass.getSimpleName) writerContainer.driverSideSetup() - sc.sparkContext.runJob(rdd, writerContainer.writeToFile _) + sqlContext.sparkContext.runJob(rdd, writerContainer.writeToFile _) writerContainer.commitJob() } @@ -86,17 +87,17 @@ case class InsertIntoHiveTable( val tableLocation = table.hiveQlTable.getDataLocation val tmpLocation = hiveContext.getExternalTmpPath(tableLocation) val 
fileSinkConf = new FileSinkDesc(tmpLocation.toString, tableDesc, false) - val isCompressed = sc.hiveconf.getBoolean( + val isCompressed = hiveconf.getBoolean( ConfVars.COMPRESSRESULT.varname, ConfVars.COMPRESSRESULT.defaultBoolVal) if (isCompressed) { // Please note that isCompressed, "mapred.output.compress", "mapred.output.compression.codec", // and "mapred.output.compression.type" have no impact on ORC because it uses table properties // to store compression information. - sc.hiveconf.set("mapred.output.compress", "true") + hiveconf.set("mapred.output.compress", "true") fileSinkConf.setCompressed(true) - fileSinkConf.setCompressCodec(sc.hiveconf.get("mapred.output.compression.codec")) - fileSinkConf.setCompressType(sc.hiveconf.get("mapred.output.compression.type")) + fileSinkConf.setCompressCodec(hiveconf.get("mapred.output.compression.codec")) + fileSinkConf.setCompressType(hiveconf.get("mapred.output.compression.type")) } val numDynamicPartitions = partition.values.count(_.isEmpty) @@ -113,13 +114,13 @@ case class InsertIntoHiveTable( // Validate partition spec if there exist any dynamic partitions if (numDynamicPartitions > 0) { // Report error if dynamic partitioning is not enabled - if (!sc.hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { + if (!hiveconf.getBoolVar(HiveConf.ConfVars.DYNAMICPARTITIONING)) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_DISABLED.getMsg) } // Report error if dynamic partition strict mode is on but no static partition is found - if (numStaticPartitions == 0 && - sc.hiveconf.getVar(HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { + if (numStaticPartitions == 0 && hiveconf.getVar( + HiveConf.ConfVars.DYNAMICPARTITIONINGMODE).equalsIgnoreCase("strict")) { throw new SparkException(ErrorMsg.DYNAMIC_PARTITION_STRICT_MODE.getMsg) } @@ -130,7 +131,7 @@ case class InsertIntoHiveTable( } } - val jobConf = new JobConf(sc.hiveconf) + val jobConf = new JobConf(hiveconf) val jobConfSer = new SerializableJobConf(jobConf) // When speculation is on and output committer class name contains "Direct", we should warn diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala index 3566526561b2f..2f7cec354d84f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala @@ -26,6 +26,7 @@ import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.ql.exec.{RecordReader, RecordWriter} import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.AbstractSerDe @@ -39,7 +40,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ScriptInputOutputSchema import org.apache.spark.sql.execution._ -import org.apache.spark.sql.hive.{HiveContext, HiveInspectors} +import org.apache.spark.sql.hive.HiveInspectors import org.apache.spark.sql.hive.HiveShim._ import org.apache.spark.sql.types.DataType import org.apache.spark.util.{CircularBuffer, RedirectThread, SerializableConfiguration, Utils} @@ -57,14 +58,14 @@ case class ScriptTransformation( script: String, output: Seq[Attribute], child: SparkPlan, - 
ioschema: HiveScriptIOSchema)(@transient private val sc: HiveContext) + ioschema: HiveScriptIOSchema)(@transient private val hiveconf: HiveConf) extends UnaryNode { - override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil + override protected def otherCopyArgs: Seq[HiveConf] = hiveconf :: Nil override def producedAttributes: AttributeSet = outputSet -- inputSet - private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf) + private val serializedHiveConf = new SerializableConfiguration(hiveconf) protected override def doExecute(): RDD[InternalRow] = { def processIterator(inputIterator: Iterator[InternalRow]): Iterator[InternalRow] = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 06badff474f49..b5ee9a62954ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -17,7 +17,11 @@ package org.apache.spark.sql.hive.execution +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.hive.common.StatsSetupConst +import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.MetaStoreUtils +import org.apache.hadoop.hive.ql.metadata.Table import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier @@ -25,8 +29,8 @@ import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.command.RunnableCommand -import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, LogicalRelation} -import org.apache.spark.sql.hive.HiveContext +import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, HadoopFsRelation, LogicalRelation} +import org.apache.spark.sql.hive.{HiveSessionState, MetastoreRelation} import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ @@ -41,7 +45,80 @@ private[hive] case class AnalyzeTable(tableName: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.asInstanceOf[HiveContext].analyze(tableName) + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + val relation = EliminateSubqueryAliases(sessionState.catalog.lookupRelation(tableIdent)) + + relation match { + case relation: MetastoreRelation => + // This method is mainly based on + // org.apache.hadoop.hive.ql.stats.StatsUtils.getFileSizeForTable(HiveConf, Table) + // in Hive 0.13 (except that we do not use fs.getContentSummary). + // TODO: Generalize statistics collection. + // TODO: Why fs.getContentSummary returns wrong size on Jenkins? + // Can we use fs.getContentSummary in future? + // Seems fs.getContentSummary returns wrong table size on Jenkins. So we use + // countFileSize to count the table size. 
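For contrast with the manual recursion that follows, here is a minimal sketch of the getContentSummary route the comment above rules out. The helper name is hypothetical and not part of the patch; the point is that HDFS would aggregate the size for us, but cannot skip Hive's staging directories, which is one reason the patch walks the tree itself.

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

// Hypothetical alternative: let HDFS sum the bytes under the table path.
def sizeViaContentSummary(hadoopConf: Configuration, tablePath: Path): Long = {
  val fs = tablePath.getFileSystem(hadoopConf)
  // getContentSummary counts every file under `tablePath`, staging dirs included.
  fs.getContentSummary(tablePath).getLength
}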
+ val stagingDir = sessionState.metadataHive.getConf( + HiveConf.ConfVars.STAGINGDIR.varname, + HiveConf.ConfVars.STAGINGDIR.defaultStrVal) + + def calculateTableSize(fs: FileSystem, path: Path): Long = { + val fileStatus = fs.getFileStatus(path) + val size = if (fileStatus.isDirectory) { + fs.listStatus(path) + .map { status => + if (!status.getPath().getName().startsWith(stagingDir)) { + calculateTableSize(fs, status.getPath) + } else { + 0L + } + } + .sum + } else { + fileStatus.getLen + } + + size + } + + def getFileSizeForTable(conf: HiveConf, table: Table): Long = { + val path = table.getPath + var size: Long = 0L + try { + val fs = path.getFileSystem(conf) + size = calculateTableSize(fs, path) + } catch { + case e: Exception => + logWarning( + s"Failed to get the size of table ${table.getTableName} in the " + + s"database ${table.getDbName} because of ${e.toString}", e) + size = 0L + } + + size + } + + val tableParameters = relation.hiveQlTable.getParameters + val oldTotalSize = + Option(tableParameters.get(StatsSetupConst.TOTAL_SIZE)) + .map(_.toLong) + .getOrElse(0L) + val newTotalSize = + getFileSizeForTable(sessionState.hiveconf, relation.hiveQlTable) + // Update the Hive metastore if the total size of the table is different than the size + // recorded in the Hive metastore. + // This logic is based on org.apache.hadoop.hive.ql.exec.StatsTask.aggregateStats(). + if (newTotalSize > 0 && newTotalSize != oldTotalSize) { + sessionState.catalog.alterTable( + relation.table.copy( + properties = relation.table.properties + + (StatsSetupConst.TOTAL_SIZE -> newTotalSize.toString))) + } + case otherRelation => + throw new UnsupportedOperationException( + s"Analyze only works for Hive tables, but $tableName is a ${otherRelation.nodeName}") + } Seq.empty[Row] } } @@ -66,9 +143,8 @@ private[hive] case class AddFile(path: String) extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - hiveContext.runSqlHive(s"ADD FILE $path") - hiveContext.sparkContext.addFile(path) + sqlContext.sessionState.runNativeSql(s"ADD FILE $path") + sqlContext.sparkContext.addFile(path) Seq.empty[Row] } } @@ -98,9 +174,9 @@ case class CreateMetastoreDataSource( } val tableName = tableIdent.unquotedString - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] - if (hiveContext.sessionState.catalog.tableExists(tableIdent)) { + if (sessionState.catalog.tableExists(tableIdent)) { if (allowExisting) { return Seq.empty[Row] } else { @@ -112,8 +188,7 @@ case class CreateMetastoreDataSource( val optionsWithPath = if (!options.contains("path") && managedIfNoPath) { isExternal = false - options + ("path" -> - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } @@ -126,7 +201,7 @@ case class CreateMetastoreDataSource( bucketSpec = None, options = optionsWithPath).resolveRelation() - hiveContext.sessionState.catalog.createDataSourceTable( + sessionState.catalog.createDataSourceTable( tableIdent, userSpecifiedSchema, Array.empty[String], @@ -165,14 +240,13 @@ case class CreateMetastoreDataSourceAsSelect( } val tableName = tableIdent.unquotedString - val hiveContext = sqlContext.asInstanceOf[HiveContext] + val sessionState = sqlContext.sessionState.asInstanceOf[HiveSessionState] var createMetastoreTable = false var isExternal = true val optionsWithPath = 
if (!options.contains("path")) { isExternal = false - options + ("path" -> - hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) + options + ("path" -> sessionState.catalog.hiveDefaultTableFilePath(tableIdent)) } else { options } @@ -203,14 +277,14 @@ case class CreateMetastoreDataSourceAsSelect( // inserting into (i.e. using the same compression). EliminateSubqueryAliases( - sqlContext.sessionState.catalog.lookupRelation(tableIdent)) match { + sessionState.catalog.lookupRelation(tableIdent)) match { case l @ LogicalRelation(_: InsertableRelation | _: HadoopFsRelation, _, _) => existingSchema = Some(l.schema) case o => throw new AnalysisException(s"Saving data in ${o.toString} is not supported.") } case SaveMode.Overwrite => - hiveContext.sql(s"DROP TABLE IF EXISTS $tableName") + sqlContext.sql(s"DROP TABLE IF EXISTS $tableName") // Need to create the table again. createMetastoreTable = true } @@ -219,10 +293,10 @@ case class CreateMetastoreDataSourceAsSelect( createMetastoreTable = true } - val data = Dataset.ofRows(hiveContext, query) + val data = Dataset.ofRows(sqlContext, query) val df = existingSchema match { // If we are inserting into an existing table, just use the existing schema. - case Some(s) => sqlContext.internalCreateDataFrame(data.queryExecution.toRdd, s) + case Some(s) => data.selectExpr(s.fieldNames: _*) case None => data } @@ -240,7 +314,7 @@ case class CreateMetastoreDataSourceAsSelect( // We will use the schema of resolved.relation as the schema of the table (instead of // the schema of df). It is important since the nullability may be changed by the relation // provider (for example, see org.apache.spark.sql.parquet.DefaultSource). - hiveContext.sessionState.catalog.createDataSourceTable( + sessionState.catalog.createDataSourceTable( tableIdent, Some(result.schema), partitionColumns, @@ -251,7 +325,7 @@ case class CreateMetastoreDataSourceAsSelect( } // Refresh the cache of the table in the catalog. - hiveContext.sessionState.catalog.refreshTable(tableIdent) + sessionState.catalog.refreshTable(tableIdent) Seq.empty[Row] } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 784b018353472..5aab4132bc4ce 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -82,7 +82,7 @@ private[hive] case class HiveSimpleUDF( // TODO: Finish input output types. 
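A quick illustration of the `data.selectExpr(s.fieldNames: _*)` change in CreateMetastoreDataSourceAsSelect above: projecting by the existing table's field names lines the incoming DataFrame up with the stored schema without rebuilding it from the internal RDD. The session setup and column names here are hypothetical.

import org.apache.spark.sql.SparkSession

object SelectExprAlignmentSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("sketch").getOrCreate()
    // Incoming data whose columns are in a different order than the existing table.
    val data = spark.range(3).selectExpr("id AS b", "id * 2 AS a")
    val existingFieldNames = Array("a", "b")              // field names of the table being inserted into
    val aligned = data.selectExpr(existingFieldNames: _*) // project columns by name, in table order
    aligned.printSchema()                                 // now prints a, then b
    spark.stop()
  }
}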
override def eval(input: InternalRow): Any = { - val inputs = wrap(children.map(c => c.eval(input)), arguments, cached, inputDataTypes) + val inputs = wrap(children.map(_.eval(input)), arguments, cached, inputDataTypes) val ret = FunctionRegistry.invoke( method, function, @@ -152,10 +152,8 @@ private[hive] case class HiveGenericUDF( var i = 0 while (i < children.length) { val idx = i - deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set( - () => { - children(idx).eval(input) - }) + deferredObjects(i).asInstanceOf[DeferredObjectAdapter] + .set(() => children(idx).eval(input)) i += 1 } unwrap(function.evaluate(deferredObjects), returnInspector) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 7f6ca21782da4..e629099086899 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -32,16 +32,16 @@ import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.CacheManager import org.apache.spark.sql.execution.command.CacheTableCommand -import org.apache.spark.sql.execution.ui.SQLListener import org.apache.spark.sql.hive._ -import org.apache.spark.sql.hive.client.{HiveClient, HiveClientImpl} +import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.hive.execution.HiveNativeCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.{ShutdownHookManager, Utils} @@ -71,100 +71,95 @@ object TestHive * hive metastore seems to lead to weird non-deterministic failures. Therefore, the execution of * test cases that rely on TestHive must be serialized. */ -class TestHiveContext private[hive]( - sc: SparkContext, - cacheManager: CacheManager, - listener: SQLListener, - executionHive: HiveClientImpl, - metadataHive: HiveClient, - isRootContext: Boolean, - hiveCatalog: HiveExternalCatalog, - val warehousePath: File, - val scratchDirPath: File, - metastoreTemporaryConf: Map[String, String]) - extends HiveContext( - sc, - cacheManager, - listener, - executionHive, - metadataHive, - isRootContext, - hiveCatalog) { self => - - // Unfortunately, due to the complex interactions between the construction parameters - // and the limitations in scala constructors, we need many of these constructors to - // provide a shorthand to create a new TestHiveContext with only a SparkContext. - // This is not a great design pattern but it's necessary here. 
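The `val idx = i` kept in the HiveGenericUDF change above is load-bearing: a Scala closure over a `var` captures the variable, not its value at capture time. A standalone sketch of the same pattern, with hypothetical names:

object ClosureCaptureSketch {
  def main(args: Array[String]): Unit = {
    val children = Array("a", "b", "c")
    val thunks = new Array[() => String](children.length)
    var i = 0
    while (i < children.length) {
      val idx = i                      // stable copy; closing over `i` itself would see its final value
      thunks(i) = () => children(idx)
      i += 1
    }
    // Prints a, b, c; without `idx` every thunk would index past the end of the array.
    thunks.foreach(t => println(t()))
  }
}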
- - private def this( - sc: SparkContext, - executionHive: HiveClientImpl, - metadataHive: HiveClient, - warehousePath: File, - scratchDirPath: File, - metastoreTemporaryConf: Map[String, String]) { - this( - sc, - new CacheManager, - SQLContext.createListenerAndUI(sc), - executionHive, - metadataHive, - true, - new HiveExternalCatalog(metadataHive), - warehousePath, - scratchDirPath, - metastoreTemporaryConf) +class TestHiveContext(@transient val sparkSession: TestHiveSparkSession, isRootContext: Boolean) + extends HiveContext(sparkSession, isRootContext) { + + def this(sc: SparkContext) { + this(new TestHiveSparkSession(HiveContext.withHiveExternalCatalog(sc)), true) } - private def this( - sc: SparkContext, - warehousePath: File, - scratchDirPath: File, - metastoreTemporaryConf: Map[String, String]) { - this( - sc, - HiveContext.newClientForExecution(sc.conf, sc.hadoopConfiguration), - TestHiveContext.newClientForMetadata( - sc.conf, sc.hadoopConfiguration, warehousePath, scratchDirPath, metastoreTemporaryConf), - warehousePath, - scratchDirPath, - metastoreTemporaryConf) + override def newSession(): TestHiveContext = { + new TestHiveContext(sparkSession.newSession(), false) } + override def sharedState: TestHiveSharedState = sparkSession.sharedState + + override def sessionState: TestHiveSessionState = sparkSession.sessionState + + def setCacheTables(c: Boolean): Unit = { + sparkSession.setCacheTables(c) + } + + def getHiveFile(path: String): File = { + sparkSession.getHiveFile(path) + } + + def loadTestTable(name: String): Unit = { + sparkSession.loadTestTable(name) + } + + def reset(): Unit = { + sparkSession.reset() + } + +} + + +private[hive] class TestHiveSparkSession( + sc: SparkContext, + val warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String], + existingSharedState: Option[TestHiveSharedState]) + extends SparkSession(sc) with Logging { self => + def this(sc: SparkContext) { this( sc, Utils.createTempDir(namePrefix = "warehouse"), TestHiveContext.makeScratchDir(), - HiveContext.newTemporaryConfiguration(useInMemoryDerby = false)) + HiveContext.newTemporaryConfiguration(useInMemoryDerby = false), + None) + } + + assume(sc.conf.get(CATALOG_IMPLEMENTATION) == "hive") + + // TODO: Let's remove TestHiveSharedState and TestHiveSessionState. Otherwise, + // we are not really testing the reflection logic based on the setting of + // CATALOG_IMPLEMENTATION. + @transient + override lazy val sharedState: TestHiveSharedState = { + existingSharedState.getOrElse( + new TestHiveSharedState(sc, warehousePath, scratchDirPath, metastoreTemporaryConf)) + } + + @transient + override lazy val sessionState: TestHiveSessionState = new TestHiveSessionState(self) + + override def newSession(): TestHiveSparkSession = { + new TestHiveSparkSession( + sc, warehousePath, scratchDirPath, metastoreTemporaryConf, Some(sharedState)) } - override def newSession(): HiveContext = { - new TestHiveContext( - sc = sc, - cacheManager = cacheManager, - listener = listener, - executionHive = executionHive.newSession(), - metadataHive = metadataHive.newSession(), - isRootContext = false, - hiveCatalog = hiveCatalog, - warehousePath = warehousePath, - scratchDirPath = scratchDirPath, - metastoreTemporaryConf = metastoreTemporaryConf) + private var cacheTables: Boolean = false + + def setCacheTables(c: Boolean): Unit = { + cacheTables = c } // By clearing the port we force Spark to pick a new one. This allows us to rerun tests // without restarting the JVM. 
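The shape of the new TestHiveSparkSession above is worth spelling out: the expensive, process-wide pieces live in a shared state that newSession() hands along, while per-session state is rebuilt lazily for each session. A stripped-down sketch of that split, with hypothetical class names standing in for TestHiveSharedState and TestHiveSessionState:

object StateSplitSketch {
  class SharedStateLike                     // stands in for the shared metastore client and caches
  class SessionStateLike(owner: AnyRef)     // stands in for per-session conf, parser, catalog wrappers

  class SessionLike(existingShared: Option[SharedStateLike]) {
    lazy val sharedState: SharedStateLike =
      existingShared.getOrElse(new SharedStateLike)   // reused across newSession()
    lazy val sessionState: SessionStateLike =
      new SessionStateLike(this)                      // fresh for every session
    def newSession(): SessionLike = new SessionLike(Some(sharedState))
  }

  def main(args: Array[String]): Unit = {
    val root = new SessionLike(None)
    val child = root.newSession()
    assert(root.sharedState eq child.sharedState)     // shared
    assert(root.sessionState ne child.sessionState)   // per session
  }
}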
System.clearProperty("spark.hostPort") - CommandProcessorFactory.clean(hiveconf) + CommandProcessorFactory.clean(sessionState.hiveconf) - hiveconf.set("hive.plan.serialization.format", "javaXML") + sessionState.hiveconf.set("hive.plan.serialization.format", "javaXML") // A snapshot of the entries in the starting SQLConf // We save this because tests can mutate this singleton object if they want + // This snapshot is saved when we create this TestHiveSparkSession. val initialSQLConf: SQLConf = { val snapshot = new SQLConf - conf.getAllConfs.foreach { case (k, v) => snapshot.setConfString(k, v) } + sessionState.conf.getAllConfs.foreach { case (k, v) => snapshot.setConfString(k, v) } snapshot } @@ -175,42 +170,10 @@ class TestHiveContext private[hive]( /** The location of the compiled hive distribution */ lazy val hiveHome = envVarToFile("HIVE_HOME") + /** The location of the hive source code. */ lazy val hiveDevHome = envVarToFile("HIVE_DEV_HOME") - // Override so we can intercept relative paths and rewrite them to point at hive. - override def runSqlHive(sql: String): Seq[String] = - super.runSqlHive(rewritePaths(substitutor.substitute(this.hiveconf, sql))) - - override def executePlan(plan: LogicalPlan): this.QueryExecution = - new this.QueryExecution(plan) - - @transient - protected[sql] override lazy val sessionState = new HiveSessionState(this) { - override lazy val conf: SQLConf = { - new SQLConf { - clear() - override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) - override def clear(): Unit = { - super.clear() - TestHiveContext.overrideConfs.map { - case (key, value) => setConfString(key, value) - } - } - } - } - - override lazy val functionRegistry = { - // We use TestHiveFunctionRegistry at here to track functions that have been explicitly - // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). - val fr = new TestHiveFunctionRegistry - org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { - case (name, (info, builder)) => fr.registerFunction(name, info, builder) - } - fr - } - } - /** * Returns the value of specified environmental variable as a [[java.io.File]] after checking * to ensure it exists @@ -223,7 +186,7 @@ class TestHiveContext private[hive]( * Replaces relative paths to the parent directory "../" with hiveDevHome since this is how the * hive test cases assume the system is set up. */ - private def rewritePaths(cmd: String): String = + private[hive] def rewritePaths(cmd: String): String = if (cmd.toUpperCase contains "LOAD DATA") { val testDataLocation = hiveDevHome.map(_.getCanonicalPath).getOrElse(inRepoTests.getCanonicalPath) @@ -254,36 +217,11 @@ class TestHiveContext private[hive]( val describedTable = "DESCRIBE (\\w+)".r - /** - * Override QueryExecution with special debug workflow. - */ - class QueryExecution(logicalPlan: LogicalPlan) - extends super.QueryExecution(logicalPlan) { - def this(sql: String) = this(parseSql(sql)) - override lazy val analyzed = { - val describedTables = logical match { - case HiveNativeCommand(describedTable(tbl)) => tbl :: Nil - case CacheTableCommand(tbl, _, _) => tbl :: Nil - case _ => Nil - } - - // Make sure any test tables referenced are loaded. 
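The `describedTable` regex defined above ("DESCRIBE (\\w+)".r) is later matched as an extractor in TestHiveQueryExecution. A minimal illustration of that Scala idiom, with a hypothetical input string:

object RegexExtractorSketch {
  def main(args: Array[String]): Unit = {
    val describedTable = "DESCRIBE (\\w+)".r
    "DESCRIBE srcpart" match {
      case describedTable(tbl) => println(s"references table: $tbl")   // prints srcpart
      case _ => println("not a DESCRIBE command")
    }
  }
}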
- val referencedTables = - describedTables ++ - logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } - val referencedTestTables = referencedTables.filter(testTables.contains) - logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") - referencedTestTables.foreach(loadTestTable) - // Proceed with analysis. - sessionState.analyzer.execute(logical) - } - } - case class TestTable(name: String, commands: (() => Unit)*) protected[hive] implicit class SqlCmd(sql: String) { def cmd: () => Unit = { - () => new QueryExecution(sql).stringResult(): Unit + () => new TestHiveQueryExecution(sql).stringResult(): Unit } } @@ -310,19 +248,20 @@ class TestHiveContext private[hive]( "CREATE TABLE src1 (key INT, value STRING)".cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv3.txt")}' INTO TABLE src1".cmd), TestTable("srcpart", () => { - runSqlHive( + sessionState.runNativeSql( "CREATE TABLE srcpart (key INT, value STRING) PARTITIONED BY (ds STRING, hr STRING)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- Seq("11", "12")) { - runSqlHive( + sessionState.runNativeSql( s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart PARTITION (ds='$ds',hr='$hr') """.stripMargin) } }), TestTable("srcpart1", () => { - runSqlHive("CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") + sessionState.runNativeSql( + "CREATE TABLE srcpart1 (key INT, value STRING) PARTITIONED BY (ds STRING, hr INT)") for (ds <- Seq("2008-04-08", "2008-04-09"); hr <- 11 to 12) { - runSqlHive( + sessionState.runNativeSql( s"""LOAD DATA LOCAL INPATH '${getHiveFile("data/files/kv1.txt")}' |OVERWRITE INTO TABLE srcpart1 PARTITION (ds='$ds',hr='$hr') """.stripMargin) @@ -333,7 +272,7 @@ class TestHiveContext private[hive]( import org.apache.hadoop.mapred.{SequenceFileInputFormat, SequenceFileOutputFormat} import org.apache.thrift.protocol.TBinaryProtocol - runSqlHive( + sessionState.runNativeSql( s""" |CREATE TABLE src_thrift(fake INT) |ROW FORMAT SERDE '${classOf[ThriftDeserializer].getName}' @@ -346,7 +285,7 @@ class TestHiveContext private[hive]( |OUTPUTFORMAT '${classOf[SequenceFileOutputFormat[_, _]].getName}' """.stripMargin) - runSqlHive( + sessionState.runNativeSql( s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/complex.seq")}' INTO TABLE src_thrift") }), TestTable("serdeins", @@ -459,7 +398,6 @@ class TestHiveContext private[hive]( private val loadedTables = new collection.mutable.HashSet[String] - var cacheTables: Boolean = false def loadTestTable(name: String) { if (!(loadedTables contains name)) { // Marks the table as loaded first to prevent infinite mutually recursive table loading. @@ -470,7 +408,7 @@ class TestHiveContext private[hive]( createCmds.foreach(_()) if (cacheTables) { - cacheTable(name) + new SQLContext(self).cacheTable(name) } } } @@ -495,34 +433,35 @@ class TestHiveContext private[hive]( } } - cacheManager.clearCache() + sharedState.cacheManager.clearCache() loadedTables.clear() sessionState.catalog.clearTempTables() sessionState.catalog.invalidateCache() - metadataHive.reset() + + sessionState.metadataHive.reset() FunctionRegistry.getFunctionNames.asScala.filterNot(originalUDFs.contains(_)). foreach { udfName => FunctionRegistry.unregisterTemporaryUDF(udfName) } // Some tests corrupt this value on purpose, which breaks the RESET call below. 
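The loadTestTable logic above marks a table as loaded before running its setup commands; that ordering is what keeps mutually referential test tables from recursing forever. The same guard in isolation, with hypothetical table names and commands:

object LoadGuardSketch {
  private val loaded = scala.collection.mutable.HashSet[String]()
  // Two hypothetical tables whose setup commands reference each other.
  private lazy val setupCommands: Map[String, () => Unit] = Map(
    "a" -> (() => loadTestTable("b")),
    "b" -> (() => loadTestTable("a")))

  def loadTestTable(name: String): Unit = {
    if (!loaded.contains(name)) {
      loaded += name                 // mark first, so a -> b -> a terminates here
      setupCommands(name)()
    }
  }

  def main(args: Array[String]): Unit = loadTestTable("a")
}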
- hiveconf.set("fs.default.name", new File(".").toURI.toString) + sessionState.hiveconf.set("fs.default.name", new File(".").toURI.toString) // It is important that we RESET first as broken hooks that might have been set could break // other sql exec here. - executionHive.runSqlHive("RESET") - metadataHive.runSqlHive("RESET") + sessionState.executionHive.runSqlHive("RESET") + sessionState.metadataHive.runSqlHive("RESET") // For some reason, RESET does not reset the following variables... // https://issues.apache.org/jira/browse/HIVE-9004 - runSqlHive("set hive.table.parameters.default=") - runSqlHive("set datanucleus.cache.collections=true") - runSqlHive("set datanucleus.cache.collections.lazy=true") + sessionState.runNativeSql("set hive.table.parameters.default=") + sessionState.runNativeSql("set datanucleus.cache.collections=true") + sessionState.runNativeSql("set datanucleus.cache.collections.lazy=true") // Lots of tests fail if we do not change the partition whitelist from the default. - runSqlHive("set hive.metastore.partition.name.whitelist.pattern=.*") + sessionState.runNativeSql("set hive.metastore.partition.name.whitelist.pattern=.*") // In case a test changed any of these values, restore all the original ones here. TestHiveContext.hiveClientConfigurations( - hiveconf, warehousePath, scratchDirPath, metastoreTemporaryConf) - .foreach { case (k, v) => metadataHive.runSqlHive(s"SET $k=$v") } - defaultOverrides() + sessionState.hiveconf, warehousePath, scratchDirPath, metastoreTemporaryConf) + .foreach { case (k, v) => sessionState.metadataHive.runSqlHive(s"SET $k=$v") } + sessionState.setDefaultOverrideConfs() sessionState.catalog.setCurrentDatabase("default") } catch { @@ -533,6 +472,40 @@ class TestHiveContext private[hive]( } + +private[hive] class TestHiveQueryExecution( + sparkSession: TestHiveSparkSession, + logicalPlan: LogicalPlan) + extends HiveQueryExecution(new SQLContext(sparkSession), logicalPlan) with Logging { + + def this(sparkSession: TestHiveSparkSession, sql: String) { + this(sparkSession, sparkSession.sessionState.sqlParser.parsePlan(sql)) + } + + def this(sql: String) { + this(TestHive.sparkSession, sql) + } + + override lazy val analyzed: LogicalPlan = { + val describedTables = logical match { + case HiveNativeCommand(sparkSession.describedTable(tbl)) => tbl :: Nil + case CacheTableCommand(tbl, _, _) => tbl :: Nil + case _ => Nil + } + + // Make sure any test tables referenced are loaded. + val referencedTables = + describedTables ++ + logical.collect { case UnresolvedRelation(tableIdent, _) => tableIdent.table } + val referencedTestTables = referencedTables.filter(sparkSession.testTables.contains) + logDebug(s"Query references test tables: ${referencedTestTables.mkString(", ")}") + referencedTestTables.foreach(sparkSession.loadTestTable) + // Proceed with analysis. 
+ sparkSession.sessionState.analyzer.execute(logical) + } +} + + private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { private val removedFunctions = @@ -549,6 +522,58 @@ private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { } } + +private[hive] class TestHiveSharedState( + sc: SparkContext, + warehousePath: File, + scratchDirPath: File, + metastoreTemporaryConf: Map[String, String]) + extends HiveSharedState(sc) { + + override lazy val metadataHive: HiveClient = { + TestHiveContext.newClientForMetadata( + sc.conf, sc.hadoopConfiguration, warehousePath, scratchDirPath, metastoreTemporaryConf) + } +} + + +private[hive] class TestHiveSessionState(sparkSession: TestHiveSparkSession) + extends HiveSessionState(new SQLContext(sparkSession)) { + + override lazy val conf: SQLConf = { + new SQLConf { + clear() + override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) + override def clear(): Unit = { + super.clear() + TestHiveContext.overrideConfs.map { + case (key, value) => setConfString(key, value) + } + } + } + } + + override lazy val functionRegistry: TestHiveFunctionRegistry = { + // We use TestHiveFunctionRegistry at here to track functions that have been explicitly + // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). + val fr = new TestHiveFunctionRegistry + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + fr + } + + override def executePlan(plan: LogicalPlan): TestHiveQueryExecution = { + new TestHiveQueryExecution(sparkSession, plan) + } + + // Override so we can intercept relative paths and rewrite them to point at hive. + override def runNativeSql(sql: String): Seq[String] = { + super.runNativeSql(sparkSession.rewritePaths(substitutor.substitute(hiveconf, sql))) + } +} + + private[hive] object TestHiveContext { /** @@ -563,7 +588,7 @@ private[hive] object TestHiveContext { /** * Create a [[HiveClient]] used to retrieve metadata from the Hive MetaStore. */ - private def newClientForMetadata( + def newClientForMetadata( conf: SparkConf, hadoopConf: Configuration, warehousePath: File, @@ -580,7 +605,7 @@ private[hive] object TestHiveContext { /** * Configurations needed to create a [[HiveClient]]. 
*/ - private def hiveClientConfigurations( + def hiveClientConfigurations( hiveconf: HiveConf, warehousePath: File, scratchDirPath: File, @@ -592,7 +617,7 @@ private[hive] object TestHiveContext { ConfVars.METASTORE_CLIENT_CONNECT_RETRY_DELAY.varname -> "1") } - private def makeScratchDir(): File = { + def makeScratchDir(): File = { val scratchDir = Utils.createTempDir(namePrefix = "scratch") scratchDir.delete() scratchDir diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index d9664680f4a11..61910b8e6b51d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.util.quietly -import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { @@ -131,7 +130,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = HiveSqlParser.parsePlan(query) + def ast = hiveContext.parseSql(query) def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala index bf85d71c66759..4d75becdb01d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala @@ -92,8 +92,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT abs(15), abs(-15)") checkSqlGeneration("SELECT array(1,2,3)") checkSqlGeneration("SELECT coalesce(null, 1, 2)") - // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators - // checkSqlGeneration("SELECT explode(array(1,2,3))") + checkSqlGeneration("SELECT explode(array(1,2,3))") checkSqlGeneration("SELECT greatest(1,null,3)") checkSqlGeneration("SELECT if(1==2, 'yes', 'no')") checkSqlGeneration("SELECT isnan(15), isnan('invalid')") @@ -200,8 +199,7 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT locate('is', 'This is a test', 3)") checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')") checkSqlGeneration("SELECT ltrim(' SparkSql ')") - // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators - // checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')") + checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')") checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')") checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)") checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala index b644a50613337..b2c0f7e0e57b4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala +++ 
b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveContextSuite.scala @@ -28,9 +28,12 @@ class HiveContextSuite extends SparkFunSuite { val sc = TestHive.sparkContext require(sc.conf.get("spark.sql.hive.metastore.barrierPrefixes") == "org.apache.spark.sql.hive.execution.PairSerDe") - assert(TestHive.initialSQLConf.getConfString("spark.sql.hive.metastore.barrierPrefixes") == + assert(TestHive.sparkSession.initialSQLConf.getConfString( + "spark.sql.hive.metastore.barrierPrefixes") == "org.apache.spark.sql.hive.execution.PairSerDe") - assert(TestHive.metadataHive.getConf("spark.sql.hive.metastore.barrierPrefixes", "") == + // This setting should be also set in the hiveconf of the current session. + assert(TestHive.sessionState.hiveconf.get( + "spark.sql.hive.metastore.barrierPrefixes", "") == "org.apache.spark.sql.hive.execution.PairSerDe") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 110c6d19d89ba..484cf528e6db7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -30,10 +30,11 @@ import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} -import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.test.TestHive class HiveDDLCommandSuite extends PlanTest { - val parser = HiveSqlParser + val parser = TestHive.sessionState.sqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index 3334c16f0be87..84285b7f40832 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -18,12 +18,10 @@ package org.apache.spark.sql.hive import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.util.VersionInfo import org.apache.spark.SparkConf import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} -import org.apache.spark.util.Utils +import org.apache.spark.sql.hive.client.HiveClient /** * Test suite for the [[HiveExternalCatalog]]. @@ -31,11 +29,9 @@ import org.apache.spark.util.Utils class HiveExternalCatalogSuite extends CatalogTestCases { private val client: HiveClient = { - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = HiveContext.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = new SparkConf(), - hadoopConf = new Configuration()).createClient() + // We create a metastore at a temp location to avoid any potential + // conflict of having multiple connections to a single derby instance. 
+ HiveContext.newClientForExecution(new SparkConf, new Configuration) } protected override val utils: CatalogTestUtils = new CatalogTestUtils { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 8648834f0d881..2a201c195f167 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -96,7 +96,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } @@ -129,7 +129,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("decimal(10,3)", "string")) checkAnswer(table("t"), testDF) - assert(runSqlHive("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1.1\t1", "2.1\t2")) } } } @@ -159,7 +159,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(columns.map(_.dataType) === Seq("int", "string")) checkAnswer(table("t"), Row(1, "val_1")) - assert(runSqlHive("SELECT * FROM t") === Seq("1\tval_1")) + assert(sessionState.runNativeSql("SELECT * FROM t") === Seq("1\tval_1")) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala new file mode 100644 index 0000000000000..069fdd02fa2d5 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveShowDDLSuite.scala @@ -0,0 +1,468 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.hive + +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils + +class HiveShowDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ + + var jsonFilePath: String = _ + override def beforeAll(): Unit = { + super.beforeAll() + jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile + } + // there are 3 types of create table to test + // 1. HIVE Syntax: create table t1 (c1 int) partitionedby (c2 int) row format... tblproperties.. + // 2. Spark sql syntx: crate table t1 (c1 int) using .. 
options (... ) + // 3. saving table from datasource: df.write.format("parquet").saveAsTable("t1") + + /** + * Hive syntax DDL + */ + test("Hive syntax DDL: no row format") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |stored as parquet + |location '${tmpDir}' + """.stripMargin) + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - partitioned table with column and table comments") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 bigint, c2 string) + |PARTITIONED BY (c3 int COMMENT 'partition column', c4 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - external, row format and tblproperties") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create external table t1(c1 bigint, c2 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - more row format definition") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create external table t1(c1 int COMMENT 'first column', c2 string) + |COMMENT 'some table' + |PARTITIONED BY (c3 int COMMENT 'partition column', c4 string) + |row format delimited fields terminated by ',' + |COLLECTION ITEMS TERMINATED BY '@' + |MAP KEYS TERMINATED BY '#' + |NULL DEFINED AS 'NaN' + |stored as parquet + |location '${tmpDir}' + |TBLPROPERTIES ('my.property.one'='true', 'my.property.two'='1', + |'my.property.three'='2', 'my.property.four'='false') + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + test("Hive syntax DDL - hive view") { + withTempDir { tmpDir => + withTable("t1") { + withView("v1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |row format delimited fields terminated by ',' + |stored as parquet + |location '${tmpDir}' + """.stripMargin) + sql( + """ + |create view v1 as select * from t1 + """.stripMargin) + assert(compareCatalog( + TableIdentifier("v1"), + sql("show create table v1").collect()(0).toSeq(0).toString)) + } + } + } + } + + test("Hive syntax DDL - SERDE") { + withTempDir { tmpDir => + withTable("t1") { + sql( + s""" + |create table t1(c1 int, c2 string) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + |WITH SERDEPROPERTIES ('mapkey.delim'=',', 'field.delim'=',') + |STORED AS + |INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + |OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + |location '${tmpDir}' + """.stripMargin) + + assert(compareCatalog( + TableIdentifier("t1"), + sql("show create table t1").collect()(0).toSeq(0).toString)) + } + } + } + + // hive native DDL that is not supported by Spark SQL DDL + test("Hive syntax DDL -- CLUSTERED") { + 
withTable("t1") { + HiveNativeCommand( + s""" + |create table t1 (c1 int, c2 string) + |clustered by (c1) sorted by (c2 desc) into 5 buckets + """.stripMargin).run(sqlContext) + val ddl = sql("show create table t1").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString + .contains("CLUSTERED BY ( `c1` ) \nSORTED BY ( `c2` DESC ) \nINTO 5 BUCKETS")) + } + } + + test("Hive syntax DDL -- STORED BY") { + withTable("tmp_showcrt1") { + HiveNativeCommand( + s""" + |CREATE EXTERNAL TABLE tmp_showcrt1 (key string, value boolean) + |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe' + |STORED BY 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler' + |WITH SERDEPROPERTIES ('field.delim'=',', 'serialization.format'=',') + """.stripMargin).run(sqlContext) + val ddl = sql("show create table tmp_showcrt1").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString + .contains("STORED BY 'org.apache.hadoop.hive.ql.metadata.DefaultStorageHandler'")) + } + } + + test("Hive syntax DDL -- SKEWED BY") { + withTable("stored_as_dirs_multiple") { + HiveNativeCommand( + s""" + |CREATE TABLE stored_as_dirs_multiple (col1 STRING, col2 int, col3 STRING) + |SKEWED BY (col1, col2) ON (('s1',1), ('s3',3), ('s13',13), ('s78',78)) + |stored as DIRECTORIES + """.stripMargin).run(sqlContext) + val ddl = sql("show create table stored_as_dirs_multiple").collect() + assert(ddl(0).toSeq(0).toString.contains("WARN")) + assert(ddl(1).toSeq(0).toString.contains("SKEWED BY ( `col1`, `col2` )")) + } + } + + /** + * Datasource table syntax DDL + */ + test("Datasource Table DDL syntax - persistent JSON table with a user specified schema") { + withTable("jsonTable") { + sql( + s""" + |CREATE TABLE jsonTable ( + |a string, + |b String, + |`c_!@(3)` int, + |`` Struct<`d!`:array, `=`:array>>) + |USING org.apache.spark.sql.json.DefaultSource + |OPTIONS ( + | path '$jsonFilePath') + """.stripMargin) + assert(compareCatalog( + TableIdentifier("jsonTable"), + sql("show create table jsonTable").collect()(0).toSeq(0).toString)) + } + } + + test("Datasource Table DDL syntax - persistent JSON with a subset of user-specified fields") { + withTable("jsonTable") { + // This works because JSON objects are self-describing and JSONRelation can get needed + // field values based on field names. 
+      sql(
+        s"""
+           |CREATE TABLE jsonTable (`<d>` Struct<`=`:array<struct<Dd2: boolean>>>, b String)
+           |USING org.apache.spark.sql.json.DefaultSource
+           |OPTIONS (path '$jsonFilePath')
+         """.stripMargin)
+      assert(compareCatalog(
+        TableIdentifier("jsonTable"),
+        sql("show create table jsonTable").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  test("Datasource Table DDL syntax - USING and OPTIONS - no user-specified schema") {
+    withTable("jsonTable") {
+      sql(
+        s"""
+           |CREATE TABLE jsonTable
+           |USING org.apache.spark.sql.json.DefaultSource
+           |OPTIONS (
+           |  path '$jsonFilePath',
+           |  key.key1 'value1',
+           |  'key.key2' 'value2'
+           |)
+         """.stripMargin)
+      assert(compareCatalog(
+        TableIdentifier("jsonTable"),
+        sql("show create table jsonTable").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  test("Datasource Table DDL syntax - USING and NO options") {
+    withTable("parquetTable") {
+      sql(
+        s"""
+           |CREATE TABLE parquetTable (c1 int, c2 string, c3 long)
+           |USING parquet
+         """.stripMargin)
+      assert(compareCatalog(
+        TableIdentifier("parquetTable"),
+        sql("show create table parquetTable").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  /**
+   * Datasource saved to table
+   */
+  test("Save datasource to table - dataframe with a select") {
+    withTable("t_datasource") {
+      val df = sql("select 1, 'abc'")
+      df.write.saveAsTable("t_datasource")
+      assert(compareCatalog(
+        TableIdentifier("t_datasource"),
+        sql("show create table t_datasource").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  test("Save datasource to table - dataframe from json file") {
+    val jsonFilePath = Utils.getSparkClassLoader.getResource("sample.json").getFile
+    withTable("t_datasource") {
+      val df = sqlContext.read.json(jsonFilePath)
+      df.write.format("json").saveAsTable("t_datasource")
+      assert(compareCatalog(
+        TableIdentifier("t_datasource"),
+        sql("show create table t_datasource").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  test("Save datasource to table -- dataframe with user-specified schema") {
+    withTable("ttt3") {
+      val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+      df.write
+        .format("parquet")
+        .mode(SaveMode.Overwrite)
+        .saveAsTable("ttt3")
+      assert(compareCatalog(
+        TableIdentifier("ttt3"),
+        sql("show create table ttt3").collect()(0).toSeq(0).toString))
+    }
+  }
+
+  test("Save datasource to table -- partitioned, bucket and sort") {
+    // The generated statement is a DataFrameWriter call chain, since the DDL syntax
+    // does not have a place to keep the partitioning columns.
+    val df = (1 to 3).map(i => (i, s"val_$i", i * 2)).toDF("a", "b", "c")
+    withTable("ttt3", "ttt5") {
+      df.write
+        .partitionBy("a", "b")
+        .bucketBy(5, "c")
+        .sortBy("c")
+        .format("parquet")
+        .mode(SaveMode.Overwrite)
+        .saveAsTable("ttt3")
+      val generatedDDL =
+        sql("show create table ttt3").collect()(0).toSeq(0).toString.replace("", "df")
+      assert(generatedDDL.contains("partitionBy(\"a\", \"b\")"))
+      assert(generatedDDL.contains("bucketBy(5, \"c\")"))
+      assert(generatedDDL.contains("sortBy(\"c\")"))
+      assert(generatedDDL.contains("format(\"parquet\")"))
+    }
+  }
+
+  /**
+   * In order to verify whether the generated DDL for a table is correct, we compare the
+   * CatalogTable generated from the existing table to the CatalogTable generated from the
+   * table created with the generated DDL.
+   * @param expectedTable identifier of the existing table whose DDL was generated
+   * @param actualDDL the statement returned by SHOW CREATE TABLE for that table
+   * @return true if the two CatalogTables match, false otherwise
+   */
+  private def compareCatalog(expectedTable: TableIdentifier, actualDDL: String): Boolean = {
+    val actualTable = expectedTable.table + "_actual"
+    var actual: CatalogTable = null
+    val expected: CatalogTable =
+      sqlContext.sessionState.catalog.getTableMetadata(expectedTable)
+    withTempDir { tmpDir =>
+      if (actualDDL.contains("CREATE VIEW")) {
+        withView(actualTable) {
+          sql(actualDDL.replace(expectedTable.table.toLowerCase(), actualTable))
+          actual = sqlContext.sessionState.catalog.getTableMetadata(TableIdentifier(actualTable))
+        }
+      } else {
+        withTable(actualTable) {
+          var revisedActualDDL: String = null
+          if (expected.tableType == CatalogTableType.EXTERNAL_TABLE) {
+            revisedActualDDL = actualDDL.replace(expectedTable.table.toLowerCase(), actualTable)
+          } else {
+            // For managed tables, point the path option at the temp directory before
+            // re-running the generated DDL.
+            revisedActualDDL = actualDDL
+              .replace(expectedTable.table.toLowerCase(), actualTable)
+              .replaceAll("path.*,", s"path '${tmpDir}',")
+          }
+          sql(revisedActualDDL)
+          actual = sqlContext.sessionState.catalog.getTableMetadata(TableIdentifier(actualTable))
+        }
+      }
+    }
+
+    if (expected.properties.get("spark.sql.sources.provider").isDefined) {
+      // datasource table: The generated DDL will be like:
+      // CREATE EXTERNAL TABLE () USING
+      // OPTIONS (