diff --git a/.travis.yml b/.travis.yml index d94872db6437a..d7e9f8c0290e8 100644 --- a/.travis.yml +++ b/.travis.yml @@ -28,7 +28,6 @@ dist: trusty # 2. Choose language and target JDKs for parallel builds. language: java jdk: - - oraclejdk7 - oraclejdk8 # 3. Setup cache directory for SBT and Maven. diff --git a/R/WINDOWS.md b/R/WINDOWS.md index cb2eebb9ffe6e..9ca7e58e20cd2 100644 --- a/R/WINDOWS.md +++ b/R/WINDOWS.md @@ -6,7 +6,7 @@ To build SparkR on Windows, the following steps are required include Rtools and R in `PATH`. 2. Install -[JDK7](http://www.oracle.com/technetwork/java/javase/downloads/jdk7-downloads-1880260.html) and set +[JDK8](http://www.oracle.com/technetwork/java/javase/downloads/jdk8-downloads-2133151.html) and set `JAVA_HOME` in the system environment variables. 3. Download and install [Maven](http://maven.apache.org/download.html). Also include the `bin` diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 81e19364ae7ea..871f8e41a0f23 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -229,6 +229,7 @@ exportMethods("%in%", "floor", "format_number", "format_string", + "from_json", "from_unixtime", "from_utc_timestamp", "getField", @@ -327,6 +328,7 @@ exportMethods("%in%", "toDegrees", "toRadians", "to_date", + "to_json", "to_timestamp", "to_utc_timestamp", "translate", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index cc4cfa3423ced..97e0c9edeab48 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -280,7 +280,7 @@ setMethod("dtypes", #' Column Names of SparkDataFrame #' -#' Return all column names as a list. +#' Return a vector of column names. #' #' @param x a SparkDataFrame. #' @@ -338,7 +338,7 @@ setMethod("colnames", }) #' @param value a character vector. Must have the same length as the number -#' of columns in the SparkDataFrame. +#' of columns to be renamed. #' @rdname columns #' @aliases colnames<-,SparkDataFrame-method #' @name colnames<- @@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. +#' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2685,7 +2686,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' -#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. +#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method +#' requires that the input SparkDataFrames have the same column names. #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2709,6 +2711,10 @@ setMethod("unionAll", setMethod("rbind", signature(... = "SparkDataFrame"), function(x, ..., deparse.level = 1) { + nm <- lapply(list(x, ...), names) + if (length(unique(nm)) != 1) { + stop("Names of input data frames are different.") + } if (nargs() == 3) { union(x, ...) } else { diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index e771a057e2444..8354f705f6dea 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -332,8 +332,10 @@ setMethod("toDF", signature(x = "RDD"), #' Create a SparkDataFrame from a JSON file. 
#' -#' Loads a JSON file (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} -#' ), returning the result as a SparkDataFrame +#' Loads a JSON file, returning the result as a SparkDataFrame. +#' By default, (\href{http://jsonlines.org/}{JSON Lines text format or newline-delimited JSON} +#' ) is supported. For JSON (one record per file), set a named property \code{wholeFile} to +#' \code{TRUE}. #' It goes through the entire dataset once to determine the schema. #' #' @param path Path of file to read. A vector of multiple paths is allowed. @@ -346,6 +348,7 @@ setMethod("toDF", signature(x = "RDD"), #' sparkR.session() #' path <- "path/to/file.json" #' df <- read.json(path) +#' df <- read.json(path, wholeFile = TRUE) #' df <- jsonFile(path) #' } #' @name read.json @@ -778,6 +781,7 @@ dropTempView <- function(viewName) { #' @return SparkDataFrame #' @rdname read.df #' @name read.df +#' @seealso \link{read.json} #' @export #' @examples #'\dontrun{ @@ -785,7 +789,7 @@ dropTempView <- function(viewName) { #' df1 <- read.df("path/to/file.json", source = "json") #' schema <- structType(structField("name", "string"), #' structField("info", "map<string,double>")) -#' df2 <- read.df(mapTypeJsonPath, "json", schema) +#' df2 <- read.df(mapTypeJsonPath, "json", schema, wholeFile = TRUE) #' df3 <- loadDF("data/test_table", "parquet", mergeSchema = "true") #' } #' @name read.df diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R index 9e5084481fcde..edf2bcf8fdb3c 100644 --- a/R/pkg/R/functions.R +++ b/R/pkg/R/functions.R @@ -1793,6 +1793,33 @@ setMethod("to_date", column(jc) }) +#' to_json +#' +#' Converts a column containing a \code{structType} into a Column of JSON strings. +#' Resolving the Column can fail if an unsupported type is encountered. +#' +#' @param x Column containing the struct +#' @param ... additional named properties to control how it is converted; accepts the same options +#' as the JSON data source. +#' +#' @family normal_funcs +#' @rdname to_json +#' @name to_json +#' @aliases to_json,Column-method +#' @export +#' @examples +#' \dontrun{ +#' to_json(df$t, dateFormat = 'dd/MM/yyyy') +#' select(df, to_json(df$t)) +#'} +#' @note to_json since 2.2.0 +setMethod("to_json", signature(x = "Column"), + function(x, ...) { + options <- varargsToStrEnv(...) + jc <- callJStatic("org.apache.spark.sql.functions", "to_json", x@jc, options) + column(jc) + }) + #' to_timestamp #' #' Converts the column into a TimestampType. You may optionally specify a format @@ -2403,6 +2430,36 @@ setMethod("date_format", signature(y = "Column", x = "character"), column(jc) }) +#' from_json +#' +#' Parses a column containing a JSON string into a Column of \code{structType} with the specified +#' \code{schema}. If the string is unparseable, the Column will contain the value NA. +#' +#' @param x Column containing the JSON string. +#' @param schema a structType object to use as the schema when parsing the JSON string. +#' @param ... additional named properties to control how the JSON is parsed; accepts the same +#' options as the JSON data source. +#' +#' @family normal_funcs +#' @rdname from_json +#' @name from_json +#' @aliases from_json,Column,structType-method +#' @export +#' @examples +#' \dontrun{ +#' schema <- structType(structField("name", "string")) +#' select(df, from_json(df$value, schema, dateFormat = "dd/MM/yyyy")) +#'} +#' @note from_json since 2.2.0 +setMethod("from_json", signature(x = "Column", schema = "structType"), + function(x, schema, ...) { + options <- varargsToStrEnv(...)
+ jc <- callJStatic("org.apache.spark.sql.functions", + "from_json", + x@jc, schema$jobj, options) + column(jc) + }) + #' from_utc_timestamp #' #' Given a timestamp, which corresponds to a certain time of day in UTC, returns another timestamp diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 647cbbdd825e3..45bc12746511c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -991,6 +991,10 @@ setGeneric("format_number", function(y, x) { standardGeneric("format_number") }) #' @export setGeneric("format_string", function(format, x, ...) { standardGeneric("format_string") }) +#' @rdname from_json +#' @export +setGeneric("from_json", function(x, schema, ...) { standardGeneric("from_json") }) + #' @rdname from_unixtime #' @export setGeneric("from_unixtime", function(x, ...) { standardGeneric("from_unixtime") }) @@ -1265,6 +1269,10 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @export setGeneric("to_date", function(x, format) { standardGeneric("to_date") }) +#' @rdname to_json +#' @export +setGeneric("to_json", function(x, ...) { standardGeneric("to_json") }) + #' @rdname to_timestamp #' @export setGeneric("to_timestamp", function(x, format) { standardGeneric("to_timestamp") }) diff --git a/R/pkg/R/mllib_classification.R b/R/pkg/R/mllib_classification.R index 05bb95266173a..4db9cc30fb0c1 100644 --- a/R/pkg/R/mllib_classification.R +++ b/R/pkg/R/mllib_classification.R @@ -75,9 +75,9 @@ setClass("NaiveBayesModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.svmLinear(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.svmLinear(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -220,9 +220,9 @@ function(object, path, overwrite = FALSE) { #' \dontrun{ #' sparkR.session() #' # binary logistic regression -#' df <- createDataFrame(iris) -#' training <- df[df$Species %in% c("versicolor", "virginica"), ] -#' model <- spark.logit(training, Species ~ ., regParam = 0.5) +#' t <- as.data.frame(Titanic) +#' training <- createDataFrame(t) +#' model <- spark.logit(training, Survived ~ ., regParam = 0.5) #' summary <- summary(model) #' #' # fitted values on training data @@ -239,8 +239,7 @@ function(object, path, overwrite = FALSE) { #' #' # multinomial logistic regression #' -#' df <- createDataFrame(iris) -#' model <- spark.logit(df, Species ~ ., regParam = 0.5) +#' model <- spark.logit(training, Class ~ ., regParam = 0.5) #' summary <- summary(model) #' #' } diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R index 8823f90775960..0ebdb5a273088 100644 --- a/R/pkg/R/mllib_clustering.R +++ b/R/pkg/R/mllib_clustering.R @@ -72,8 +72,9 @@ setClass("LDAModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' df <- createDataFrame(iris) -#' model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.bisectingKmeans(df, Class ~ Survived, k = 4) #' summary(model) #' #' # get fitted result from a bisecting k-means model @@ -82,7 +83,7 @@ setClass("LDAModel", representation(jobj = "jobj")) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' 
#' # save fitted model to input path #' path <- "path/to/model" @@ -338,14 +339,14 @@ setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "charact #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.kmeans(df, Sepal_Length ~ Sepal_Width, k = 4, initMode = "random") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.kmeans(df, Class ~ Survived, k = 4, initMode = "random") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Class", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" diff --git a/R/pkg/R/mllib_regression.R b/R/pkg/R/mllib_regression.R index ac0578c4ab259..648d363f1a255 100644 --- a/R/pkg/R/mllib_regression.R +++ b/R/pkg/R/mllib_regression.R @@ -68,14 +68,14 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj")) #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- spark.glm(df, Sepal_Length ~ Sepal_Width, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.glm(df, Freq ~ Sex + Age, family = "gaussian") #' summary(model) #' #' # fitted values on training data #' fitted <- predict(model, df) -#' head(select(fitted, "Sepal_Length", "prediction")) +#' head(select(fitted, "Freq", "prediction")) #' #' # save fitted model to input path #' path <- "path/to/model" @@ -137,9 +137,9 @@ setMethod("spark.glm", signature(data = "SparkDataFrame", formula = "formula"), #' @examples #' \dontrun{ #' sparkR.session() -#' data(iris) -#' df <- createDataFrame(iris) -#' model <- glm(Sepal_Length ~ Sepal_Width, df, family = "gaussian") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- glm(Freq ~ Sex + Age, df, family = "gaussian") #' summary(model) #' } #' @note glm since 1.5.0 diff --git a/R/pkg/R/mllib_tree.R b/R/pkg/R/mllib_tree.R index 0d53fad061809..40a806c41bad0 100644 --- a/R/pkg/R/mllib_tree.R +++ b/R/pkg/R/mllib_tree.R @@ -143,14 +143,15 @@ print.summary.treeEnsemble <- function(x) { #' #' # fit a Gradient Boosted Tree Classification Model #' # label must be binary - Only binary classification is supported for GBT. 
-#' df <- createDataFrame(iris[iris$Species != "virginica", ]) -#' model <- spark.gbt(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.gbt(df, Survived ~ Age + Freq, "classification") #' #' # numeric label is also supported -#' iris2 <- iris[iris$Species != "virginica", ] -#' iris2$NumericSpecies <- ifelse(iris2$Species == "setosa", 0, 1) -#' df <- createDataFrame(iris2) -#' model <- spark.gbt(df, NumericSpecies ~ ., type = "classification") +#' t2 <- as.data.frame(Titanic) +#' t2$NumericGender <- ifelse(t2$Sex == "Male", 0, 1) +#' df <- createDataFrame(t2) +#' model <- spark.gbt(df, NumericGender ~ ., type = "classification") #' } #' @note spark.gbt since 2.1.0 setMethod("spark.gbt", signature(data = "SparkDataFrame", formula = "formula"), @@ -351,8 +352,9 @@ setMethod("write.ml", signature(object = "GBTClassificationModel", path = "chara #' summary(savedModel) #' #' # fit a Random Forest Classification Model -#' df <- createDataFrame(iris) -#' model <- spark.randomForest(df, Species ~ Petal_Length + Petal_Width, "classification") +#' t <- as.data.frame(Titanic) +#' df <- createDataFrame(t) +#' model <- spark.randomForest(df, Survived ~ Freq + Age, "classification") #' } #' @note spark.randomForest since 2.1.0 setMethod("spark.randomForest", signature(data = "SparkDataFrame", formula = "formula"), diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index ce0f5a198a259..620b633637138 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -88,6 +88,13 @@ mockLinesComplexType <- complexTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") writeLines(mockLinesComplexType, complexTypeJsonPath) +# For test map type and struct type in DataFrame +mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", + "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", + "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") +mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") +writeLines(mockLinesMapType, mapTypeJsonPath) + test_that("calling sparkRSQL.init returns existing SQL context", { sqlContext <- suppressWarnings(sparkRSQL.init(sc)) expect_equal(suppressWarnings(sparkRSQL.init(sc)), sqlContext) @@ -466,13 +473,6 @@ test_that("create DataFrame from a data.frame with complex types", { expect_equal(ldf$an_envir, collected$an_envir) }) -# For test map type and struct type in DataFrame -mockLinesMapType <- c("{\"name\":\"Bob\",\"info\":{\"age\":16,\"height\":176.5}}", - "{\"name\":\"Alice\",\"info\":{\"age\":20,\"height\":164.3}}", - "{\"name\":\"David\",\"info\":{\"age\":60,\"height\":180}}") -mapTypeJsonPath <- tempfile(pattern = "sparkr-test", fileext = ".tmp") -writeLines(mockLinesMapType, mapTypeJsonPath) - test_that("Collect DataFrame with complex types", { # ArrayType df <- read.json(complexTypeJsonPath) @@ -898,6 +898,12 @@ test_that("names() colnames() set the column names", { expect_equal(names(z)[3], "c") names(z)[3] <- "c2" expect_equal(names(z)[3], "c2") + + # Test subset assignment + colnames(df)[1] <- "col5" + expect_equal(colnames(df)[1], "col5") + names(df)[2] <- "col6" + expect_equal(names(df)[2], "col6") }) test_that("head() and first() return the correct data", { @@ -1331,6 +1337,33 @@ test_that("column functions", { df <- createDataFrame(data.frame(x = c(2.5, 3.5))) expect_equal(collect(select(df, bround(df$x, 0)))[[1]][1], 2) 
expect_equal(collect(select(df, bround(df$x, 0)))[[1]][2], 4) + + # Test to_json(), from_json() + df <- read.json(mapTypeJsonPath) + j <- collect(select(df, alias(to_json(df$info), "json"))) + expect_equal(j[order(j$json), ][1], "{\"age\":16,\"height\":176.5}") + df <- as.DataFrame(j) + schema <- structType(structField("age", "integer"), + structField("height", "double")) + s <- collect(select(df, alias(from_json(df$json, schema), "structcol"))) + expect_equal(ncol(s), 1) + expect_equal(nrow(s), 3) + expect_is(s[[1]][[1]], "struct") + expect_true(any(apply(s, 1, function(x) { x[[1]]$age == 16 } ))) + + # passing option + df <- as.DataFrame(list(list("col" = "{\"date\":\"21/10/2014\"}"))) + schema2 <- structType(structField("date", "date")) + expect_error(tryCatch(collect(select(df, from_json(df$col, schema2))), + error = function(e) { stop(e) }), + paste0(".*(java.lang.NumberFormatException: For input string:).*")) + s <- collect(select(df, from_json(df$col, schema2, dateFormat = "dd/MM/yyyy"))) + expect_is(s[[1]][[1]]$date, "Date") + expect_equal(as.character(s[[1]][[1]]$date), "2014-10-21") + + # check for unparseable + df <- as.DataFrame(list(list("a" = ""))) + expect_equal(collect(select(df, from_json(df$a, schema)))[[1]][[1]], NA) }) test_that("column binary mathfunctions", { @@ -1817,6 +1850,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(count(unioned2), 12) expect_equal(first(unioned2)$name, "Michael") + df3 <- df2 + names(df3)[1] <- "newName" + expect_error(rbind(df, df3), + "Names of input data frames are different.") + expect_error(rbind(df, df2, df3), + "Names of input data frames are different.") + excepted <- arrange(except(df, df2), desc(df$age)) expect_is(unioned, "SparkDataFrame") expect_equal(count(excepted), 2) @@ -2861,5 +2901,7 @@ unlink(parquetPath) unlink(orcPath) unlink(jsonPath) unlink(jsonPathNa) +unlink(complexTypeJsonPath) +unlink(mapTypeJsonPath) sparkR.session.stop() diff --git a/R/pkg/vignettes/sparkr-vignettes.Rmd b/R/pkg/vignettes/sparkr-vignettes.Rmd index bc8bc3c26c116..43c255cff3028 100644 --- a/R/pkg/vignettes/sparkr-vignettes.Rmd +++ b/R/pkg/vignettes/sparkr-vignettes.Rmd @@ -565,11 +565,10 @@ We use a simple example to demonstrate `spark.logit` usage. In general, there ar and 3). Obtain the coefficient matrix of the fitted model using `summary` and use the model for prediction with `predict`. Binomial logistic regression -```{r, warning=FALSE} -df <- createDataFrame(iris) -# Create a DataFrame containing two classes -training <- df[df$Species %in% c("versicolor", "virginica"), ] -model <- spark.logit(training, Species ~ ., regParam = 0.00042) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.logit(training, Survived ~ ., regParam = 0.04741301) summary(model) ``` @@ -579,10 +578,11 @@ fitted <- predict(model, training) ``` Multinomial logistic regression against three classes -```{r, warning=FALSE} -df <- createDataFrame(iris) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Note in this case, Spark infers it is multinomial logistic regression, so family = "multinomial" is optional. -model <- spark.logit(df, Species ~ ., regParam = 0.056) +model <- spark.logit(training, Class ~ ., regParam = 0.07815179) summary(model) ``` @@ -609,11 +609,12 @@ MLPC employs backpropagation for learning the model. We use the logistic loss fu `spark.mlp` requires at least two columns in `data`: one named `"label"` and the other one `"features"`. 
The `"features"` column should be in libSVM-format. -We use iris data set to show how to use `spark.mlp` in classification. -```{r, warning=FALSE} -df <- createDataFrame(iris) +We use Titanic data set to show how to use `spark.mlp` in classification. +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # fit a Multilayer Perceptron Classification Model -model <- spark.mlp(df, Species ~ ., blockSize = 128, layers = c(4, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c(0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 9, 9, 9, 9, 9)) +model <- spark.mlp(training, Survived ~ Age + Sex, blockSize = 128, layers = c(2, 3), solver = "l-bfgs", maxIter = 100, tol = 0.5, stepSize = 1, seed = 1, initialWeights = c( 0, 0, 0, 5, 5, 5, 9, 9, 9)) ``` To avoid lengthy display, we only present partial results of the model summary. You can check the full result from your sparkR shell. @@ -630,7 +631,7 @@ options(ops) ``` ```{r} # make predictions use the fitted model -predictions <- predict(model, df) +predictions <- predict(model, training) head(select(predictions, predictions$prediction)) ``` @@ -769,12 +770,13 @@ predictions <- predict(rfModel, df) `spark.bisectingKmeans` is a kind of [hierarchical clustering](https://en.wikipedia.org/wiki/Hierarchical_clustering) using a divisive (or "top-down") approach: all observations start in one cluster, and splits are performed recursively as one moves down the hierarchy. -```{r, warning=FALSE} -df <- createDataFrame(iris) -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) summary(model) -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) ``` #### Gaussian Mixture Model @@ -912,9 +914,10 @@ testSummary ### Model Persistence The following example shows how to save/load an ML model by SparkR. 
-```{r, warning=FALSE} -irisDF <- createDataFrame(iris) -gaussianGLM <- spark.glm(irisDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +```{r} +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +gaussianGLM <- spark.glm(training, Freq ~ Sex + Age, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") @@ -925,7 +928,7 @@ gaussianGLM2 <- read.ml(modelPath) summary(gaussianGLM2) # Check model prediction -gaussianPredictions <- predict(gaussianGLM2, irisDF) +gaussianPredictions <- predict(gaussianGLM2, training) head(gaussianPredictions) unlink(modelPath) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0fd777ed12829..f0867ecb16ea3 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.executor.TaskMetrics import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util.{AccumulatorV2, TaskCompletionListener, TaskFailureListener} @@ -190,4 +191,10 @@ abstract class TaskContext extends Serializable { */ private[spark] def registerAccumulator(a: AccumulatorV2[_, _]): Unit + /** + * Record that this task has failed due to a fetch failure from a remote host. This allows + * fetch-failure handling to get triggered by the driver, regardless of intervening user-code. + */ + private[spark] def setFetchFailed(fetchFailed: FetchFailedException): Unit + } diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index c904e083911cd..dc0d12878550a 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.MetricsSystem import org.apache.spark.metrics.source.Source +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ private[spark] class TaskContextImpl( @@ -56,6 +57,10 @@ private[spark] class TaskContextImpl( // Whether the task has failed. @volatile private var failed: Boolean = false + // If there was a fetch failure in the task, we store it here, to make sure user-code doesn't + // hide the exception. 
See SPARK-19276 + @volatile private var _fetchFailedException: Option[FetchFailedException] = None + override def addTaskCompletionListener(listener: TaskCompletionListener): this.type = { onCompleteCallbacks += listener this @@ -126,4 +131,10 @@ private[spark] class TaskContextImpl( taskMetrics.registerAccumulator(a) } + private[spark] override def setFetchFailed(fetchFailed: FetchFailedException): Unit = { + this._fetchFailedException = Option(fetchFailed) + } + + private[spark] def fetchFailed: Option[FetchFailedException] = _fetchFailedException + } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 941e2d13fb28e..f475ce87540aa 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -82,17 +82,20 @@ class SparkHadoopUtil extends Logging { // the behavior of the old implementation of this code, for backwards compatibility. if (conf != null) { // Explicitly check for S3 environment variables - if (System.getenv("AWS_ACCESS_KEY_ID") != null && - System.getenv("AWS_SECRET_ACCESS_KEY") != null) { - val keyId = System.getenv("AWS_ACCESS_KEY_ID") - val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") - + val keyId = System.getenv("AWS_ACCESS_KEY_ID") + val accessKey = System.getenv("AWS_SECRET_ACCESS_KEY") + if (keyId != null && accessKey != null) { hadoopConf.set("fs.s3.awsAccessKeyId", keyId) hadoopConf.set("fs.s3n.awsAccessKeyId", keyId) hadoopConf.set("fs.s3a.access.key", keyId) hadoopConf.set("fs.s3.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3n.awsSecretAccessKey", accessKey) hadoopConf.set("fs.s3a.secret.key", accessKey) + + val sessionToken = System.getenv("AWS_SESSION_TOKEN") + if (sessionToken != null) { + hadoopConf.set("fs.s3a.session.token", sessionToken) + } } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" conf.getAll.foreach { case (key, value) => diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index 5ffdedd1658ab..1e50eb6635651 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -665,7 +665,8 @@ object SparkSubmit extends CommandLineUtils { if (verbose) { printStream.println(s"Main class:\n$childMainClass") printStream.println(s"Arguments:\n${childArgs.mkString("\n")}") - printStream.println(s"System properties:\n${sysProps.mkString("\n")}") + // sysProps may contain sensitive information, so redact before printing + printStream.println(s"System properties:\n${Utils.redact(sysProps).mkString("\n")}") printStream.println(s"Classpath elements:\n${childClasspath.mkString("\n")}") printStream.println("\n") } diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala index dee77343d806d..0614d80b60e1c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala @@ -84,9 +84,15 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S // scalastyle:off println if (verbose) SparkSubmit.printStream.println(s"Using properties file: $propertiesFile") Option(propertiesFile).foreach { filename => - Utils.getPropertiesFromFile(filename).foreach { case 
(k, v) => + val properties = Utils.getPropertiesFromFile(filename) + properties.foreach { case (k, v) => defaultProperties(k) = v - if (verbose) SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } + // Property files may contain sensitive information, so redact before printing + if (verbose) { + Utils.redact(properties).foreach { case (k, v) => + SparkSubmit.printStream.println(s"Adding default property: $k=$v") + } } } // scalastyle:on println @@ -318,7 +324,7 @@ private[deploy] class SparkSubmitArguments(args: Seq[String], env: Map[String, S | |Spark properties used, including those specified through | --conf and those from the properties file $propertiesFile: - |${sparkProperties.mkString(" ", "\n ", "\n")} + |${Utils.redact(sparkProperties).mkString(" ", "\n ", "\n")} """.stripMargin } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7dbe32975435d..e722a24d4a89e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -76,7 +76,7 @@ private[ui] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { val aliveWorkers = state.workers.filter(_.state == WorkerState.ALIVE) val workerTable = UIUtils.listingTable(workerHeaders, workerRow, workers) - val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Node", "Submitted Time", + val appHeaders = Seq("Application ID", "Name", "Cores", "Memory per Executor", "Submitted Time", "User", "State", "Duration") val activeApps = state.activeApps.sortBy(_.startTime).reverse val activeAppsTable = UIUtils.listingTable(appHeaders, appRow, activeApps) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 975a6e4eeb33a..790c1ae942474 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{File, NotSerializableException} +import java.lang.Thread.UncaughtExceptionHandler import java.lang.management.ManagementFactory import java.net.{URI, URL} import java.nio.ByteBuffer @@ -52,7 +53,8 @@ private[spark] class Executor( executorHostname: String, env: SparkEnv, userClassPath: Seq[URL] = Nil, - isLocal: Boolean = false) + isLocal: Boolean = false, + uncaughtExceptionHandler: UncaughtExceptionHandler = SparkUncaughtExceptionHandler) extends Logging { logInfo(s"Starting executor ID $executorId on host $executorHostname") @@ -78,7 +80,7 @@ private[spark] class Executor( // Setup an uncaught exception handler for non-local mode. // Make any thread terminations due to uncaught exceptions kill the entire // executor process to avoid surprising stalls. - Thread.setDefaultUncaughtExceptionHandler(SparkUncaughtExceptionHandler) + Thread.setDefaultUncaughtExceptionHandler(uncaughtExceptionHandler) } // Start worker thread pool @@ -342,6 +344,14 @@ private[spark] class Executor( } } } + task.context.fetchFailed.foreach { fetchFailure => + // uh-oh. it appears the user code has caught the fetch-failure without throwing any + // other exceptions. Its *possible* this is what the user meant to do (though highly + // unlikely). So we will log an error and keep going. + logError(s"TID ${taskId} completed successfully though internally it encountered " + + s"unrecoverable fetch failures! 
Most likely this means user code is incorrectly " + + s"swallowing Spark's internal ${classOf[FetchFailedException]}", fetchFailure) + } val taskFinish = System.currentTimeMillis() val taskFinishCpu = if (threadMXBean.isCurrentThreadCpuTimeSupported) { threadMXBean.getCurrentThreadCpuTime @@ -402,8 +412,17 @@ private[spark] class Executor( execBackend.statusUpdate(taskId, TaskState.FINISHED, serializedResult) } catch { - case ffe: FetchFailedException => - val reason = ffe.toTaskFailedReason + case t: Throwable if hasFetchFailure && !Utils.isFatalError(t) => + val reason = task.context.fetchFailed.get.toTaskFailedReason + if (!t.isInstanceOf[FetchFailedException]) { + // there was a fetch failure in the task, but some user code wrapped that exception + // and threw something else. Regardless, we treat it as a fetch failure. + val fetchFailedCls = classOf[FetchFailedException].getName + logWarning(s"TID ${taskId} encountered a ${fetchFailedCls} and " + + s"failed, but the ${fetchFailedCls} was hidden by another " + + s"exception. Spark is handling this like a fetch failure and ignoring the " + + s"other exception: $t") + } setTaskFinishedAndClearInterruptStatus() execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason)) @@ -455,13 +474,17 @@ private[spark] class Executor( // Don't forcibly exit unless the exception was inherently fatal, to avoid // stopping other tasks unnecessarily. if (Utils.isFatalError(t)) { - SparkUncaughtExceptionHandler.uncaughtException(t) + uncaughtExceptionHandler.uncaughtException(Thread.currentThread(), t) } } finally { runningTasks.remove(taskId) } } + + private def hasFetchFailure: Boolean = { + task != null && task.context != null && task.context.fetchFailed.isDefined + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 08d220b40b6f3..83d87b548a430 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -48,25 +48,29 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) private type StageId = Int private type PartitionId = Int private type TaskAttemptNumber = Int - private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1 + private case class StageState(numPartitions: Int) { + val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER) + val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]() + } /** - * Map from active stages's id => partition id => task attempt with exclusive lock on committing - * output for that partition. + * Map from active stages's id => authorized task attempts for each partition id, which hold an + * exclusive lock on committing task output for that partition, as well as any known failed + * attempts in the stage. * * Entries are added to the top-level map when stages start and are removed they finish * (either successfully or unsuccessfully). * * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance. */ - private val authorizedCommittersByStage = mutable.Map[StageId, Array[TaskAttemptNumber]]() + private val stageStates = mutable.Map[StageId, StageState]() /** * Returns whether the OutputCommitCoordinator's internal data structures are all empty. 
*/ def isEmpty: Boolean = { - authorizedCommittersByStage.isEmpty + stageStates.isEmpty } /** @@ -105,19 +109,13 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e. * the maximum possible value of `context.partitionId`). */ - private[scheduler] def stageStart( - stage: StageId, - maxPartitionId: Int): Unit = { - val arr = new Array[TaskAttemptNumber](maxPartitionId + 1) - java.util.Arrays.fill(arr, NO_AUTHORIZED_COMMITTER) - synchronized { - authorizedCommittersByStage(stage) = arr - } + private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized { + stageStates(stage) = new StageState(maxPartitionId + 1) } // Called by DAGScheduler private[scheduler] def stageEnd(stage: StageId): Unit = synchronized { - authorizedCommittersByStage.remove(stage) + stageStates.remove(stage) } // Called by DAGScheduler @@ -126,7 +124,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) partition: PartitionId, attemptNumber: TaskAttemptNumber, reason: TaskEndReason): Unit = synchronized { - val authorizedCommitters = authorizedCommittersByStage.getOrElse(stage, { + val stageState = stageStates.getOrElse(stage, { logDebug(s"Ignoring task completion for completed stage") return }) @@ -137,10 +135,12 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " + s"attempt: $attemptNumber") case otherReason => - if (authorizedCommitters(partition) == attemptNumber) { + // Mark the attempt as failed to blacklist from future commit protocol + stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber + if (stageState.authorizedCommitters(partition) == attemptNumber) { logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " + s"partition=$partition) failed; clearing lock") - authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER + stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER } } } @@ -149,7 +149,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) if (isDriver) { coordinatorRef.foreach(_ send StopCoordinator) coordinatorRef = None - authorizedCommittersByStage.clear() + stageStates.clear() } } @@ -158,13 +158,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) stage: StageId, partition: PartitionId, attemptNumber: TaskAttemptNumber): Boolean = synchronized { - authorizedCommittersByStage.get(stage) match { - case Some(authorizedCommitters) => - authorizedCommitters(partition) match { + stageStates.get(stage) match { + case Some(state) if attemptFailed(state, partition, attemptNumber) => + logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," + + s" partition=$partition as task attempt $attemptNumber has already failed.") + false + case Some(state) => + state.authorizedCommitters(partition) match { case NO_AUTHORIZED_COMMITTER => logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " + s"partition=$partition") - authorizedCommitters(partition) = attemptNumber + state.authorizedCommitters(partition) = attemptNumber true case existingCommitter => // Coordinator should be idempotent when receiving AskPermissionToCommit. 
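As an aside, the commit-authorization rule introduced above can be summarized with the following self-contained Scala sketch (illustrative only, not Spark's API; the class and method names are hypothetical): an attempt that is known to have failed is never granted the commit lock, even if no committer has been authorized for that partition yet.

```scala
import scala.collection.mutable

// Toy model of the per-stage commit arbitration described above.
class CommitArbiter(numPartitions: Int) {
  private val NoCommitter = -1
  private val authorized = Array.fill(numPartitions)(NoCommitter)
  private val failures = mutable.Map[Int, mutable.Set[Int]]()

  // Record a failed attempt and release its lock if it held one.
  def taskFailed(partition: Int, attempt: Int): Unit = synchronized {
    failures.getOrElseUpdate(partition, mutable.Set()) += attempt
    if (authorized(partition) == attempt) authorized(partition) = NoCommitter
  }

  // A failed attempt is always denied; otherwise first-come-first-served, idempotent per attempt.
  def canCommit(partition: Int, attempt: Int): Boolean = synchronized {
    if (failures.get(partition).exists(_.contains(attempt))) {
      false
    } else if (authorized(partition) == NoCommitter) {
      authorized(partition) = attempt
      true
    } else {
      authorized(partition) == attempt
    }
  }
}

// Mirrors the SPARK-19631 scenario: a failed attempt cannot commit, the next attempt can.
val arbiter = new CommitArbiter(numPartitions = 2)
arbiter.taskFailed(partition = 1, attempt = 0)
assert(!arbiter.canCommit(partition = 1, attempt = 0))
assert(arbiter.canCommit(partition = 1, attempt = 1))
```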
@@ -181,11 +185,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean) } } case None => - logDebug(s"Stage $stage has completed, so not allowing attempt number $attemptNumber of" + - s"partition $partition to commit") + logDebug(s"Stage $stage has completed, so not allowing" + + s" attempt number $attemptNumber of partition $partition to commit") false } } + + private def attemptFailed( + stageState: StageState, + partition: PartitionId, + attempt: TaskAttemptNumber): Boolean = synchronized { + stageState.failures.get(partition).exists(_.contains(attempt)) + } } private[spark] object OutputCommitCoordinator { diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7b726d5659e91..70213722aae4f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -17,19 +17,14 @@ package org.apache.spark.scheduler -import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer import java.util.Properties -import scala.collection.mutable -import scala.collection.mutable.HashMap - import org.apache.spark._ import org.apache.spark.executor.TaskMetrics import org.apache.spark.internal.config.APP_CALLER_CONTEXT import org.apache.spark.memory.{MemoryMode, TaskMemoryManager} import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.serializer.SerializerInstance import org.apache.spark.util._ /** @@ -137,6 +132,8 @@ private[spark] abstract class Task[T]( memoryManager.synchronized { memoryManager.notifyAll() } } } finally { + // Though we unset the ThreadLocal here, the context member variable itself is still queried + // directly in the TaskRunner to check for FetchFailedExceptions. TaskContext.unset() } } @@ -156,7 +153,7 @@ private[spark] abstract class Task[T]( var epoch: Long = -1 // Task context, to be initialized in run(). - @transient protected var context: TaskContextImpl = _ + @transient var context: TaskContextImpl = _ // The actual Thread on which the task is running, if any. Initialized in run(). @volatile @transient private var taskThread: Thread = _ diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 78aa5c40010cc..c98b87148e404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.Properties import scala.collection.JavaConverters._ @@ -86,7 +87,10 @@ private[spark] object TaskDescription { dataOut.writeInt(taskDescription.properties.size()) taskDescription.properties.asScala.foreach { case (key, value) => dataOut.writeUTF(key) - dataOut.writeUTF(value) + // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values + val bytes = value.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) } // Write the task. The task is already serialized, so write it directly to the byte buffer. 
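As an aside, the following self-contained Scala sketch (illustrative only, not code from the patch; the oversized test string is arbitrary) shows the length-prefixed encoding that the hunk above adopts for property values, since `DataOutputStream.writeUTF` is limited to 65535 encoded bytes, together with the matching read path used during decoding:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

val value = "x" * 100000 // longer than writeUTF's 64 KB limit

// Write: explicit byte-length prefix followed by the raw UTF-8 payload.
val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(buffer)
val bytes = value.getBytes(StandardCharsets.UTF_8)
out.writeInt(bytes.length)
out.write(bytes)
out.flush()

// Read: length first, then exactly that many bytes, then decode as UTF-8.
val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
val valueBytes = new Array[Byte](in.readInt())
in.readFully(valueBytes)
assert(new String(valueBytes, StandardCharsets.UTF_8) == value)
```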
@@ -124,7 +128,11 @@ private[spark] object TaskDescription { val properties = new Properties() val numProperties = dataIn.readInt() for (i <- 0 until numProperties) { - properties.setProperty(dataIn.readUTF(), dataIn.readUTF()) + val key = dataIn.readUTF() + val valueLength = dataIn.readInt() + val valueBytes = new Array[Byte](valueLength) + dataIn.readFully(valueBytes) + properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) } // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 3b25513bea057..19ebaf817e24e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -874,7 +874,8 @@ private[spark] class TaskSetManager( // and we are not using an external shuffle server which could serve the shuffle outputs. // The reason is the next stage wouldn't be able to fetch the data from this dead executor // so we would need to rerun these tasks on other executors. - if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled) { + if (tasks(0).isInstanceOf[ShuffleMapTask] && !env.blockManager.externalShuffleServiceEnabled + && !isZombie) { for ((tid, info) <- taskInfos if info.executorId == execId) { val index = taskInfos(tid).index if (successful(index)) { @@ -906,8 +907,6 @@ private[spark] class TaskSetManager( * Check for tasks to be speculated and return true if there are any. This is called periodically * by the TaskScheduler. * - * TODO: To make this scale to large jobs, we need to maintain a list of running tasks, so that - * we don't scan the whole task set. It might also help to make this sorted by launch time. */ override def checkSpeculatableTasks(minTimeToSpeculation: Int): Boolean = { // Can't speculate if we only have one task, and no need to speculate if the task set is a @@ -927,7 +926,8 @@ private[spark] class TaskSetManager( // TODO: Threshold should also look at standard deviation of task durations and have a lower // bound based on that. logDebug("Task length threshold for speculation: " + threshold) - for ((tid, info) <- taskInfos) { + for (tid <- runningTasksSet) { + val info = taskInfos(tid) val index = info.index if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold && !speculatableTasks.contains(index)) { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala index 498c12e196ce0..265a8acfa8d61 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala @@ -17,7 +17,7 @@ package org.apache.spark.shuffle -import org.apache.spark.{FetchFailed, TaskFailedReason} +import org.apache.spark.{FetchFailed, TaskContext, TaskFailedReason} import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.Utils @@ -26,6 +26,11 @@ import org.apache.spark.util.Utils * back to DAGScheduler (through TaskEndReason) so we'd resubmit the previous stage. * * Note that bmAddress can be null. + * + * To prevent user code from hiding this fetch failure, in the constructor we call + * [[TaskContext.setFetchFailed()]]. 
This means that you *must* throw this exception immediately + * after creating it -- you cannot create it, check some condition, and then decide to ignore it + * (or risk triggering any other exceptions). See SPARK-19276. */ private[spark] class FetchFailedException( bmAddress: BlockManagerId, @@ -45,6 +50,12 @@ private[spark] class FetchFailedException( this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause) } + // SPARK-19276. We set the fetch failure in the task context, so that even if there is user-code + // which intercepts this exception (possibly wrapping it), the Executor can still tell there was + // a fetch failure, and send the correct error msg back to the driver. We wrap with an Option + // because the TaskContext is not defined in some test cases. + Option(TaskContext.get()).map(_.setFetchFailed(this)) + def toTaskFailedReason: TaskFailedReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, Utils.exceptionString(this)) } diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 10e5233679562..1af34e3da231f 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -39,6 +39,7 @@ import scala.io.Source import scala.reflect.ClassTag import scala.util.Try import scala.util.control.{ControlThrowable, NonFatal} +import scala.util.matching.Regex import _root_.io.netty.channel.unix.Errors.NativeIoException import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache} @@ -2588,13 +2589,31 @@ private[spark] object Utils extends Logging { def redact(conf: SparkConf, kvs: Seq[(String, String)]): Seq[(String, String)] = { val redactionPattern = conf.get(SECRET_REDACTION_PATTERN).r + redact(redactionPattern, kvs) + } + + private def redact(redactionPattern: Regex, kvs: Seq[(String, String)]): Seq[(String, String)] = { kvs.map { kv => redactionPattern.findFirstIn(kv._1) - .map { ignore => (kv._1, REDACTION_REPLACEMENT_TEXT) } + .map { _ => (kv._1, REDACTION_REPLACEMENT_TEXT) } .getOrElse(kv) } } + /** + * Looks up the redaction regex from within the key value pairs and uses it to redact the rest + * of the key value pairs. No care is taken to make sure the redaction property itself is not + * redacted. So theoretically, the property itself could be configured to redact its own value + * when printing. 
+ */ + def redact(kvs: Map[String, String]): Seq[(String, String)] = { + val redactionPattern = kvs.getOrElse( + SECRET_REDACTION_PATTERN.key, + SECRET_REDACTION_PATTERN.defaultValueString + ).r + redact(redactionPattern, kvs.toArray) + } + } private[util] object CallerContext extends Logging { diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala index b743ff5376c49..8150fff2d018d 100644 --- a/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.executor import java.io.{Externalizable, ObjectInput, ObjectOutput} +import java.lang.Thread.UncaughtExceptionHandler import java.nio.ByteBuffer import java.util.Properties import java.util.concurrent.{CountDownLatch, TimeUnit} @@ -27,7 +28,7 @@ import scala.concurrent.duration._ import org.mockito.ArgumentCaptor import org.mockito.Matchers.{any, eq => meq} -import org.mockito.Mockito.{inOrder, when} +import org.mockito.Mockito.{inOrder, verify, when} import org.mockito.invocation.InvocationOnMock import org.mockito.stubbing.Answer import org.scalatest.concurrent.Eventually @@ -37,9 +38,12 @@ import org.apache.spark._ import org.apache.spark.TaskState.TaskState import org.apache.spark.memory.MemoryManager import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rdd.RDD import org.apache.spark.rpc.RpcEnv -import org.apache.spark.scheduler.{FakeTask, TaskDescription} +import org.apache.spark.scheduler.{FakeTask, ResultTask, TaskDescription} import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.storage.BlockManagerId class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSugar with Eventually { @@ -123,6 +127,75 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } } + test("SPARK-19276: Handle FetchFailedExceptions that are hidden by user exceptions") { + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but user code has a try/catch which hides + // the fetch failure. The executor should still tell the driver that the task failed due to a + // fetch failure, not a generic exception from user code. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = false) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val failReason = runTaskAndGetFailReason(taskDescription) + assert(failReason.isInstanceOf[FetchFailed]) + } + + test("SPARK-19276: OOMs correctly handled with a FetchFailure") { + // when there is a fatal error like an OOM, we don't do normal fetch failure handling, since it + // may be a false positive. 
And we should call the uncaught exception handler. + val conf = new SparkConf().setMaster("local").setAppName("executor suite test") + sc = new SparkContext(conf) + val serializer = SparkEnv.get.closureSerializer.newInstance() + val resultFunc = (context: TaskContext, itr: Iterator[Int]) => itr.size + + // Submit a job where a fetch failure is thrown, but then there is an OOM. We should treat + // the fetch failure as a false positive, and just do normal OOM handling. + val inputRDD = new FetchFailureThrowingRDD(sc) + val secondRDD = new FetchFailureHidingRDD(sc, inputRDD, throwOOM = true) + val taskBinary = sc.broadcast(serializer.serialize((secondRDD, resultFunc)).array()) + val serializedTaskMetrics = serializer.serialize(TaskMetrics.registered).array() + val task = new ResultTask( + stageId = 1, + stageAttemptId = 0, + taskBinary = taskBinary, + partition = secondRDD.partitions(0), + locs = Seq(), + outputId = 0, + localProperties = new Properties(), + serializedTaskMetrics = serializedTaskMetrics + ) + + val serTask = serializer.serialize(task) + val taskDescription = createFakeTaskDescription(serTask) + + val (failReason, uncaughtExceptionHandler) = + runTaskGetFailReasonAndExceptionHandler(taskDescription) + // make sure the task failure just looks like a OOM, not a fetch failure + assert(failReason.isInstanceOf[ExceptionFailure]) + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]) + verify(uncaughtExceptionHandler).uncaughtException(any(), exceptionCaptor.capture()) + assert(exceptionCaptor.getAllValues.size === 1) + assert(exceptionCaptor.getAllValues.get(0).isInstanceOf[OutOfMemoryError]) + } + test("Gracefully handle error in task deserialization") { val conf = new SparkConf val serializer = new JavaSerializer(conf) @@ -169,13 +242,20 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug } private def runTaskAndGetFailReason(taskDescription: TaskDescription): TaskFailedReason = { + runTaskGetFailReasonAndExceptionHandler(taskDescription)._1 + } + + private def runTaskGetFailReasonAndExceptionHandler( + taskDescription: TaskDescription): (TaskFailedReason, UncaughtExceptionHandler) = { val mockBackend = mock[ExecutorBackend] + val mockUncaughtExceptionHandler = mock[UncaughtExceptionHandler] var executor: Executor = null try { - executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true) + executor = new Executor("id", "localhost", SparkEnv.get, userClassPath = Nil, isLocal = true, + uncaughtExceptionHandler = mockUncaughtExceptionHandler) // the task will be launched in a dedicated worker thread executor.launchTask(mockBackend, taskDescription) - eventually(timeout(5 seconds), interval(10 milliseconds)) { + eventually(timeout(5.seconds), interval(10.milliseconds)) { assert(executor.numRunningTasks === 0) } } finally { @@ -193,7 +273,56 @@ class ExecutorSuite extends SparkFunSuite with LocalSparkContext with MockitoSug assert(statusCaptor.getAllValues().get(0).remaining() === 0) // second update is more interesting val failureData = statusCaptor.getAllValues.get(1) - SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + val failReason = + SparkEnv.get.closureSerializer.newInstance().deserialize[TaskFailedReason](failureData) + (failReason, mockUncaughtExceptionHandler) + } +} + +class FetchFailureThrowingRDD(sc: SparkContext) extends RDD[Int](sc, Nil) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + new Iterator[Int] { + override 
def hasNext: Boolean = true + override def next(): Int = { + throw new FetchFailedException( + bmAddress = BlockManagerId("1", "hostA", 1234), + shuffleId = 0, + mapId = 0, + reduceId = 0, + message = "fake fetch failure" + ) + } + } + } + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) + } +} + +class SimplePartition extends Partition { + override def index: Int = 0 +} + +class FetchFailureHidingRDD( + sc: SparkContext, + val input: FetchFailureThrowingRDD, + throwOOM: Boolean) extends RDD[Int](input) { + override def compute(split: Partition, context: TaskContext): Iterator[Int] = { + val inItr = input.compute(split, context) + try { + Iterator(inItr.size) + } catch { + case t: Throwable => + if (throwOOM) { + throw new OutOfMemoryError("OOM while handling another exception") + } else { + throw new RuntimeException("User Exception that hides the original exception", t) + } + } + } + + override protected def getPartitions: Array[Partition] = { + Array(new SimplePartition) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala index 0c362b881d912..83ed12752074d 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala @@ -195,6 +195,17 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter { sc.runJob(rdd, OutputCommitFunctions(tempDir.getAbsolutePath).callCanCommitMultipleTimes _, 0 until rdd.partitions.size) } + + test("SPARK-19631: Do not allow failed attempts to be authorized for committing") { + val stage: Int = 1 + val partition: Int = 1 + val failedAttempt: Int = 0 + outputCommitCoordinator.stageStart(stage, maxPartitionId = 1) + outputCommitCoordinator.taskCompleted(stage, partition, attemptNumber = failedAttempt, + reason = ExecutorLostFailure("0", exitCausedByApp = true, None)) + assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt)) + assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1)) + } } /** diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 9f1fe0515732e..97487ce1d2ca8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException} import java.nio.ByteBuffer import java.util.Properties @@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite { val originalProperties = new Properties() originalProperties.put("property1", "18") originalProperties.put("property2", "test value") + // SPARK-19796 -- large property values (like a large job description for a long sql query) + // can cause problems for DataOutputStream, make sure we handle correctly + val sb = new StringBuilder() + (0 to 10000).foreach(_ => sb.append("1234567890")) + val largeString = sb.toString() + originalProperties.put("property3", largeString) + // make sure we've got a good test case + intercept[UTFDataFormatException] { + val out = new DataOutputStream(new ByteArrayOutputStream()) + try { + out.writeUTF(largeString) + } finally { + out.close() + } + } // Create a dummy byte buffer for the 
task. val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index d03a0c990a02b..2c2cda9f318eb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.scheduler -import java.util.Random +import java.util.{Properties, Random} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer @@ -28,6 +28,7 @@ import org.mockito.Mockito.{mock, never, spy, verify, when} import org.apache.spark._ import org.apache.spark.internal.config import org.apache.spark.internal.Logging +import org.apache.spark.serializer.SerializerInstance import org.apache.spark.storage.BlockManagerId import org.apache.spark.util.{AccumulatorV2, ManualClock} @@ -664,6 +665,67 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg assert(thrown2.getMessage().contains("bigger than spark.driver.maxResultSize")) } + test("[SPARK-13931] taskSetManager should not send Resubmitted tasks after being a zombie") { + val conf = new SparkConf().set("spark.speculation", "true") + sc = new SparkContext("local", "test", conf) + + val sched = new FakeTaskScheduler(sc, ("execA", "host1"), ("execB", "host2")) + sched.initialize(new FakeSchedulerBackend() { + override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = {} + }) + + // Keep track of the number of tasks that are resubmitted, + // so that the test can check that no tasks were resubmitted. + var resubmittedTasks = 0 + val dagScheduler = new FakeDAGScheduler(sc, sched) { + override def taskEnded( + task: Task[_], + reason: TaskEndReason, + result: Any, + accumUpdates: Seq[AccumulatorV2[_, _]], + taskInfo: TaskInfo): Unit = { + super.taskEnded(task, reason, result, accumUpdates, taskInfo) + reason match { + case Resubmitted => resubmittedTasks += 1 + case _ => + } + } + } + sched.setDAGScheduler(dagScheduler) + + val singleTask = new ShuffleMapTask(0, 0, null, new Partition { + override def index: Int = 0 + }, Seq(TaskLocation("host1", "execA")), new Properties, null) + val taskSet = new TaskSet(Array(singleTask), 0, 0, 0, null) + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES) + + // Offer host1, which should be accepted as a PROCESS_LOCAL location + // by the one task in the task set + val task1 = manager.resourceOffer("execA", "host1", TaskLocality.PROCESS_LOCAL).get + + // Mark the task as available for speculation, and then offer another resource, + // which should be used to launch a speculative copy of the task. + manager.speculatableTasks += singleTask.partitionId + val task2 = manager.resourceOffer("execB", "host2", TaskLocality.ANY).get + + assert(manager.runningTasks === 2) + assert(manager.isZombie === false) + + val directTaskResult = new DirectTaskResult[String](null, Seq()) { + override def value(resultSer: SerializerInstance): String = "" + } + // Complete one copy of the task, which should result in the task set manager + // being marked as a zombie, because at least one copy of its only task has completed. 
+ manager.handleSuccessfulTask(task1.taskId, directTaskResult) + assert(manager.isZombie === true) + assert(resubmittedTasks === 0) + assert(manager.runningTasks === 1) + + manager.executorLost("execB", "host2", new SlaveLost()) + assert(manager.runningTasks === 0) + assert(resubmittedTasks === 0) + } + test("speculative and noPref task should be scheduled after node-local") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler( diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ccede34b8cb4d..75dc04038debc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -489,12 +489,12 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav Thread.sleep(200) } - // giving enough time for replication complete and locks released - Thread.sleep(500) - - val newLocations = master.getLocations(blockId).toSet + val newLocations = eventually(timeout(5 seconds), interval(10 millis)) { + val _newLocations = master.getLocations(blockId).toSet + assert(_newLocations.size === replicationFactor) + _newLocations + } logInfo(s"New locations : $newLocations") - assert(newLocations.size === replicationFactor) // there should only be one common block manager between initial and new locations assert(newLocations.intersect(blockLocations.toSet).size === 1) diff --git a/docs/ml-collaborative-filtering.md b/docs/ml-collaborative-filtering.md index cfe835172ab45..58f2d4b531e70 100644 --- a/docs/ml-collaborative-filtering.md +++ b/docs/ml-collaborative-filtering.md @@ -59,6 +59,34 @@ This approach is named "ALS-WR" and discussed in the paper It makes `regParam` less dependent on the scale of the dataset, so we can apply the best parameter learned from a sampled subset to the full dataset and expect similar performance. +### Cold-start strategy + +When making predictions using an `ALSModel`, it is common to encounter users and/or items in the +test dataset that were not present during training the model. This typically occurs in two +scenarios: + +1. In production, for new users or items that have no rating history and on which the model has not +been trained (this is the "cold start problem"). +2. During cross-validation, the data is split between training and evaluation sets. When using +simple random splits as in Spark's `CrossValidator` or `TrainValidationSplit`, it is actually +very common to encounter users and/or items in the evaluation set that are not in the training set + +By default, Spark assigns `NaN` predictions during `ALSModel.transform` when a user and/or item +factor is not present in the model. This can be useful in a production system, since it indicates +a new user or item, and so the system can make a decision on some fallback to use as the prediction. + +However, this is undesirable during cross-validation, since any `NaN` predicted values will result +in `NaN` results for the evaluation metric (for example when using `RegressionEvaluator`). +This makes model selection impossible. + +Spark allows users to set the `coldStartStrategy` parameter +to "drop" in order to drop any rows in the `DataFrame` of predictions that contain `NaN` values. +The evaluation metric will then be computed over the non-`NaN` data and will be valid. +Usage of this parameter is illustrated in the example below. 
+
+**Note:** Currently the supported cold start strategies are "nan" (the default behavior mentioned
+above) and "drop". Further strategies may be supported in the future.
+
 **Examples**
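For reference, a minimal Scala sketch of the workflow described above (the `ratings` DataFrame and its `userId`, `movieId` and `rating` column names are assumptions for illustration):

{% highlight scala %}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS

// Hold out a test set; some of its users/items may be missing from the training set.
val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

val als = new ALS()
  .setUserCol("userId")
  .setItemCol("movieId")
  .setRatingCol("rating")
  .setColdStartStrategy("drop")  // drop rows with NaN predictions before evaluation

val model = als.fit(training)
val predictions = model.transform(test)

// With "drop", the metric is computed only over rows that received a prediction.
val rmse = new RegressionEvaluator()
  .setMetricName("rmse")
  .setLabelCol("rating")
  .setPredictionCol("prediction")
  .evaluate(predictions)
{% endhighlight %}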
diff --git a/docs/ml-features.md b/docs/ml-features.md index 57605bafbf4c3..dad1c6db18f8b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md
@@ -503,6 +503,7 @@ for more details on the API.
 `StringIndexer` encodes a string column of labels to a column of label indices.
 The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`.
+The unseen labels will be put at index numLabels if the user chooses to keep them.
 If the input column is numeric, we cast it to string and index the string values. When downstream
 pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input
@@ -542,12 +543,13 @@ column, we should get the following:
 "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`.
-Additionally, there are two strategies regarding how `StringIndexer` will handle
+Additionally, there are three strategies regarding how `StringIndexer` will handle
 unseen labels when you have fit a `StringIndexer` on one dataset and then use it to transform another:
 - throw an exception (which is the default)
 - skip the row containing the unseen label entirely
+- put unseen labels in a special additional bucket, at index numLabels
 **Examples**
@@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined
  1 | b
  2 | c
  3 | d
+ 4 | e
~~~~
If you've not set how `StringIndexer` handles unseen labels or set it to
@@ -576,7 +579,22 @@ will be generated:
  2 | c | 1.0
~~~~
-Notice that the row containing "d" does not appear.
+Notice that the rows containing "d" or "e" do not appear.
+
+If you call `setHandleInvalid("keep")`, the following dataset
+will be generated:
+
+~~~~
+ id | category | categoryIndex
+----|----------|---------------
+ 0  | a        | 0.0
+ 1  | b        | 2.0
+ 2  | c        | 1.0
+ 3  | d        | 3.0
+ 4  | e        | 3.0
+~~~~
+
+Notice that the rows containing "d" or "e" are mapped to index "3.0".
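As a rough Scala sketch of the "keep" option (the `train` and `test` DataFrames and their `category` column are assumptions for illustration, not the exact tables above):

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

// Fit on `train`, which does not contain the labels "d" and "e".
val indexerModel = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")  // unseen labels all map to the extra index numLabels
  .fit(train)

// Rows of `test` with unseen labels are kept rather than dropped or rejected.
indexerModel.transform(test).show()
{% endhighlight %}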
diff --git a/docs/ml-pipeline.md b/docs/ml-pipeline.md index 7cbb14654e9d1..aa92c0a37c0f4 100644 --- a/docs/ml-pipeline.md +++ b/docs/ml-pipeline.md @@ -132,7 +132,7 @@ The `Pipeline.fit()` method is called on the original `DataFrame`, which has raw The `Tokenizer.transform()` method splits the raw text documents into words, adding a new column with words to the `DataFrame`. The `HashingTF.transform()` method converts the words column into feature vectors, adding a new column with those vectors to the `DataFrame`. Now, since `LogisticRegression` is an `Estimator`, the `Pipeline` first calls `LogisticRegression.fit()` to produce a `LogisticRegressionModel`. -If the `Pipeline` had more stages, it would call the `LogisticRegressionModel`'s `transform()` +If the `Pipeline` had more `Estimator`s, it would call the `LogisticRegressionModel`'s `transform()` method on the `DataFrame` before passing the `DataFrame` to the next stage. A `Pipeline` is an `Estimator`. diff --git a/docs/quick-start.md b/docs/quick-start.md index aa4319a23325c..b88ae5f6bb313 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -10,12 +10,13 @@ description: Quick start tutorial for Spark SPARK_VERSION_SHORT This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. + # Interactive Analysis with the Spark Shell ## Basics @@ -29,28 +30,28 @@ or Python. Start it by running the following in the Spark directory: ./bin/spark-shell -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Let's make a new Dataset from the text of the README file in the Spark source directory: {% highlight scala %} -scala> val textFile = sc.textFile("README.md") -textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at :25 +scala> val textFile = spark.read.textFile("README.md") +textFile: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. 
Let's start with a few actions: +You can get values from Dataset directly, by calling some actions, or transform the Dataset to get a new one. For more details, please read the _[API doc](api/scala/index.html#org.apache.spark.sql.Dataset)_. {% highlight scala %} -scala> textFile.count() // Number of items in this RDD +scala> textFile.count() // Number of items in this Dataset res0: Long = 126 // May be different from yours as README.md will change over time, similar to other outputs -scala> textFile.first() // First item in this RDD +scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. {% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) -linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at :27 +linesWithSpark: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} We can chain together transformations and actions: @@ -65,32 +66,32 @@ res3: Long = 15 ./bin/pyspark -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Due to Python's dynamic nature, we don't need the Dataset to be strongly-typed in Python. As a result, all Datasets in Python are Dataset[Row], and we call it `DataFrame` to be consistent with the data frame concept in Pandas and R. Let's make a new DataFrame from the text of the README file in the Spark source directory: {% highlight python %} ->>> textFile = sc.textFile("README.md") +>>> textFile = spark.read.text("README.md") {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from DataFrame directly, by calling some actions, or transform the DataFrame to get a new one. For more details, please read the _[API doc](api/python/index.html#pyspark.sql.DataFrame)_. {% highlight python %} ->>> textFile.count() # Number of items in this RDD +>>> textFile.count() # Number of rows in this DataFrame 126 ->>> textFile.first() # First item in this RDD -u'# Apache Spark' +>>> textFile.first() # First row in this DataFrame +Row(value=u'# Apache Spark') {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this DataFrame to a new one. We call `filter` to return a new DataFrame with a subset of the lines in the file. 
{% highlight python %} ->>> linesWithSpark = textFile.filter(lambda line: "Spark" in line) +>>> linesWithSpark = textFile.filter(textFile.value.contains("Spark")) {% endhighlight %} We can chain together transformations and actions: {% highlight python %} ->>> textFile.filter(lambda line: "Spark" in line).count() # How many lines contain "Spark"? +>>> textFile.filter(textFile.value.contains("Spark")).count() # How many lines contain "Spark"? 15 {% endhighlight %} @@ -98,8 +99,8 @@ We can chain together transformations and actions:
-## More on RDD Operations -RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words: +## More on Dataset Operations +Dataset actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
@@ -109,7 +110,7 @@ scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a res4: Long = 15 {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: +This first maps a line to an integer value, creating a new Dataset. `reduce` is called on that Dataset to find the largest word count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: {% highlight scala %} scala> import java.lang.Math @@ -122,11 +123,11 @@ res5: Int = 15 One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight scala %} -scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b) -wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at :28 +scala> val wordCounts = textFile.flatMap(line => line.split(" ")).groupByKey(identity).count() +wordCounts: org.apache.spark.sql.Dataset[(String, Long)] = [value: string, count(1): bigint] {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we call `flatMap` to transform a Dataset of lines to a Dataset of words, and then combine `groupByKey` and `count` to compute the per-word counts in the file as a Dataset of (String, Long) pairs. To collect the word counts in our shell, we can call `collect`: {% highlight scala %} scala> wordCounts.collect() @@ -137,37 +138,24 @@ res6: Array[(String, Int)] = Array((means,1), (under,2), (this,3), (Because,1),
{% highlight python %}
->>> textFile.map(lambda line: len(line.split())).reduce(lambda a, b: a if (a > b) else b)
-15
+>>> from pyspark.sql.functions import *
+>>> textFile.select(size(split(textFile.value, "\s+")).name("numWords")).agg(max(col("numWords"))).collect()
+[Row(max(numWords)=15)]
{% endhighlight %}
-This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Python [anonymous functions (lambdas)](https://docs.python.org/2/reference/expressions.html#lambda),
-but we can also pass any top-level Python function we want.
-For example, we'll define a `max` function to make this code easier to understand:
-
-{% highlight python %}
->>> def max(a, b):
-...     if a > b:
-...         return a
-...     else:
-...         return b
-...
-
->>> textFile.map(lambda line: len(line.split())).reduce(max)
-15
-{% endhighlight %}
+This first maps a line to an integer value and aliases it as "numWords", creating a new DataFrame. `agg` is called on that DataFrame to find the largest word count. The arguments to `select` and `agg` are both _[Column](api/python/index.html#pyspark.sql.Column)_; we can use `df.colName` to get a column from a DataFrame. We can also import pyspark.sql.functions, which provides a lot of convenient functions to build a new Column from an old one.
 One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily:
{% highlight python %}
->>> wordCounts = textFile.flatMap(lambda line: line.split()).map(lambda word: (word, 1)).reduceByKey(lambda a, b: a+b)
+>>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count()
{% endhighlight %}
-Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action:
+Here, we use the `explode` function in `select` to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of two columns: "word" and "count". To collect the word counts in our shell, we can call `collect`:
{% highlight python %}
 >>> wordCounts.collect()
-[(u'and', 9), (u'A', 1), (u'webpage', 1), (u'README', 1), (u'Note', 1), (u'"local"', 1), (u'variable', 1), ...]
+[Row(word=u'online', count=1), Row(word=u'graphs', count=1), ...]
{% endhighlight %}
@@ -181,7 +169,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This {% highlight scala %} scala> linesWithSpark.cache() -res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at :27 +res7: linesWithSpark.type = [value: string] scala> linesWithSpark.count() res8: Long = 15 @@ -193,7 +181,7 @@ res9: Long = 15 It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/spark-shell` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
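As a small follow-on sketch (assuming the `textFile` Dataset from the shell session above), a derived Dataset can also be cached and reused across several actions:

{% highlight scala %}
scala> val words = textFile.flatMap(line => line.split(" ")).cache()

scala> words.filter(word => word == "Spark").count()  // first action materializes the cache

scala> words.distinct().count()                       // reuses the in-memory data
{% endhighlight %}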
@@ -211,7 +199,7 @@ a cluster, as described in the [programming guide](programming-guide.html#initia It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/pyspark` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -228,20 +216,17 @@ named `SimpleApp.scala`: {% highlight scala %} /* SimpleApp.scala */ -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession object SimpleApp { def main(args: Array[String]) { val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system - val conf = new SparkConf().setAppName("Simple Application") - val sc = new SparkContext(conf) - val logData = sc.textFile(logFile, 2).cache() + val spark = SparkSession.builder.appName("Simple Application").getOrCreate() + val logData = spark.read.textFile(logFile).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() println(s"Lines with a: $numAs, Lines with b: $numBs") - sc.stop() + spark.stop() } } {% endhighlight %} @@ -251,16 +236,13 @@ Subclasses of `scala.App` may not work correctly. This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is -installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, -we initialize a SparkContext as part of the program. +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. -We pass the SparkContext constructor a -[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) -object which contains information about our -application. +We call `SparkSession.builder` to construct a [[SparkSession]], then set the application name, and finally call `getOrCreate` to get the [[SparkSession]] instance. -Our application depends on the Spark API, so we'll also include an sbt configuration file, -`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that +Our application depends on the Spark API, so we'll also include an sbt configuration file, +`build.sbt`, which explains that Spark is a dependency. 
This file also adds a repository that Spark depends on: {% highlight scala %} @@ -270,7 +252,7 @@ version := "1.0" scalaVersion := "{{site.SCALA_VERSION}}" -libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" +libraryDependencies += "org.apache.spark" %% "spark-sql" % "{{site.SPARK_VERSION}}" {% endhighlight %} For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt` @@ -309,34 +291,28 @@ We'll create a very simple Spark application, `SimpleApp.java`: {% highlight java %} /* SimpleApp.java */ -import org.apache.spark.api.java.*; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SparkSession; public class SimpleApp { public static void main(String[] args) { String logFile = "YOUR_SPARK_HOME/README.md"; // Should be some file on your system - SparkConf conf = new SparkConf().setAppName("Simple Application"); - JavaSparkContext sc = new JavaSparkContext(conf); - JavaRDD logData = sc.textFile(logFile).cache(); + SparkSession spark = SparkSession.builder().appName("Simple Application").getOrCreate(); + Dataset logData = spark.read.textFile(logFile).cache(); long numAs = logData.filter(s -> s.contains("a")).count(); long numBs = logData.filter(s -> s.contains("b")).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); - - sc.stop(); + + spark.stop(); } } {% endhighlight %} -This program just counts the number of lines containing 'a' and the number containing 'b' in a text -file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala example, we initialize a SparkContext, though we use the special -`JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by -`JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes -that extend `spark.api.java.function.Function`. The -[Spark programming guide](programming-guide.html) describes these differences in more detail. +This program just counts the number of lines containing 'a' and the number containing 'b' in the +Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. To build the program, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version. @@ -352,7 +328,7 @@ Note that Spark artifacts are tagged with a Scala version. 
org.apache.spark - spark-core_{{site.SCALA_BINARY_VERSION}} + spark-sql_{{site.SCALA_BINARY_VERSION}} {{site.SPARK_VERSION}} @@ -395,27 +371,25 @@ As an example, we'll create a simple Spark application, `SimpleApp.py`: {% highlight python %} """SimpleApp.py""" -from pyspark import SparkContext +from pyspark.sql import SparkSession logFile = "YOUR_SPARK_HOME/README.md" # Should be some file on your system -sc = SparkContext("local", "Simple App") -logData = sc.textFile(logFile).cache() +spark = SparkSession.builder().appName(appName).master(master).getOrCreate() +logData = spark.read.text(logFile).cache() -numAs = logData.filter(lambda s: 'a' in s).count() -numBs = logData.filter(lambda s: 'b' in s).count() +numAs = logData.filter(logData.value.contains('a')).count() +numBs = logData.filter(logData.value.contains('b')).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) -sc.stop() +spark.stop() {% endhighlight %} This program just counts the number of lines containing 'a' and the number containing 'b' in a text file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala and Java examples, we use a SparkContext to create RDDs. -We can pass Python functions to Spark, which are automatically serialized along with any variables -that they reference. +As with the Scala and Java examples, we use a SparkSession to create Datasets. For applications that use custom classes or third-party libraries, we can also add code dependencies to `spark-submit` through its `--py-files` argument by packaging them into a .zip file (see `spark-submit --help` for details). @@ -438,8 +412,7 @@ Lines with a: 46, Lines with b: 23 # Where to Go from Here Congratulations on running your first Spark application! -* For an in-depth overview of the API, start with the [Spark programming guide](programming-guide.html), - or see "Programming Guides" menu for other components. +* For an in-depth overview of the API, start with the [RDD programming guide](rdd-programming-guide.html) and the [SQL programming guide](sql-programming-guide.html), or see "Programming Guides" menu for other components. * For running applications on a cluster, head to the [deployment overview](cluster-overview.html). * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), diff --git a/docs/programming-guide.md b/docs/rdd-programming-guide.md similarity index 99% rename from docs/programming-guide.md rename to docs/rdd-programming-guide.md index 6740dbe0014b4..cad9ff4e646e5 100644 --- a/docs/programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -24,7 +24,7 @@ along with if you launch Spark's interactive shell -- either `bin/spark-shell` f
-Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} +Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} by default. (Spark can be built to work with other versions of Scala, too.) To write applications in Scala, you will need to use a compatible Scala version (e.g. {{site.SCALA_BINARY_VERSION}}.X). @@ -76,10 +76,10 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency Finally, you need to import some Spark classes into your program. Add the following lines: -{% highlight scala %} -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.SparkConf +{% highlight java %} +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.SparkConf; {% endhighlight %}
@@ -244,13 +244,13 @@ use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running $ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -To use the Jupyter notebook (previously known as the IPython notebook), +To use the Jupyter notebook (previously known as the IPython notebook), {% highlight bash %} $ PYSPARK_DRIVER_PYTHON=jupyter ./bin/pyspark {% endhighlight %} -You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. +You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of @@ -811,7 +811,7 @@ The variables within the closure sent to each executor are now copies and thus, In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1230,8 +1230,8 @@ storage levels is: -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, -so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1346,7 +1346,7 @@ As a user, you can create named or unnamed accumulators. As seen in the image be Accumulators in the Spark UI

-Tracking accumulators in the UI can be useful for understanding the progress of +Tracking accumulators in the UI can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python).
@@ -1355,7 +1355,7 @@ running stages (NOTE: this is not yet supported in Python). A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: @@ -1409,7 +1409,7 @@ Note that, when programmers define their own type of AccumulatorV2, the resultin A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: diff --git a/docs/security.md b/docs/security.md index a4796767832b9..9eda42888637f 100644 --- a/docs/security.md +++ b/docs/security.md @@ -12,7 +12,7 @@ Spark currently supports authentication via a shared secret. Authentication can ## Web UI The Spark UI can be secured by using [javax servlet filters](http://docs.oracle.com/javaee/6/api/javax/servlet/Filter.html) via the `spark.ui.filters` setting -and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via the `spark.ui.https.enabled` setting. +and by using [https/SSL](http://en.wikipedia.org/wiki/HTTPS) via [SSL settings](security.html#ssl-configuration). ### Authentication diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 2dd1ab6ef3de1..b077575155eb0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -386,8 +386,8 @@ For example: The [built-in DataFrames functions](api/scala/index.html#org.apache.spark.sql.functions$) provide common aggregations such as `count()`, `countDistinct()`, `avg()`, `max()`, `min()`, etc. -While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in -[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and +While those functions are designed for DataFrames, Spark SQL also has type-safe versions for some of them in +[Scala](api/scala/index.html#org.apache.spark.sql.expressions.scalalang.typed$) and [Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets. Moreover, users are not limited to the predefined aggregate functions and can create their own. @@ -397,7 +397,7 @@ Moreover, users are not limited to the predefined aggregate functions and can cr
-Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) +Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction) abstract class to implement a custom untyped aggregate function. For example, a user-defined average can look like: @@ -888,8 +888,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset scala/org/apache/spark/examples/sql/SQLDataSourceExample.scala %}
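For instance, a minimal Scala sketch of reading such a file (the path is a placeholder; the option name follows the `wholeFile` option described above):

{% highlight scala %}
// Each input file is parsed as a single, possibly multi-line, JSON record.
val people = spark.read
  .option("wholeFile", "true")
  .json("path/to/multiline-people.json")

people.printSchema()
{% endhighlight %}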
@@ -901,8 +902,9 @@ or a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` option to `true`. {% include_example json_dataset java/org/apache/spark/examples/sql/JavaSQLDataSourceExample.java %}
@@ -913,8 +915,9 @@ This conversion can be done using `SparkSession.read.json` on a JSON file. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set the `wholeFile` parameter to `True`. {% include_example json_dataset python/sql/datasource.py %}
@@ -926,8 +929,9 @@ files is a JSON object. Note that the file that is offered as _a json file_ is not a typical JSON file. Each line must contain a separate, self-contained valid JSON object. For more information, please see -[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). As a -consequence, a regular multi-line JSON file will most often fail. +[JSON Lines text format, also called newline-delimited JSON](http://jsonlines.org/). + +For a regular multi-line JSON file, set a named parameter `wholeFile` to `TRUE`. {% include_example json_dataset r/RSparkSQLExample.R %} diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java index 33ba668b32fc2..81970b7c81f40 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaALSExample.java @@ -103,6 +103,8 @@ public static void main(String[] args) { ALSModel model = als.fit(training); // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop"); Dataset predictions = model.transform(test); RegressionEvaluator evaluator = new RegressionEvaluator() diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java index 4594e3462b2a5..ff917b720c8b6 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBucketedRandomProjectionLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating BucketedRandomProjectionLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.JavaBucketedRandomProjectionLSHExample + * bin/run-example ml.JavaBucketedRandomProjectionLSHExample */ public class JavaBucketedRandomProjectionLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java index 0aace46939257..e164598e3ef87 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaMinHashLSHExample.java @@ -42,7 +42,7 @@ /** * An example demonstrating MinHashLSH. 
* Run with: - * bin/run-example org.apache.spark.examples.ml.JavaMinHashLSHExample + * bin/run-example ml.JavaMinHashLSHExample */ public class JavaMinHashLSHExample { public static void main(String[] args) { diff --git a/examples/src/main/python/ml/als_example.py b/examples/src/main/python/ml/als_example.py index 1a979ff5b5be2..2e7214ed56f98 100644 --- a/examples/src/main/python/ml/als_example.py +++ b/examples/src/main/python/ml/als_example.py @@ -44,7 +44,9 @@ (training, test) = ratings.randomSplit([0.8, 0.2]) # Build the recommendation model using ALS on the training data - als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating") + # Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + als = ALS(maxIter=5, regParam=0.01, userCol="userId", itemCol="movieId", ratingCol="rating", + coldStartStrategy="drop") model = als.fit(training) # Evaluate the model by computing the RMSE on the test data diff --git a/examples/src/main/r/ml/bisectingKmeans.R b/examples/src/main/r/ml/bisectingKmeans.R index 5fb5bfb0fa5a3..b3eaa6dd86d7d 100644 --- a/examples/src/main/r/ml/bisectingKmeans.R +++ b/examples/src/main/r/ml/bisectingKmeans.R @@ -25,20 +25,21 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-bisectingKmeans-example") # $example on$ -irisDF <- createDataFrame(iris) +t <- as.data.frame(Titanic) +training <- createDataFrame(t) # Fit bisecting k-means model with four centers -model <- spark.bisectingKmeans(df, Sepal_Length ~ Sepal_Width, k = 4) +model <- spark.bisectingKmeans(training, Class ~ Survived, k = 4) # get fitted result from a bisecting k-means model fitted.model <- fitted(model, "centers") # Model summary -summary(fitted.model) +head(summary(fitted.model)) # fitted values on training data -fitted <- predict(model, df) -head(select(fitted, "Sepal_Length", "prediction")) +fitted <- predict(model, training) +head(select(fitted, "Class", "prediction")) # $example off$ sparkR.session.stop() diff --git a/examples/src/main/r/ml/glm.R b/examples/src/main/r/ml/glm.R index e41af97751d3f..ee13910382c58 100644 --- a/examples/src/main/r/ml/glm.R +++ b/examples/src/main/r/ml/glm.R @@ -25,11 +25,12 @@ library(SparkR) sparkR.session(appName = "SparkR-ML-glm-example") # $example on$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Model summary summary(gaussianGLM) @@ -39,14 +40,15 @@ gaussianPredictions <- predict(gaussianGLM, gaussianTestDF) head(gaussianPredictions) # Fit a generalized linear model with glm (R-compliant) -gaussianGLM2 <- glm(Sepal_Length ~ Sepal_Width + Species, gaussianDF, family = "gaussian") +gaussianGLM2 <- glm(label ~ features, gaussianDF, family = "gaussian") summary(gaussianGLM2) # Fit a generalized linear model of family "binomial" with spark.glm -# Note: Filter out "setosa" from label column (two labels left) to match "binomial" family. 
-binomialDF <- filter(irisDF, irisDF$Species != "setosa") -binomialTestDF <- binomialDF -binomialGLM <- spark.glm(binomialDF, Species ~ Sepal_Length + Sepal_Width, family = "binomial") +training2 <- read.df("data/mllib/sample_binary_classification_data.txt", source = "libsvm") +df_list2 <- randomSplit(training2, c(7,3), 2) +binomialDF <- df_list2[[1]] +binomialTestDF <- df_list2[[2]] +binomialGLM <- spark.glm(binomialDF, label ~ features, family = "binomial") # Model summary summary(binomialGLM) diff --git a/examples/src/main/r/ml/kmeans.R b/examples/src/main/r/ml/kmeans.R index 288e2f9724e00..824df20644fa1 100644 --- a/examples/src/main/r/ml/kmeans.R +++ b/examples/src/main/r/ml/kmeans.R @@ -26,10 +26,12 @@ sparkR.session(appName = "SparkR-ML-kmeans-example") # $example on$ # Fit a k-means model with spark.kmeans -irisDF <- suppressWarnings(createDataFrame(iris)) -kmeansDF <- irisDF -kmeansTestDF <- irisDF -kmeansModel <- spark.kmeans(kmeansDF, ~ Sepal_Length + Sepal_Width + Petal_Length + Petal_Width, +t <- as.data.frame(Titanic) +training <- createDataFrame(t) +df_list <- randomSplit(training, c(7,3), 2) +kmeansDF <- df_list[[1]] +kmeansTestDF <- df_list[[2]] +kmeansModel <- spark.kmeans(kmeansDF, ~ Class + Sex + Age + Freq, k = 3) # Model summary diff --git a/examples/src/main/r/ml/ml.R b/examples/src/main/r/ml/ml.R index b96819418bad3..41b7867f64e36 100644 --- a/examples/src/main/r/ml/ml.R +++ b/examples/src/main/r/ml/ml.R @@ -26,11 +26,12 @@ sparkR.session(appName = "SparkR-ML-example") ############################ model read/write ############################################## # $example on:read_write$ -irisDF <- suppressWarnings(createDataFrame(iris)) +training <- read.df("data/mllib/sample_multiclass_classification_data.txt", source = "libsvm") # Fit a generalized linear model of family "gaussian" with spark.glm -gaussianDF <- irisDF -gaussianTestDF <- irisDF -gaussianGLM <- spark.glm(gaussianDF, Sepal_Length ~ Sepal_Width + Species, family = "gaussian") +df_list <- randomSplit(training, c(7,3), 2) +gaussianDF <- df_list[[1]] +gaussianTestDF <- df_list[[2]] +gaussianGLM <- spark.glm(gaussianDF, label ~ features, family = "gaussian") # Save and then load a fitted MLlib model modelPath <- tempfile(pattern = "ml", fileext = ".tmp") diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala index bb5d163608494..868f49b16f218 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ALSExample.scala @@ -65,6 +65,8 @@ object ALSExample { val model = als.fit(training) // Evaluate the model by computing the RMSE on the test data + // Note we set cold start strategy to 'drop' to ensure we don't get NaN evaluation metrics + model.setColdStartStrategy("drop") val predictions = model.transform(test) val evaluator = new RegressionEvaluator() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala index 654535c264a35..16da4fa887aaf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/BucketedRandomProjectionLSHExample.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * An example demonstrating BucketedRandomProjectionLSH. 
* Run with: - * bin/run-example org.apache.spark.examples.ml.BucketedRandomProjectionLSHExample + * bin/run-example ml.BucketedRandomProjectionLSHExample */ object BucketedRandomProjectionLSHExample { def main(args: Array[String]): Unit = { diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala index 6c1e22268ad2c..b94ab9b8bedc1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MinHashLSHExample.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.SparkSession /** * An example demonstrating MinHashLSH. * Run with: - * bin/run-example org.apache.spark.examples.ml.MinHashLSHExample + * bin/run-example ml.MinHashLSHExample */ object MinHashLSHExample { def main(args: Array[String]): Unit = { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala new file mode 100644 index 0000000000000..08914d82fffdd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.streaming.Sink + +private[kafka010] class KafkaSink( + sqlContext: SQLContext, + executorKafkaParams: ju.Map[String, Object], + topic: Option[String]) extends Sink with Logging { + @volatile private var latestBatchId = -1L + + override def toString(): String = "KafkaSink" + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId <= latestBatchId) { + logInfo(s"Skipping already committed batch $batchId") + } else { + KafkaWriter.write(sqlContext.sparkSession, + data.queryExecution, executorKafkaParams, topic) + latestBatchId = batchId + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a7456719875f..febe3c217122a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -23,12 +23,14 @@ import java.util.UUID import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig -import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -36,8 +38,12 @@ import org.apache.spark.sql.types.StructType * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ -private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider - with RelationProvider with Logging { +private[kafka010] class KafkaSourceProvider extends DataSourceRegister + with StreamSourceProvider + with StreamSinkProvider + with RelationProvider + with CreatableRelationProvider + with Logging { import KafkaSourceProvider._ override def shortName(): String = "kafka" @@ -152,6 +158,72 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre endingRelationOffsets) } + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + new KafkaSink(sqlContext, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic) + } + + override def createRelation( + outerSQLContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + mode match { + case SaveMode.Overwrite | SaveMode.Ignore => + throw new AnalysisException(s"Save mode $mode not allowed for Kafka. 
" + + s"Allowed save modes are ${SaveMode.Append} and " + + s"${SaveMode.ErrorIfExists} (default).") + case _ => // good + } + val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic) + + /* This method is suppose to return a relation that reads the data that was written. + * We cannot support this for Kafka. Therefore, in order to make things consistent, + * we return an empty base relation. + */ + new BaseRelation { + override def sqlContext: SQLContext = unsupportedException + override def schema: StructType = unsupportedException + override def needConversion: Boolean = unsupportedException + override def sizeInBytes: Long = unsupportedException + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException + private def unsupportedException = + throw new UnsupportedOperationException("BaseRelation from Kafka write " + + "operation is not usable.") + } + } + + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } + private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = ConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -381,6 +453,7 @@ private[kafka010] object KafkaSourceProvider { private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" private val deserClassName = classOf[ByteArrayDeserializer].getName } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala new file mode 100644 index 0000000000000..6e160cbe2db52 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.producer.{KafkaProducer, _} +import org.apache.kafka.common.serialization.ByteArraySerializer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.types.{BinaryType, StringType} + +/** + * A simple trait for writing out data in a single Spark task, without any concerns about how + * to commit or abort tasks. Exceptions thrown by the implementation of this class will + * automatically trigger task aborts. + */ +private[kafka010] class KafkaWriteTask( + producerConfiguration: ju.Map[String, Object], + inputSchema: Seq[Attribute], + topic: Option[String]) { + // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection + private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + + /** + * Writes key value data out to topics. + */ + def execute(iterator: Iterator[InternalRow]): Unit = { + producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration) + while (iterator.hasNext && failedWrite == null) { + val currentRow = iterator.next() + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) + } + } + + def close(): Unit = { + if (producer != null) { + checkForErrors + producer.close() + checkForErrors + producer = null + } + } + + private def createProjection: UnsafeProjection = { + val topicExpression = topic.map(Literal(_)).orElse { + inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) + }.getOrElse { + throw new IllegalStateException(s"topic option required when no " + + s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present") + } + topicExpression.dataType match { + case StringType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t. 
${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"must be a ${StringType}") + } + val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) + .getOrElse(Literal(null, BinaryType)) + keyExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + val valueExpression = inputSchema + .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( + throw new IllegalStateException(s"Required attribute " + + s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") + ) + valueExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + UnsafeProjection.create( + Seq(topicExpression, Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType)), inputSchema) + } + + private def checkForErrors: Unit = { + if (failedWrite != null) { + throw failedWrite + } + } +} + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala new file mode 100644 index 0000000000000..a637d52c933a3 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.util.Utils + +/** + * The [[KafkaWriter]] class is used to write data from a batch query + * or structured streaming query, given by a [[QueryExecution]], to Kafka. + * The data is assumed to have a value column, and an optional topic and key + * columns. If the topic column is missing, then the topic must come from + * the 'topic' configuration option. If the key column is missing, then a + * null valued key field will be added to the + * [[org.apache.kafka.clients.producer.ProducerRecord]]. 
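The error-propagation pattern in `KafkaWriteTask` above (asynchronous `send` calls plus a `@volatile` slot whose first failure is rethrown on `close`) can be exercised outside Spark. Below is a minimal standalone sketch of that pattern using only the Kafka producer API; the class name, the broker address `localhost:9092`, and the topic `demo-topic` are assumptions for illustration, not part of the patch.

```scala
import java.util.Properties

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}
import org.apache.kafka.common.serialization.ByteArraySerializer

/** Minimal stand-in for the send/close error handling used by KafkaWriteTask (sketch only). */
class SimpleWriteTask(props: Properties, topic: String) {
  // First asynchronous failure reported by the producer callback.
  @volatile private var failedWrite: Exception = null
  private val producer = new KafkaProducer[Array[Byte], Array[Byte]](props)

  def write(values: Iterator[Array[Byte]]): Unit = {
    val callback = new Callback {
      override def onCompletion(metadata: RecordMetadata, e: Exception): Unit = {
        if (failedWrite == null && e != null) failedWrite = e
      }
    }
    // Stop early if a previous send has already failed, as the real write task does.
    while (values.hasNext && failedWrite == null) {
      val record = new ProducerRecord[Array[Byte], Array[Byte]](topic, null, values.next())
      producer.send(record, callback) // asynchronous; errors surface via the callback
    }
  }

  def close(): Unit = {
    producer.close() // flushes outstanding sends before returning
    if (failedWrite != null) throw failedWrite
  }
}

object SimpleWriteTaskDemo {
  def main(args: Array[String]): Unit = {
    val props = new Properties()
    props.put("bootstrap.servers", "localhost:9092") // assumed broker address
    props.put("key.serializer", classOf[ByteArraySerializer].getName)
    props.put("value.serializer", classOf[ByteArraySerializer].getName)

    val task = new SimpleWriteTask(props, "demo-topic") // assumed topic
    try task.write(Iterator("a", "b", "c").map(_.getBytes("UTF-8")))
    finally task.close()
  }
}
```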
+ */ +private[kafka010] object KafkaWriter extends Logging { + val TOPIC_ATTRIBUTE_NAME: String = "topic" + val KEY_ATTRIBUTE_NAME: String = "key" + val VALUE_ATTRIBUTE_NAME: String = "value" + + override def toString: String = "KafkaWriter" + + def validateQuery( + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( + if (topic == None) { + throw new AnalysisException(s"topic option required when no " + + s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") + } else { + Literal(topic.get, StringType) + } + ).dataType match { + case StringType => // good + case _ => + throw new AnalysisException(s"Topic type must be a String") + } + schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( + Literal(null, StringType) + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( + throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + } + + def write( + sparkSession: SparkSession, + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + validateQuery(queryExecution, kafkaParameters, topic) + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) + } + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala new file mode 100644 index 0000000000000..490535623cb36 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
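With `KafkaWriter` wired into `KafkaSourceProvider` as both a `StreamSinkProvider` and a `CreatableRelationProvider`, the sink becomes reachable through the ordinary DataFrame APIs. The sketch below shows the streaming path end to end; it is illustrative only, and the socket source on `localhost:9999`, the broker on `localhost:9092`, the topic name, and the checkpoint directory are all assumed values.

```scala
import org.apache.spark.sql.SparkSession

object KafkaSinkStreamingSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-sink-streaming-sketch").getOrCreate()

    // The socket source yields a single string column named 'value', which is exactly
    // the one required column for the Kafka sink; 'key' and 'topic' columns are optional.
    val lines = spark.readStream
      .format("socket")
      .option("host", "localhost") // assumed: `nc -lk 9999` running locally
      .option("port", "9999")
      .load()

    val query = lines.writeStream
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")       // assumed broker
      .option("topic", "demo-topic")                              // default topic option
      .option("checkpointLocation", "/tmp/kafka-sink-checkpoint") // assumed path
      .start()

    query.awaitTermination()
  }
}
```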
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, DataType} + +class KafkaSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("batch - write to kafka") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + checkAnswer( + createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + } + + test("batch - null topic field value, and no topic option") { + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + val ex = intercept[SparkException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + "null topic present in the data")) + } + + test("batch - unsupported save modes") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + + // Test bad save mode Ignore + var ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Ignore) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode ignore not allowed for kafka")) + + // Test bad save mode Overwrite + ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Overwrite) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode overwrite not allowed for kafka")) + } + + test("streaming - write to kafka with topic field") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + input.addData("1", "2", "3", "4", "5") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + input.addData("6", "7", "8", "9", "10") + 
failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write aggregation w/o topic field, with topic option") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + test("streaming - aggregation with topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "'foo' as topic", + "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + + test("streaming - write data with bad schema") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + } + + 
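The batch paths exercised above can also be sketched without the test harness: write a small DataFrame through the new `CreatableRelationProvider` and read it back with the existing Kafka source. This is a usage sketch only; the broker address and topic are assumptions, and the topic is expected to exist already (the suite runs its broker with topic auto-creation disabled).

```scala
import org.apache.spark.sql.SparkSession

object KafkaBatchRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("kafka-batch-round-trip").getOrCreate()
    import spark.implicits._

    // Batch write: only a 'value' column is strictly required; the destination comes
    // from the 'topic' option because no 'topic' column is present.
    Seq("1", "2", "3").toDF("value")
      .write
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092") // assumed broker
      .option("topic", "demo-topic")                        // assumed, pre-created topic
      .save()

    // Batch read-back over the full offset range, mirroring the createKafkaReader helper below.
    val readBack = spark.read
      .format("kafka")
      .option("kafka.bootstrap.servers", "localhost:9092")
      .option("subscribe", "demo-topic")
      .option("startingOffsets", "earliest")
      .option("endingOffsets", "latest")
      .load()
      .selectExpr("CAST(value AS STRING) AS value")

    readBack.show()
  }
}
```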
test("streaming - write data with valid schema but wrong types") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val input = MemoryStream[String] + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val input = MemoryStream[String] + var writer: StreamingQuery = null + var ex: Exception = null + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'key.serializer' is not supported")) + + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'value.serializer' is not supported")) + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. 
+ */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, Object] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaWriteTask(options, inputSchema, Some(topic)) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + writeTask.execute(iter) + } finally { + writeTask.close() + } + } + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + withTempDir { checkpointDir => + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + } + stream.start() + } +} diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala index 23c4d99e50f51..0f1790bddcc3d 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDD.scala @@ -36,7 +36,11 @@ import org.apache.spark.util.NextIterator /** Class representing a range of Kinesis sequence numbers. Both sequence numbers are inclusive. 
*/ private[kinesis] case class SequenceNumberRange( - streamName: String, shardId: String, fromSeqNumber: String, toSeqNumber: String) + streamName: String, + shardId: String, + fromSeqNumber: String, + toSeqNumber: String, + recordCount: Int) /** Class representing an array of Kinesis sequence number ranges */ private[kinesis] @@ -136,6 +140,8 @@ class KinesisSequenceRangeIterator( private val client = new AmazonKinesisClient(credentials) private val streamName = range.streamName private val shardId = range.shardId + // AWS limits to maximum of 10k records per get call + private val maxGetRecordsLimit = 10000 private var toSeqNumberReceived = false private var lastSeqNumber: String = null @@ -153,12 +159,14 @@ class KinesisSequenceRangeIterator( // If the internal iterator has not been initialized, // then fetch records from starting sequence number - internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber) + internalIterator = getRecords(ShardIteratorType.AT_SEQUENCE_NUMBER, range.fromSeqNumber, + range.recordCount) } else if (!internalIterator.hasNext) { // If the internal iterator does not have any more records, // then fetch more records after the last consumed sequence number - internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber) + internalIterator = getRecords(ShardIteratorType.AFTER_SEQUENCE_NUMBER, lastSeqNumber, + range.recordCount) } if (!internalIterator.hasNext) { @@ -191,9 +199,12 @@ class KinesisSequenceRangeIterator( /** * Get records starting from or after the given sequence number. */ - private def getRecords(iteratorType: ShardIteratorType, seqNum: String): Iterator[Record] = { + private def getRecords( + iteratorType: ShardIteratorType, + seqNum: String, + recordCount: Int): Iterator[Record] = { val shardIterator = getKinesisIterator(iteratorType, seqNum) - val result = getRecordsAndNextKinesisIterator(shardIterator) + val result = getRecordsAndNextKinesisIterator(shardIterator, recordCount) result._1 } @@ -202,10 +213,12 @@ class KinesisSequenceRangeIterator( * to get records from Kinesis), and get the next shard iterator for next consumption. 
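The Kinesis change threads each block's `recordCount` down to the `GetRecords` call so the request can be capped, as the hunk just below shows. A self-contained sketch of that clamping, assuming `aws-java-sdk-kinesis` is on the classpath:

```scala
import com.amazonaws.services.kinesis.model.GetRecordsRequest

object GetRecordsLimitSketch {
  // AWS allows at most 10,000 records per GetRecords call.
  val MaxGetRecordsLimit = 10000

  /** Cap the requested record count at the AWS per-call limit. */
  def effectiveLimit(recordCount: Int): Int = math.min(recordCount, MaxGetRecordsLimit)

  /** Build a GetRecords request for a shard iterator with the capped limit. */
  def buildRequest(shardIterator: String, recordCount: Int): GetRecordsRequest = {
    val request = new GetRecordsRequest
    request.setShardIterator(shardIterator)
    request.setLimit(effectiveLimit(recordCount))
    request
  }
}
```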
*/ private def getRecordsAndNextKinesisIterator( - shardIterator: String): (Iterator[Record], String) = { + shardIterator: String, + recordCount: Int): (Iterator[Record], String) = { val getRecordsRequest = new GetRecordsRequest getRecordsRequest.setRequestCredentials(credentials) getRecordsRequest.setShardIterator(shardIterator) + getRecordsRequest.setLimit(Math.min(recordCount, this.maxGetRecordsLimit)) val getRecordsResult = retryOrTimeout[GetRecordsResult]( s"getting records using shard iterator") { client.getRecords(getRecordsRequest) diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala index 13fc54e531dda..320728f4bb221 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisReceiver.scala @@ -210,7 +210,8 @@ private[kinesis] class KinesisReceiver[T]( if (records.size > 0) { val dataIterator = records.iterator().asScala.map(messageHandler) val metadata = SequenceNumberRange(streamName, shardId, - records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber()) + records.get(0).getSequenceNumber(), records.get(records.size() - 1).getSequenceNumber(), + records.size()) blockGenerator.addMultipleDataWithCallback(dataIterator, metadata) } } diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala index 18a5a1509a33a..2c7b9c58e6fa6 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisBackedBlockRDDSuite.scala @@ -51,7 +51,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) shardIdToSeqNumbers = shardIdToDataAndSeqNumbers.mapValues { _.map { _._2 }} shardIdToRange = shardIdToSeqNumbers.map { case (shardId, seqNumbers) => val seqNumRange = SequenceNumberRange( - testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last) + testUtils.streamName, shardId, seqNumbers.head, seqNumbers.last, seqNumbers.size) (shardId, seqNumRange) } allRanges = shardIdToRange.values.toSeq @@ -181,7 +181,7 @@ abstract class KinesisBackedBlockRDDTests(aggregateTestData: Boolean) // Create the necessary ranges to use in the RDD val fakeRanges = Array.fill(numPartitions - numPartitionsInKinesis)( - SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy"))) + SequenceNumberRanges(SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 1))) val realRanges = Array.tabulate(numPartitionsInKinesis) { i => val range = shardIdToRange(shardIds(i + (numPartitions - numPartitionsInKinesis))) SequenceNumberRanges(Array(range)) diff --git a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala index 387a96f26b305..afb55c84f81fe 100644 --- a/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala +++ b/external/kinesis-asl/src/test/scala/org/apache/spark/streaming/kinesis/KinesisStreamSuite.scala @@ -119,13 +119,13 @@ abstract class KinesisStreamTests(aggregateTestData: Boolean) 
extends KinesisFun // Generate block info data for testing val seqNumRanges1 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy")) + SequenceNumberRange("fakeStream", "fakeShardId", "xxx", "yyy", 67)) val blockId1 = StreamBlockId(kinesisStream.id, 123) val blockInfo1 = ReceivedBlockInfo( 0, None, Some(seqNumRanges1), new BlockManagerBasedStoreResult(blockId1, None)) val seqNumRanges2 = SequenceNumberRanges( - SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb")) + SequenceNumberRange("fakeStream", "fakeShardId", "aaa", "bbb", 89)) val blockId2 = StreamBlockId(kinesisStream.id, 345) val blockInfo2 = ReceivedBlockInfo( 0, None, Some(seqNumRanges2), new BlockManagerBasedStoreResult(blockId2, None)) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala index bf6e76d7ac44e..f76b14eeeb542 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala @@ -440,19 +440,14 @@ private class LinearSVCAggregator( private val numFeatures: Int = bcFeaturesStd.value.length private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures - private val coefficients: Vector = bcCoefficients.value private var weightSum: Double = 0.0 private var lossSum: Double = 0.0 - require(numFeaturesPlusIntercept == coefficients.size, s"Dimension mismatch. Coefficients " + - s"length ${coefficients.size}, FeaturesStd length ${numFeatures}, fitIntercept: $fitIntercept") - - private val coefficientsArray = coefficients match { - case dv: DenseVector => dv.values - case _ => - throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + @transient private lazy val coefficientsArray = bcCoefficients.value match { + case DenseVector(values) => values + case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector" + + s" but got type ${bcCoefficients.value.getClass}.") } - private val gradientSumArray = Array.fill[Double](coefficientsArray.length)(0) + private lazy val gradientSumArray = new Array[Double](numFeaturesPlusIntercept) /** * Add a new training instance to this LinearSVCAggregator, and update the loss and gradient @@ -463,6 +458,9 @@ private class LinearSVCAggregator( */ def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => + require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0") + require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." 
+ + s" Expecting $numFeatures but got ${features.size}.") if (weight == 0.0) return this val localFeaturesStd = bcFeaturesStd.value val localCoefficients = coefficientsArray @@ -530,18 +528,15 @@ private class LinearSVCAggregator( this } - def loss: Double = { - if (weightSum != 0) { - lossSum / weightSum - } else 0.0 - } + def loss: Double = if (weightSum != 0) lossSum / weightSum else 0.0 def gradient: Vector = { if (weightSum != 0) { val result = Vectors.dense(gradientSumArray.clone()) scal(1.0 / weightSum, result) result - } else Vectors.dense(Array.fill[Double](coefficientsArray.length)(0)) + } else { + Vectors.dense(new Array[Double](numFeaturesPlusIntercept)) + } } - } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 738b35135f7ae..1a78187d4f8e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -1436,7 +1436,7 @@ private class LogisticAggregator( case _ => throw new IllegalArgumentException(s"coefficients only supports dense vector but " + s"got type ${bcCoefficients.value.getClass}.)") } - private val gradientSumArray = new Array[Double](coefficientSize) + private lazy val gradientSumArray = new Array[Double](coefficientSize) if (multinomial && numClasses <= 2) { logInfo(s"Multinomial logistic regression for binary classification yields separate " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ea2dc6cfd8d31..a9c1a7ba0bc8a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -580,10 +580,10 @@ private class ExpectationAggregator( private val k: Int = bcWeights.value.length private var totalCnt: Long = 0L private var newLogLikelihood: Double = 0.0 - private val newWeights: Array[Double] = new Array[Double](k) - private val newMeans: Array[DenseVector] = Array.fill(k)( + private lazy val newWeights: Array[Double] = new Array[Double](k) + private lazy val newMeans: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures)(0.0))) - private val newCovs: Array[DenseVector] = Array.fill(k)( + private lazy val newCovs: Array[DenseVector] = Array.fill(k)( new DenseVector(Array.fill[Double](numFeatures * (numFeatures + 1) / 2)(0.0))) @transient private lazy val oldGaussians = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a503411b63612..810b02febbe77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import 
org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** + * Param for how to handle unseen labels. Options are 'skip' (filter out rows with + * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional + * bucket, at index numLabels. + * Default: "error" + * @group param + */ + @Since("1.6.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + + "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + + "at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + + /** @group getParam */ + @Since("1.6.0") + def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { + private[feature] val SKIP_UNSEEN_LABEL: String = "skip" + private[feature] val ERROR_UNSEEN_LABEL: String = "error" + private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -144,7 +169,6 @@ class StringIndexerModel ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -163,25 +187,34 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) - val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else { - throw new SparkException(s"Unseen label: $label.") - } + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case _ => labels } val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. 
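The `StringIndexer` hunks add a third `handleInvalid` setting, `keep`, which maps unseen labels to an extra bucket at index `numLabels`; the rewritten `transform` logic that implements this follows just below. A minimal usage sketch of the new option (illustrative only; the column names and data are made up):

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

object StringIndexerKeepSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("string-indexer-keep").getOrCreate()
    import spark.implicits._

    val train = Seq((0, "a"), (1, "b"), (2, "b")).toDF("id", "label")
    val test = Seq((3, "a"), (4, "c")).toDF("id", "label") // "c" was never seen

    val model = new StringIndexer()
      .setInputCol("label")
      .setOutputCol("labelIndex")
      .setHandleInvalid("keep") // previously only "skip" and "error" were accepted
      .fit(train)

    // "c" is kept and indexed as numLabels (2.0 here) instead of raising an exception.
    model.transform(test).show()
  }
}
```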
- val filteredDataset = getHandleInvalid match { - case "skip" => + val (filteredDataset, keepInvalid) = getHandleInvalid match { + case StringIndexer.SKIP_UNSEEN_LABEL => val filterer = udf { label: String => labelToIndex.contains(label) } - dataset.where(filterer(dataset($(inputCol)))) - case _ => dataset + (dataset.where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) } + + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + } + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index af007625d1827..60dd7367053e2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,8 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel @@ -80,14 +81,24 @@ private[recommendation] trait ALSModelParams extends Params with HasPredictionCo /** * Attempts to safely cast a user/item id to an Int. Throws an exception if the value is - * out of integer range. + * out of integer range or contains a fractional part. */ - protected val checkedCast = udf { (n: Double) => - if (n > Int.MaxValue || n < Int.MinValue) { - throw new IllegalArgumentException(s"ALS only supports values in Integer range for columns " + - s"${$(userCol)} and ${$(itemCol)}. Value $n was out of Integer range.") - } else { - n.toInt + protected[recommendation] val checkedCast = udf { (n: Any) => + n match { + case v: Int => v // Avoid unnecessary casting + case v: Number => + val intV = v.intValue + // Checks if number within Int range and has no fractional part. + if (v.doubleValue == intV) { + intV + } else { + throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"and without fractional part for columns ${$(userCol)} and ${$(itemCol)}. " + + s"Value $n was either out of Integer range or contained a fractional part that " + + s"could not be converted.") + } + case _ => throw new IllegalArgumentException(s"ALS only supports values in Integer range " + + s"for columns ${$(userCol)} and ${$(itemCol)}. Value $n was not numeric.") } } @@ -274,23 +285,25 @@ class ALSModel private[ml] ( @Since("2.2.0") def setColdStartStrategy(value: String): this.type = set(coldStartStrategy, value) + private val predict = udf { (featuresA: Seq[Float], featuresB: Seq[Float]) => + if (featuresA != null && featuresB != null) { + // TODO(SPARK-19759): try dot-producting on Seqs or another non-converted type for + // potential optimization. 
+ blas.sdot(rank, featuresA.toArray, 1, featuresB.toArray, 1) + } else { + Float.NaN + } + } + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema) - // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. - val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => - if (userFeatures != null && itemFeatures != null) { - blas.sdot(rank, userFeatures.toArray, 1, itemFeatures.toArray, 1) - } else { - Float.NaN - } - } val predictions = dataset .join(userFactors, - checkedCast(dataset($(userCol)).cast(DoubleType)) === userFactors("id"), "left") + checkedCast(dataset($(userCol))) === userFactors("id"), "left") .join(itemFactors, - checkedCast(dataset($(itemCol)).cast(DoubleType)) === itemFactors("id"), "left") + checkedCast(dataset($(itemCol))) === itemFactors("id"), "left") .select(dataset("*"), predict(userFactors("features"), itemFactors("features")).as($(predictionCol))) getColdStartStrategy match { @@ -317,6 +330,64 @@ class ALSModel private[ml] ( @Since("1.6.0") override def write: MLWriter = new ALSModel.ALSModelWriter(this) + + /** + * Returns top `numItems` items recommended for each user, for all users. + * @param numItems max number of recommendations for each user + * @return a DataFrame of (userCol: Int, recommendations), where recommendations are + * stored as an array of (itemCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllUsers(numItems: Int): DataFrame = { + recommendForAll(userFactors, itemFactors, $(userCol), $(itemCol), numItems) + } + + /** + * Returns top `numUsers` users recommended for each item, for all items. + * @param numUsers max number of recommendations for each item + * @return a DataFrame of (itemCol: Int, recommendations), where recommendations are + * stored as an array of (userCol: Int, rating: Float) Rows. + */ + @Since("2.2.0") + def recommendForAllItems(numUsers: Int): DataFrame = { + recommendForAll(itemFactors, userFactors, $(itemCol), $(userCol), numUsers) + } + + /** + * Makes recommendations for all users (or items). + * @param srcFactors src factors for which to generate recommendations + * @param dstFactors dst factors used to make recommendations + * @param srcOutputColumn name of the column for the source ID in the output DataFrame + * @param dstOutputColumn name of the column for the destination ID in the output DataFrame + * @param num max number of recommendations for each record + * @return a DataFrame of (srcOutputColumn: Int, recommendations), where recommendations are + * stored as an array of (dstOutputColumn: Int, rating: Float) Rows. + */ + private def recommendForAll( + srcFactors: DataFrame, + dstFactors: DataFrame, + srcOutputColumn: String, + dstOutputColumn: String, + num: Int): DataFrame = { + import srcFactors.sparkSession.implicits._ + + val ratings = srcFactors.crossJoin(dstFactors) + .select( + srcFactors("id"), + dstFactors("id"), + predict(srcFactors("features"), dstFactors("features"))) + // We'll force the IDs to be Int. Unfortunately this converts IDs to Int in the output. 
+ val topKAggregator = new TopByKeyAggregator[Int, Int, Float](num, Ordering.by(_._2)) + val recs = ratings.as[(Int, Int, Float)].groupByKey(_._1).agg(topKAggregator.toColumn) + .toDF("id", "recommendations") + + val arrayType = ArrayType( + new StructType() + .add(dstOutputColumn, IntegerType) + .add("rating", FloatType) + ) + recs.select($"id" as srcOutputColumn, $"recommendations" cast arrayType) + } } @Since("1.6.0") @@ -491,8 +562,7 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset - .select(checkedCast(col($(userCol)).cast(DoubleType)), - checkedCast(col($(itemCol)).cast(DoubleType)), r) + .select(checkedCast(col($(userCol))), checkedCast(col($(itemCol))), r) .rdd .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) @@ -711,7 +781,7 @@ object ALS extends DefaultParamsReadable[ALS] with Logging { numUserBlocks: Int = 10, numItemBlocks: Int = 10, maxIter: Int = 10, - regParam: Double = 1.0, + regParam: Double = 0.1, implicitPrefs: Boolean = false, alpha: Double = 1.0, nonnegative: Boolean = false, diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala new file mode 100644 index 0000000000000..517179c0eb9ae --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/TopByKeyAggregator.scala @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import scala.language.implicitConversions +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.{Encoder, Encoders} +import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.expressions.Aggregator +import org.apache.spark.util.BoundedPriorityQueue + + +/** + * Works on rows of the form (K1, K2, V) where K1 & K2 are IDs and V is the score value. Finds + * the top `num` K2 items based on the given Ordering. 
+ */ +private[recommendation] class TopByKeyAggregator[K1: TypeTag, K2: TypeTag, V: TypeTag] + (num: Int, ord: Ordering[(K2, V)]) + extends Aggregator[(K1, K2, V), BoundedPriorityQueue[(K2, V)], Array[(K2, V)]] { + + override def zero: BoundedPriorityQueue[(K2, V)] = new BoundedPriorityQueue[(K2, V)](num)(ord) + + override def reduce( + q: BoundedPriorityQueue[(K2, V)], + a: (K1, K2, V)): BoundedPriorityQueue[(K2, V)] = { + q += {(a._2, a._3)} + } + + override def merge( + q1: BoundedPriorityQueue[(K2, V)], + q2: BoundedPriorityQueue[(K2, V)]): BoundedPriorityQueue[(K2, V)] = { + q1 ++= q2 + } + + override def finish(r: BoundedPriorityQueue[(K2, V)]): Array[(K2, V)] = { + r.toArray.sorted(ord.reverse) + } + + override def bufferEncoder: Encoder[BoundedPriorityQueue[(K2, V)]] = { + Encoders.kryo[BoundedPriorityQueue[(K2, V)]] + } + + override def outputEncoder: Encoder[Array[(K2, V)]] = ExpressionEncoder[Array[(K2, V)]]() +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala index 2f78dd30b3af7..094853b6f4802 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala @@ -106,7 +106,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params fitting: Boolean): StructType = { SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT) if (fitting) { - SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType) + SchemaUtils.checkNumericType(schema, $(censorCol)) SchemaUtils.checkNumericType(schema, $(labelCol)) } if (hasQuantilesCol) { @@ -200,8 +200,8 @@ class AFTSurvivalRegression @Since("1.6.0") (@Since("1.6.0") override val uid: S * and put it in an RDD with strong types. 
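Putting the ALS additions above together: `recommendForAllUsers` cross-joins the user and item factor DataFrames, scores every pair with the shared `predict` UDF, and keeps the top k per id with `TopByKeyAggregator`. A short usage sketch of the new API (the tiny ratings data is made up for illustration):

```scala
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.SparkSession

object RecommendForAllSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("als-recommend-for-all").getOrCreate()
    import spark.implicits._

    // Tiny made-up ratings; real input would come from a data source.
    val ratings = Seq(
      (0, 3, 4.0f), (0, 4, 1.0f),
      (1, 3, 5.0f), (1, 5, 2.0f),
      (2, 4, 3.0f), (2, 5, 4.0f)
    ).toDF("user", "item", "rating")

    val model = new ALS()
      .setUserCol("user")
      .setItemCol("item")
      .setRatingCol("rating")
      .setRank(2)
      .setMaxIter(5)
      .fit(ratings)

    // New in this patch: one row per user (or item) with a 'recommendations'
    // column holding an array of (id, rating) structs, at most 3 entries each.
    model.recommendForAllUsers(3).show(truncate = false)
    model.recommendForAllItems(3).show(truncate = false)
  }
}
```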
*/ protected[ml] def extractAFTPoints(dataset: Dataset[_]): RDD[AFTPoint] = { - dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), col($(censorCol))) - .rdd.map { + dataset.select(col($(featuresCol)), col($(labelCol)).cast(DoubleType), + col($(censorCol)).cast(DoubleType)).rdd.map { case Row(features: Vector, label: Double, censor: Double) => AFTPoint(features, label, censor) } @@ -526,7 +526,7 @@ private class AFTAggregator( private var totalCnt: Long = 0L private var lossSum = 0.0 // Here we optimize loss function over log(sigma), intercept and coefficients - private val gradientSumArray = Array.ofDim[Double](length) + private lazy val gradientSumArray = Array.ofDim[Double](length) def count: Long = totalCnt def loss: Double = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala index a6c29433d7303..529f66eadbcff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala @@ -49,7 +49,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures */ final val isotonic: BooleanParam = new BooleanParam(this, "isotonic", - "whether the output sequence should be isotonic/increasing (true) or" + + "whether the output sequence should be isotonic/increasing (true) or " + "antitonic/decreasing (false)") /** @group getParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index 2de7e81d8d41e..45df1d9be647d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -959,7 +959,7 @@ private class LeastSquaresAggregator( @transient private lazy val effectiveCoefficientsVector = effectiveCoefAndOffset._1 @transient private lazy val offset = effectiveCoefAndOffset._2 - private val gradientSumArray = Array.ofDim[Double](dim) + private lazy val gradientSumArray = Array.ofDim[Double](dim) /** * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index ee2aefee7a6db..fe47176a4aaa6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -23,13 +23,14 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ -import org.apache.spark.ml.feature.LabeledPoint -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.feature.{Instance, LabeledPoint} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.udf class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite 
with MLlibTestSparkContext with Defau @transient var smallValidationDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ + @transient var smallSparseBinaryDataset: Dataset[_] = _ + @transient var smallSparseValidationDataset: Dataset[_] = _ + override def beforeAll(): Unit = { super.beforeAll() @@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + + // Dataset for testing SparseVector + val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse + val sparse = udf(toSparse) + smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features)) + smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features)) + } /** @@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("Linear SVC binary classification with regularization") { @@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.setRegParam(0.1).fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("params") { @@ -123,6 +138,21 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(model2.intercept !== 0.0) } + test("sparse coefficients in SVCAggregator") { + val bcCoefficients = spark.sparkContext.broadcast(Vectors.sparse(2, Array(0), Array(1.0))) + val bcFeaturesStd = spark.sparkContext.broadcast(Array(1.0)) + val agg = new LinearSVCAggregator(bcCoefficients, bcFeaturesStd, true) + val thrown = withClue("LinearSVCAggregator cannot handle sparse coefficients") { + intercept[IllegalArgumentException] { + agg.add(Instance(1.0, 1.0, Vectors.dense(1.0))) + } + } + assert(thrown.getMessage.contains("coefficients only supports dense")) + + bcCoefficients.destroy(blocking = false) + bcFeaturesStd.destroy(blocking = false) + } + test("linearSVC with sample weights") { def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = { assert(m1.coefficients ~== m2.coefficients absTol 0.05) @@ -220,7 +250,7 @@ object LinearSVCSuite { "aggregationDepth" -> 3 ) - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) def generateSVMInput( intercept: Double, weights: Array[Double], @@ -237,5 +267,10 @@ object LinearSVCSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { + assert(model1.intercept == model2.intercept) + assert(model1.coefficients.equals(model2.coefficients)) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d0e63c9d669c..188dffb3dd55f 100644 --- 
a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -64,7 +64,7 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) - val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) val df = data.toDF("id", "label") val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() @@ -75,22 +75,32 @@ class StringIndexerSuite intercept[SparkException] { indexer.transform(df2).collect() } - val indexerSkipInvalid = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .setHandleInvalid("skip") - .fit(df) + + indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformed = indexerSkipInvalid.transform(df2) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedSkip = indexer.transform(df2) + val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 - val expected = Set((0, 1.0), (1, 0.0)) - assert(output === expected) + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + // Verify that we keep the unseen records + val transformedKeep = indexer.transform(df2) + val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 + val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexer with a numeric input column") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index c9e7b505b2bd2..e494ea89e63bd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -22,6 +22,7 @@ import java.util.Random import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.WrappedArray import scala.collection.JavaConverters._ import scala.language.existentials @@ -40,7 +41,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerStageCompleted} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import org.apache.spark.sql.types.{FloatType, IntegerType} +import org.apache.spark.sql.functions.lit +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel import org.apache.spark.util.Utils @@ -205,6 +207,70 @@ class ALSSuite assert(decompressed.toSet === expected) } + test("CheckedCast") { + val checkedCast = new ALS().checkedCast + val df = spark.range(1) + + withClue("Valid Integer Ids") { + df.select(checkedCast(lit(123))).collect() + } + + withClue("Valid Long Ids") { + 
df.select(checkedCast(lit(1231L))).collect() + } + + withClue("Valid Decimal Ids") { + df.select(checkedCast(lit(123).cast(DecimalType(15, 2)))).collect() + } + + withClue("Valid Double Ids") { + df.select(checkedCast(lit(123.0))).collect() + } + + val msg = "either out of Integer range or contained a fractional part" + withClue("Invalid Long: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000L))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Decimal: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1).cast(DecimalType(15, 2)))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: out of range") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(1231000000000.0))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Double: fractional part") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit(123.1))).collect() + } + assert(e.getMessage.contains(msg)) + } + + withClue("Invalid Type") { + val e: SparkException = intercept[SparkException] { + df.select(checkedCast(lit("123.1"))).collect() + } + assert(e.getMessage.contains("was not numeric")) + } + } + /** * Generates an explicit feedback dataset for testing ALS. * @param numUsers number of users @@ -510,34 +576,35 @@ class ALSSuite (0, big, small, 0, big, small, 2.0), (1, 1L, 1d, 0, 0L, 0d, 5.0) ).toDF("user", "user_big", "user_small", "item", "item_big", "item_small", "rating") + val msg = "either out of Integer range or contained a fractional part" withClue("fit should fail when ids exceed integer range. ") { assert(intercept[SparkException] { als.fit(df.select(df("user_big").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("user_small").as("user"), df("item"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_big").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) assert(intercept[SparkException] { als.fit(df.select(df("item_small").as("item"), df("user"), df("rating"))) - }.getCause.getMessage.contains("was out of Integer range")) + }.getCause.getMessage.contains(msg)) } withClue("transform should fail when ids exceed integer range. 
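Editor's note on the CheckedCast test above: it pins down the user/item id contract for ALS. Long, double, and decimal ids are accepted only if they are whole numbers inside the Int range; everything else now fails with the "either out of Integer range or contained a fractional part" message. Below is a small, hedged PySpark sketch of the user-facing behaviour; the toy ratings and app name are invented.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.recommendation import ALS

spark = SparkSession.builder.master("local[2]").appName("als-id-sketch").getOrCreate()

ratings = spark.createDataFrame(
    [(0, 3, 54.0), (0, 4, 44.0), (1, 3, 39.0), (2, 5, 45.0)],
    ["user", "item", "rating"])

als = ALS(rank=2, maxIter=2, userCol="user", itemCol="item", ratingCol="rating")

# Whole-number ids stored as double (or long/decimal) are fine...
als.fit(ratings.withColumn("user", col("user").cast("double")))

# ...but an id outside Int range (or with a fractional part) makes fit() fail.
bad = ratings.withColumn("user", col("user") + 1231000000000)
try:
    als.fit(bad)
except Exception as e:
    print("fit failed as expected: %s" % str(e)[:120])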
") { val model = als.fit(df) assert(intercept[SparkException] { model.transform(df.select(df("user_big").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("user_small").as("user"), df("item"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_big").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) assert(intercept[SparkException] { model.transform(df.select(df("item_small").as("item"), df("user"))).first - }.getMessage.contains("was out of Integer range")) + }.getMessage.contains(msg)) } } @@ -594,6 +661,99 @@ class ALSSuite model.setColdStartStrategy(s).transform(data) } } + + private def getALSModel = { + val spark = this.spark + import spark.implicits._ + + val userFactors = Seq( + (0, Array(6.0f, 4.0f)), + (1, Array(3.0f, 4.0f)), + (2, Array(3.0f, 6.0f)) + ).toDF("id", "features") + val itemFactors = Seq( + (3, Array(5.0f, 6.0f)), + (4, Array(6.0f, 2.0f)), + (5, Array(3.0f, 6.0f)), + (6, Array(4.0f, 1.0f)) + ).toDF("id", "features") + val als = new ALS().setRank(2) + new ALSModel(als.uid, als.getRank, userFactors, itemFactors) + .setUserCol("user") + .setItemCol("item") + } + + test("recommendForAllUsers with k < num_items") { + val topItems = getALSModel.recommendForAllUsers(2) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f), (5, 33f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllUsers with k = num_items") { + val topItems = getALSModel.recommendForAllUsers(4) + assert(topItems.count() == 3) + assert(topItems.columns.contains("user")) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f), (5, 33f), (4, 26f), (6, 16f)), + 2 -> Array((3, 51f), (5, 45f), (4, 30f), (6, 18f)) + ) + checkRecommendations(topItems, expected, "item") + } + + test("recommendForAllItems with k < num_users") { + val topUsers = getALSModel.recommendForAllItems(2) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f)), + 4 -> Array((0, 44f), (2, 30f)), + 5 -> Array((2, 45f), (0, 42f)), + 6 -> Array((0, 28f), (2, 18f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + test("recommendForAllItems with k = num_users") { + val topUsers = getALSModel.recommendForAllItems(3) + assert(topUsers.count() == 4) + assert(topUsers.columns.contains("item")) + + val expected = Map( + 3 -> Array((0, 54f), (2, 51f), (1, 39f)), + 4 -> Array((0, 44f), (2, 30f), (1, 26f)), + 5 -> Array((2, 45f), (0, 42f), (1, 33f)), + 6 -> Array((0, 28f), (2, 18f), (1, 16f)) + ) + checkRecommendations(topUsers, expected, "user") + } + + private def checkRecommendations( + topK: DataFrame, + expected: Map[Int, Array[(Int, Float)]], + dstColName: String): Unit = { + val spark = this.spark + import spark.implicits._ + + assert(topK.columns.contains("recommendations")) + topK.as[(Int, Seq[(Int, Float)])].collect().foreach { case (id: Int, recs: Seq[(Int, Float)]) => + assert(recs === expected(id)) + } + topK.collect().foreach { row => + val recs = row.getAs[WrappedArray[Row]]("recommendations") + assert(recs(0).fieldIndex(dstColName) == 0) + 
assert(recs(0).fieldIndex("rating") == 1) + } + } } class ALSCleanerSuite extends SparkFunSuite { diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala new file mode 100644 index 0000000000000..5e763a8e908b8 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/TopByKeyAggregatorSuite.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.recommendation + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.Dataset + + +class TopByKeyAggregatorSuite extends SparkFunSuite with MLlibTestSparkContext { + + private def getTopK(k: Int): Dataset[(Int, Array[(Int, Float)])] = { + val sqlContext = spark.sqlContext + import sqlContext.implicits._ + + val topKAggregator = new TopByKeyAggregator[Int, Int, Float](k, Ordering.by(_._2)) + Seq( + (0, 3, 54f), + (0, 4, 44f), + (0, 5, 42f), + (0, 6, 28f), + (1, 3, 39f), + (2, 3, 51f), + (2, 5, 45f), + (2, 6, 18f) + ).toDS().groupByKey(_._1).agg(topKAggregator.toColumn) + } + + test("topByKey with k < #items") { + val topK = getTopK(2) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f)) + ) + checkTopK(topK, expected) + } + + test("topByKey with k > #items") { + val topK = getTopK(5) + assert(topK.count() === 3) + + val expected = Map( + 0 -> Array((3, 54f), (4, 44f), (5, 42f), (6, 28f)), + 1 -> Array((3, 39f)), + 2 -> Array((3, 51f), (5, 45f), (6, 18f)) + ) + checkTopK(topK, expected) + } + + private def checkTopK( + topK: Dataset[(Int, Array[(Int, Float)])], + expected: Map[Int, Array[(Int, Float)]]): Unit = { + topK.collect().foreach { case (id, recs) => assert(recs === expected(id)) } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 0fdfdf37cf38d..3cd4b0ac308ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -27,6 +27,8 @@ import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.random.{ExponentialGenerator, WeibullGenerator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.functions.{col, lit} +import org.apache.spark.sql.types._ class AFTSurvivalRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -352,7 +354,7 
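Editor's note on the recommendForAllUsers/recommendForAllItems tests and the new TopByKeyAggregatorSuite above: both exercise the same contract, namely that for each key only the k highest-scoring (id, rating) pairs are kept, in descending score order. The few lines of plain Python below reproduce that contract on the suite's own fixture; heapq stands in for the Dataset aggregator purely for illustration. (On the DataFrame side, the same result would surface through ALSModel.recommendForAllUsers(k), assuming a Python wrapper is available in the build at hand; that wrapper is not part of this patch.)

import heapq
from collections import defaultdict

triples = [(0, 3, 54.0), (0, 4, 44.0), (0, 5, 42.0), (0, 6, 28.0),
           (1, 3, 39.0), (2, 3, 51.0), (2, 5, 45.0), (2, 6, 18.0)]

def top_by_key(rows, k):
    # Group by key, then keep only the k pairs with the largest score.
    grouped = defaultdict(list)
    for key, value, score in rows:
        grouped[key].append((value, score))
    return {key: heapq.nlargest(k, pairs, key=lambda p: p[1])
            for key, pairs in grouped.items()}

print(top_by_key(triples, 2))
# A k larger than the number of items per key just returns everything,
# matching the "k > #items" test above.
print(top_by_key(triples, 5))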
@@ class AFTSurvivalRegressionSuite } } - test("should support all NumericType labels") { + test("should support all NumericType labels, and not support other types") { val aft = new AFTSurvivalRegression().setMaxIter(1) MLTestingUtils.checkNumericTypes[AFTSurvivalRegressionModel, AFTSurvivalRegression]( aft, spark, isClassification = false) { (expected, actual) => @@ -361,6 +363,36 @@ class AFTSurvivalRegressionSuite } } + test("should support all NumericType censors, and not support other types") { + val df = spark.createDataFrame(Seq( + (0, Vectors.dense(0)), + (1, Vectors.dense(1)), + (2, Vectors.dense(2)), + (3, Vectors.dense(3)), + (4, Vectors.dense(4)) + )).toDF("label", "features") + .withColumn("censor", lit(0.0)) + val aft = new AFTSurvivalRegression().setMaxIter(1) + val expected = aft.fit(df) + + val types = Seq(ShortType, LongType, IntegerType, FloatType, ByteType, DecimalType(10, 0)) + types.foreach { t => + val actual = aft.fit(df.select(col("label"), col("features"), + col("censor").cast(t))) + assert(expected.intercept === actual.intercept) + assert(expected.coefficients === actual.coefficients) + } + + val dfWithStringCensors = spark.createDataFrame(Seq( + (0, Vectors.dense(0, 2, 3), "0") + )).toDF("label", "features", "censor") + val thrown = intercept[IllegalArgumentException] { + aft.fit(dfWithStringCensors) + } + assert(thrown.getMessage.contains( + "Column censor must be of type NumericType but was actually of type StringType")) + } + test("numerical stability of standardization") { val trainer = new AFTSurvivalRegression() val model1 = trainer.fit(datasetUnivariate) diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 511686fb4f37f..bd4528bd21264 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -55,6 +55,9 @@ object MimaExcludes { // [SPARK-14272][ML] Add logLikelihood in GaussianMixtureSummary ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.clustering.GaussianMixtureSummary.this"), + // [SPARK-19267] Fetch Failure handling robust to user error handling + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.setFetchFailed"), + // [SPARK-19069] [CORE] Expose task 'status' and 'duration' in spark history server REST API. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskData.$default$10"), @@ -911,6 +914,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") + ) ++ Seq( + // [SPARK-17498] StringIndexer enhancement for handling unseen labels + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 9331e74eede59..14c51a306e1c2 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -93,13 +93,15 @@ def keyword_only(func): """ A decorator that forces keyword arguments in the wrapped method and saves actual input keyword arguments in `_input_kwargs`. 
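Editor's note on the AFTSurvivalRegressionSuite censor-type test above: the censor column may now be any NumericType (it is cast internally), while a string censor is rejected with a schema error. A compact, hedged PySpark illustration of the same behaviour follows; the survival data, app name, and master setting are made up.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.ml.linalg import Vectors
from pyspark.ml.regression import AFTSurvivalRegression

spark = SparkSession.builder.master("local[2]").appName("aft-censor-sketch").getOrCreate()

df = spark.createDataFrame(
    [(1.218, 1.0, Vectors.dense(1.560, -0.605)),
     (2.949, 0.0, Vectors.dense(0.346, 2.158)),
     (3.627, 0.0, Vectors.dense(1.380, 0.231)),
     (0.273, 1.0, Vectors.dense(0.520, 1.151))],
    ["label", "censor", "features"])

aft = AFTSurvivalRegression(maxIter=1)

# Double and int censors give the same fit; a string censor column would be
# rejected with a "must be of type NumericType" error, as the suite asserts.
m_double = aft.fit(df)
m_int = aft.fit(df.withColumn("censor", col("censor").cast("int")))
print("coefficients: %s vs %s" % (m_double.coefficients, m_int.coefficients))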
+ + .. note:: Should only be used to wrap a method where first arg is `self` """ @wraps(func) - def wrapper(*args, **kwargs): - if len(args) > 1: + def wrapper(self, *args, **kwargs): + if len(args) > 0: raise TypeError("Method %s forces keyword arguments." % func.__name__) - wrapper._input_kwargs = kwargs - return func(*args, **kwargs) + self._input_kwargs = kwargs + return func(self, **kwargs) return wrapper diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index ac40fceaf8e96..b4fc357e42d71 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -124,7 +124,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred "org.apache.spark.ml.classification.LinearSVC", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6, fitIntercept=True, standardization=True, threshold=0.0, aggregationDepth=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -140,7 +140,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre aggregationDepth=2): Sets params for Linear SVM Classifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -266,7 +266,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.LogisticRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1E-6, threshold=0.5, family="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) self._checkThresholdConsistency() @@ -286,7 +286,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre Sets params for logistic regression. If the threshold and thresholds Params are both set, they must be equivalent. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) self._checkThresholdConsistency() return self @@ -760,7 +760,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -778,7 +778,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre seed=None) Sets params for the DecisionTreeClassifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -890,7 +890,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -908,7 +908,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="gini", numTrees=20, featureSubsetStrategy="auto", subsamplingRate=1.0) Sets params for linear classification. 
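Editor's note on the keyword_only rewrite above: it is the reason every estimator and transformer below switches from self.__init__._input_kwargs / self.setParams._input_kwargs to plain self._input_kwargs, since the captured keyword arguments now live on the instance rather than on the shared function object. The stripped-down reproduction below is independent of pyspark and mirrors the new KeywordOnlyTests added later in this diff; it is a sketch of the pattern, not the project's code.

from functools import wraps

def keyword_only(func):
    # Same shape as the new decorator: reject positional args and stash the
    # keyword args on the instance itself.
    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if len(args) > 0:
            raise TypeError("Method %s forces keyword arguments." % func.__name__)
        self._input_kwargs = kwargs
        return func(self, **kwargs)
    return wrapper

class Setter(object):
    @keyword_only
    def set(self, x=None, other=None, other_x=None):
        if "other" in self._input_kwargs:
            self._input_kwargs["other"].set(x=self._input_kwargs["other_x"])
        self._x = self._input_kwargs.get("x")

a, b = Setter(), Setter()
a.set(x=1, other=b, other_x=2)
# With function-level storage the nested b.set() call would have clobbered
# a's kwargs; with instance-level storage each object keeps its own.
print("a._x=%s b._x=%s" % (a._x, b._x))  # a._x=1 b._x=2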
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1031,7 +1031,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, lossType="logistic", maxIter=20, stepSize=0.1, subsamplingRate=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1047,7 +1047,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre lossType="logistic", maxIter=20, stepSize=0.1, seed=None, subsamplingRate=1.0) Sets params for Gradient Boosted Tree Classification. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1174,7 +1174,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.NaiveBayes", self.uid) self._setDefault(smoothing=1.0, modelType="multinomial") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1188,7 +1188,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre modelType="multinomial", thresholds=None, weightCol=None) Sets params for Naive Bayes. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1329,7 +1329,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid) self._setDefault(maxIter=100, tol=1E-4, blockSize=128, stepSize=0.03, solver="l-bfgs") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1343,7 +1343,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre solver="l-bfgs", initialWeights=None) Sets params for MultilayerPerceptronClassifier. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1519,7 +1519,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred classifier=None) """ super(OneVsRest, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -1529,7 +1529,7 @@ def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classif setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): Sets params for OneVsRest. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index c6c1a0033190e..88ac7e275e386 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -224,7 +224,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.GaussianMixture", self.uid) self._setDefault(k=2, tol=0.01, maxIter=100) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -240,7 +240,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for GaussianMixture. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -414,7 +414,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", k=2, super(KMeans, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.KMeans", self.uid) self._setDefault(k=2, initMode="k-means||", initSteps=2, tol=1e-4, maxIter=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -430,7 +430,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", k=2, Sets params for KMeans. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -591,7 +591,7 @@ def __init__(self, featuresCol="features", predictionCol="prediction", maxIter=2 self._java_obj = self._new_java_obj("org.apache.spark.ml.clustering.BisectingKMeans", self.uid) self._setDefault(maxIter=20, k=4, minDivisibleClusterSize=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -603,7 +603,7 @@ def setParams(self, featuresCol="features", predictionCol="prediction", maxIter= seed=None, k=4, minDivisibleClusterSize=1.0) Sets params for BisectingKMeans. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -916,7 +916,7 @@ def __init__(self, featuresCol="features", maxIter=20, seed=None, checkpointInte k=10, optimizer="online", learningOffset=1024.0, learningDecay=0.51, subsamplingRate=0.05, optimizeDocConcentration=True, topicDistributionCol="topicDistribution", keepLastCheckpoint=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) def _create_model(self, java_model): @@ -941,7 +941,7 @@ def setParams(self, featuresCol="features", maxIter=20, seed=None, checkpointInt Sets params for LDA. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 7aa16fa5b90f2..7cb8d62f212cb 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -148,7 +148,7 @@ def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator", self.uid) self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -174,7 +174,7 @@ def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", metricName="areaUnderROC") Sets params for binary classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -226,7 +226,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.RegressionEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="rmse") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.4.0") @@ -252,7 +252,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="rmse") Sets params for regression evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -299,7 +299,7 @@ def __init__(self, predictionCol="prediction", labelCol="label", "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) self._setDefault(predictionCol="prediction", labelCol="label", metricName="f1") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("1.5.0") @@ -325,7 +325,7 @@ def setParams(self, predictionCol="prediction", labelCol="label", metricName="f1") Sets params for multiclass classification evaluator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) if __name__ == "__main__": diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index c2eafbefcdec1..92f8549e9cb9e 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -94,7 +94,7 @@ def __init__(self, threshold=0.0, inputCol=None, outputCol=None): super(Binarizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid) self._setDefault(threshold=0.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -104,7 +104,7 @@ def setParams(self, threshold=0.0, inputCol=None, outputCol=None): setParams(self, threshold=0.0, inputCol=None, outputCol=None) Sets params for this Binarizer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -258,14 +258,14 @@ class BucketedRandomProjectionLSH(JavaEstimator, LSHParams, HasInputCol, HasOutp def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None): """ - __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, + __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, \ bucketLength=None) """ super(BucketedRandomProjectionLSH, self).__init__() self._java_obj = \ self._new_java_obj("org.apache.spark.ml.feature.BucketedRandomProjectionLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -277,7 +277,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1, bucketLength=None) Sets params for this BucketedRandomProjectionLSH. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.2.0") @@ -370,7 +370,7 @@ def __init__(self, splits=None, inputCol=None, outputCol=None, handleInvalid="er super(Bucketizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -380,7 +380,7 @@ def setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="e setParams(self, splits=None, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this Bucketizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -484,7 +484,7 @@ def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputC self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -496,7 +496,7 @@ def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, input outputCol=None) Set the params for the CountVectorizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -616,7 +616,7 @@ def __init__(self, inverse=False, inputCol=None, outputCol=None): super(DCT, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.DCT", self.uid) self._setDefault(inverse=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -626,7 +626,7 @@ def setParams(self, inverse=False, inputCol=None, outputCol=None): setParams(self, inverse=False, inputCol=None, outputCol=None) Sets params for this DCT. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -680,7 +680,7 @@ def __init__(self, scalingVec=None, inputCol=None, outputCol=None): super(ElementwiseProduct, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ElementwiseProduct", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -690,7 +690,7 @@ def setParams(self, scalingVec=None, inputCol=None, outputCol=None): setParams(self, scalingVec=None, inputCol=None, outputCol=None) Sets params for this ElementwiseProduct. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -750,7 +750,7 @@ def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=N super(HashingTF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid) self._setDefault(numFeatures=1 << 18, binary=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -760,7 +760,7 @@ def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol= setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None) Sets params for this HashingTF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -823,7 +823,7 @@ def __init__(self, minDocFreq=0, inputCol=None, outputCol=None): super(IDF, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid) self._setDefault(minDocFreq=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -833,7 +833,7 @@ def setParams(self, minDocFreq=0, inputCol=None, outputCol=None): setParams(self, minDocFreq=0, inputCol=None, outputCol=None) Sets params for this IDF. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -913,7 +913,7 @@ def __init__(self, inputCol=None, outputCol=None): super(MaxAbsScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MaxAbsScaler", self.uid) self._setDefault() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -923,7 +923,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this MaxAbsScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1011,7 +1011,7 @@ def __init__(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): super(MinHashLSH, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinHashLSH", self.uid) self._setDefault(numHashTables=1) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1021,7 +1021,7 @@ def setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1): setParams(self, inputCol=None, outputCol=None, seed=None, numHashTables=1) Sets params for this MinHashLSH. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1106,7 +1106,7 @@ def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None): super(MinMaxScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.MinMaxScaler", self.uid) self._setDefault(min=0.0, max=1.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1116,7 +1116,7 @@ def setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None): setParams(self, min=0.0, max=1.0, inputCol=None, outputCol=None) Sets params for this MinMaxScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1224,7 +1224,7 @@ def __init__(self, n=2, inputCol=None, outputCol=None): super(NGram, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.NGram", self.uid) self._setDefault(n=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1234,7 +1234,7 @@ def setParams(self, n=2, inputCol=None, outputCol=None): setParams(self, n=2, inputCol=None, outputCol=None) Sets params for this NGram. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -1288,7 +1288,7 @@ def __init__(self, p=2.0, inputCol=None, outputCol=None): super(Normalizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid) self._setDefault(p=2.0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1298,7 +1298,7 @@ def setParams(self, p=2.0, inputCol=None, outputCol=None): setParams(self, p=2.0, inputCol=None, outputCol=None) Sets params for this Normalizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1363,12 +1363,12 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, @keyword_only def __init__(self, dropLast=True, inputCol=None, outputCol=None): """ - __init__(self, includeFirst=True, inputCol=None, outputCol=None) + __init__(self, dropLast=True, inputCol=None, outputCol=None) """ super(OneHotEncoder, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid) self._setDefault(dropLast=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1378,7 +1378,7 @@ def setParams(self, dropLast=True, inputCol=None, outputCol=None): setParams(self, dropLast=True, inputCol=None, outputCol=None) Sets params for this OneHotEncoder. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1434,7 +1434,7 @@ def __init__(self, degree=2, inputCol=None, outputCol=None): self._java_obj = self._new_java_obj( "org.apache.spark.ml.feature.PolynomialExpansion", self.uid) self._setDefault(degree=2) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1444,7 +1444,7 @@ def setParams(self, degree=2, inputCol=None, outputCol=None): setParams(self, degree=2, inputCol=None, outputCol=None) Sets params for this PolynomialExpansion. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1540,7 +1540,7 @@ def __init__(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0. self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.QuantileDiscretizer", self.uid) self._setDefault(numBuckets=2, relativeError=0.001, handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1552,7 +1552,7 @@ def setParams(self, numBuckets=2, inputCol=None, outputCol=None, relativeError=0 handleInvalid="error") Set the params for the QuantileDiscretizer """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") @@ -1665,7 +1665,7 @@ def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, super(RegexTokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid) self._setDefault(minTokenLength=1, gaps=True, pattern="\\s+", toLowercase=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1677,7 +1677,7 @@ def setParams(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None, outputCol=None, toLowercase=True) Sets params for this RegexTokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1768,7 +1768,7 @@ def __init__(self, statement=None): """ super(SQLTransformer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.SQLTransformer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1778,7 +1778,7 @@ def setParams(self, statement=None): setParams(self, statement=None) Sets params for this SQLTransformer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -1847,7 +1847,7 @@ def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None): super(StandardScaler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid) self._setDefault(withMean=False, withStd=True) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1857,7 +1857,7 @@ def setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) setParams(self, withMean=False, withStd=True, inputCol=None, outputCol=None) Sets params for this StandardScaler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -1963,7 +1963,7 @@ def __init__(self, inputCol=None, outputCol=None, handleInvalid="error"): super(StringIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid) self._setDefault(handleInvalid="error") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1973,7 +1973,7 @@ def setParams(self, inputCol=None, outputCol=None, handleInvalid="error"): setParams(self, inputCol=None, outputCol=None, handleInvalid="error") Sets params for this StringIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -2021,7 +2021,7 @@ def __init__(self, inputCol=None, outputCol=None, labels=None): super(IndexToString, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IndexToString", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2031,7 +2031,7 @@ def setParams(self, inputCol=None, outputCol=None, labels=None): setParams(self, inputCol=None, outputCol=None, labels=None) Sets params for this IndexToString. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2085,7 +2085,7 @@ def __init__(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive= self.uid) self._setDefault(stopWords=StopWordsRemover.loadDefaultStopWords("english"), caseSensitive=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2095,7 +2095,7 @@ def setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive setParams(self, inputCol=None, outputCol=None, stopWords=None, caseSensitive=false) Sets params for this StopWordRemover. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2178,7 +2178,7 @@ def __init__(self, inputCol=None, outputCol=None): """ super(Tokenizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2188,7 +2188,7 @@ def setParams(self, inputCol=None, outputCol=None): setParams(self, inputCol=None, outputCol=None) Sets params for this Tokenizer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2222,7 +2222,7 @@ def __init__(self, inputCols=None, outputCol=None): """ super(VectorAssembler, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2232,7 +2232,7 @@ def setParams(self, inputCols=None, outputCol=None): setParams(self, inputCols=None, outputCol=None) Sets params for this VectorAssembler. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -2320,7 +2320,7 @@ def __init__(self, maxCategories=20, inputCol=None, outputCol=None): super(VectorIndexer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid) self._setDefault(maxCategories=20) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2330,7 +2330,7 @@ def setParams(self, maxCategories=20, inputCol=None, outputCol=None): setParams(self, maxCategories=20, inputCol=None, outputCol=None) Sets params for this VectorIndexer. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2435,7 +2435,7 @@ def __init__(self, inputCol=None, outputCol=None, indices=None, names=None): super(VectorSlicer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid) self._setDefault(indices=[], names=[]) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2445,7 +2445,7 @@ def setParams(self, inputCol=None, outputCol=None, indices=None, names=None): setParams(self, inputCol=None, outputCol=None, indices=None, names=None): Sets params for this VectorSlicer. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.6.0") @@ -2558,7 +2558,7 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid) self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, windowSize=5, maxSentenceLength=1000) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2570,7 +2570,7 @@ def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, inputCol=None, outputCol=None, windowSize=5, maxSentenceLength=1000) Sets params for this Word2Vec. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -2718,7 +2718,7 @@ def __init__(self, k=None, inputCol=None, outputCol=None): """ super(PCA, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.PCA", self.uid) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2728,7 +2728,7 @@ def setParams(self, k=None, inputCol=None, outputCol=None): setParams(self, k=None, inputCol=None, outputCol=None) Set params for this PCA. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -2858,7 +2858,7 @@ def __init__(self, formula=None, featuresCol="features", labelCol="label", super(RFormula, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid) self._setDefault(forceIndexLabel=False) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -2870,7 +2870,7 @@ def setParams(self, formula=None, featuresCol="features", labelCol="label", forceIndexLabel=False) Sets params for RFormula. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.5.0") @@ -3017,7 +3017,7 @@ def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid) self._setDefault(numTopFeatures=50, selectorType="numTopFeatures", percentile=0.1, fpr=0.05, fdr=0.05, fwe=0.05) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -3031,7 +3031,7 @@ def setParams(self, numTopFeatures=50, featuresCol="features", outputCol=None, fdr=0.05, fwe=0.05) Sets params for this ChiSqSelector. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.1.0") diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a78e3b49fbcfc..4aac6a4466b54 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -58,7 +58,7 @@ def __init__(self, stages=None): __init__(self, stages=None) """ super(Pipeline, self).__init__() - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @since("1.3.0") @@ -85,7 +85,7 @@ def setParams(self, stages=None): setParams(self, stages=None) Sets params for Pipeline. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _fit(self, dataset): diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 43f82daa9fcd3..8bc899a0788bb 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -152,7 +152,7 @@ def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemB ratingCol="rating", nonnegative=False, checkpointInterval=10, intermediateStorageLevel="MEMORY_AND_DISK", finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -170,7 +170,7 @@ def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItem finalStorageLevel="MEMORY_AND_DISK", coldStartStrategy="nan") Sets params for ALS. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b42e807069802..b199bf282e4f2 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -108,7 +108,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.LinearRegression", self.uid) self._setDefault(maxIter=100, regParam=0.0, tol=1e-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -122,7 +122,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre standardization=True, solver="auto", weightCol=None, aggregationDepth=2) Sets params for linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -464,7 +464,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.IsotonicRegression", self.uid) self._setDefault(isotonic=True, featureIndex=0) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -475,7 +475,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre weightCol=None, isotonic=True, featureIndex=0): Set the params for IsotonicRegression. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -704,7 +704,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -720,7 +720,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance", seed=None, varianceCol=None) Sets params for the DecisionTreeRegressor. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -895,7 +895,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance", subsamplingRate=1.0, numTrees=20, featureSubsetStrategy="auto") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -913,7 +913,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre featureSubsetStrategy="auto") Sets params for linear regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1022,7 +1022,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred maxMemoryInMB=256, cacheNodeIds=False, subsamplingRate=1.0, checkpointInterval=10, lossType="squared", maxIter=20, stepSize=0.1, impurity="variance") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1040,7 +1040,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre impurity="variance") Sets params for Gradient Boosted Tree Regression. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1171,7 +1171,7 @@ def __init__(self, featuresCol="features", labelCol="label", predictionCol="pred self._setDefault(censorCol="censor", quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], maxIter=100, tol=1E-6) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1186,7 +1186,7 @@ def setParams(self, featuresCol="features", labelCol="label", predictionCol="pre quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99], \ quantilesCol=None, aggregationDepth=2): """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): @@ -1366,7 +1366,7 @@ def __init__(self, labelCol="label", featuresCol="features", predictionCol="pred self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -1380,7 +1380,7 @@ def setParams(self, labelCol="label", featuresCol="features", predictionCol="pre regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) Sets params for generalized linear regression. 
""" - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) def _create_model(self, java_model): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 293c6c0b0f36a..352416055791e 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -250,7 +250,7 @@ class TestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(TestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -259,7 +259,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @@ -271,7 +271,7 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed): def __init__(self, seed=None): super(OtherTestParams, self).__init__() self._setDefault(maxIter=10) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self.setParams(**kwargs) @keyword_only @@ -280,7 +280,7 @@ def setParams(self, seed=None): setParams(self, seed=None) Sets params for this test. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 2dcc99cef8aa2..ffeb4459e1aac 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -186,7 +186,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numF """ super(CrossValidator, self).__init__() self._setDefault(numFolds=3) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @keyword_only @@ -198,7 +198,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, num seed=None): Sets params for cross validator. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("1.4.0") @@ -346,7 +346,7 @@ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, trai """ super(TrainValidationSplit, self).__init__() self._setDefault(trainRatio=0.75) - kwargs = self.__init__._input_kwargs + kwargs = self._input_kwargs self._set(**kwargs) @since("2.0.0") @@ -358,7 +358,7 @@ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, tra seed=None): Sets params for the train validation split. """ - kwargs = self.setParams._input_kwargs + kwargs = self._input_kwargs return self._set(**kwargs) @since("2.0.0") diff --git a/python/pyspark/sql/column.py b/python/pyspark/sql/column.py index c10ab9638a21f..ec05c18d4f062 100644 --- a/python/pyspark/sql/column.py +++ b/python/pyspark/sql/column.py @@ -180,7 +180,9 @@ def __init__(self, jc): __ror__ = _bin_op("or") # container operators - __contains__ = _bin_op("contains") + def __contains__(self, item): + raise ValueError("Cannot apply 'in' operator against a column: please use 'contains' " + "in a string column or 'array_contains' function for an array column.") # bitwise operators bitwiseOR = _bin_op("bitwiseOR") diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 426a4a8c93a67..376b86ea69bd4 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1773,11 +1773,11 @@ def json_tuple(col, *fields): @since(2.1) def from_json(col, schema, options={}): """ - Parses a column containing a JSON string into a [[StructType]] with the - specified schema. 
Returns `null`, in the case of an unparseable string. + Parses a column containing a JSON string into a [[StructType]] or [[ArrayType]] + with the specified schema. Returns `null`, in the case of an unparseable string. :param col: string column in json format - :param schema: a StructType to use when parsing the json column + :param schema: a StructType or ArrayType to use when parsing the json column :param options: options to control parsing. accepts the same options as the json datasource >>> from pyspark.sql.types import * @@ -1786,6 +1786,11 @@ def from_json(col, schema, options={}): >>> df = spark.createDataFrame(data, ("key", "value")) >>> df.select(from_json(df.value, schema).alias("json")).collect() [Row(json=Row(a=1))] + >>> data = [(1, '''[{"a": 1}]''')] + >>> schema = ArrayType(StructType([StructField("a", IntegerType())])) + >>> df = spark.createDataFrame(data, ("key", "value")) + >>> df.select(from_json(df.value, schema).alias("json")).collect() + [Row(json=[Row(a=1)])] """ sc = SparkContext._active_spark_context diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index ec47618e73a6c..45fb9b7591529 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -163,8 +163,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. + For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index 7587875cb9849..625fb9ba385af 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -433,8 +433,8 @@ def json(self, path, schema=None, primitivesAsString=None, prefersDecimal=None, """ Loads a JSON file stream and returns the results as a :class:`DataFrame`. - Both JSON (one record per file) and `JSON Lines `_ - (newline-delimited JSON) are supported and can be selected with the `wholeFile` parameter. + `JSON Lines `_(newline-delimited JSON) is supported by default. + For JSON (one record per file), set the `wholeFile` parameter to ``true``. If the ``schema`` parameter is not specified, this function goes through the input once to determine the input schema. 
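Editor's note on the functions.py and reader/stream docstring changes above: from_json now also accepts an ArrayType schema, turning a JSON array string into an array of structs, and JSON Lines is described as the default input format with per-file JSON opted into via the wholeFile option (as the option is named at this point in the patch). A short sketch of the ArrayType case, mirroring the new doctest; the toy DataFrame and app name are invented.

from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json
from pyspark.sql.types import ArrayType, StructType, StructField, IntegerType

spark = SparkSession.builder.master("local[2]").appName("from_json-sketch").getOrCreate()

df = spark.createDataFrame([(1, '[{"a": 1}, {"a": 2}]')], ["key", "value"])
schema = ArrayType(StructType([StructField("a", IntegerType())]))

# Each value parses into an array of Row(a=...) structs; an unparseable
# string comes back as null, as the docstring states.
df.select(from_json(df.value, schema).alias("json")).show(truncate=False)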
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index e943f8da3db14..81f3d1d36a342 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -967,6 +967,9 @@ def test_column_operators(self): cs.startswith('a'), cs.endswith('a') self.assertTrue(all(isinstance(c, Column) for c in css)) self.assertTrue(isinstance(ci.cast(LongType()), Column)) + self.assertRaisesRegexp(ValueError, + "Cannot apply 'in' operator against a column", + lambda: 1 in cs) def test_column_getitem(self): from pyspark.sql.functions import col diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index a2aead7e6b9a6..c6c87a9ea5555 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -58,6 +58,7 @@ from StringIO import StringIO +from pyspark import keyword_only from pyspark.conf import SparkConf from pyspark.context import SparkContext from pyspark.rdd import RDD @@ -2161,6 +2162,44 @@ def test_memory_conf(self): sc.stop() +class KeywordOnlyTests(unittest.TestCase): + class Wrapped(object): + @keyword_only + def set(self, x=None, y=None): + if "x" in self._input_kwargs: + self._x = self._input_kwargs["x"] + if "y" in self._input_kwargs: + self._y = self._input_kwargs["y"] + return x, y + + def test_keywords(self): + w = self.Wrapped() + x, y = w.set(y=1) + self.assertEqual(y, 1) + self.assertEqual(y, w._y) + self.assertIsNone(x) + self.assertFalse(hasattr(w, "_x")) + + def test_non_keywords(self): + w = self.Wrapped() + self.assertRaises(TypeError, lambda: w.set(0, y=1)) + + def test_kwarg_ownership(self): + # test _input_kwargs is owned by each class instance and not a shared static variable + class Setter(object): + @keyword_only + def set(self, x=None, other=None, other_x=None): + if "other" in self._input_kwargs: + self._input_kwargs["other"].set(x=self._input_kwargs["other_x"]) + self._x = self._input_kwargs["x"] + + a = Setter() + b = Setter() + a.set(x=1, other=b, other_x=2) + self.assertEqual(a._x, 1) + self.assertEqual(b._x, 2) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 2760f31b12fa7..1bc6f71860c3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -152,6 +152,7 @@ private[spark] class MesosClusterScheduler( // is registered with Mesos master. @volatile protected var ready = false private var masterInfo: Option[MasterInfo] = None + private var schedulerDriver: SchedulerDriver = _ def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { val c = new CreateSubmissionResponse @@ -168,9 +169,8 @@ private[spark] class MesosClusterScheduler( return c } c.submissionId = desc.submissionId - queuedDriversState.persist(desc.submissionId, desc) - queuedDrivers += desc c.success = true + addDriverToQueue(desc) } c } @@ -191,7 +191,7 @@ private[spark] class MesosClusterScheduler( // 4. Check if it has already completed. 
if (launchedDrivers.contains(submissionId)) { val task = launchedDrivers(submissionId) - mesosDriver.killTask(task.taskId) + schedulerDriver.killTask(task.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -324,7 +324,7 @@ private[spark] class MesosClusterScheduler( ready = false metricsSystem.report() metricsSystem.stop() - mesosDriver.stop(true) + schedulerDriver.stop(true) } override def registered( @@ -340,6 +340,8 @@ private[spark] class MesosClusterScheduler( stateLock.synchronized { this.masterInfo = Some(masterInfo) + this.schedulerDriver = driver + if (!pendingRecover.isEmpty) { // Start task reconciliation if we need to recover. val statuses = pendingRecover.collect { @@ -506,11 +508,10 @@ private[spark] class MesosClusterScheduler( } private class ResourceOffer( - val offerId: OfferID, - val slaveId: SlaveID, - var resources: JList[Resource]) { + val offer: Offer, + var remainingResources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offerId}, resources: ${resources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}" } } @@ -518,16 +519,16 @@ private[spark] class MesosClusterScheduler( val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.resources, "cpus", desc.cores) + partitionResources(offer.remainingResources, "cpus", desc.cores) val (finalResources, memResourcesToUse) = partitionResources(remainingResources.asJava, "mem", desc.mem) - offer.resources = finalResources.asJava + offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") - .setSlaveId(offer.slaveId) + .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) @@ -549,23 +550,29 @@ private[spark] class MesosClusterScheduler( val driverCpu = submission.cores val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") - val offerOption = currentOffers.find { o => - getResource(o.resources, "cpus") >= driverCpu && - getResource(o.resources, "mem") >= driverMem + val offerOption = currentOffers.find { offer => + getResource(offer.remainingResources, "cpus") >= driverCpu && + getResource(offer.remainingResources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) try { val task = createTaskInfo(submission, offer) queuedTasks += task - logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, task.getTaskId, offer.slaveId, - None, new Date(), None, getDriverFrameworkID(submission)) + val newState = new MesosClusterSubmissionState( + submission, + task.getTaskId, + offer.offer.getSlaveId, + None, + new Date(), + None, + getDriverFrameworkID(submission)) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, 
newState) afterLaunchCallback(submission.submissionId) @@ -588,7 +595,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList) }.toList stateLock.synchronized { @@ -615,8 +622,8 @@ private[spark] class MesosClusterScheduler( driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - for (o <- currentOffers if !tasks.contains(o.offerId)) { - driver.declineOffer(o.offerId) + for (offer <- currentOffers if !tasks.contains(offer.offer.getId)) { + declineOffer(driver, offer.offer, None, Some(getRejectOfferDuration(conf))) } } @@ -662,6 +669,12 @@ private[spark] class MesosClusterScheduler( override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { val taskId = status.getTaskId.getValue + + logInfo(s"Received status update: taskId=${taskId}" + + s" state=${status.getState}" + + s" message=${status.getMessage}" + + s" reason=${status.getReason}"); + stateLock.synchronized { if (launchedDrivers.contains(taskId)) { if (status.getReason == Reason.REASON_RECONCILIATION && @@ -682,8 +695,7 @@ private[spark] class MesosClusterScheduler( val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - pendingRetryDrivers += newDriverDescription - pendingRetryDriversState.persist(taskId, newDriverDescription) + addDriverToPending(newDriverDescription, taskId); } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { removeFromLaunchedDrivers(taskId) state.finishDate = Some(new Date()) @@ -746,4 +758,21 @@ private[spark] class MesosClusterScheduler( def getQueuedDriversSize: Int = queuedDrivers.size def getLaunchedDriversSize: Int = launchedDrivers.size def getPendingRetryDriversSize: Int = pendingRetryDrivers.size + + private def addDriverToQueue(desc: MesosDriverDescription): Unit = { + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + revive() + } + + private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { + pendingRetryDriversState.persist(taskId, desc) + pendingRetryDrivers += desc + revive() + } + + private def revive(): Unit = { + logInfo("Reviving Offers.") + schedulerDriver.reviveOffers() + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f69c223ab9b6d..85c2e9c76f4b0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf @@ -119,11 +120,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + 
getRejectOfferDurationForUnmetConstraints(sc.conf) // Reject offers when we reached the maximum number of cores for this framework private val rejectOfferDurationForReachedMaxCores = - getRejectOfferDurationForReachedMaxCores(sc) + getRejectOfferDurationForReachedMaxCores(sc.conf) // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { @@ -146,6 +147,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( @volatile var appId: String = _ + private var schedulerDriver: SchedulerDriver = _ + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -252,9 +255,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - appId = frameworkId.getValue - mesosExternalShuffleClient.foreach(_.init(appId)) + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { + this.appId = frameworkId.getValue + this.mesosExternalShuffleClient.foreach(_.init(appId)) + this.schedulerDriver = driver markRegistered() } @@ -293,46 +299,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def declineUnmatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { offers.foreach { offer => - declineOffer(d, offer, Some("unmet constraints"), + declineOffer( + driver, + offer, + Some("unmet constraints"), Some(rejectOfferDurationForUnmetConstraints)) } } - private def declineOffer( - d: org.apache.mesos.SchedulerDriver, - offer: Offer, - reason: Option[String] = None, - refuseSeconds: Option[Long] = None): Unit = { - - val id = offer.getId.getValue - val offerAttributes = toAttributeMap(offer.getAttributesList) - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus") - val ports = getRangeResource(offer.getResourcesList, "ports") - - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" + - s" cpu: $cpus port: $ports for $refuseSeconds seconds" + - reason.map(r => s" (reason: $r)").getOrElse("")) - - refuseSeconds match { - case Some(seconds) => - val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() - d.declineOffer(offer.getId, filters) - case _ => d.declineOffer(offer.getId) - } - } - /** * Launches executors on accepted offers, and declines unused offers. Executors are launched * round-robin on offers. 
* - * @param d SchedulerDriver + * @param driver SchedulerDriver * @param offers Mesos offers that match attribute constraints */ private def handleMatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -358,15 +343,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" ports: $ports") } - d.launchTasks( + driver.launchTasks( Collections.singleton(offer.getId), offerTasks.asJava) } else if (totalCoresAcquired >= maxCores) { // Reject an offer for a configurable amount of time to avoid starving other frameworks - declineOffer(d, offer, Some("reached spark.cores.max"), + declineOffer(driver, + offer, + Some("reached spark.cores.max"), Some(rejectOfferDurationForReachedMaxCores)) } else { - declineOffer(d, offer) + declineOffer( + driver, + offer) } } } @@ -582,8 +571,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Close the mesos external shuffle client if used mesosExternalShuffleClient.foreach(_.close()) - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } @@ -634,13 +623,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { - if (mesosDriver == null) { + if (schedulerDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false } else { for (executorId <- executorIds) { val taskId = TaskID.newBuilder().setValue(executorId).build() - mesosDriver.killTask(taskId) + schedulerDriver.killTask(taskId) } // no need to adjust `executorLimitOption` since the AllocationManager already communicated // the desired limit through a call to `doRequestTotalExecutors`. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 7e561916a71e2..215271302ec51 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} @@ -65,7 +66,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( // reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) + + private var schedulerDriver: SchedulerDriver = _ @volatile var appId: String = _ @@ -89,6 +92,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * * @param availableResources Available resources that is offered by Mesos * @param execId The executor id to assign to this new executor. 
* @return A tuple of the new mesos executor info and the remaining available resources. @@ -178,10 +182,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) + this.schedulerDriver = driver markRegistered() } } @@ -383,13 +390,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def stop() { - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } override def reviveOffers() { - mesosDriver.reviveOffers() + schedulerDriver.reviveOffers() } override def frameworkMessage( @@ -426,7 +433,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - mesosDriver.killTask( + schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 1d742fefbbacf..3f25535cb5ec2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -46,9 +46,6 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) - // Driver for talking to Mesos - protected var mesosDriver: SchedulerDriver = null - /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. 
* @@ -115,10 +112,6 @@ trait MesosSchedulerUtils extends Logging { */ def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { - if (mesosDriver != null) { - registerLatch.await() - return - } @volatile var error: Option[Exception] = None @@ -128,8 +121,7 @@ trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { try { - mesosDriver = newDriver - val ret = mesosDriver.run() + val ret = newDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) @@ -379,12 +371,24 @@ trait MesosSchedulerUtils extends Logging { } } - protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + private def getRejectOfferDurationStr(conf: SparkConf): String = { + conf.get("spark.mesos.rejectOfferDuration", "120s") + } + + protected def getRejectOfferDuration(conf: SparkConf): Long = { + Utils.timeStringAsSeconds(getRejectOfferDurationStr(conf)) + } + + protected def getRejectOfferDurationForUnmetConstraints(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForUnmetConstraints", + getRejectOfferDurationStr(conf)) } - protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s") + protected def getRejectOfferDurationForReachedMaxCores(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForReachedMaxCores", + getRejectOfferDurationStr(conf)) } /** @@ -438,6 +442,7 @@ trait MesosSchedulerUtils extends Logging { /** * The values of the non-zero ports to be used by the executor process. 
+ * * @param conf the spark config to use * @return the ono-zero values of the ports */ @@ -521,4 +526,33 @@ trait MesosSchedulerUtils extends Logging { case TaskState.KILLED => MesosTaskState.TASK_KILLED case TaskState.LOST => MesosTaskState.TASK_LOST } + + protected def declineOffer( + driver: org.apache.mesos.SchedulerDriver, + offer: Offer, + reason: Option[String] = None, + refuseSeconds: Option[Long] = None): Unit = { + + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val ports = getRangeResource(offer.getResourcesList, "ports") + + logDebug(s"Declining offer: $id with " + + s"attributes: $offerAttributes " + + s"mem: $mem " + + s"cpu: $cpus " + + s"port: $ports " + + refuseSeconds.map(s => s"for ${s} seconds ").getOrElse("") + + reason.map(r => s" (reason: $r)").getOrElse("")) + + refuseSeconds match { + case Some(seconds) => + val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() + driver.declineOffer(offer.getId, filters) + case _ => + driver.declineOffer(offer.getId) + } + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index b9d098486b675..32967b04cd346 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -53,19 +53,32 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi override def start(): Unit = { ready = true } } scheduler.start() + scheduler.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) + } + + private def testDriverDescription(submissionId: String): MesosDriverDescription = { + new MesosDriverDescription( + "d1", + "jar", + 1000, + 1, + true, + command, + Map[String, String](), + submissionId, + new Date()) } test("can queue drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + verify(driver, times(1)).reviveOffers() + + val response2 = scheduler.submitDriver(testDriverDescription("s2")) assert(response2.success) + val state = scheduler.getSchedulerState() val queuedDrivers = state.queuedDrivers.toList assert(queuedDrivers(0).submissionId == response.submissionId) @@ -75,9 +88,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi test("can kill queued drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) val killResponse = scheduler.killDriver(response.submissionId) assert(killResponse.success) @@ -238,18 +249,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("can kill supervised drivers") { - val driver = 
mock[SchedulerDriver] val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") - scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { - ready = true - mesosDriver = driver - } - } - scheduler.start() + setScheduler(conf.getAll.toMap) val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, @@ -291,4 +294,16 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(state.launchedDrivers.isEmpty) assert(state.finishedDrivers.size == 1) } + + test("Declines offer with refuse seconds = 120.") { + setScheduler() + + val filter = Filters.newBuilder().setRefuseSeconds(120).build() + val offerId = OfferID.newBuilder().setValue("o1").build() + val offer = Utils.createOffer(offerId.getValue, "s1", 1000, 1) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + verify(driver, times(1)).declineOffer(offerId, filter) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 78346e9744957..98033bec6dd68 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -552,17 +552,14 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient // override to avoid race condition with the driver thread on `mesosDriver` - override def startScheduler(newDriver: SchedulerDriver): Unit = { - mesosDriver = newDriver - } + override def startScheduler(newDriver: SchedulerDriver): Unit = {} override def stopExecutors(): Unit = { stopCalled = true } - - markRegistered() } backend.start() + backend.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) backend } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 7ebb294aa9080..2a67cbc913ffe 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -28,6 +28,17 @@ import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ object Utils { + + val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() + .setValue("test-framework-id") + .build() + + val TEST_MASTER_INFO = MasterInfo.newBuilder() + .setId("test-master") + .setIp(0) + .setPort(0) + .build() + def createOffer( offerId: String, slaveId: String, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 2fdb70a73c754..41b7b5d60b038 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -60,7 +60,7 @@ private[spark] class CredentialUpdater( 
if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.") + logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -81,8 +81,8 @@ private[spark] class CredentialUpdater( UserGroupInformation.getCurrentUser.addCredentials(newCredentials) logInfo("Credentials updated from credentials file.") - val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis() + val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) + - System.currentTimeMillis()) if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime } else { // If current credential file is older than expected, sleep 1 hour and check again. @@ -100,6 +100,7 @@ private[spark] class CredentialUpdater( TimeUnit.HOURS.toMillis(1) } + logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") credentialUpdater.schedule( credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index c477cb48d0b07..ffa5aed30e19f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -117,6 +117,8 @@ class Analyzer( Batch("Hints", fixedPoint, new ResolveHints.ResolveBroadcastHints(conf), ResolveHints.RemoveAllHints), + Batch("Simple Sanity Check", Once, + LookupFunctions), Batch("Substitution", fixedPoint, CTESubstitution, WindowsSubstitution, @@ -146,7 +148,7 @@ class Analyzer( GlobalAggregates :: ResolveAggregateFunctions :: TimeWindowing :: - ResolveInlineTables :: + ResolveInlineTables(conf) :: TypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Post-Hoc Resolution", Once, postHocResolutionRules: _*), @@ -604,7 +606,11 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) + lookupTableFromCatalog(u).canonicalized match { + case v: View => + u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.") + case other => i.copy(table = other) + } case u: UnresolvedRelation => resolveRelation(u) } @@ -1038,6 +1044,25 @@ class Analyzer( } } + /** + * Checks whether a function identifier referenced by an [[UnresolvedFunction]] is defined in the + * function registry. Note that this rule doesn't try to resolve the [[UnresolvedFunction]]. It + * only performs simple existence check according to the function identifier to quickly identify + * undefined functions without triggering relation resolution, which may incur potentially + * expensive partition/schema discovery process in some cases. 
+ * + * @see [[ResolveFunctions]] + * @see https://issues.apache.org/jira/browse/SPARK-19737 + */ + object LookupFunctions extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case f: UnresolvedFunction if !catalog.functionExists(f.name) => + withPosition(f) { + throw new NoSuchFunctionException(f.name.database.getOrElse("default"), f.name.funcName) + } + } + } + /** * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 9c9465f6b8def..556fa9901701b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -421,6 +421,9 @@ object FunctionRegistry { expression[BitwiseOr]("|"), expression[BitwiseXor]("^"), + // json + expression[StructToJson]("to_json"), + // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), castAlias("tinyint", ByteType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala index 7323197b10f6e..d5b3ea8c37c66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTables.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.analysis import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.Cast +import org.apache.spark.sql.catalyst.{CatalystConf, InternalRow} +import org.apache.spark.sql.catalyst.expressions.{Cast, TimeZoneAwareExpression} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.types.{StructField, StructType} @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.{StructField, StructType} /** * An analyzer rule that replaces [[UnresolvedInlineTable]] with [[LocalRelation]]. 
*/ -object ResolveInlineTables extends Rule[LogicalPlan] { +case class ResolveInlineTables(conf: CatalystConf) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case table: UnresolvedInlineTable if table.expressionsResolved => validateInputDimension(table) @@ -95,11 +95,15 @@ object ResolveInlineTables extends Rule[LogicalPlan] { InternalRow.fromSeq(row.zipWithIndex.map { case (e, ci) => val targetType = fields(ci).dataType try { - if (e.dataType.sameType(targetType)) { - e.eval() + val castedExpr = if (e.dataType.sameType(targetType)) { + e } else { - Cast(e, targetType).eval() + Cast(e, targetType) } + castedExpr.transform { + case e: TimeZoneAwareExpression if e.timeZoneId.isEmpty => + e.withTimeZone(conf.sessionLocalTimeZone) + }.eval() } catch { case NonFatal(ex) => table.failAnalysis(s"failed to evaluate expression ${e.sql}: ${ex.getMessage}") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala index a3a4ab37ea714..31eded4deba7d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalog.scala @@ -244,11 +244,13 @@ abstract class ExternalCatalog { * @param db database name * @param table table name * @param predicates partition-pruning predicates + * @param defaultTimeZoneId default timezone id to parse partition values of TimestampType */ def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] // -------------------------------------------------------------------------- // Functions diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala index 58ced549bafe9..a418edc302d9c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogUtils.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI + import org.apache.hadoop.fs.Path import org.apache.hadoop.util.Shell @@ -162,6 +164,30 @@ object CatalogUtils { BucketSpec(numBuckets, normalizedBucketCols, normalizedSortCols) } + /** + * Convert URI to String. + * Since URI.toString does not decode the uri, e.g. change '%25' to '%'. + * Here we create a hadoop Path with the given URI, and rely on Path.toString + * to decode the uri + * @param uri the URI of the path + * @return the String of the path + */ + def URIToString(uri: URI): String = { + new Path(uri).toString + } + + /** + * Convert String to URI. + * Since new URI(string) does not encode string, e.g. change '%' to '%25'. 
+ * Here we create a hadoop Path with the given String, and rely on Path.toUri + * to encode the string + * @param str the String of the path + * @return the URI of the path + */ + def stringToURI(str: String): URI = { + new Path(str).toUri + } + private def normalizeColumnName( tableName: String, tableCols: Seq[String], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 6bb2b2d4ff72e..80aba4af9436c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -202,7 +202,7 @@ class InMemoryCatalog( tableDefinition.storage.locationUri.isEmpty val tableWithLocation = if (needDefaultTableLocation) { - val defaultTableLocation = new Path(catalog(db).db.locationUri, table) + val defaultTableLocation = new Path(new Path(catalog(db).db.locationUri), table) try { val fs = defaultTableLocation.getFileSystem(hadoopConfig) fs.mkdirs(defaultTableLocation) @@ -211,7 +211,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to create table $table as failed " + s"to create its directory $defaultTableLocation", e) } - tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri.toString)) + tableDefinition.withNewStorage(locationUri = Some(defaultTableLocation.toUri)) } else { tableDefinition } @@ -274,7 +274,7 @@ class InMemoryCatalog( "Managed table should always have table location, as we will assign a default location " + "to it if it doesn't have one.") val oldDir = new Path(oldDesc.table.location) - val newDir = new Path(catalog(db).db.locationUri, newName) + val newDir = new Path(new Path(catalog(db).db.locationUri), newName) try { val fs = oldDir.getFileSystem(hadoopConfig) fs.rename(oldDir, newDir) @@ -283,7 +283,7 @@ class InMemoryCatalog( throw new SparkException(s"Unable to rename table $oldName to $newName as failed " + s"to rename its directory $oldDir", e) } - oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri.toString)) + oldDesc.table = oldDesc.table.withNewStorage(locationUri = Some(newDir.toUri)) } catalog(db).tables.put(newName, oldDesc) @@ -389,7 +389,7 @@ class InMemoryCatalog( existingParts.put( p.spec, - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toString)))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri)))) } } @@ -462,7 +462,7 @@ class InMemoryCatalog( } oldPartition.copy( spec = newSpec, - storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toString))) + storage = oldPartition.storage.copy(locationUri = Some(newPartPath.toUri))) } else { oldPartition.copy(spec = newSpec) } @@ -544,7 +544,8 @@ class InMemoryCatalog( override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = { // TODO: Provide an implementation throw new UnsupportedOperationException( "listPartitionsByFilter is not implemented for InMemoryCatalog") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 06734891b6937..498bfbde9d7a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala 
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -131,10 +132,10 @@ class SessionCatalog( * does not contain a scheme, this path will not be changed after the default * FileSystem is changed. */ - private def makeQualifiedPath(path: String): Path = { + private def makeQualifiedPath(path: URI): URI = { val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(hadoopConf) - fs.makeQualified(hadoopPath) + fs.makeQualified(hadoopPath).toUri } private def requireDbExists(db: String): Unit = { @@ -170,7 +171,7 @@ class SessionCatalog( "you cannot create a database with this name.") } validateName(dbName) - val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri).toString + val qualifiedPath = makeQualifiedPath(dbDefinition.locationUri) externalCatalog.createDatabase( dbDefinition.copy(name = dbName, locationUri = qualifiedPath), ignoreIfExists) @@ -228,9 +229,9 @@ class SessionCatalog( * Get the path for creating a non-default database when database location is not provided * by users. */ - def getDefaultDBPath(db: String): String = { + def getDefaultDBPath(db: String): URI = { val database = formatDatabaseName(db) - new Path(new Path(conf.warehousePath), database + ".db").toString + new Path(new Path(conf.warehousePath), database + ".db").toUri } // ---------------------------------------------------------------------------- @@ -351,11 +352,11 @@ class SessionCatalog( db, table, loadPath, spec, isOverwrite, inheritTableSpecs, isSrcLocal) } - def defaultTablePath(tableIdent: TableIdentifier): String = { + def defaultTablePath(tableIdent: TableIdentifier): URI = { val dbName = formatDatabaseName(tableIdent.database.getOrElse(getCurrentDatabase)) val dbLocation = getDatabaseMetadata(dbName).locationUri - new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toString + new Path(new Path(dbLocation), formatTableName(tableIdent.table)).toUri } // ---------------------------------------------- @@ -841,7 +842,7 @@ class SessionCatalog( val table = formatTableName(tableName.table) requireDbExists(db) requireTableExists(TableIdentifier(table, Option(db))) - externalCatalog.listPartitionsByFilter(db, table, predicates) + externalCatalog.listPartitionsByFilter(db, table, predicates, conf.sessionLocalTimeZone) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index cb939026f1bf1..4452c479875fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI import java.util.Date import com.google.common.base.Objects @@ -26,8 +27,8 @@ import org.apache.spark.sql.catalyst.{CatalystConf, FunctionIdentifier, Internal import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Cast, Literal} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.catalyst.util.quoteIdentifier -import org.apache.spark.sql.catalyst.util.DateTimeUtils import 
org.apache.spark.sql.types.StructType @@ -48,10 +49,7 @@ case class CatalogFunction( * Storage format, used to describe how a partition or a table is stored. */ case class CatalogStorageFormat( - // TODO(ekl) consider storing this field as java.net.URI for type safety. Note that this must - // be converted to/from a hadoop Path object using new Path(new URI(locationUri)) and - // path.toUri respectively before use as a filesystem path due to URI char escaping. - locationUri: Option[String], + locationUri: Option[URI], inputFormat: Option[String], outputFormat: Option[String], serde: Option[String], @@ -105,7 +103,7 @@ case class CatalogTablePartition( } /** Return the partition location, assuming it is specified. */ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { val specString = spec.map { case (k, v) => s"$k=$v" }.mkString(", ") throw new AnalysisException(s"Partition [$specString] did not specify locationUri") } @@ -113,11 +111,11 @@ case class CatalogTablePartition( /** * Given the partition schema, returns a row with that schema holding the partition values. */ - def toRow(partitionSchema: StructType): InternalRow = { + def toRow(partitionSchema: StructType, defaultTimeZondId: String): InternalRow = { + val caseInsensitiveProperties = CaseInsensitiveMap(storage.properties) + val timeZoneId = caseInsensitiveProperties.getOrElse("timeZone", defaultTimeZondId) InternalRow.fromSeq(partitionSchema.map { field => - // TODO: use correct timezone for partition values. - Cast(Literal(spec(field.name)), field.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(spec(field.name)), field.dataType, Option(timeZoneId)).eval() }) } } @@ -210,7 +208,7 @@ case class CatalogTable( } /** Return the table location, assuming it is specified. */ - def location: String = storage.locationUri.getOrElse { + def location: URI = storage.locationUri.getOrElse { throw new AnalysisException(s"table $identifier did not specify locationUri") } @@ -241,7 +239,7 @@ case class CatalogTable( /** Syntactic sugar to update a field in `storage`. 
*/ def withNewStorage( - locationUri: Option[String] = storage.locationUri, + locationUri: Option[URI] = storage.locationUri, inputFormat: Option[String] = storage.inputFormat, outputFormat: Option[String] = storage.outputFormat, compressed: Boolean = false, @@ -337,7 +335,7 @@ object CatalogTableType { case class CatalogDatabase( name: String, description: String, - locationUri: String, + locationUri: URI, properties: Map[String, String]) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala index 2d9c2e42064b3..03101b4bfc5f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import java.math.{BigDecimal, RoundingMode} import java.security.{MessageDigest, NoSuchAlgorithmException} import java.util.zip.CRC32 @@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction { * We should use this hash function for both shuffle and bucket of Hive tables, so that * we can guarantee shuffle and bucketing have same data distribution * - * TODO: Support Decimal and date related types + * TODO: Support date related types */ @ExpressionDescription( usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.") @@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] { override protected def genHashBytes(b: String, result: String): String = s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);" + override protected def genHashDecimal( + ctx: CodegenContext, + d: DecimalType, + input: String, + result: String): String = { + s""" + $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal( + $input.toJavaBigDecimal()).hashCode();""" + } + override protected def genHashCalendarInterval(input: String, result: String): String = { s""" $result = (31 * $hasherClassName.hashInt($input.months)) + @@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction { HiveHasher.hashUnsafeBytes(base, offset, len) } + private val HIVE_DECIMAL_MAX_PRECISION = 38 + private val HIVE_DECIMAL_MAX_SCALE = 38 + + // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize() + def normalizeDecimal(input: BigDecimal): BigDecimal = { + if (input == null) return null + + def trimDecimal(input: BigDecimal) = { + var result = input + if (result.compareTo(BigDecimal.ZERO) == 0) { + // Special case for 0, because java doesn't strip zeros correctly on that number. + result = BigDecimal.ZERO + } else { + result = result.stripTrailingZeros + if (result.scale < 0) { + // no negative scale decimals + result = result.setScale(0) + } + } + result + } + + var result = trimDecimal(input) + val intDigits = result.precision - result.scale + if (intDigits > HIVE_DECIMAL_MAX_PRECISION) { + return null + } + + val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE, + Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale)) + if (result.scale > maxScale) { + result = result.setScale(maxScale, RoundingMode.HALF_UP) + // Trimming is again necessary, because rounding may introduce new trailing 0's. 
+ result = trimDecimal(result) + } + result + } + override def hash(value: Any, dataType: DataType, seed: Long): Long = { value match { case null => 0 @@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction { } result - case _ => super.hash(value, dataType, 0) + case d: Decimal => + normalizeDecimal(d.toJavaBigDecimal).hashCode() + + case _ => super.hash(value, dataType, seed) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 1e690a446951e..18b5f2f7ed2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -23,11 +23,12 @@ import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.ParseModes +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -330,7 +331,7 @@ case class GetJsonObject(json: Expression, path: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Return a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", extended = """ Examples: > SELECT _FUNC_('{"a":1, "b":2}', 'a', 'b'); @@ -480,23 +481,45 @@ case class JsonTuple(children: Seq[Expression]) } /** - * Converts an json input string to a [[StructType]] with the specified schema. + * Converts an json input string to a [[StructType]] or [[ArrayType]] with the specified schema. */ case class JsonToStruct( - schema: StructType, + schema: DataType, options: Map[String, String], child: Expression, timeZoneId: Option[String] = None) extends UnaryExpression with TimeZoneAwareExpression with CodegenFallback with ExpectsInputTypes { override def nullable: Boolean = true - def this(schema: StructType, options: Map[String, String], child: Expression) = + def this(schema: DataType, options: Map[String, String], child: Expression) = this(schema, options, child, None) + override def checkInputDataTypes(): TypeCheckResult = schema match { + case _: StructType | ArrayType(_: StructType, _) => + super.checkInputDataTypes() + case _ => TypeCheckResult.TypeCheckFailure( + s"Input schema ${schema.simpleString} must be a struct or an array of structs.") + } + + @transient + lazy val rowSchema = schema match { + case st: StructType => st + case ArrayType(st: StructType, _) => st + } + + // This converts parsed rows to the desired output by the given schema. 
+ @transient + lazy val converter = schema match { + case _: StructType => + (rows: Seq[InternalRow]) => if (rows.length == 1) rows.head else null + case ArrayType(_: StructType, _) => + (rows: Seq[InternalRow]) => new GenericArrayData(rows) + } + @transient lazy val parser = new JacksonParser( - schema, + rowSchema, new JSONOptions(options + ("mode" -> ParseModes.FAIL_FAST_MODE), timeZoneId.get)) override def dataType: DataType = schema @@ -505,11 +528,32 @@ case class JsonToStruct( copy(timeZoneId = Option(timeZoneId)) override def nullSafeEval(json: Any): Any = { + // When input is, + // - `null`: `null`. + // - invalid json: `null`. + // - empty string: `null`. + // + // When the schema is array, + // - json array: `Array(Row(...), ...)` + // - json object: `Array(Row(...))` + // - empty json array: `Array()`. + // - empty json object: `Array(Row(null))`. + // + // When the schema is a struct, + // - json object/array with single element: `Row(...)` + // - json array with multiple elements: `null` + // - empty json array: `null`. + // - empty json object: `Row(null)`. + + // We need `null` if the input string is an empty string. `JacksonParser` can + // deal with this but produces `Nil`. + if (json.toString.trim.isEmpty) return null + try { - parser.parse( + converter(parser.parse( json.asInstanceOf[UTF8String], CreateJacksonParser.utf8String, - identity[UTF8String]).headOption.orNull + identity[UTF8String])) } catch { case _: SparkSQLJsonProcessingException => null } @@ -521,6 +565,17 @@ case class JsonToStruct( /** * Converts a [[StructType]] to a json output string. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + extended = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + """) +// scalastyle:on line.size.limit case class StructToJson( options: Map[String, String], child: Expression, @@ -530,6 +585,14 @@ case class StructToJson( def this(options: Map[String, String], child: Expression) = this(options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression) = this(Map.empty, child, None) + def this(child: Expression, options: Expression) = + this( + options = StructToJson.convertToMapData(options), + child = child, + timeZoneId = None) + @transient lazy val writer = new CharArrayWriter() @@ -570,3 +633,20 @@ case class StructToJson( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil } + +object StructToJson { + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index e66fb893394eb..eaeaf08c37b4e 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -32,11 +32,13 @@ import java.util.Objects import javax.xml.bind.DatatypeConverter import scala.math.{BigDecimal, BigInt} +import scala.reflect.runtime.universe.TypeTag +import scala.util.Try import org.json4s.JsonAST._ import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, ScalaReflection} import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -153,6 +155,14 @@ object Literal { Literal(CatalystTypeConverters.convertToCatalyst(v), dataType) } + def create[T : TypeTag](v: T): Literal = Try { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + val convert = CatalystTypeConverters.createToCatalystConverter(dataType) + Literal(convert(v), dataType) + }.getOrElse { + Literal(v) + } + /** * Create a literal with default value for given DataType */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 4f593c894acd2..21d1cd5932620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -457,7 +457,7 @@ object FoldablePropagation extends Rule[LogicalPlan] { // join is not always picked from its children, but can also be null. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) => + case j @ Join(_, _, Inner, _) if !stop => j.transformExpressions(replaceFoldable) // We can fold the projections an expand holds. However expand changes the output columns diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc2701..4d27ff2acdbad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). 
- 1 - } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum - } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val childStats = child.stats(conf) + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after a limit operation + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + isBroadcastable = childStats.isBroadcastable) } } @@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { + val childStats = child.stats(conf) + if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - 1 + Statistics( + sizeInBytes = 1, + rowCount = Some(0), + isBroadcastable = childStats.isBroadcastable) } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. + childStats.copy(attributeStats = AttributeMap(Nil)) } - child.stats(conf).copy(sizeInBytes = sizeInBytes) } } @@ -816,12 +822,14 @@ case class Sample( override def computeStats(conf: CatalystConf): Statistics = { val ratio = upperBound - lowerBound - // BigInt can't multiply with Double - var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 + val childStats = child.stats(conf) + var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0c928832d7d22..b10785b05d6c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: 
CatalystConf) extends Lo def estimate: Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // save a mutable copy of colStats so that we can later change it recursively + // Save a mutable copy of colStats so that we can later change it recursively. colStatsMap.setInitValues(childStats.attributeStats) - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } + // Estimate selectivity of this filter predicate, and update column stats if needed. + // For not-supported condition, set filter selectivity to a conservative estimate 100% + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) - val newColStats = colStatsMap.toColumnStats + val newColStats = if (filterSelectivity == 0) { + // The output is empty, we don't need to keep column stats. + AttributeMap[ColumnStat](Nil) + } else { + colStatsMap.toColumnStats + } val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) @@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * Returns a percentage of rows meeting a condition in Filter node. + * If it's a single condition, we calculate the percentage directly. + * If it's a compound condition, it is decomposed into multiple single conditions linked with + * AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. This is needed for * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. + * For logical OR and NOT conditions, we do not update stats after a condition estimation. * * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column @@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update) - val percent2 = calculateFilterSelectivity(cond2, update) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. 
- val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) - case _ => calculateSingleCondition(condition, update) + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None + } + + case _ => + calculateSingleCondition(condition, update) } } @@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = if (isNull) { - nullPercent.toDouble + nullPercent } else { - 1.0 - nullPercent.toDouble + 1.0 - nullPercent } - Some(percent) + Some(percent.toDouble) } /** @@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo attr: Attribute, literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + attr.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType | DateType | TimestampType | BooleanType => evaluateBinaryForNumeric(op, attr, literal, update) case StringType | BinaryType => // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) None - case _ => - // TODO: support boolean type. - None } } @@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * + * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo; + * if the literal is null, the condition will be changed to IsNull after optimization. + * So we don't need specific logic for EqualNullSafe here. + * * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -323,7 +329,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo colStatsMap(attr) = newStats } - Some(1.0 / ndv.toDouble) + Some((1.0 / BigDecimal(ndv)).toDouble) } else { Some(0.0) } @@ -394,12 +400,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. 
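To sanity-check the compound-condition handling above, a small worked example (illustrative only) against the 10-row column cint in [1, 10] with ndv = 10 that FilterEstimationSuite uses later in this diff:

  val pGt3 = (10.0 - 3) / (10 - 1)   // cint > 3 => 7/9
  val pLe6 = (6.0 - 3) / (10 - 3)    // cint <= 6, evaluated on the range narrowed to [3, 10] by the AND update
  val pAnd = pGt3 * pLe6             // 1/3, so ceil(10 * 1/3) = 4 rows
  val pOr = 0.1 + 0.1 - 0.1 * 0.1    // cint = 3 OR cint = 6: each 1/ndv, combined to 0.19 => 2 rows
  // Not(cint > 3 AND cint <= 6) is rewritten via De Morgan to an OR of negations,
  // evaluated without updating column stats:
  val pNot = (1 - pGt3) + (1 - 5.0 / 9) - (1 - pGt3) * (1 - 5.0 / 9)   // ~0.57 => 6 rows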
- Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)) } /** * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. + * This method evaluates the expression for Numeric/Date/Timestamp/Boolean columns. * * @param op a binary comparison operator uch as =, <, <=, >, >= * @param attr an Attribute (or a column) @@ -414,53 +420,66 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo literal: Literal, update: Boolean): Option[Double] = { - var percent = 1.0 val colStat = colStatsMap(attr) - val statsRange = - Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val max = BigDecimal(statsRange.max) + val min = BigDecimal(statsRange.min) + val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range - val literalValueBD = BigDecimal(literal.value.toString) + val numericLiteral = if (literal.dataType == BooleanType) { + if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) + } else { + BigDecimal(literal.value.toString) + } val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + (numericLiteral <= min, numericLiteral > max) case _: LessThanOrEqual => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + (numericLiteral < min, numericLiteral >= max) case _: GreaterThan => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + (numericLiteral >= max, numericLiteral < min) case _: GreaterThanOrEqual => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + (numericLiteral > max, numericLiteral <= min) } + var percent = BigDecimal(1.0) if (noOverlap) { percent = 0.0 } else if (completeOverlap) { percent = 1.0 } else { - // this is partial overlap case - val literalDouble = literalValueBD.toDouble - val maxDouble = BigDecimal(statsRange.max).toDouble - val minDouble = BigDecimal(statsRange.min).toDouble - + // This is the partial overlap case: // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. - // For ease of computation, we convert all relevant numeric values to Double. + assert(max > min) percent = op match { case _: LessThan => - (literalDouble - minDouble) / (maxDouble - minDouble) + if (numericLiteral == max) { + // If the literal value is right on the boundary, we subtract the boundary value's + // share (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == min) { + // The boundary value is the only satisfying value.
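Continuing with the same illustrative numbers (cint in [1, 10], ndv = 10, 10 rows), the partial-overlap and boundary branches above work out as follows; `cint <= 1` and `cint < 10` are hypothetical predicates chosen to hit the boundary cases, not tests from the suite:

  val pLt3 = (3.0 - 1) / (10 - 1)      // cint < 3 => 2/9, i.e. ceil(10 * 2/9) = 3 rows
  val newNdv = math.round(10 * pLt3)   // 2, and the updated column range becomes [1, 3]
  val pLe1 = 1.0 / 10                  // cint <= 1: the literal equals min, so 1/ndv
  val pLt10 = 1.0 - 1.0 / 10           // cint < 10: the literal equals max, so 1 - 1/ndv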
+ 1.0 / ndv } else { - (literalDouble - minDouble) / (maxDouble - minDouble) + (numericLiteral - min) / (max - min) } case _: GreaterThan => - (maxDouble - literalDouble) / (maxDouble - minDouble) + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == max) { + 1.0 / ndv } else { - (maxDouble - literalDouble) / (maxDouble - minDouble) + (max - numericLiteral) / (max - min) } } @@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val newValue = convertBoundValue(attr.dataType, literal.value) var newMax = colStat.max var newMin = colStat.min + var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdv < 1) newNdv = 1 + op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue + case _: GreaterThan | _: GreaterThanOrEqual => + // If new ndv is 1, then new max must be equal to new min. + newMin = if (newNdv == 1) newMax else newValue + case _: LessThan | _: LessThanOrEqual => + newMax = if (newNdv == 1) newMin else newValue } - val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = + colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) colStatsMap(attr) = newStats } } - Some(percent) + Some(percent.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala index 920c6ea50f4ba..f45a826869842 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ResolveInlineTablesSuite.scala @@ -20,68 +20,67 @@ package org.apache.spark.sql.catalyst.analysis import org.scalatest.BeforeAndAfter import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.expressions.{Literal, Rand} +import org.apache.spark.sql.catalyst.expressions.{Cast, Literal, Rand} import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types.{LongType, NullType} +import org.apache.spark.sql.types.{LongType, NullType, TimestampType} /** * Unit tests for [[ResolveInlineTables]]. Note that there are also test cases defined in * end-to-end tests (in sql/core module) for verifying the correct error messages are shown * in negative cases. 
*/ -class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { +class ResolveInlineTablesSuite extends AnalysisTest with BeforeAndAfter { private def lit(v: Any): Literal = Literal(v) test("validate inputs are foldable") { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1))))) // nondeterministic (rand) should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Rand(1))))) } // aggregate should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(Count(lit(1)))))) } // unresolved attribute should not work intercept[AnalysisException] { - ResolveInlineTables.validateInputEvaluable( + ResolveInlineTables(conf).validateInputEvaluable( UnresolvedInlineTable(Seq("c1"), Seq(Seq(UnresolvedAttribute("A"))))) } } test("validate input dimensions") { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2))))) // num alias != data dimension intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(lit(1)), Seq(lit(2))))) } // num alias == data dimension, but data themselves are inconsistent intercept[AnalysisException] { - ResolveInlineTables.validateInputDimension( + ResolveInlineTables(conf).validateInputDimension( UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(21), lit(22))))) } } test("do not fire the rule if not all expressions are resolved") { val table = UnresolvedInlineTable(Seq("c1", "c2"), Seq(Seq(UnresolvedAttribute("A")))) - assert(ResolveInlineTables(table) == table) + assert(ResolveInlineTables(conf)(table) == table) } test("convert") { val table = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted = ResolveInlineTables.convert(table) + val converted = ResolveInlineTables(conf).convert(table) assert(converted.output.map(_.dataType) == Seq(LongType)) assert(converted.data.size == 2) @@ -89,13 +88,24 @@ class ResolveInlineTablesSuite extends PlanTest with BeforeAndAfter { assert(converted.data(1).getLong(0) == 2L) } + test("convert TimeZoneAwareExpression") { + val table = UnresolvedInlineTable(Seq("c1"), + Seq(Seq(Cast(lit("1991-12-06 00:00:00.0"), TimestampType)))) + val converted = ResolveInlineTables(conf).convert(table) + val correct = Cast(lit("1991-12-06 00:00:00.0"), TimestampType) + .withTimeZone(conf.sessionLocalTimeZone).eval().asInstanceOf[Long] + assert(converted.output.map(_.dataType) == Seq(TimestampType)) + assert(converted.data.size == 1) + assert(converted.data(0).getLong(0) == correct) + } + test("nullability inference in convert") { val table1 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(lit(2L)))) - val converted1 = ResolveInlineTables.convert(table1) + val converted1 = ResolveInlineTables(conf).convert(table1) assert(!converted1.schema.fields(0).nullable) val table2 = UnresolvedInlineTable(Seq("c1"), Seq(Seq(lit(1)), Seq(Literal(null, NullType)))) - val converted2 = ResolveInlineTables.convert(table2) + val converted2 = ResolveInlineTables(conf).convert(table2) assert(converted2.schema.fields(0).nullable) } } diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala index a5d399a065589..07ccd68698e94 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/ExternalCatalogSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -340,7 +342,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac "db1", "tbl", Map("partCol1" -> "1", "partCol2" -> "2")).location - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val defaultPartitionLocation = new Path(new Path(tableLocation, "partCol1=1"), "partCol2=2") assert(new Path(partitionLocation) == defaultPartitionLocation) } @@ -508,7 +510,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac partitionColumnNames = Seq("partCol1", "partCol2")) catalog.createTable(table, ignoreIfExists = false) - val tableLocation = catalog.getTable("db1", "tbl").location + val tableLocation = new Path(catalog.getTable("db1", "tbl").location) val mixedCasePart1 = CatalogTablePartition( Map("partCol1" -> "1", "partCol2" -> "2"), storageFormat) @@ -699,7 +701,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac // File System operations // -------------------------------------------------------------------------- - private def exists(uri: String, children: String*): Boolean = { + private def exists(uri: URI, children: String*): Boolean = { val base = new Path(uri) val finalPath = children.foldLeft(base) { case (parent, child) => new Path(parent, child) @@ -742,7 +744,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac identifier = TableIdentifier("external_table", Some("db1")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat( - Some(Utils.createTempDir().getAbsolutePath), + Some(Utils.createTempDir().toURI), None, None, None, false, Map.empty), schema = new StructType().add("a", "int").add("b", "string"), provider = Some(defaultProvider) @@ -790,7 +792,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithExistingDir = CatalogTablePartition( Map("partCol1" -> "7", "partCol2" -> "8"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithExistingDir), ignoreIfExists = false) @@ -799,7 +801,7 @@ abstract class ExternalCatalogSuite extends SparkFunSuite with BeforeAndAfterEac val partWithNonExistingDir = CatalogTablePartition( Map("partCol1" -> "9", "partCol2" -> "10"), CatalogStorageFormat( - Some(tempPath.toURI.toString), + Some(tempPath.toURI), None, None, None, false, Map.empty)) catalog.createPartitions("db1", "tbl", Seq(partWithNonExistingDir), ignoreIfExists = false) assert(tempPath.exists()) @@ -883,7 +885,7 @@ abstract class CatalogTestUtils { def newFunc(): CatalogFunction = newFunc("funcName") - def newUriForDatabase(): String = Utils.createTempDir().toURI.toString.stripSuffix("/") + def newUriForDatabase(): URI = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) 
def newDb(name: String): CatalogDatabase = { CatalogDatabase(name, name + " description", newUriForDatabase(), Map.empty) @@ -895,7 +897,7 @@ abstract class CatalogTestUtils { CatalogTable( identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL, - storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().getAbsolutePath)), + storage = storageFormat.copy(locationUri = Some(Utils.createTempDir().toURI)), schema = new StructType() .add("col1", "int") .add("col2", "string") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index a755231962be2..ffc272c6c0c39 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.{FunctionIdentifier, SimpleCatalystConf, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser @@ -1196,4 +1196,25 @@ class SessionCatalogSuite extends PlanTest { catalog.listFunctions("unknown_db", "func*") } } + + test("SPARK-19737: detect undefined functions without triggering relation resolution") { + import org.apache.spark.sql.catalyst.dsl.plans._ + + Seq(true, false) foreach { caseSensitive => + val conf = SimpleCatalystConf(caseSensitive) + val catalog = new SessionCatalog(newBasicCatalog(), new SimpleFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) + + // The analyzer should report the undefined function rather than the undefined table first. 
+ val cause = intercept[AnalysisException] { + analyzer.execute( + UnresolvedRelation(TableIdentifier("undefined_table")).select( + UnresolvedFunction("undefined_fn", Nil, isDistinct = false) + ) + ) + } + + assert(cause.getMessage.contains("Undefined function: 'undefined_fn'")) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala index 0cb3a79eee67d..0c77dc2709dad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala @@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } - def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = { // Note : All expected hashes need to be computed using Hive 1.2.1 val actual = HiveHashFunction.hash(input, dataType, seed = 0) @@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { new StructType().add("array", arrayOfString).add("map", mapOfString)) .add("structOfUDT", structOfUDT)) + test("hive-hash for decimal") { + def checkHiveHashForDecimal( + input: String, + precision: Int, + scale: Int, + expected: Long): Unit = { + val decimalType = DataTypes.createDecimalType(precision, scale) + val decimal = { + val value = Decimal.apply(new java.math.BigDecimal(input)) + if (value.changePrecision(precision, scale)) value else null + } + + checkHiveHash(decimal, decimalType, expected) + } + + checkHiveHashForDecimal("18", 38, 0, 558) + checkHiveHashForDecimal("-18", 38, 0, -558) + checkHiveHashForDecimal("-18", 38, 12, -558) + checkHiveHashForDecimal("18446744073709001000", 38, 19, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0) + checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057) + checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057) + checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656) + checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656) + checkHiveHashForDecimal("00000.00000000000", 38, 34, 0) + checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0) + checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252) + checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582) + checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666) + 
checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608) + checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666) + } + test("SPARK-18207: Compute hash for a lot of expressions") { val N = 1000 val wideRow = new GenericInternalRow( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala index 0c46819cdb9cd..e3584909ddc4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/JsonExpressionsSuite.scala @@ -22,7 +22,7 @@ import java.util.Calendar import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.{DateTimeTestUtils, DateTimeUtils, ParseModes} -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -372,6 +372,62 @@ class JsonExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { ) } + test("from_json - input=array, schema=array, output=array") { + val input = """[{"a": 1}, {"a": 2}]""" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(1) :: InternalRow(2) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=object, schema=array, output=array of single row") { + val input = """{"a": 1}""" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(1) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty array, schema=array, output=empty array") { + val input = "[ ]" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty object, schema=array, output=array of single row with null") { + val input = "{ }" + val schema = ArrayType(StructType(StructField("a", IntegerType) :: Nil)) + val output = InternalRow(null) :: Nil + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=array of single object, schema=struct, output=single row") { + val input = """[{"a": 1}]""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = InternalRow(1) + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=array, schema=struct, output=null") { + val input = """[{"a": 1}, {"a": 2}]""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = null + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty array, schema=struct, output=null") { + val input = """[]""" + val schema = StructType(StructField("a", IntegerType) :: Nil) + val output = null + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + + test("from_json - input=empty object, schema=struct, output=single row with null") { + val input = """{ }""" + val schema = 
StructType(StructField("a", IntegerType) :: Nil) + val output = InternalRow(null) + checkEvaluation(JsonToStruct(schema, Map.empty, Literal(input), gmtId), output) + } + test("from_json null input column") { val schema = StructType(StructField("a", IntegerType) :: Nil) checkEvaluation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala index 15e8e6c057baf..a9e0eb0e377a6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/LiteralExpressionSuite.scala @@ -19,9 +19,11 @@ package org.apache.spark.sql.catalyst.expressions import java.nio.charset.StandardCharsets +import scala.reflect.runtime.universe.{typeTag, TypeTag} + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.CatalystTypeConverters +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection} import org.apache.spark.sql.catalyst.encoders.ExamplePointUDT import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -75,6 +77,9 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { test("boolean literals") { checkEvaluation(Literal(true), true) checkEvaluation(Literal(false), false) + + checkEvaluation(Literal.create(true), true) + checkEvaluation(Literal.create(false), false) } test("int literals") { @@ -83,36 +88,60 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Literal(d.toLong), d.toLong) checkEvaluation(Literal(d.toShort), d.toShort) checkEvaluation(Literal(d.toByte), d.toByte) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toLong), d.toLong) + checkEvaluation(Literal.create(d.toShort), d.toShort) + checkEvaluation(Literal.create(d.toByte), d.toByte) } checkEvaluation(Literal(Long.MinValue), Long.MinValue) checkEvaluation(Literal(Long.MaxValue), Long.MaxValue) + + checkEvaluation(Literal.create(Long.MinValue), Long.MinValue) + checkEvaluation(Literal.create(Long.MaxValue), Long.MaxValue) } test("double literals") { List(0.0, -0.0, Double.NegativeInfinity, Double.PositiveInfinity).foreach { d => checkEvaluation(Literal(d), d) checkEvaluation(Literal(d.toFloat), d.toFloat) + + checkEvaluation(Literal.create(d), d) + checkEvaluation(Literal.create(d.toFloat), d.toFloat) } checkEvaluation(Literal(Double.MinValue), Double.MinValue) checkEvaluation(Literal(Double.MaxValue), Double.MaxValue) checkEvaluation(Literal(Float.MinValue), Float.MinValue) checkEvaluation(Literal(Float.MaxValue), Float.MaxValue) + checkEvaluation(Literal.create(Double.MinValue), Double.MinValue) + checkEvaluation(Literal.create(Double.MaxValue), Double.MaxValue) + checkEvaluation(Literal.create(Float.MinValue), Float.MinValue) + checkEvaluation(Literal.create(Float.MaxValue), Float.MaxValue) + } test("string literals") { checkEvaluation(Literal(""), "") checkEvaluation(Literal("test"), "test") checkEvaluation(Literal("\u0000"), "\u0000") + + checkEvaluation(Literal.create(""), "") + checkEvaluation(Literal.create("test"), "test") + checkEvaluation(Literal.create("\u0000"), "\u0000") } test("sum two literals") { checkEvaluation(Add(Literal(1), Literal(1)), 2) + checkEvaluation(Add(Literal.create(1), Literal.create(1)), 2) } test("binary literals") { 
checkEvaluation(Literal.create(new Array[Byte](0), BinaryType), new Array[Byte](0)) checkEvaluation(Literal.create(new Array[Byte](2), BinaryType), new Array[Byte](2)) + + checkEvaluation(Literal.create(new Array[Byte](0)), new Array[Byte](0)) + checkEvaluation(Literal.create(new Array[Byte](2)), new Array[Byte](2)) } test("decimal") { @@ -124,24 +153,63 @@ class LiteralExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { Decimal((d * 1000L).toLong, 10, 3)) checkEvaluation(Literal(BigDecimal(d.toString)), Decimal(d)) checkEvaluation(Literal(new java.math.BigDecimal(d.toString)), Decimal(d)) + + checkEvaluation(Literal.create(Decimal(d)), Decimal(d)) + checkEvaluation(Literal.create(Decimal(d.toInt)), Decimal(d.toInt)) + checkEvaluation(Literal.create(Decimal(d.toLong)), Decimal(d.toLong)) + checkEvaluation(Literal.create(Decimal((d * 1000L).toLong, 10, 3)), + Decimal((d * 1000L).toLong, 10, 3)) + checkEvaluation(Literal.create(BigDecimal(d.toString)), Decimal(d)) + checkEvaluation(Literal.create(new java.math.BigDecimal(d.toString)), Decimal(d)) + } } + private def toCatalyst[T: TypeTag](value: T): Any = { + val ScalaReflection.Schema(dataType, _) = ScalaReflection.schemaFor[T] + CatalystTypeConverters.createToCatalystConverter(dataType)(value) + } + test("array") { - def checkArrayLiteral(a: Array[_], elementType: DataType): Unit = { - val toCatalyst = (a: Array[_], elementType: DataType) => { - CatalystTypeConverters.createToCatalystConverter(ArrayType(elementType))(a) - } - checkEvaluation(Literal(a), toCatalyst(a, elementType)) + def checkArrayLiteral[T: TypeTag](a: Array[T]): Unit = { + checkEvaluation(Literal(a), toCatalyst(a)) + checkEvaluation(Literal.create(a), toCatalyst(a)) + } + checkArrayLiteral(Array(1, 2, 3)) + checkArrayLiteral(Array("a", "b", "c")) + checkArrayLiteral(Array(1.0, 4.0)) + checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR)) + } + + test("seq") { + def checkSeqLiteral[T: TypeTag](a: Seq[T], elementType: DataType): Unit = { + checkEvaluation(Literal.create(a), toCatalyst(a)) } - checkArrayLiteral(Array(1, 2, 3), IntegerType) - checkArrayLiteral(Array("a", "b", "c"), StringType) - checkArrayLiteral(Array(1.0, 4.0), DoubleType) - checkArrayLiteral(Array(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), + checkSeqLiteral(Seq(1, 2, 3), IntegerType) + checkSeqLiteral(Seq("a", "b", "c"), StringType) + checkSeqLiteral(Seq(1.0, 4.0), DoubleType) + checkSeqLiteral(Seq(CalendarInterval.MICROS_PER_DAY, CalendarInterval.MICROS_PER_HOUR), CalendarIntervalType) } - test("unsupported types (map and struct) in literals") { + test("map") { + def checkMapLiteral[T: TypeTag](m: T): Unit = { + checkEvaluation(Literal.create(m), toCatalyst(m)) + } + checkMapLiteral(Map("a" -> 1, "b" -> 2, "c" -> 3)) + checkMapLiteral(Map("1" -> 1.0, "2" -> 2.0, "3" -> 3.0)) + } + + test("struct") { + def checkStructLiteral[T: TypeTag](s: T): Unit = { + checkEvaluation(Literal.create(s), toCatalyst(s)) + } + checkStructLiteral((1, 3.0, "abcde")) + checkStructLiteral(("de", 1, 2.0f)) + checkStructLiteral((1, ("fgh", 3.0))) + } + + test("unsupported types (map and struct) in Literal.apply") { def checkUnsupportedTypeInLiteral(v: Any): Unit = { val errMsgMap = intercept[RuntimeException] { Literal(v) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index 
82756f545a8c7..d128315b68869 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -130,6 +130,20 @@ class FoldablePropagationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("Propagate in inner join") { + val ta = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('ta) + val tb = testRelation.select('a, Literal(1).as('tag)) + .union(testRelation.select('a, Literal(2).as('tag))) + .subquery('tb) + val query = ta.join(tb, Inner, + Some("ta.a".attr === "tb.a".attr && "ta.tag".attr === "tb.tag".attr)) + val optimized = Optimize.execute(query.analyze) + val correctAnswer = query.analyze + comparePlans(optimized, correctAnswer) + } + test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala new file mode 100644 index 0000000000000..e5dc811c8b7db --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + + +class BasicStatsEstimationSuite extends StatsEstimationTestBase { + val attribute = attr("key") + val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val plan = StatsTestPlan( + outputList = Seq(attribute), + attributeStats = AttributeMap(Seq(attribute -> colStat)), + rowCount = 10, + // row count * (overhead + column size) + size = Some(10 * (8 + 4))) + + test("limit estimation: limit < child's rowCount") { + val localLimit = LocalLimit(Literal(2), plan) + val globalLimit = GlobalLimit(Literal(2), plan) + // LocalLimit's stats is just its child's stats except column stats + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) + } + + test("limit estimation: limit > child's rowCount") { + val localLimit = LocalLimit(Literal(20), plan) + val globalLimit = GlobalLimit(Literal(20), plan) + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. + checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + } + + test("limit estimation: limit = 0") { + val localLimit = LocalLimit(Literal(0), plan) + val globalLimit = GlobalLimit(Literal(0), plan) + val stats = Statistics(sizeInBytes = 1, rowCount = Some(0)) + checkStats(localLimit, stats) + checkStats(globalLimit, stats) + } + + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), + isBroadcastable = false) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), + isBroadcastable = false) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + + /** Check estimated stats when cbo is turned on/off. 
*/ + private def checkStats( + plan: LogicalPlan, + expectedStatsCboOn: Statistics, + expectedStatsCboOff: Statistics): Unit = { + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + } + + /** Check estimated stats when it's the same whether cbo is turned on or off. */ + private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = + checkStats(plan, expectedStats, expectedStats) +} + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. + */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) extends LogicalPlan { + override def output: Seq[Attribute] = Nil + override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: CatalystConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 8be74ced7bb71..4691913c8c986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val arInt = AttributeReference("cint", IntegerType)() - val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val attrInt = AttributeReference("cint", IntegerType)() + val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) // only 2 values - val arBool = AttributeReference("cbool", BooleanType)() - val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + val attrBool = AttributeReference("cbool", BooleanType)() + val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") - val arDate = AttributeReference("cdate", DateType)() - val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + val attrDate = AttributeReference("cdate", DateType)() + val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", DoubleType)() - val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + val attrDouble = AttributeReference("cdouble", DoubleType)() + val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) // Sixth column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" - val arString = AttributeReference("cstring", StringType)() - val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + val attrString = AttributeReference("cstring", StringType)() + val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + val attributeMap = AttributeMap(Seq( + attrInt -> colStatInt, + attrBool -> colStatBool, + attrDate -> colStatDate, + attrDecimal -> colStatDecimal, + attrDouble -> colStatDouble, + attrString -> colStatString)) + test("cint = 2") { validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( - arInt, - Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint = 0") { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint < 3") { validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint < 0") { // This is a corner case since literal 0 is smaller than min. 
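For the out-of-range corner cases in this suite, the expectations change from the unchanged column stats to Nil because a predicate such as cint < 0 against min = 1 hits the noOverlap branch, so the selectivity is 0 and estimate() keeps no column statistics at all. A minimal illustration (not part of the patch):

  val (literalValue, minValue) = (BigDecimal(0), BigDecimal(1))
  val noOverlap = literalValue <= minValue   // true for LessThan, so the selectivity is 0
  val expectedRows = 0                       // and no column stats survive, hence the `Nil` below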
validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint <= 3") { validateEstimatedStats( - arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint > 10") { // This is a corner case since max value is 10. validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint >= 6") { validateEstimatedStats( - arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint IS NULL") { validateEstimatedStats( - arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint IS NOT NULL") { validateEstimatedStats( - arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 10) + Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) } test("cint > 3 AND cint <= 6") { - val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - 4) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4)), + 
expectedRowCount = 4) } test("cint = 3 OR cint = 6") { - val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + test("Not(cint > 3 AND cint <= 6)") { + val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 6) + } + + test("Not(cint <= 3 OR cint > 6)") { + val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 5) + } + + test("Not(cint = 3 AND cstring < 'A8')") { + val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 10) + } + + test("Not(cint = 3 OR cstring < 'A8')") { + val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 7) + Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 7) } test("cbool = true") { validateEstimatedStats( - arBool, - Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 5) + Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool > false") { - // bool comparison is not supported yet, so stats remain same. 
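The revised expectation just below follows from the boolean support added to evaluateBinaryForNumeric earlier in this diff, which maps false/true to 0/1. A quick check with the suite's numbers (2 distinct values over 10 rows); illustrative only:

  val ndvBool = 2.0
  val pGtFalse = 1.0 - 1.0 / ndvBool            // the literal 0 equals min, so 1 - 1/ndv = 0.5 => 5 rows
  val newNdv = math.round(ndvBool * pGtFalse)   // 1, so the new min is pinned to max, i.e. Some(true)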
validateEstimatedStats( - arBool, - Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 10) + Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrDate, Literal(d20170102)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal(d20170103)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrDate, Literal(d20170103)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( - arDate, - Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( - arDecimal, - Filter(EqualTo(arDecimal, Literal(dec_0_40)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8), - 1) + Filter(EqualTo(attrDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( - arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = 
Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( - arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), + Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( - arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 1) + Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 1) } - // There is no min/max statistics for String type. We estimate 10 rows returned. - test("cstring < 'A2'") { + test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( - arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 10) + Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 10) } - // This is a corner test case. We want to test if we can handle the case when the number of - // valid values in IN clause is greater than the number of distinct values for a given column. - // For example, column has only 2 distinct values 1 and 6. - // The predicate is: column IN (1, 2, 3, 4, 5). test("cint IN (1, 2, 3, 4, 5)") { + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4) val cornerChildStatsTestplan = StatsTestPlan( - outputList = Seq(arInt), + outputList = Seq(attrInt), rowCount = 2L, - attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = tableRowCount, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arBool -> childColStatBool, - arDate -> childColStatDate, - arDecimal -> childColStatDecimal, - arDouble -> childColStatDouble, - arString -> childColStatString - )) - ) + attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a)))) } private def validateEstimatedStats( - ar: AttributeReference, filterNode: Filter, - expectedColStats: ColumnStat, - rowCount: Int): Unit = { - - val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) - - val filteredStats = filterNode.stats(conf) - assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount.get == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) - - // If the filter has a binary operator (including those nested inside - // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the - // operator, and then check again. - val rewrittenFilter = filterNode transformExpressionsDown { - case EqualTo(ar: AttributeReference, l: Literal) => - EqualTo(l, ar) - - case LessThan(ar: AttributeReference, l: Literal) => - GreaterThan(l, ar) - case LessThanOrEqual(ar: AttributeReference, l: Literal) => - GreaterThanOrEqual(l, ar) - - case GreaterThan(ar: AttributeReference, l: Literal) => - LessThan(l, ar) - case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - LessThanOrEqual(l, ar) + expectedColStats: Seq[(Attribute, ColumnStat)], + expectedRowCount: Int): Unit = { + + // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the + // sides of the attribute and the literal, reverse the operator, and then check again. 
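// Illustration: under this rewrite a predicate such as LessThan(attrInt, Literal(3)) is
// re-checked as GreaterThan(Literal(3), attrInt), so both orderings of attribute and
// literal exercise the same estimation path.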
+ val swappedFilter = filterNode transformExpressionsDown { + case EqualTo(attr: Attribute, l: Literal) => + EqualTo(l, attr) + + case LessThan(attr: Attribute, l: Literal) => + GreaterThan(l, attr) + case LessThanOrEqual(attr: Attribute, l: Literal) => + GreaterThanOrEqual(l, attr) + + case GreaterThan(attr: Attribute, l: Literal) => + LessThan(l, attr) + case GreaterThanOrEqual(attr: Attribute, l: Literal) => + LessThanOrEqual(l, attr) + } + + val testFilters = if (swappedFilter != filterNode) { + Seq(swappedFilter, filterNode) + } else { + Seq(filterNode) } - if (rewrittenFilter != filterNode) { - validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + testFilters.foreach { filter => + val expectedAttributeMap = AttributeMap(expectedColStats) + val expectedStats = Statistics( + sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttributeMap) + assert(filter.stats(conf) == expectedStats) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala deleted file mode 100644 index 212d57a9bcf95..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.statsEstimation - -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType - - -class StatsConfSuite extends StatsEstimationTestBase { - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - // Return the statistics estimated by cbo - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) - // Invalidate statistics - plan.invalidateStatsCache() - // Return the simple statistics - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) - } -} - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 63be1e5302302..41470ae6aae19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -263,8 +263,8 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { /** * Loads a JSON file and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. 
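For illustration, reading each JSON layout from Scala might look like the following minimal sketch, assuming a Spark 2.2+ classpath, an application-created session, and hypothetical input paths:

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("json-read-sketch").getOrCreate()

// JSON Lines (newline-delimited JSON) is the default parsing mode.
val jsonLinesDf = spark.read.json("path/to/lines.json")

// For inputs where each file holds a single JSON record, enable the `wholeFile` option.
val wholeFileDf = spark.read.option("wholeFile", "true").json("path/to/records/")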
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala index e56c33e4b512f..a4c5bf756cd5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala @@ -47,12 +47,14 @@ private[sql] object SQLUtils extends Logging { jsc: JavaSparkContext, sparkConfigMap: JMap[Object, Object], enableHiveSupport: Boolean): SparkSession = { - val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) { + val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport + && jsc.sc.conf.get(CATALOG_IMPLEMENTATION.key, "hive").toLowerCase == "hive") { SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate() } else { if (enableHiveSupport) { logWarning("SparkR: enableHiveSupport is requested for SparkSession but " + - "Spark is not built with Hive; falling back to without Hive support.") + s"Spark is not built with Hive or ${CATALOG_IMPLEMENTATION.key} is not set to 'hive', " + + "falling back to without Hive support.") } SparkSession.builder().sparkContext(jsc.sc).getOrCreate() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 4ca1347008575..0ea806d6cb50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging @@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) class CacheManager extends Logging { @transient - private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + private val cachedData = new java.util.LinkedList[CachedData] @transient private val cacheLock = new ReentrantReadWriteLock @@ -70,7 +72,7 @@ class CacheManager extends Logging { /** Clears all cached tables. */ def clearCache(): Unit = writeLock { - cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } @@ -88,46 +90,81 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.queryExecution.analyzed + val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sparkSession = query.sparkSession - cachedData += - CachedData( - planToCache, - InMemoryRelation( - sparkSession.sessionState.conf.useCompression, - sparkSession.sessionState.conf.columnBatchSize, - storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, - tableName)) + cachedData.add(CachedData( + planToCache, + InMemoryRelation( + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, + storageLevel, + sparkSession.sessionState.executePlan(planToCache).executedPlan, + tableName))) } } /** - * Tries to remove the data for the given [[Dataset]] from the cache. - * No operation, if it's already uncached. 
+ * Un-cache all the cache entries that refer to the given plan. + */ + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + } + + /** + * Un-cache all the cache entries that refer to the given plan. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { - val planToCache = query.queryExecution.analyzed - val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) - val found = dataIndex >= 0 - if (found) { - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) - cachedData.remove(dataIndex) + def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + val it = cachedData.iterator() + while (it.hasNext) { + val cd = it.next() + if (cd.plan.find(_.sameResult(plan)).isDefined) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + it.remove() + } } - found + } + + /** + * Tries to re-cache all the cache entries that refer to the given plan. + */ + def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) + } + + private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + val it = cachedData.iterator() + val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] + while (it.hasNext) { + val cd = it.next() + if (condition(cd.plan)) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist() + // Remove the cache entry before we create a new one, so that we can have a different + // physical plan. + it.remove() + val newCache = InMemoryRelation( + useCompression = cd.cachedRepresentation.useCompression, + batchSize = cd.cachedRepresentation.batchSize, + storageLevel = cd.cachedRepresentation.storageLevel, + child = spark.sessionState.executePlan(cd.plan).executedPlan, + tableName = cd.cachedRepresentation.tableName) + needToRecache += cd.copy(cachedRepresentation = newCache) + } + } + + needToRecache.foreach(cachedData.add) } /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.queryExecution.analyzed) + lookupCachedData(query.logicalPlan) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { - cachedData.find(cd => plan.sameResult(cd.plan)) + cachedData.asScala.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ @@ -145,39 +182,17 @@ class CacheManager extends Logging { } /** - * Invalidates the cache of any data that contains `plan`. Note that it is possible that this - * function will over invalidate. - */ - def invalidateCache(plan: LogicalPlan): Unit = writeLock { - cachedData.foreach { - case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => - data.cachedRepresentation.recache() - case _ => - } - } - - /** - * Invalidates the cache of any data that contains `resourcePath` in one or more + * Tries to re-cache all the cache entries that contain `resourcePath` in one or more * `HadoopFsRelation` node(s) as part of its logical plan. 
*/ - def invalidateCachedPath( - sparkSession: SparkSession, resourcePath: String): Unit = writeLock { + def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock { val (fs, qualifiedPath) = { val path = new Path(resourcePath) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + (fs, fs.makeQualified(path)) } - cachedData.foreach { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - case _ => // Do Nothing - } + recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala index b02edd4c74129..aa578f4d23133 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/OptimizeMetadataOnlyQuery.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} import org.apache.spark.sql.internal.SQLConf @@ -103,11 +103,13 @@ case class OptimizeMetadataOnlyQuery( case relation: CatalogRelation => val partAttrs = getPartitionAttrs(relation.tableMeta.partitionColumnNames, relation) + val caseInsensitiveProperties = + CaseInsensitiveMap(relation.tableMeta.storage.properties) + val timeZoneId = caseInsensitiveProperties.get("timeZone") + .getOrElse(conf.sessionLocalTimeZone) val partitionData = catalog.listPartitions(relation.tableMeta.identifier).map { p => InternalRow.fromSeq(partAttrs.map { attr => - // TODO: use correct timezone for partition values. 
- Cast(Literal(p.spec(attr.name)), attr.dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Cast(Literal(p.spec(attr.name)), attr.dataType, Option(timeZoneId)).eval() }) } LocalRelation(partAttrs, partitionData) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 65df688689397..00d1d6d2701f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -386,7 +386,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + "you can only specify one of them.", ctx) } - val customLocation = storage.locationUri.orElse(location) + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI(_))) val tableType = if (customLocation.isDefined) { CatalogTableType.EXTERNAL @@ -1080,8 +1080,10 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) } + + val locUri = location.map(CatalogUtils.stringToURI(_)) val storage = CatalogStorageFormat( - locationUri = location, + locationUri = locUri, inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), @@ -1132,7 +1134,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. val newTableDesc = tableDesc.copy( - storage = CatalogStorageFormat.empty.copy(locationUri = location), + storage = CatalogStorageFormat.empty.copy(locationUri = locUri), provider = Some(conf.defaultDataSourceName)) CreateTable(newTableDesc, mode, Some(q)) } else { @@ -1329,6 +1331,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 37bd95e737786..36037ac003728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -85,12 +85,6 @@ case class InMemoryRelation( buildBuffers() } - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() - } - private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala index 5abd579476504..3da66afceda9c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/createDataSourceTables.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.command +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -54,7 +58,7 @@ case class CreateDataSourceTableCommand(table: CatalogTable, ignoreIfExists: Boo // Create the relation to validate the arguments before writing the metadata to the metastore, // and infer the table schema and partition if users didn't specify schema in CREATE TABLE. - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // Fill in some default table options from the session conf val tableWithDefaultOptions = table.copy( identifier = table.identifier.copy( @@ -141,7 +145,7 @@ case class CreateDataSourceTableAsSelectCommand( } saveDataIntoTable( - sparkSession, table, table.storage.locationUri, query, mode, tableExists = true) + sparkSession, table, table.storage.locationUri, query, SaveMode.Append, tableExists = true) } else { assert(table.schema.isEmpty) @@ -151,7 +155,7 @@ case class CreateDataSourceTableAsSelectCommand( table.storage.locationUri } val result = saveDataIntoTable( - sparkSession, table, tableLocation, query, mode, tableExists = false) + sparkSession, table, tableLocation, query, SaveMode.Overwrite, tableExists = false) val newTable = table.copy( storage = table.storage.copy(locationUri = tableLocation), // We will use the schema of resolved.relation as the schema of the table (instead of @@ -175,12 +179,12 @@ case class CreateDataSourceTableAsSelectCommand( private def saveDataIntoTable( session: SparkSession, table: CatalogTable, - tableLocation: Option[String], + tableLocation: Option[URI], data: LogicalPlan, mode: SaveMode, tableExists: Boolean): BaseRelation = { // Create the relation based on the input logical plan: `data`. 
- val pathOption = tableLocation.map("path" -> _) + val pathOption = tableLocation.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( session, className = table.provider.get, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 82cbb4aa47445..9d3c55060dfb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -66,7 +66,7 @@ case class CreateDatabaseCommand( CatalogDatabase( databaseName, comment.getOrElse(""), - path.getOrElse(catalog.getDefaultDBPath(databaseName)), + path.map(CatalogUtils.stringToURI(_)).getOrElse(catalog.getDefaultDBPath(databaseName)), props), ifNotExists) Seq.empty[Row] @@ -146,7 +146,7 @@ case class DescribeDatabaseCommand( val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: - Row("Location", dbMetadata.locationUri) :: Nil + Row("Location", CatalogUtils.URIToString(dbMetadata.locationUri)) :: Nil if (extended) { val properties = @@ -199,8 +199,7 @@ case class DropTableCommand( } } try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } catch { case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) @@ -426,7 +425,8 @@ case class AlterTableAddPartitionCommand( table.identifier.quotedString, sparkSession.sessionState.conf.resolver) // inherit table storage format (possibly except for location) - CatalogTablePartition(normalizedSpec, table.storage.copy(locationUri = location)) + CatalogTablePartition(normalizedSpec, table.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_)))) } catalog.createPartitions(table.identifier, parts, ignoreIfExists = ifNotExists) Seq.empty[Row] @@ -710,7 +710,7 @@ case class AlterTableRecoverPartitionsCommand( // inherit table storage format (possibly except for location) CatalogTablePartition( spec, - table.storage.copy(locationUri = Some(location.toUri.toString)), + table.storage.copy(locationUri = Some(location.toUri)), params) } spark.sessionState.catalog.createPartitions(tableName, parts, ignoreIfExists = true) @@ -741,6 +741,7 @@ case class AlterTableSetLocationCommand( override def run(sparkSession: SparkSession): Seq[Row] = { val catalog = sparkSession.sessionState.catalog val table = catalog.getTableMetadata(tableName) + val locUri = CatalogUtils.stringToURI(location) DDLUtils.verifyAlterTableType(catalog, table, isView = false) partitionSpec match { case Some(spec) => @@ -748,11 +749,11 @@ case class AlterTableSetLocationCommand( sparkSession, table, "ALTER TABLE ... 
SET LOCATION") // Partition spec is specified, so we set the location only for this partition val part = catalog.getPartition(table.identifier, spec) - val newPart = part.copy(storage = part.storage.copy(locationUri = Some(location))) + val newPart = part.copy(storage = part.storage.copy(locationUri = Some(locUri))) catalog.alterPartitions(table.identifier, Seq(newPart)) case None => // No partition spec is specified, so we set the location for the table itself - catalog.alterTable(table.withNewStorage(locationUri = Some(location))) + catalog.alterTable(table.withNewStorage(locationUri = Some(locUri))) } Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 3e80916104bd9..86394ff23e379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -79,7 +79,8 @@ case class CreateTableLikeCommand( CatalogTable( identifier = targetTable, tableType = tblType, - storage = sourceTableDesc.storage.copy(locationUri = location), + storage = sourceTableDesc.storage.copy( + locationUri = location.map(CatalogUtils.stringToURI(_))), schema = sourceTableDesc.schema, provider = newProvider, partitionColumnNames = sourceTableDesc.partitionColumnNames, @@ -495,7 +496,8 @@ case class DescribeTableCommand( append(buffer, "Owner:", table.owner, "") append(buffer, "Create Time:", new Date(table.createTime).toString, "") append(buffer, "Last Access Time:", new Date(table.lastAccessTime).toString, "") - append(buffer, "Location:", table.storage.locationUri.getOrElse(""), "") + append(buffer, "Location:", table.storage.locationUri.map(CatalogUtils.URIToString(_)) + .getOrElse(""), "") append(buffer, "Table Type:", table.tableType.name, "") table.stats.foreach(s => append(buffer, "Statistics:", s.simpleString, "")) @@ -587,7 +589,8 @@ case class DescribeTableCommand( append(buffer, "Partition Value:", s"[${partition.spec.values.mkString(", ")}]", "") append(buffer, "Database:", table.database, "") append(buffer, "Table:", tableIdentifier.table, "") - append(buffer, "Location:", partition.storage.locationUri.getOrElse(""), "") + append(buffer, "Location:", partition.storage.locationUri.map(CatalogUtils.URIToString(_)) + .getOrElse(""), "") append(buffer, "Partition Parameters:", "", "") partition.parameters.foreach { case (key, value) => append(buffer, s" $key", value, "") @@ -953,7 +956,7 @@ case class ShowCreateTableCommand(table: TableIdentifier) extends RunnableComman // when the table creation DDL contains the PATH option. 
None } else { - Some(s"path '${escapeSingleQuotedString(location)}'") + Some(s"path '${escapeSingleQuotedString(CatalogUtils.URIToString(location))}'") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala index 1235a4b12f1d0..d6c4b97ebd080 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/CatalogFileIndex.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import java.net.URI + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -46,7 +48,7 @@ class CatalogFileIndex( assert(table.identifier.database.isDefined, "The table identifier must be qualified in CatalogFileIndex") - private val baseLocation: Option[String] = table.storage.locationUri + private val baseLocation: Option[URI] = table.storage.locationUri override def partitionSchema: StructType = table.partitionSchema @@ -72,7 +74,8 @@ class CatalogFileIndex( val path = new Path(p.location) val fs = path.getFileSystem(hadoopConf) PartitionPath( - p.toRow(partitionSchema), path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + p.toRow(partitionSchema, sparkSession.sessionState.conf.sessionLocalTimeZone), + path.makeQualified(fs.getUri, fs.getWorkingDirectory)) } val partitionSpec = PartitionSpec(partitionSchema, partitions) new PrunedInMemoryFileIndex( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index d510581f90e69..c9384e44255b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -29,7 +29,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.Logging import org.apache.spark.sql._ -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogUtils} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -106,10 +106,13 @@ case class DataSource( * be any further inference in any triggers. * * @param format the file format object for this DataSource + * @param fileStatusCache the shared cache for file statuses to speed up listing * @return A pair of the data schema (excluding partition columns) and the schema of the partition * columns. 
*/ - private def getOrInferFileFormatSchema(format: FileFormat): (StructType, StructType) = { + private def getOrInferFileFormatSchema( + format: FileFormat, + fileStatusCache: FileStatusCache = NoopCache): (StructType, StructType) = { // the operations below are expensive therefore try not to do them if we don't need to, e.g., // in streaming mode, we have already inferred and registered partition columns, we will // never have to materialize the lazy val below @@ -122,7 +125,7 @@ case class DataSource( val qualified = hdfsPath.makeQualified(fs.getUri, fs.getWorkingDirectory) SparkHadoopUtil.get.globPathIfNecessary(qualified) }.toArray - new InMemoryFileIndex(sparkSession, globbedPaths, options, None) + new InMemoryFileIndex(sparkSession, globbedPaths, options, None, fileStatusCache) } val partitionSchema = if (partitionColumns.isEmpty) { // Try to infer partitioning, because no DataSource in the read path provides the partitioning @@ -279,28 +282,6 @@ case class DataSource( } } - /** - * Returns true if there is a single path that has a metadata log indicating which files should - * be read. - */ - def hasMetadata(path: Seq[String]): Boolean = { - path match { - case Seq(singlePath) => - try { - val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - val metadataPath = new Path(hdfsPath, FileStreamSink.metadataDir) - val res = fs.exists(metadataPath) - res - } catch { - case NonFatal(e) => - logWarning(s"Error while looking for metadata directory.") - false - } - case _ => false - } - } - /** * Create a resolved [[BaseRelation]] that can be used to read data from or write data into this * [[DataSource]] @@ -331,7 +312,9 @@ case class DataSource( // We are reading from the results of a streaming query. Load files from the metadata log // instead of listing them using HDFS APIs. 
case (format: FileFormat, _) - if hasMetadata(caseInsensitiveOptions.get("path").toSeq ++ paths) => + if FileStreamSink.hasMetadata( + caseInsensitiveOptions.get("path").toSeq ++ paths, + sparkSession.sessionState.newHadoopConf()) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath) val dataSchema = userSpecifiedSchema.orElse { @@ -374,7 +357,8 @@ case class DataSource( globPath }.toArray - val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + val (dataSchema, partitionSchema) = getOrInferFileFormatSchema(format, fileStatusCache) val fileCatalog = if (sparkSession.sqlContext.conf.manageFilesourcePartitions && catalogTable.isDefined && catalogTable.get.tracksPartitionsInCatalog) { @@ -384,7 +368,8 @@ case class DataSource( catalogTable.get, catalogTable.get.stats.map(_.sizeInBytes.toLong).getOrElse(defaultTableSize)) } else { - new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(partitionSchema)) + new InMemoryFileIndex( + sparkSession, globbedPaths, options, Some(partitionSchema), fileStatusCache) } HadoopFsRelation( @@ -612,6 +597,7 @@ object DataSource { def buildStorageFormatFromOptions(options: Map[String, String]): CatalogStorageFormat = { val path = CaseInsensitiveMap(options).get("path") val optionsWithoutPath = options.filterKeys(_.toLowerCase != "path") - CatalogStorageFormat.empty.copy(locationUri = path, properties = optionsWithoutPath) + CatalogStorageFormat.empty.copy( + locationUri = path.map(CatalogUtils.stringToURI(_)), properties = optionsWithoutPath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index f694a0d6d724b..bddf5af23e060 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -21,13 +21,15 @@ import java.util.concurrent.Callable import scala.collection.mutable.ArrayBuffer +import org.apache.hadoop.fs.Path + import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{CatalystConf, CatalystTypeConverters, InternalRow, QualifiedTableName, TableIdentifier} import org.apache.spark.sql.catalyst.CatalystTypeConverters.convertToScala import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogRelation +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogUtils} import org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning.PhysicalOperation @@ -220,7 +222,7 @@ class FindDataSourceTable(sparkSession: SparkSession) extends Rule[LogicalPlan] val plan = cache.get(qualifiedTableName, new Callable[LogicalPlan]() { override def call(): LogicalPlan = { - val pathOption = table.storage.locationUri.map("path" -> _) + val pathOption = table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) val dataSource = DataSource( sparkSession, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index c17796811cdfd..30a09a9ad3370 
100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.{QueryExecution, SortExec, SQLExecution} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} @@ -68,7 +68,8 @@ object FileFormatWriter extends Logging { val bucketIdExpression: Option[Expression], val path: String, val customPartitionLocations: Map[TablePartitionSpec, String], - val maxRecordsPerFile: Long) + val maxRecordsPerFile: Long, + val timeZoneId: String) extends Serializable { assert(AttributeSet(allColumns) == AttributeSet(partitionColumns ++ dataColumns), @@ -122,9 +123,11 @@ object FileFormatWriter extends Logging { spec => spec.sortColumnNames.map(c => dataColumns.find(_.name == c).get) } + val caseInsensitiveOptions = CaseInsensitiveMap(options) + // Note: prepareWrite has side effect. It sets "job". val outputWriterFactory = - fileFormat.prepareWrite(sparkSession, job, options, dataColumns.toStructType) + fileFormat.prepareWrite(sparkSession, job, caseInsensitiveOptions, dataColumns.toStructType) val description = new WriteJobDescription( uuid = UUID.randomUUID().toString, @@ -136,8 +139,10 @@ object FileFormatWriter extends Logging { bucketIdExpression = bucketIdExpression, path = outputSpec.outputPath, customPartitionLocations = outputSpec.customPartitionLocations, - maxRecordsPerFile = options.get("maxRecordsPerFile").map(_.toLong) - .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile) + maxRecordsPerFile = caseInsensitiveOptions.get("maxRecordsPerFile").map(_.toLong) + .getOrElse(sparkSession.sessionState.conf.maxRecordsPerFile), + timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) ) // We should first sort by partition columns, then bucket id, and finally sorting columns. @@ -330,14 +335,13 @@ object FileFormatWriter extends Logging { /** Expressions that given partition columns build a path string like: col1=val/col2=val/... */ private def partitionPathExpression: Seq[Expression] = { desc.partitionColumns.zipWithIndex.flatMap { case (c, i) => - // TODO: use correct timezone for partition values. 
val escaped = ScalaUDF( ExternalCatalogUtils.escapePathName _, StringType, - Seq(Cast(c, StringType, Option(DateTimeUtils.defaultTimeZone().getID))), + Seq(Cast(c, StringType, Option(desc.timeZoneId))), Seq(StringType)) val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil + val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand( val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) - // Invalidate the cache. - sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this + // data source relation. + sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala index 75f87a5503b8c..c8097a7fabc2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningAwareFileIndex.scala @@ -30,7 +30,7 @@ import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.types.{StringType, StructType} import org.apache.spark.util.SerializableConfiguration @@ -125,22 +125,27 @@ abstract class PartitioningAwareFileIndex( val leafDirs = leafDirToChildrenFiles.filter { case (_, files) => files.exists(f => isDataPath(f.getPath)) }.keys.toSeq + + val caseInsensitiveOptions = CaseInsensitiveMap(parameters) + val timeZoneId = caseInsensitiveOptions.get("timeZone") + .getOrElse(sparkSession.sessionState.conf.sessionLocalTimeZone) + userPartitionSchema match { case Some(userProvidedSchema) if userProvidedSchema.nonEmpty => val spec = PartitioningUtils.parsePartitions( leafDirs, typeInference = false, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) // Without auto inference, all of value in the `row` should be null or in StringType, // we need to cast into the data type that user specified. def castPartitionValuesToUserSchema(row: InternalRow) = { InternalRow((0 until row.numFields).map { i => - // TODO: use correct timezone for partition values. 
Cast( Literal.create(row.getUTF8String(i), StringType), userProvidedSchema.fields(i).dataType, - Option(DateTimeUtils.defaultTimeZone().getID)).eval() + Option(timeZoneId)).eval() }: _*) } @@ -151,7 +156,8 @@ abstract class PartitioningAwareFileIndex( PartitioningUtils.parsePartitions( leafDirs, typeInference = sparkSession.sessionState.conf.partitionColumnTypeInferenceEnabled, - basePaths = basePaths) + basePaths = basePaths, + timeZoneId = timeZoneId) } } @@ -300,7 +306,7 @@ object PartitioningAwareFileIndex extends Logging { sparkSession: SparkSession): Seq[(Path, Seq[FileStatus])] = { // Short-circuits parallel listing when serial listing is likely to be faster. - if (paths.size < sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { + if (paths.size <= sparkSession.sessionState.conf.parallelPartitionDiscoveryThreshold) { return paths.map { path => (path, listLeafFiles(path, hadoopConf, filter, Some(sparkSession))) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala index bad59961ace12..09876bbc2f85d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/PartitioningUtils.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.datasources import java.lang.{Double => JDouble, Long => JLong} import java.math.{BigDecimal => JBigDecimal} -import java.sql.{Date => JDate, Timestamp => JTimestamp} +import java.util.TimeZone import scala.collection.mutable.ArrayBuffer import scala.util.Try @@ -31,7 +31,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Cast, Literal} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String // TODO: We should tighten up visibility of the classes here once we clean up Hive coupling. @@ -91,10 +93,19 @@ object PartitioningUtils { private[datasources] def parsePartitions( paths: Seq[Path], typeInference: Boolean, - basePaths: Set[Path]): PartitionSpec = { + basePaths: Set[Path], + timeZoneId: String): PartitionSpec = { + parsePartitions(paths, typeInference, basePaths, TimeZone.getTimeZone(timeZoneId)) + } + + private[datasources] def parsePartitions( + paths: Seq[Path], + typeInference: Boolean, + basePaths: Set[Path], + timeZone: TimeZone): PartitionSpec = { // First, we need to parse every partition's path and see if we can find partition values. 
val (partitionValues, optDiscoveredBasePaths) = paths.map { path => - parsePartition(path, typeInference, basePaths) + parsePartition(path, typeInference, basePaths, timeZone) }.unzip // We create pairs of (path -> path's partition value) here @@ -173,7 +184,8 @@ object PartitioningUtils { private[datasources] def parsePartition( path: Path, typeInference: Boolean, - basePaths: Set[Path]): (Option[PartitionValues], Option[Path]) = { + basePaths: Set[Path], + timeZone: TimeZone): (Option[PartitionValues], Option[Path]) = { val columns = ArrayBuffer.empty[(String, Literal)] // Old Hadoop versions don't have `Path.isRoot` var finished = path.getParent == null @@ -194,7 +206,7 @@ object PartitioningUtils { // Let's say currentPath is a path of "/table/a=1/", currentPath.getName will give us a=1. // Once we get the string, we try to parse it and find the partition column and value. val maybeColumn = - parsePartitionColumn(currentPath.getName, typeInference) + parsePartitionColumn(currentPath.getName, typeInference, timeZone) maybeColumn.foreach(columns += _) // Now, we determine if we should stop. @@ -226,7 +238,8 @@ object PartitioningUtils { private def parsePartitionColumn( columnSpec: String, - typeInference: Boolean): Option[(String, Literal)] = { + typeInference: Boolean, + timeZone: TimeZone): Option[(String, Literal)] = { val equalSignIndex = columnSpec.indexOf('=') if (equalSignIndex == -1) { None @@ -237,7 +250,7 @@ object PartitioningUtils { val rawColumnValue = columnSpec.drop(equalSignIndex + 1) assert(rawColumnValue.nonEmpty, s"Empty partition column value in '$columnSpec'") - val literal = inferPartitionColumnValue(rawColumnValue, typeInference) + val literal = inferPartitionColumnValue(rawColumnValue, typeInference, timeZone) Some(columnName -> literal) } } @@ -370,7 +383,8 @@ object PartitioningUtils { */ private[datasources] def inferPartitionColumnValue( raw: String, - typeInference: Boolean): Literal = { + typeInference: Boolean, + timeZone: TimeZone): Literal = { val decimalTry = Try { // `BigDecimal` conversion can fail when the `field` is not a form of number. 
val bigDecimal = new JBigDecimal(raw) @@ -390,8 +404,16 @@ object PartitioningUtils { // Then falls back to fractional types .orElse(Try(Literal.create(JDouble.parseDouble(raw), DoubleType))) // Then falls back to date/timestamp types - .orElse(Try(Literal(JDate.valueOf(raw)))) - .orElse(Try(Literal(JTimestamp.valueOf(unescapePathName(raw))))) + .orElse(Try( + Literal.create( + DateTimeUtils.getThreadLocalTimestampFormat(timeZone) + .parse(unescapePathName(raw)).getTime * 1000L, + TimestampType))) + .orElse(Try( + Literal.create( + DateTimeUtils.millisToDays( + DateTimeUtils.getThreadLocalDateFormat.parse(raw).getTime), + DateType))) // Then falls back to string .getOrElse { if (raw == DEFAULT_PARTITION_NAME) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 73e6abc6dad37..47567032b0195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -133,20 +133,24 @@ object TextInputCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) + val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first line could not be read, just return the empty schema. 
+ Some(StructType(Nil)) } - - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) } private def createBaseDataset( @@ -190,28 +194,28 @@ object WholeFileCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) - val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) + csv.flatMap { lines => UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - false, + shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) - }.take(1).headOption - - if (maybeFirstRow.isDefined) { - val firstRow = maybeFirstRow.get - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.flatMap { lines => - UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) - } else { - // If the first row could not be read, just return the empty schema. - Some(StructType(Nil)) + }.take(1).headOption match { + case Some(firstRow) => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource( + lines.getConfiguration, + lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala index 804031a5bb5f8..3b3b87e4354d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/UnivocityParser.scala @@ -54,39 +54,77 @@ private[csv] class UnivocityParser( private val dataSchema = StructType(schema.filter(_.name != options.columnNameOfCorruptRecord)) - private val valueConverters = - dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray - private val tokenizer = new CsvParser(options.asParserSettings) private var numMalformedRecords = 0 private val row = new GenericInternalRow(requiredSchema.length) - // This gets the raw input that is parsed lately. + // In `PERMISSIVE` parse mode, we should be able to put the raw malformed row into the field + // specified in `columnNameOfCorruptRecord`. The raw input is retrieved by this method. private def getCurrentInput(): String = tokenizer.getContext.currentParsedContent().stripLineEnd - // This parser loads an `indexArr._1`-th position value in input tokens, - // then put the value in `row(indexArr._2)`. 
-  private val indexArr: Array[(Int, Int)] = {
-    val fields = if (options.dropMalformed) {
-      // If `dropMalformed` is enabled, then it needs to parse all the values
-      // so that we can decide which row is malformed.
-      requiredSchema ++ schema.filterNot(requiredSchema.contains(_))
-    } else {
-      requiredSchema
-    }
-    // TODO: Revisit this; we need to clean up code here for readability.
-    // See an URL below for related discussions:
-    // https://github.com/apache/spark/pull/16928#discussion_r102636720
-    val fieldsWithIndexes = fields.zipWithIndex
-    corruptFieldIndex.map { case corrFieldIndex =>
-      fieldsWithIndexes.filter { case (_, i) => i != corrFieldIndex }
-    }.getOrElse {
-      fieldsWithIndexes
-    }.map { case (f, i) =>
-      (dataSchema.indexOf(f), i)
-    }.toArray
+  // This parser loads the value at position `tokenIndexArr(i)` in the input tokens,
+  // then puts it at position `rowIndexArr(i)` in the output `row`.
+  //
+  // For example, let's say there is CSV data as below:
+  //
+  //   a,b,c
+  //   1,2,A
+  //
+  // Also, let's say `columnNameOfCorruptRecord` is set to "_unparsed", `header` is set to
+  // `true` by the user, and the user selects the "c", "b", "_unparsed" and "a" fields. In this
+  // case, we need to map the values as below:
+  //
+  //   required schema - ["c", "b", "_unparsed", "a"]
+  //   CSV data schema - ["a", "b", "c"]
+  //   required CSV data schema - ["c", "b", "a"]
+  //
+  // with the input tokens,
+  //
+  //   input tokens - [1, 2, "A"]
+  //
+  // Each input token is placed at its position in the output row by these mappings. In this
+  // case,
+  //
+  //   output row - ["A", 2, null, 1]
+  //
+  // In more detail,
+  // - `valueConverters`, input tokens - CSV data schema
+  //   `valueConverters` maps each input token index (by array index) to that value's converter
+  //   (by array value), in the order of the CSV data schema. In this case,
+  //   [string->int, string->int, string->string].
+  //
+  // - `tokenIndexArr`, input tokens - required CSV data schema
+  //   `tokenIndexArr` maps each field of the required CSV data schema (by array index) to its
+  //   input token index (by array value). In this case, [2, 1, 0].
+  //
+  // - `rowIndexArr`, input tokens - required schema
+  //   `rowIndexArr` maps each converted field (by array index) to its position in the required
+  //   schema, i.e. the output row (by array value). In this case, [0, 1, 3].
+  private val valueConverters: Array[ValueConverter] =
+    dataSchema.map(f => makeConverter(f.name, f.dataType, f.nullable, options)).toArray
+
+  // Only used to create both `tokenIndexArr` and `rowIndexArr`. This variable holds
+  // the fields that we should try to convert.
+  private val reorderedFields = if (options.dropMalformed) {
+    // If `dropMalformed` is enabled, then it needs to parse all the values
+    // so that we can decide which row is malformed.
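+    // The resulting order is the required fields first, followed by the remaining fields of
+    // the full schema.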
+ requiredSchema ++ schema.filterNot(requiredSchema.contains(_)) + } else { + requiredSchema + } + + private val tokenIndexArr: Array[Int] = { + reorderedFields + .filter(_.name != options.columnNameOfCorruptRecord) + .map(f => dataSchema.indexOf(f)).toArray + } + + private val rowIndexArr: Array[Int] = if (corruptFieldIndex.isDefined) { + val corrFieldIndex = corruptFieldIndex.get + reorderedFields.indices.filter(_ != corrFieldIndex).toArray + } else { + reorderedFields.indices.toArray } /** @@ -200,14 +238,15 @@ private[csv] class UnivocityParser( private def convert(tokens: Array[String]): Option[InternalRow] = { convertWithParseMode(tokens) { tokens => var i: Int = 0 - while (i < indexArr.length) { - val (pos, rowIdx) = indexArr(i) + while (i < tokenIndexArr.length) { // It anyway needs to try to parse since it decides if this row is malformed // or not after trying to cast in `DROPMALFORMED` mode even if the casted // value is not stored in the row. - val value = valueConverters(pos).apply(tokens(pos)) + val from = tokenIndexArr(i) + val to = rowIndexArr(i) + val value = valueConverters(from).apply(tokens(from)) if (i < requiredSchema.length) { - row(rowIdx) = value + row(to) = value } i += 1 } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index 0dbe2a71ed3bc..07ec4e9429e42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import scala.util.control.NonFatal + +import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging @@ -25,9 +28,31 @@ import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{FileFormat, FileFormatWriter} -object FileStreamSink { +object FileStreamSink extends Logging { // The name of the subdirectory that is used to store metadata about which files are valid. val metadataDir = "_spark_metadata" + + /** + * Returns true if there is a single path that has a metadata log indicating which files should + * be read. 
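+   * The metadata log is the `_spark_metadata` directory that the file sink writes under its
+   * output path. A rough usage sketch (assuming a `SparkSession` named `spark`; the output
+   * path is hypothetical):
+   * {{{
+   *   FileStreamSink.hasMetadata(Seq("/tmp/query-output"), spark.sessionState.newHadoopConf())
+   * }}}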
+   */
+  def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = {
+    path match {
+      case Seq(singlePath) =>
+        try {
+          val hdfsPath = new Path(singlePath)
+          val fs = hdfsPath.getFileSystem(hadoopConf)
+          val metadataPath = new Path(hdfsPath, metadataDir)
+          val res = fs.exists(metadataPath)
+          res
+        } catch {
+          case NonFatal(e) =>
+            logWarning(s"Error while looking for metadata directory.")
+            false
+        }
+      case _ => false
+    }
+  }
 }

 /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
index 39c0b4979687b..6a7263ca45d85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming

 import scala.collection.JavaConverters._

-import org.apache.hadoop.fs.Path
+import org.apache.hadoop.fs.{FileStatus, Path}

 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.internal.Logging
@@ -43,8 +43,10 @@ class FileStreamSource(

   private val sourceOptions = new FileStreamOptions(options)

+  private val hadoopConf = sparkSession.sessionState.newHadoopConf()
+
   private val qualifiedBasePath: Path = {
-    val fs = new Path(path).getFileSystem(sparkSession.sessionState.newHadoopConf())
+    val fs = new Path(path).getFileSystem(hadoopConf)
     fs.makeQualified(new Path(path))  // can contains glob patterns
   }

@@ -157,14 +159,65 @@ class FileStreamSource(
           checkFilesExist = false)))
   }

+  /**
+   * If the source has a metadata log indicating which files should be read, then we should use
+   * it. Only when the user gives a non-glob path will we figure out whether the source has a
+   * metadata log.
+   *
+   * None means we don't know at the moment.
+   * Some(true) means we know for sure the source DOES have metadata.
+   * Some(false) means we know for sure the source DOES NOT have metadata.
+   */
+  @volatile private[sql] var sourceHasMetadata: Option[Boolean] =
+    if (SparkHadoopUtil.get.isGlobPath(new Path(path))) Some(false) else None
+
+  private def allFilesUsingInMemoryFileIndex() = {
+    val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath)
+    val fileIndex = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType))
+    fileIndex.allFiles()
+  }
+
+  private def allFilesUsingMetadataLogFileIndex() = {
+    // Note if `sourceHasMetadata` holds, then `qualifiedBasePath` is guaranteed to be a
+    // non-glob path
+    new MetadataLogFileIndex(sparkSession, qualifiedBasePath).allFiles()
+  }
+
   /**
    * Returns a list of files found, sorted by their timestamp.
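+   * If the source path was written by a `FileStreamSink`, the files are listed from the sink's
+   * metadata log; otherwise the path is listed with an `InMemoryFileIndex`.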
*/ private def fetchAllFiles(): Seq[(String, Long)] = { val startTime = System.nanoTime - val globbedPaths = SparkHadoopUtil.get.globPathIfNecessary(qualifiedBasePath) - val catalog = new InMemoryFileIndex(sparkSession, globbedPaths, options, Some(new StructType)) - val files = catalog.allFiles().sortBy(_.getModificationTime)(fileSortOrder).map { status => + + var allFiles: Seq[FileStatus] = null + sourceHasMetadata match { + case None => + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + allFiles = allFilesUsingInMemoryFileIndex() + if (allFiles.isEmpty) { + // we still cannot decide + } else { + // decide what to use for future rounds + // double check whether source has metadata, preventing the extreme corner case that + // metadata log and data files are only generated after the previous + // `FileStreamSink.hasMetadata` check + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + sourceHasMetadata = Some(true) + allFiles = allFilesUsingMetadataLogFileIndex() + } else { + sourceHasMetadata = Some(false) + // `allFiles` have already been fetched using InMemoryFileIndex in this round + } + } + } + case Some(true) => allFiles = allFilesUsingMetadataLogFileIndex() + case Some(false) => allFiles = allFilesUsingInMemoryFileIndex() + } + + val files = allFiles.sortBy(_.getModificationTime)(fileSortOrder).map { status => (status.getPath.toUri.toString, status.getModificationTime) } val endTime = System.nanoTime diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 4bd6431cbe110..70912d13ae458 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.streaming +import java.io.{InterruptedIOException, IOException} import java.util.UUID import java.util.concurrent.{CountDownLatch, TimeUnit} import java.util.concurrent.atomic.AtomicReference @@ -37,6 +38,12 @@ import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} +/** States for [[StreamExecution]]'s lifecycle. */ +trait State +case object INITIALIZING extends State +case object ACTIVE extends State +case object TERMINATED extends State + /** * Manages the execution of a streaming Spark SQL query that is occurring in a separate thread. * Unlike a standard query, a streaming query executes repeatedly each time new data arrives at any @@ -298,7 +305,14 @@ class StreamExecution( // `stop()` is already called. Let `finally` finish the cleanup. } } catch { - case _: InterruptedException if state.get == TERMINATED => // interrupted by stop() + case _: InterruptedException | _: InterruptedIOException if state.get == TERMINATED => + // interrupted by stop() + updateStatusMessage("Stopped") + case e: IOException if e.getMessage != null + && e.getMessage.startsWith(classOf[InterruptedException].getName) + && state.get == TERMINATED => + // This is a workaround for HADOOP-12074: `Shell.runCommand` converts `InterruptedException` + // to `new IOException(ie.toString())` before Hadoop 2.8. 
updateStatusMessage("Stopped") case e: Throwable => streamDeathCause = new StreamingQueryException( @@ -321,6 +335,7 @@ class StreamExecution( initializationLatch.countDown() try { + stopSources() state.set(TERMINATED) currentStatus = status.copy(isTriggerActive = false, isDataAvailable = false) @@ -558,6 +573,18 @@ class StreamExecution( sparkSession.streams.postListenerEvent(event) } + /** Stops all streaming sources safely. */ + private def stopSources(): Unit = { + uniqueSources.foreach { source => + try { + source.stop() + } catch { + case NonFatal(e) => + logWarning(s"Failed to stop streaming source: $source. Resources may have leaked.", e) + } + } + } + /** * Signals to the thread executing micro-batches that it should stop running after the next * batch. This method blocks until the thread stops running. @@ -570,7 +597,6 @@ class StreamExecution( microBatchThread.interrupt() microBatchThread.join() } - uniqueSources.foreach(_.stop()) logInfo(s"Query $prettyIdString was stopped") } @@ -709,10 +735,6 @@ class StreamExecution( } } - trait State - case object INITIALIZING extends State - case object ACTIVE extends State - case object TERMINATED extends State } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 2d29940eb8dab..ab1204a750fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -283,7 +283,9 @@ private[state] class HDFSBackedStateStoreProvider( // semantically correct because Structured Streaming requires rerunning a batch should // generate the same output. (SPARK-19677) // scalastyle:on - if (!fs.exists(finalDeltaFile) && !fs.rename(tempDeltaFile, finalDeltaFile)) { + if (fs.exists(finalDeltaFile)) { + fs.delete(tempDeltaFile, true) + } else if (!fs.rename(tempDeltaFile, finalDeltaFile)) { throw new IOException(s"Failed to rename $tempDeltaFile to $finalDeltaFile") } loadedMaps.put(newVersion, map) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 24ed906d33683..201f726db3fad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -91,15 +91,24 @@ object functions { * @group normal_funcs * @since 1.3.0 */ - def lit(literal: Any): Column = { - literal match { - case c: Column => return c - case s: Symbol => return new ColumnName(literal.asInstanceOf[Symbol].name) - case _ => // continue - } + def lit(literal: Any): Column = typedLit(literal) - val literalExpr = Literal(literal) - Column(literalExpr) + /** + * Creates a [[Column]] of literal value. + * + * The passed in object is returned directly if it is already a [[Column]]. + * If the object is a Scala Symbol, it is converted into a [[Column]] also. + * Otherwise, a new [[Column]] is created to represent the literal value. + * The difference between this function and [[lit]] is that this function + * can handle parameterized scala types e.g.: List, Seq and Map. 
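+   *
+   * {{{
+   *   // e.g., assuming a DataFrame `df` (these literals are not supported by `lit`):
+   *   df.select(typedLit(Seq(1, 2, 3)), typedLit(Map("a" -> 1, "b" -> 2)))
+   * }}}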
+ * + * @group normal_funcs + * @since 2.2.0 + */ + def typedLit[T : TypeTag](literal: T): Column = literal match { + case c: Column => c + case s: Symbol => new ColumnName(s.name) + case _ => Column(Literal.create(literal)) } ////////////////////////////////////////////////////////////////////////////////////////////// @@ -2964,7 +2973,22 @@ object functions { * @group collection_funcs * @since 2.1.0 */ - def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = withExpr { + def from_json(e: Column, schema: StructType, options: Map[String, String]): Column = + from_json(e, schema.asInstanceOf[DataType], options) + + /** + * (Scala-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType, options: Map[String, String]): Column = withExpr { JsonToStruct(schema, options, e.expr) } @@ -2983,6 +3007,21 @@ object functions { def from_json(e: Column, schema: StructType, options: java.util.Map[String, String]): Column = from_json(e, schema, options.asScala.toMap) + /** + * (Java-specific) Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * @param options options to control how the json is parsed. accepts the same options and the + * json data source. + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType, options: java.util.Map[String, String]): Column = + from_json(e, schema, options.asScala.toMap) + /** * Parses a column containing a JSON string into a `StructType` with the specified schema. * Returns `null`, in the case of an unparseable string. @@ -2997,8 +3036,21 @@ object functions { from_json(e, schema, Map.empty[String, String]) /** - * Parses a column containing a JSON string into a `StructType` with the specified schema. - * Returns `null`, in the case of an unparseable string. + * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. + * + * @param e a string column containing JSON data. + * @param schema the schema to use when parsing the json string + * + * @group collection_funcs + * @since 2.2.0 + */ + def from_json(e: Column, schema: DataType): Column = + from_json(e, schema, Map.empty[String, String]) + + /** + * Parses a column containing a JSON string into a `StructType` or `ArrayType` + * with the specified schema. Returns `null`, in the case of an unparseable string. * * @param e a string column containing JSON data. 
* @param schema the schema to use when parsing the json string as a json string @@ -3007,8 +3059,7 @@ object functions { * @since 2.1.0 */ def from_json(e: Column, schema: String, options: java.util.Map[String, String]): Column = - from_json(e, DataType.fromJson(schema).asInstanceOf[StructType], options) - + from_json(e, DataType.fromJson(schema), options) /** * (Scala-specific) Converts a column containing a `StructType` into a JSON string with the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index 3d9f41832bc73..53374859f13f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.internal import scala.reflect.runtime.universe.TypeTag +import org.apache.hadoop.fs.Path + import org.apache.spark.annotation.Experimental import org.apache.spark.sql._ import org.apache.spark.sql.catalog.{Catalog, Column, Database, Function, Table} @@ -77,7 +79,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { new Database( name = metadata.name, description = metadata.description, - locationUri = metadata.locationUri) + locationUri = CatalogUtils.URIToString(metadata.locationUri)) } /** @@ -341,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def dropTempView(viewName: String): Boolean = { - sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -357,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -402,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } /** @@ -440,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sparkSession, logicalPlan) + val table = sparkSession.table(tableIdent) + if (isCached(table)) { // Uncache the logicalPlan. 
- sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) // Cache it again. - sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table)) + sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } } @@ -462,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def refreshByPath(resourcePath: String): Unit = { - sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath) + sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index dc0f130406932..461dfe3a66e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -402,11 +402,13 @@ object SQLConf { val PARALLEL_PARTITION_DISCOVERY_THRESHOLD = buildConf("spark.sql.sources.parallelPartitionDiscovery.threshold") - .doc("The maximum number of files allowed for listing files at driver side. If the number " + - "of detected files exceeds this value during partition discovery, it tries to list the " + + .doc("The maximum number of paths allowed for listing files at driver side. If the number " + + "of detected paths exceeds this value during partition discovery, it tries to list the " + "files with another Spark distributed job. This applies to Parquet, ORC, CSV, JSON and " + "LibSVM data sources.") .intConf + .checkValue(parallel => parallel >= 0, "The maximum number of paths allowed for listing " + + "files at driver side must not be negative") .createWithDefault(32) val PARALLEL_PARTITION_DISCOVERY_PARALLELISM = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala index bce84de45c3d7..86129fa87feaa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SharedState.scala @@ -21,6 +21,7 @@ import scala.reflect.ClassTag import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.Path import org.apache.spark.{SparkConf, SparkContext, SparkException} import org.apache.spark.internal.Logging @@ -95,7 +96,10 @@ private[sql] class SharedState(val sparkContext: SparkContext) extends Logging { // Create the default database if it doesn't exist. 
{ val defaultDbDefinition = CatalogDatabase( - SessionCatalog.DEFAULT_DATABASE, "default database", warehousePath, Map()) + SessionCatalog.DEFAULT_DATABASE, + "default database", + CatalogUtils.stringToURI(warehousePath), + Map()) // Initialize default database if it doesn't exist if (!externalCatalog.databaseExists(SessionCatalog.DEFAULT_DATABASE)) { // There may be another Spark application creating default database at the same time, here we diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 6a275281d8697..aed8074a64d5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -143,8 +143,8 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo /** * Loads a JSON file stream and returns the results as a `DataFrame`. * - * Both JSON (one record per file) and JSON Lines - * (newline-delimited JSON) are supported and can be selected with the `wholeFile` option. + * JSON Lines (newline-delimited JSON) is supported by + * default. For JSON (one record per file), set the `wholeFile` option to true. * * This function goes through the input once to determine the input schema. If you know the * schema in advance, use the version that specifies the schema to avoid the extra scan. diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql new file mode 100644 index 0000000000000..1caa45c66749d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-negative.sql @@ -0,0 +1,36 @@ +-- Negative testcases for column resolution +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +-- Negative tests: column resolution scenarios with ambiguous cases in join queries +SET spark.sql.crossJoin.enabled = true; +USE mydb1; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +USE mydb2; +SELECT i1 FROM t1, mydb1.t1; +SELECT t1.i1 FROM t1, mydb1.t1; +SELECT i1 FROM t1, mydb2.t1; +SELECT t1.i1 FROM t1, mydb2.t1; +SELECT db1.t1.i1 FROM t1, mydb2.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Negative tests +USE mydb1; +SELECT mydb1.t1 FROM t1; +SELECT t1.x.y.* FROM t1; +SELECT t1 FROM mydb1.t1; +USE mydb2; +SELECT mydb1.t1.i1 FROM t1; + +-- reset +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql new file mode 100644 index 0000000000000..d3f928751757c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution-views.sql @@ -0,0 +1,25 @@ +-- Tests for qualified column names for the view code-path +-- Test scenario with Temporary view +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1; +SELECT view1.* FROM view1; +SELECT * FROM view1; +SELECT view1.i1 FROM view1; +SELECT i1 FROM view1; +SELECT a.i1 FROM view1 AS a; +SELECT i1 FROM view1 AS a; +-- cleanup +DROP VIEW view1; + +-- Test scenario with Global Temp view +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1; +SELECT * 
FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.* FROM global_temp.view1; +SELECT i1 FROM global_temp.view1; +-- TODO: Support this scenario +SELECT global_temp.view1.i1 FROM global_temp.view1; +SELECT view1.i1 FROM global_temp.view1; +SELECT a.i1 FROM global_temp.view1 AS a; +SELECT i1 FROM global_temp.view1 AS a; +-- cleanup +DROP VIEW global_temp.view1; diff --git a/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql new file mode 100644 index 0000000000000..79e90ad3de91d --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/columnresolution.sql @@ -0,0 +1,88 @@ +-- Tests covering different scenarios with qualified column names +-- Scenario: column resolution scenarios with datasource table +CREATE DATABASE mydb1; +USE mydb1; +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1; + +CREATE DATABASE mydb2; +USE mydb2; +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1; + +USE mydb1; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +USE mydb2; +SELECT i1 FROM t1; +SELECT i1 FROM mydb1.t1; +SELECT t1.i1 FROM t1; +SELECT t1.i1 FROM mydb1.t1; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1; + +-- Scenario: resolve fully qualified table name in star expansion +USE mydb1; +SELECT t1.* FROM t1; +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +USE mydb2; +SELECT t1.* FROM t1; +-- TODO: Support this scenario +SELECT mydb1.t1.* FROM mydb1.t1; +SELECT t1.* FROM mydb1.t1; +SELECT a.* FROM mydb1.t1 AS a; + +-- Scenario: resolve in case of subquery + +USE mydb1; +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2); +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3); + +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2); + +-- TODO: Support this scenario +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2); + +-- Scenario: column resolution scenarios in join queries +SET spark.sql.crossJoin.enabled = true; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb2.t1; + +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1; + +USE mydb2; +-- TODO: Support this scenario +SELECT mydb1.t1.i1 FROM t1, mydb1.t1; +SET spark.sql.crossJoin.enabled = false; + +-- Scenario: Table with struct column +USE mydb1; +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet; +INSERT INTO t5 VALUES(1, (2, 3)); +SELECT t5.i1 FROM t5; +SELECT t5.t5.i1 FROM t5; +SELECT t5.t5.i1 FROM mydb1.t5; +SELECT t5.i1 FROM mydb1.t5; +SELECT t5.* FROM mydb1.t5; +SELECT t5.t5.* FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i1 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.t5.i2 FROM mydb1.t5; +-- TODO: Support this scenario +SELECT mydb1.t5.* FROM mydb1.t5; + +-- Cleanup and Reset +USE default; +DROP DATABASE mydb1 CASCADE; +DROP DATABASE mydb2 CASCADE; diff --git a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql index 5107fa4d55537..b3ec956cd178e 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/inline-table.sql @@ -46,3 +46,6 @@ select * from values ("one", 
random_not_exist_func(1)), ("two", 2) as data(a, b) -- error reporting: aggregate expression select * from values ("one", count(1)), ("two", 2) as data(a, b); + +-- string to timestamp +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b); diff --git a/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql new file mode 100644 index 0000000000000..38739cb950582 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/inner-join.sql @@ -0,0 +1,17 @@ +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a); + +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2; + +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4; + +-- SPARK-19766 Constant alias columns in INNER JOIN should not be folded by FoldablePropagation rule +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag; diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql new file mode 100644 index 0000000000000..9308560451bf5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -0,0 +1,8 @@ +-- to_json +describe function to_json; +describe function extended to_json; +select to_json(named_struct('a', 1, 'b', 2)); +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(); diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql index b10c41929cdaf..880175fd7add0 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-joins.sql @@ -79,7 +79,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC; +ORDER BY t1a DESC, t3b DESC; -- TC 01.03 SELECT Count(DISTINCT(t1a)) diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql index 6b9e8bf2f362d..5c371d2305ac8 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/in-set-operations.sql @@ -287,7 +287,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC; +ORDER BY t1c DESC, t1a DESC; -- TC 01.08 SELECT t1a, @@ -351,7 +351,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last; +ORDER BY t1c DESC NULLS last, t1a DESC; -- TC 01.11 SELECT * @@ -468,5 +468,5 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST; +ORDER BY t1c DESC NULLS LAST, t1i; diff --git a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql index 505366b7acd43..e09b91f18de0a 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/subquery/in-subquery/not-in-joins.sql @@ -85,7 +85,7 @@ AND t1b != t3b AND t1d = t2d GROUP BY t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a; +ORDER BY t1a, t3b; -- TC 01.03 SELECT t1a, diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out new file mode 100644 index 0000000000000..60bd8e9cc99db --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-negative.sql.out @@ -0,0 +1,240 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 28 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SET spark.sql.crossJoin.enabled = true +-- !query 6 schema +struct +-- !query 6 output +spark.sql.crossJoin.enabled true + + +-- !query 7 +USE mydb1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +SELECT i1 FROM t1, mydb1.t1 +-- !query 8 schema +struct<> +-- !query 8 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 9 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 10 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 11 +SELECT i1 FROM t1, mydb2.t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 12 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1, mydb1.t1 +-- !query 14 schema +struct<> +-- !query 14 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 15 +SELECT t1.i1 FROM t1, mydb1.t1 +-- !query 15 schema +struct<> +-- !query 15 output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 16 +SELECT i1 FROM t1, mydb2.t1 +-- !query 16 schema +struct<> +-- !query 16 output +org.apache.spark.sql.AnalysisException +Reference 'i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 17 +SELECT t1.i1 FROM t1, mydb2.t1 +-- !query 17 schema +struct<> +-- !query 17 
output +org.apache.spark.sql.AnalysisException +Reference 't1.i1' is ambiguous, could be: i1#x, i1#x.; line 1 pos 7 + + +-- !query 18 +SELECT db1.t1.i1 FROM t1, mydb2.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`db1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 19 +SET spark.sql.crossJoin.enabled = false +-- !query 19 schema +struct +-- !query 19 output +spark.sql.crossJoin.enabled false + + +-- !query 20 +USE mydb1 +-- !query 20 schema +struct<> +-- !query 20 output + + + +-- !query 21 +SELECT mydb1.t1 FROM t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 22 +SELECT t1.x.y.* FROM t1 +-- !query 22 schema +struct<> +-- !query 22 output +org.apache.spark.sql.AnalysisException +cannot resolve 't1.x.y.*' give input columns 'i1'; + + +-- !query 23 +SELECT t1 FROM mydb1.t1 +-- !query 23 schema +struct<> +-- !query 23 output +org.apache.spark.sql.AnalysisException +cannot resolve '`t1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 24 +USE mydb2 +-- !query 24 schema +struct<> +-- !query 24 output + + + +-- !query 25 +SELECT mydb1.t1.i1 FROM t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 26 +DROP DATABASE mydb1 CASCADE +-- !query 26 schema +struct<> +-- !query 26 output + + + +-- !query 27 +DROP DATABASE mydb2 CASCADE +-- !query 27 schema +struct<> +-- !query 27 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out new file mode 100644 index 0000000000000..616421d6f2b28 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution-views.sql.out @@ -0,0 +1,140 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 17 + + +-- !query 0 +CREATE OR REPLACE TEMPORARY VIEW view1 AS SELECT 2 AS i1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +SELECT view1.* FROM view1 +-- !query 1 schema +struct +-- !query 1 output +2 + + +-- !query 2 +SELECT * FROM view1 +-- !query 2 schema +struct +-- !query 2 output +2 + + +-- !query 3 +SELECT view1.i1 FROM view1 +-- !query 3 schema +struct +-- !query 3 output +2 + + +-- !query 4 +SELECT i1 FROM view1 +-- !query 4 schema +struct +-- !query 4 output +2 + + +-- !query 5 +SELECT a.i1 FROM view1 AS a +-- !query 5 schema +struct +-- !query 5 output +2 + + +-- !query 6 +SELECT i1 FROM view1 AS a +-- !query 6 schema +struct +-- !query 6 output +2 + + +-- !query 7 +DROP VIEW view1 +-- !query 7 schema +struct<> +-- !query 7 output + + + +-- !query 8 +CREATE OR REPLACE GLOBAL TEMPORARY VIEW view1 as SELECT 1 as i1 +-- !query 8 schema +struct<> +-- !query 8 output + + + +-- !query 9 +SELECT * FROM global_temp.view1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT global_temp.view1.* FROM global_temp.view1 +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.AnalysisException +cannot resolve 'global_temp.view1.*' give input columns 'i1'; + + +-- !query 11 +SELECT i1 FROM global_temp.view1 +-- !query 11 schema +struct +-- !query 11 output +1 + + +-- !query 12 +SELECT global_temp.view1.i1 FROM global_temp.view1 +-- !query 12 schema +struct<> +-- !query 12 output 
+org.apache.spark.sql.AnalysisException +cannot resolve '`global_temp.view1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +SELECT view1.i1 FROM global_temp.view1 +-- !query 13 schema +struct +-- !query 13 output +1 + + +-- !query 14 +SELECT a.i1 FROM global_temp.view1 AS a +-- !query 14 schema +struct +-- !query 14 output +1 + + +-- !query 15 +SELECT i1 FROM global_temp.view1 AS a +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +DROP VIEW global_temp.view1 +-- !query 16 schema +struct<> +-- !query 16 output + diff --git a/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out new file mode 100644 index 0000000000000..764cad0e3943c --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/columnresolution.sql.out @@ -0,0 +1,447 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 54 + + +-- !query 0 +CREATE DATABASE mydb1 +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +USE mydb1 +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TABLE t1 USING parquet AS SELECT 1 AS i1 +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE DATABASE mydb2 +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +USE mydb2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TABLE t1 USING parquet AS SELECT 20 AS i1 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +USE mydb1 +-- !query 6 schema +struct<> +-- !query 6 output + + + +-- !query 7 +SELECT i1 FROM t1 +-- !query 7 schema +struct +-- !query 7 output +1 + + +-- !query 8 +SELECT i1 FROM mydb1.t1 +-- !query 8 schema +struct +-- !query 8 output +1 + + +-- !query 9 +SELECT t1.i1 FROM t1 +-- !query 9 schema +struct +-- !query 9 output +1 + + +-- !query 10 +SELECT t1.i1 FROM mydb1.t1 +-- !query 10 schema +struct +-- !query 10 output +1 + + +-- !query 11 +SELECT mydb1.t1.i1 FROM t1 +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 12 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 13 +USE mydb2 +-- !query 13 schema +struct<> +-- !query 13 output + + + +-- !query 14 +SELECT i1 FROM t1 +-- !query 14 schema +struct +-- !query 14 output +20 + + +-- !query 15 +SELECT i1 FROM mydb1.t1 +-- !query 15 schema +struct +-- !query 15 output +1 + + +-- !query 16 +SELECT t1.i1 FROM t1 +-- !query 16 schema +struct +-- !query 16 output +20 + + +-- !query 17 +SELECT t1.i1 FROM mydb1.t1 +-- !query 17 schema +struct +-- !query 17 output +1 + + +-- !query 18 +SELECT mydb1.t1.i1 FROM mydb1.t1 +-- !query 18 schema +struct<> +-- !query 18 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1]; line 1 pos 7 + + +-- !query 19 +USE mydb1 +-- !query 19 schema +struct<> +-- !query 19 output + + + +-- !query 20 +SELECT t1.* FROM t1 +-- !query 20 schema +struct +-- !query 20 output +1 + + +-- !query 21 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 21 schema +struct<> +-- !query 21 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 22 +SELECT t1.* FROM mydb1.t1 +-- !query 22 schema +struct +-- !query 22 output 
+1 + + +-- !query 23 +USE mydb2 +-- !query 23 schema +struct<> +-- !query 23 output + + + +-- !query 24 +SELECT t1.* FROM t1 +-- !query 24 schema +struct +-- !query 24 output +20 + + +-- !query 25 +SELECT mydb1.t1.* FROM mydb1.t1 +-- !query 25 schema +struct<> +-- !query 25 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t1.*' give input columns 'i1'; + + +-- !query 26 +SELECT t1.* FROM mydb1.t1 +-- !query 26 schema +struct +-- !query 26 output +1 + + +-- !query 27 +SELECT a.* FROM mydb1.t1 AS a +-- !query 27 schema +struct +-- !query 27 output +1 + + +-- !query 28 +USE mydb1 +-- !query 28 schema +struct<> +-- !query 28 output + + + +-- !query 29 +CREATE TABLE t3 USING parquet AS SELECT * FROM VALUES (4,1), (3,1) AS t3(c1, c2) +-- !query 29 schema +struct<> +-- !query 29 output + + + +-- !query 30 +CREATE TABLE t4 USING parquet AS SELECT * FROM VALUES (4,1), (2,1) AS t4(c2, c3) +-- !query 30 schema +struct<> +-- !query 30 output + + + +-- !query 31 +SELECT * FROM t3 WHERE c1 IN (SELECT c2 FROM t4 WHERE t4.c3 = t3.c2) +-- !query 31 schema +struct +-- !query 31 output +4 1 + + +-- !query 32 +SELECT * FROM mydb1.t3 WHERE c1 IN + (SELECT mydb1.t4.c2 FROM mydb1.t4 WHERE mydb1.t4.c3 = mydb1.t3.c2) +-- !query 32 schema +struct<> +-- !query 32 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t4.c3`' given input columns: [c2, c3]; line 2 pos 42 + + +-- !query 33 +SET spark.sql.crossJoin.enabled = true +-- !query 33 schema +struct +-- !query 33 output +spark.sql.crossJoin.enabled true + + +-- !query 34 +SELECT mydb1.t1.i1 FROM t1, mydb2.t1 +-- !query 34 schema +struct<> +-- !query 34 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 35 +SELECT mydb1.t1.i1 FROM mydb1.t1, mydb2.t1 +-- !query 35 schema +struct<> +-- !query 35 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 36 +USE mydb2 +-- !query 36 schema +struct<> +-- !query 36 output + + + +-- !query 37 +SELECT mydb1.t1.i1 FROM t1, mydb1.t1 +-- !query 37 schema +struct<> +-- !query 37 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t1.i1`' given input columns: [i1, i1]; line 1 pos 7 + + +-- !query 38 +SET spark.sql.crossJoin.enabled = false +-- !query 38 schema +struct +-- !query 38 output +spark.sql.crossJoin.enabled false + + +-- !query 39 +USE mydb1 +-- !query 39 schema +struct<> +-- !query 39 output + + + +-- !query 40 +CREATE TABLE t5(i1 INT, t5 STRUCT) USING parquet +-- !query 40 schema +struct<> +-- !query 40 output + + + +-- !query 41 +INSERT INTO t5 VALUES(1, (2, 3)) +-- !query 41 schema +struct<> +-- !query 41 output + + + +-- !query 42 +SELECT t5.i1 FROM t5 +-- !query 42 schema +struct +-- !query 42 output +1 + + +-- !query 43 +SELECT t5.t5.i1 FROM t5 +-- !query 43 schema +struct +-- !query 43 output +2 + + +-- !query 44 +SELECT t5.t5.i1 FROM mydb1.t5 +-- !query 44 schema +struct +-- !query 44 output +2 + + +-- !query 45 +SELECT t5.i1 FROM mydb1.t5 +-- !query 45 schema +struct +-- !query 45 output +1 + + +-- !query 46 +SELECT t5.* FROM mydb1.t5 +-- !query 46 schema +struct> +-- !query 46 output +1 {"i1":2,"i2":3} + + +-- !query 47 +SELECT t5.t5.* FROM mydb1.t5 +-- !query 47 schema +struct +-- !query 47 output +2 3 + + +-- !query 48 +SELECT mydb1.t5.t5.i1 FROM mydb1.t5 +-- !query 48 schema +struct<> +-- !query 48 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i1`' 
given input columns: [i1, t5]; line 1 pos 7 + + +-- !query 49 +SELECT mydb1.t5.t5.i2 FROM mydb1.t5 +-- !query 49 schema +struct<> +-- !query 49 output +org.apache.spark.sql.AnalysisException +cannot resolve '`mydb1.t5.t5.i2`' given input columns: [i1, t5]; line 1 pos 7 + + +-- !query 50 +SELECT mydb1.t5.* FROM mydb1.t5 +-- !query 50 schema +struct<> +-- !query 50 output +org.apache.spark.sql.AnalysisException +cannot resolve 'mydb1.t5.*' give input columns 'i1, t5'; + + +-- !query 51 +USE default +-- !query 51 schema +struct<> +-- !query 51 output + + + +-- !query 52 +DROP DATABASE mydb1 CASCADE +-- !query 52 schema +struct<> +-- !query 52 output + + + +-- !query 53 +DROP DATABASE mydb2 CASCADE +-- !query 53 schema +struct<> +-- !query 53 output + diff --git a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out index de6f01b8de772..4e80f0bda5513 100644 --- a/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/inline-table.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 16 +-- Number of queries: 17 -- !query 0 @@ -143,3 +143,11 @@ struct<> -- !query 15 output org.apache.spark.sql.AnalysisException cannot evaluate expression count(1) in inline table definition; line 1 pos 29 + + +-- !query 16 +select * from values (timestamp('1991-12-06 00:00:00.0'), array(timestamp('1991-12-06 01:00:00.0'), timestamp('1991-12-06 12:00:00.0'))) as data(a, b) +-- !query 16 schema +struct> +-- !query 16 output +1991-12-06 00:00:00 [1991-12-06 01:00:00.0,1991-12-06 12:00:00.0] diff --git a/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out new file mode 100644 index 0000000000000..8d56ebe9fd3b4 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/inner-join.sql.out @@ -0,0 +1,67 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 7 + + +-- !query 0 +CREATE TEMPORARY VIEW t1 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE TEMPORARY VIEW t2 AS SELECT * FROM VALUES (1) AS GROUPING(a) +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +CREATE TEMPORARY VIEW t3 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 2 schema +struct<> +-- !query 2 output + + + +-- !query 3 +CREATE TEMPORARY VIEW t4 AS SELECT * FROM VALUES (1), (1) AS GROUPING(a) +-- !query 3 schema +struct<> +-- !query 3 output + + + +-- !query 4 +CREATE TEMPORARY VIEW ta AS +SELECT a, 'a' AS tag FROM t1 +UNION ALL +SELECT a, 'b' AS tag FROM t2 +-- !query 4 schema +struct<> +-- !query 4 output + + + +-- !query 5 +CREATE TEMPORARY VIEW tb AS +SELECT a, 'a' AS tag FROM t3 +UNION ALL +SELECT a, 'b' AS tag FROM t4 +-- !query 5 schema +struct<> +-- !query 5 output + + + +-- !query 6 +SELECT tb.* FROM ta INNER JOIN tb ON ta.a = tb.a AND ta.tag = tb.tag +-- !query 6 schema +struct +-- !query 6 output +1 a +1 a +1 b +1 b diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out new file mode 100644 index 0000000000000..d8aa4fb9fa788 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -0,0 +1,63 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +describe function to_json +-- !query 
0 schema +struct +-- !query 0 output +Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 1 +describe function extended to_json +-- !query 1 schema +struct +-- !query 1 output +Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Extended Usage: + Examples: + > SELECT to_json(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 2 +select to_json(named_struct('a', 1, 'b', 2)) +-- !query 2 schema +struct +-- !query 2 output +{"a":1,"b":2} + + +-- !query 3 +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) +-- !query 3 schema +struct +-- !query 3 output +{"time":"26/08/2015"} + + +-- !query 4 +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 5 +select to_json() +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function to_json; line 1 pos 7 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out index 7258bcfc6ab72..ab6a11a2b7efa 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-joins.sql.out @@ -102,7 +102,7 @@ GROUP BY t1a, t3a, t3b, t3c -ORDER BY t1a DESC +ORDER BY t1a DESC, t3b DESC -- !query 4 schema struct -- !query 4 output diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out index 878bc755ef5fc..e06f9206d3401 100644 --- a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/in-set-operations.sql.out @@ -353,7 +353,7 @@ WHERE t1a IN (SELECT t3a WHERE t1b > 6) AS t5) GROUP BY t1a, t1b, t1c, t1d HAVING t1c IS NOT NULL AND t1b IS NOT NULL -ORDER BY t1c DESC +ORDER BY t1c DESC, t1a DESC -- !query 9 schema struct -- !query 9 output @@ -445,7 +445,7 @@ WHERE t1b IN FROM t1 WHERE t1b > 6) AS t4 WHERE t2b = t1b) -ORDER BY t1c DESC NULLS last +ORDER BY t1c DESC NULLS last, t1a DESC -- !query 12 schema struct -- !query 12 output @@ -580,16 +580,16 @@ HAVING t1b NOT IN EXCEPT SELECT t3b FROM t3) -ORDER BY t1c DESC NULLS LAST +ORDER BY t1c DESC NULLS LAST, t1i -- !query 15 schema struct -- !query 15 output -1 8 16 2014-05-05 1 8 16 2014-05-04 +1 8 16 2014-05-05 1 16 12 2014-06-04 1 16 12 2014-07-04 1 6 8 2014-04-04 +1 10 NULL 2014-05-04 1 10 NULL 2014-08-04 1 10 NULL 2014-09-04 1 10 NULL 2015-05-04 -1 10 NULL 2014-05-04 diff --git a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out index db01fa455735c..bae5d00cc8632 100644 --- 
a/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/subquery/in-subquery/not-in-joins.sql.out @@ -112,12 +112,12 @@ AND t1b != t3b AND t1d = t2d GROUP BY t1a, t1b, t1c, t3a, t3b, t3c HAVING count(distinct(t3a)) >= 1 -ORDER BY t1a +ORDER BY t1a, t3b -- !query 4 schema struct -- !query 4 output -val1c 8 16 1 10 12 val1c 8 16 1 6 12 +val1c 8 16 1 10 12 val1c 8 16 1 17 16 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 1af1a3652971c..7a7d52b21427a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -24,15 +24,15 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} private case class BigData(s: String) @@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext maybeBlock.nonEmpty } - private def getNumInMemoryRelations(plan: LogicalPlan): Int = { + private def getNumInMemoryRelations(ds: Dataset[_]): Int = { + val plan = ds.queryExecution.withCachedData var sum = plan.collect { case _: InMemoryRelation => 1 }.sum plan.transformAllExpressions { case e: SubqueryExpression => @@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData) + getNumInMemoryRelations(spark.table("testData")) } spark.catalog.cacheTable("testData") @@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext localRelation.createOrReplaceTempView("localRelation") spark.catalog.cacheTable("localRelation") - assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1) + assert(getNumInMemoryRelations(localRelation) == 1) } test("SPARK-19093 Caching in side subquery") { withTempView("t1") { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 2) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) } } @@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t4") // Nested predicate subquery - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - 
""".stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 3) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 3) // Scalar subquery and predicate subquery - val cachedPlan2 = + val ds2 = sql( """ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) @@ -630,8 +631,43 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext |EXISTS (SELECT c1 FROM t3) |OR |c1 IN (SELECT c1 FROM t4) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan2) == 4) + """.stripMargin) + assert(getNumInMemoryRelations(ds2) == 4) + } + } + + test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { + withTable("t") { + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + spark.catalog.cacheTable("t") + spark.table("t").select($"i").cache() + checkAnswer(spark.table("t").select($"i"), Row(1)) + assertCached(spark.table("t").select($"i")) + + Utils.deleteRecursively(path) + spark.sessionState.catalog.refreshTable(TableIdentifier("t")) + spark.catalog.uncacheTable("t") + assert(spark.table("t").select($"i").count() == 0) + assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0) + } + } + } + + test("refreshByPath should refresh all cached plans with the specified path") { + withTempDir { dir => + val path = dir.getCanonicalPath() + + spark.range(10).write.mode("overwrite").parquet(path) + spark.read.parquet(path).cache() + spark.read.parquet(path).filter($"id" > 4).cache() + assert(spark.read.parquet(path).filter($"id" > 4).count() == 5) + + spark.range(20).write.mode("overwrite").parquet(path) + spark.catalog.refreshByPath(path) + assert(spark.read.parquet(path).count() == 20) + assert(spark.read.parquet(path).filter($"id" > 4).count() == 15) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index ee280a313cc04..b0f398dab7455 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -712,4 +712,18 @@ class ColumnExpressionSuite extends QueryTest with SharedSQLContext { testData2.select($"a".bitwiseXOR($"b").bitwiseXOR(39)), testData2.collect().toSeq.map(r => Row(r.getInt(0) ^ r.getInt(1) ^ 39))) } + + test("typedLit") { + val df = Seq(Tuple1(0)).toDF("a") + // Only check the types `lit` cannot handle + checkAnswer( + df.select(typedLit(Seq(1, 2, 3))), + Row(Seq(1, 2, 3)) :: Nil) + checkAnswer( + df.select(typedLit(Map("a" -> 1, "b" -> 2))), + Row(Map("a" -> 1, "b" -> 2)) :: Nil) + checkAnswer( + df.select(typedLit(("a", 2, 1.0))), + Row(Row("a", 2, 1.0)) :: Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 9c39b3c7f09bf..cdea3b9a0f79f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions.{from_json, struct, to_json} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{CalendarIntervalType, IntegerType, StructType, TimestampType} +import org.apache.spark.sql.types._ class JsonFunctionsSuite extends QueryTest 
with SharedSQLContext { import testImplicits._ @@ -133,6 +133,29 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { Row(null) :: Nil) } + test("from_json invalid schema") { + val df = Seq("""{"a" 1}""").toDS() + val schema = ArrayType(StringType) + val message = intercept[AnalysisException] { + df.select(from_json($"value", schema)) + }.getMessage + + assert(message.contains( + "Input schema array must be a struct or an array of structs.")) + } + + test("from_json array support") { + val df = Seq("""[{"a": 1, "b": "a"}, {"a": 2}, { }]""").toDS() + val schema = ArrayType( + StructType( + StructField("a", IntegerType) :: + StructField("b", StringType) :: Nil)) + + checkAnswer( + df.select(from_json($"value", schema)), + Row(Seq(Row(1, "a"), Row(2, null), Row(null, null)))) + } + test("to_json") { val df = Seq(Tuple1(Tuple1(1))).toDF("a") @@ -174,4 +197,27 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { .select(to_json($"struct").as("json")) checkAnswer(dfTwo, readBackTwo) } + + test("SPARK-19637 Support to_json in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer( + df1.selectExpr("to_json(a)"), + Row("""{"_1":1}""") :: Nil) + + val df2 = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") + checkAnswer( + df2.selectExpr("to_json(a, map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row("""{"_1":"26/08/2015 18:00"}""") :: Nil) + + val errMsg1 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, named_struct('a', 1))") + } + assert(errMsg1.getMessage.startsWith("Must use a map() function for options")) + + val errMsg2 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, map('a', 1))") + } + assert(errMsg2.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 34fa626e00e31..f9808834df4a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -312,13 +312,23 @@ object QueryTest { sparkAnswer: Seq[Row], isSorted: Boolean = false): Option[String] = { if (prepareAnswer(expectedAnswer, isSorted) != prepareAnswer(sparkAnswer, isSorted)) { + val getRowType: Option[Row] => String = row => + row.map(row => + if (row.schema == null) { + "struct<>" + } else { + s"${row.schema.catalogString}" + }).getOrElse("struct<>") + val errorMessage = s""" |== Results == |${sideBySide( s"== Correct Answer - ${expectedAnswer.size} ==" +: + getRowType(expectedAnswer.headOption) +: prepareAnswer(expectedAnswer, isSorted).map(_.toString()), s"== Spark Answer - ${sparkAnswer.size} ==" +: + getRowType(sparkAnswer.headOption) +: prepareAnswer(sparkAnswer, isSorted).map(_.toString())).mkString("\n")} """.stripMargin return Some(errorMessage) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 0b3da9aa8fbee..68ababcd11027 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -228,12 +228,12 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { if (isSorted(df.queryExecution.analyzed)) (schema, answer) else (schema, answer.sorted) } catch { - case a: AnalysisException if a.plan.nonEmpty => + case a: AnalysisException => // Do not 
output the logical plan tree which contains expression IDs. // Also implement a crude way of masking expression IDs in the error message // with a generic pattern "###". - (StructType(Seq.empty), - Seq(a.getClass.getName, a.getSimpleMessage.replaceAll("#\\d+", "#x"))) + val msg = if (a.plan.nonEmpty) a.getSimpleMessage else a.getMessage + (StructType(Seq.empty), Seq(a.getClass.getName, msg.replaceAll("#\\d+", "#x"))) case NonFatal(e) => // If there is an exception, put the exception class followed by the message. (StructType(Seq.empty), Seq(e.getClass.getName, e.getMessage)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bbb31dbc8f3de..1f547c5a2a8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared spark.sessionState.conf.autoBroadcastJoinThreshold) } - test("estimates the size of limit") { - withTempView("test") { - Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.stats(conf).sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.stats(conf).sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2d95cb6d64a87..0e5a1dc6ab629 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -172,7 +172,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { var e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }.getMessage - assert(e.contains("Inserting into an RDD-based table is not allowed")) + assert(e.contains("Inserting into a view is not allowed. 
View: `default`.`testview`")) val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/employee.dat") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bb6c486e880a0..d44a6e41cb347 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -210,6 +210,17 @@ class SparkSqlParserSuite extends PlanTest { "no viable alternative at input") } + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 76bb9e5929a71..4b73b078da38e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.command +import java.net.URI + import scala.reflect.{classTag, ClassTag} import org.apache.spark.sql.catalyst.TableIdentifier @@ -317,7 +319,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE EXTERNAL TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create hive table - property values must be set") { @@ -334,7 +336,7 @@ class DDLCommandSuite extends PlanTest { val query = "CREATE TABLE my_tab LOCATION '/something/anything'" val ct = parseAs[CreateTable](query) assert(ct.tableDesc.tableType == CatalogTableType.EXTERNAL) - assert(ct.tableDesc.storage.locationUri == Some("/something/anything")) + assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } test("create table - with partitioned by") { @@ -409,7 +411,7 @@ class DDLCommandSuite extends PlanTest { val expectedTableDesc = CatalogTable( identifier = TableIdentifier("my_tab"), tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy(locationUri = Some("/tmp/file")), + storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), schema = new StructType().add("a", IntegerType).add("b", StringType), provider = Some("parquet")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index b44f20e367f0a..b2199fdf90e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -26,9 +26,7 @@ import org.scalatest.BeforeAndAfterEach import 
org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{DatabaseAlreadyExistsException, FunctionRegistry, NoSuchPartitionException, NoSuchTableException, TempTableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogDatabase, CatalogStorageFormat} -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.CATALOG_IMPLEMENTATION @@ -72,7 +70,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase( - CatalogDatabase(name, "", spark.sessionState.conf.warehousePath, Map()), + CatalogDatabase( + name, "", CatalogUtils.stringToURI(spark.sessionState.conf.warehousePath), Map()), ignoreIfExists = false) } @@ -133,11 +132,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - private def makeQualifiedPath(path: String): String = { + private def makeQualifiedPath(path: String): URI = { // copy-paste from SessionCatalog val hadoopPath = new Path(path) val fs = hadoopPath.getFileSystem(sparkContext.hadoopConfiguration) - fs.makeQualified(hadoopPath).toString + fs.makeQualified(hadoopPath).toUri } test("Create Database using Default Warehouse Path") { @@ -449,7 +448,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") @@ -458,7 +457,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") @@ -467,7 +466,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql(s"DESCRIBE DATABASE EXTENDED $dbName"), Row("Database Name", dbNameWithoutBackTicks) :: Row("Description", "") :: - Row("Location", location) :: + Row("Location", CatalogUtils.URIToString(location)) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) } finally { catalog.reset() @@ -1094,7 +1093,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isDefined) assert(catalog.getPartition(tableIdent, partSpec).storage.properties.isEmpty) // Verify that the location is set to the expected string - def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + def verifyLocation(expected: URI, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } .getOrElse { 
catalog.getTableMetadata(tableIdent).storage } @@ -1111,17 +1110,17 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") - verifyLocation("/path/to/your/lovely/heart") + verifyLocation(new URI("/path/to/your/lovely/heart")) // set table partition location sql("ALTER TABLE dbx.tab1 PARTITION (a='1', b='2') SET LOCATION '/path/to/part/ways'") - verifyLocation("/path/to/part/ways", Some(partSpec)) + verifyLocation(new URI("/path/to/part/ways"), Some(partSpec)) // set table location without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") - verifyLocation("/swanky/steak/place") + verifyLocation(new URI("/swanky/steak/place")) // set table partition location without explicitly specifying database sql("ALTER TABLE tab1 PARTITION (a='1', b='2') SET LOCATION 'vienna'") - verifyLocation("vienna", Some(partSpec)) + verifyLocation(new URI("vienna"), Some(partSpec)) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") @@ -1255,7 +1254,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { "PARTITION (a='2', b='6') LOCATION 'paris' PARTITION (a='3', b='7')") assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isDefined) - assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option("paris")) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri == Option(new URI("paris"))) assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isDefined) // add partitions without explicitly specifying database @@ -1819,7 +1818,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // SET LOCATION won't move data from previous table path to new table path. assert(spark.table("tbl").count() == 0) // the previous table path should be still there. 
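The location assertions in the hunks above all rely on CatalogTable.storage.locationUri now being a java.net.URI rather than a raw string. A minimal sketch of that conversion, assuming only the CatalogUtils helpers already used elsewhere in this patch (stringToURI / URIToString); it is an illustration, not part of the change itself:

import java.net.URI
import org.apache.spark.sql.catalyst.catalog.CatalogUtils

// A plain path taken from DDL text such as SET LOCATION '/path/to/part/ways'
// is stored as a URI, so tests compare it with new URI(...) instead of a string.
val loc: URI = CatalogUtils.stringToURI("/path/to/part/ways")
assert(loc == new URI("/path/to/part/ways"))
// Converting back yields the string form surfaced by DESCRIBE DATABASE output and error messages.
assert(CatalogUtils.URIToString(loc) == "/path/to/part/ways")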
- assert(new File(new URI(defaultTablePath)).exists()) + assert(new File(defaultTablePath).exists()) sql("INSERT INTO tbl SELECT 2") checkAnswer(spark.table("tbl"), Row(2)) @@ -1836,36 +1835,34 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert data to a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$path") + |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == new URI(dir.getAbsolutePath)) dir.delete - val tableLocFile = new File(table.location.stripPrefix("file:")) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) Utils.deleteRecursively(dir) - assert(!tableLocFile.exists) + assert(!dir.exists) spark.sql("INSERT OVERWRITE TABLE t SELECT 'c', 1") - assert(tableLocFile.exists) + assert(dir.exists) checkAnswer(spark.table("t"), Row("c", 1) :: Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.getAbsolutePath spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") spark.sessionState.catalog.refreshTable(TableIdentifier("t")) val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table1.location == newDir) + assert(table1.location == new URI(newDir)) assert(!newDirFile.exists) spark.sql("INSERT INTO TABLE t SELECT 'c', 1") @@ -1878,16 +1875,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("insert into a data source table with no existed partition location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "$path" + |LOCATION "$dir" """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == new URI(dir.getAbsolutePath)) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1906,21 +1902,20 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { test("read data from a data source table which has a not existed location should succeed") { withTable("t") { withTempDir { dir => - val path = dir.toURI.toString.stripSuffix("/") spark.sql( s""" |CREATE TABLE t(a string, b int) |USING parquet - |OPTIONS(path "$path") + |OPTIONS(path "$dir") """.stripMargin) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location == path) + assert(table.location == new URI(dir.getAbsolutePath)) dir.delete() checkAnswer(spark.table("t"), Nil) val newDirFile = new File(dir, "x") - val newDir = newDirFile.toURI.toString + val newDir = newDirFile.toURI spark.sql(s"ALTER TABLE t SET LOCATION '$newDir'") val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) @@ -1939,7 +1934,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { |CREATE TABLE t(a int, b int, c int, d int) |USING parquet |PARTITIONED BY(a, b) - |LOCATION "${dir.toURI}" + |LOCATION 
"$dir" """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) @@ -1952,4 +1947,154 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new URI(dir.getAbsolutePath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new URI(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"data source table:partition column name containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING parquet + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for datasource table") { + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new 
File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + try { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + spark.sql("USE tmpdb") + + import testImplicits._ + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val tblPath = new Path(tblloc.getAbsolutePath) + val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(tblPath).toUri) + assert(tblloc.listFiles().nonEmpty) + } + } + } finally { + spark.sql("DROP DATABASE IF EXISTS tmpdb") + } + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala index efbfc2417dfef..7ea4064927576 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileIndexSuite.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.apache.spark.metrics.source.HiveCatalogMetrics import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext class FileIndexSuite extends SharedSQLContext { @@ -179,6 +180,21 @@ class FileIndexSuite extends SharedSQLContext { } } + test("InMemoryFileIndex with empty rootPaths when PARALLEL_PARTITION_DISCOVERY_THRESHOLD" + + "is a nonpositive number") { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "0") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + + val e = intercept[IllegalArgumentException] { + withSQLConf(SQLConf.PARALLEL_PARTITION_DISCOVERY_THRESHOLD.key -> "-1") { + new InMemoryFileIndex(spark, Seq.empty, Map.empty, None) + } + }.getMessage + assert(e.contains("The maximum number of paths allowed for listing files at " + + "driver side must not be negative")) + } + test("refresh for InMemoryFileIndex with FileStatusCache") { withTempDir { dir => val fileStatusCache = FileStatusCache.getOrCreate(spark) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index d94eb66201112..eaedede349134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -742,10 +742,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601timestampsPath) // This will load back the timestamps as string. 
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601Timestamps = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601timestampsPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd'T'HH:mm:ss.SSSZZ", Locale.US) @@ -775,10 +776,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(iso8601datesPath) // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val iso8601dates = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(iso8601datesPath) val iso8501 = FastDateFormat.getInstance("yyyy-MM-dd", Locale.US) @@ -833,10 +835,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(datesWithFormatPath) // This will load back the dates as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringDatesWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(datesWithFormatPath) val expectedStringDatesWithFormat = Seq( Row("2015/08/26"), @@ -864,10 +867,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. + val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/26 18:00"), @@ -896,10 +900,11 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { .save(timestampsWithFormatPath) // This will load back the timestamps as string. 
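Each of these CSV read-back hunks swaps .option("inferSchema", "false") for an explicit one-column string schema. A condensed sketch of the shared pattern, assuming the spark session and a temp output path provided by the surrounding test fixture (outputPath is a placeholder, not a name from the patch):

import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Read the just-written CSV back as a single string column; passing a schema
// makes the "no inference" intent explicit instead of toggling inferSchema off.
val stringSchema = StructType(StructField("date", StringType, nullable = true) :: Nil)
val readBack = spark.read
  .format("csv")
  .schema(stringSchema)
  .option("header", "true")
  .load(outputPath) // placeholder for the temp directory written by the test above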
+ val stringSchema = StructType(StructField("date", StringType, true) :: Nil) val stringTimestampsWithFormat = spark.read .format("csv") + .schema(stringSchema) .option("header", "true") - .option("inferSchema", "false") .load(timestampsWithFormatPath) val expectedStringTimestampsWithFormat = Seq( Row("2015/08/27 01:00"), @@ -1072,14 +1077,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe with empty schema - wholeFile option") { - withTempPath { path => - path.createNewFile() - + test("Empty file produces empty dataframe with empty schema") { + Seq(false, true).foreach { wholeFile => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", true) - .load(path.getAbsolutePath) + .option("wholeFile", wholeFile) + .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) checkAnswer(df, spark.emptyDataFrame) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala index 420cff878fa0d..88cb8a0bad21e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.io.File import java.math.BigInteger import java.sql.{Date, Timestamp} +import java.util.{Calendar, TimeZone} import scala.collection.mutable.ArrayBuffer @@ -51,9 +52,12 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val defaultPartitionName = ExternalCatalogUtils.DEFAULT_PARTITION_NAME + val timeZone = TimeZone.getDefault() + val timeZoneId = timeZone.getID + test("column type inference") { - def check(raw: String, literal: Literal): Unit = { - assert(inferPartitionColumnValue(raw, true) === literal) + def check(raw: String, literal: Literal, timeZone: TimeZone = timeZone): Unit = { + assert(inferPartitionColumnValue(raw, true, timeZone) === literal) } check("10", Literal.create(10, IntegerType)) @@ -66,6 +70,14 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha check("1990-02-24", Literal.create(Date.valueOf("1990-02-24"), DateType)) check("1990-02-24 12:00:30", Literal.create(Timestamp.valueOf("1990-02-24 12:00:30"), TimestampType)) + + val c = Calendar.getInstance(TimeZone.getTimeZone("GMT")) + c.set(1990, 1, 24, 12, 0, 30) + c.set(Calendar.MILLISECOND, 0) + check("1990-02-24 12:00:30", + Literal.create(new Timestamp(c.getTimeInMillis), TimestampType), + TimeZone.getTimeZone("GMT")) + check(defaultPartitionName, Literal.create(null, NullType)) } @@ -77,7 +89,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha "hdfs://host:9000/path/a=10.5/b=hello") var exception = intercept[AssertionError] { - parsePartitions(paths.map(new Path(_)), true, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), true, Set.empty[Path], timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -90,7 +102,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) // Valid 
paths = Seq( @@ -102,7 +115,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/something=true/table"))) + Set(new Path("hdfs://host:9000/path/something=true/table")), + timeZoneId) // Valid paths = Seq( @@ -114,7 +128,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/table=true"))) + Set(new Path("hdfs://host:9000/path/table=true")), + timeZoneId) // Invalid paths = Seq( @@ -126,7 +141,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/path/"))) + Set(new Path("hdfs://host:9000/path/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) @@ -145,20 +161,21 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - Set(new Path("hdfs://host:9000/tmp/tables/"))) + Set(new Path("hdfs://host:9000/tmp/tables/")), + timeZoneId) } assert(exception.getMessage().contains("Conflicting directory structures detected")) } test("parse partition") { def check(path: String, expected: Option[PartitionValues]): Unit = { - val actual = parsePartition(new Path(path), true, Set.empty[Path])._1 + val actual = parsePartition(new Path(path), true, Set.empty[Path], timeZone)._1 assert(expected === actual) } def checkThrows[T <: Throwable: Manifest](path: String, expected: String): Unit = { val message = intercept[T] { - parsePartition(new Path(path), true, Set.empty[Path]) + parsePartition(new Path(path), true, Set.empty[Path], timeZone) }.getMessage assert(message.contains(expected)) @@ -201,7 +218,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec1: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path/a=10")))._1 + basePaths = Set(new Path("file://path/a=10")), + timeZone = timeZone)._1 assert(partitionSpec1.isEmpty) @@ -209,7 +227,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val partitionSpec2: Option[PartitionValues] = parsePartition( path = new Path("file://path/a=10"), typeInference = true, - basePaths = Set(new Path("file://path")))._1 + basePaths = Set(new Path("file://path")), + timeZone = timeZone)._1 assert(partitionSpec2 == Option(PartitionValues( @@ -226,7 +245,8 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha parsePartitions( paths.map(new Path(_)), true, - rootPaths) + rootPaths, + timeZoneId) assert(actualSpec === spec) } @@ -307,7 +327,7 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha test("parse partitions with type inference disabled") { def check(paths: Seq[String], spec: PartitionSpec): Unit = { val actualSpec = - parsePartitions(paths.map(new Path(_)), false, Set.empty[Path]) + parsePartitions(paths.map(new Path(_)), false, Set.empty[Path], timeZoneId) assert(actualSpec === spec) } @@ -686,6 +706,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name).cast(f.dataType)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") 
+ .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name).cast(f.dataType)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("Various inferred partition value types") { @@ -720,6 +747,13 @@ class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with Sha val fields = schema.map(f => Column(f.name)) checkAnswer(spark.read.load(dir.toString).select(fields: _*), row) } + + withTempPath { dir => + df.write.option("timeZone", "GMT") + .format("parquet").partitionBy(partitionColumns.map(_.name): _*).save(dir.toString) + val fields = schema.map(f => Column(f.name)) + checkAnswer(spark.read.option("timeZone", "GMT").load(dir.toString).select(fields: _*), row) + } } test("SPARK-8037: Ignores files whose name starts with dot") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index dc4e935f3d9ba..e848f74e3159f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.sql.execution.streaming.state import java.io.{File, IOException} import java.net.URI +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.Random +import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path, RawLocalFileSystem} import org.scalatest.{BeforeAndAfter, PrivateMethodTester} @@ -293,6 +295,11 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = newStoreProvider(hadoopConf = conf) provider.getStore(0).commit() provider.getStore(0).commit() + + // Verify we don't leak temp files + val tempFiles = FileUtils.listFiles(new File(provider.id.checkpointLocation), + null, true).asScala.filter(_.getName.startsWith("temp-")) + assert(tempFiles.isEmpty) } test("corrupted file handling") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala index 75723d0abcfcc..989a7f2698171 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/CatalogSuite.scala @@ -459,7 +459,7 @@ class CatalogSuite options = Map("path" -> dir.getAbsolutePath)) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.EXTERNAL) - assert(table.storage.locationUri.get == dir.getAbsolutePath) + assert(table.storage.locationUri.get == new URI(dir.getAbsolutePath)) Seq((1)).toDF("i").write.insertInto("t") assert(dir.exists() && dir.listFiles().nonEmpty) @@ -481,7 +481,7 @@ class CatalogSuite options = Map.empty[String, String]) val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) assert(table.tableType == CatalogTableType.MANAGED) - val tablePath = new File(new URI(table.storage.locationUri.get)) + val tablePath = new File(table.storage.locationUri.get) assert(tablePath.exists() && tablePath.listFiles().isEmpty) Seq((1)).toDF("i").write.insertInto("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 9082261af7b00..93f3efe2ccc4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -92,7 +92,7 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { def tableDir: File = { val identifier = spark.sessionState.sqlParser.parseTableIdentifier("bucketed_table") - new File(URI.create(spark.sessionState.catalog.defaultTablePath(identifier))) + new File(spark.sessionState.catalog.defaultTablePath(identifier)) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala index bf7fabe33266b..f251290583c5e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala @@ -18,11 +18,13 @@ package org.apache.spark.sql.sources import java.io.File +import java.sql.Timestamp import org.apache.hadoop.mapreduce.TaskAttemptContext import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils import org.apache.spark.sql.execution.datasources.SQLHadoopMapReduceCommitProtocol import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -124,6 +126,39 @@ class PartitionedWriteSuite extends QueryTest with SharedSQLContext { } } + test("timeZone setting in dynamic partition writes") { + def checkPartitionValues(file: File, expected: String): Unit = { + val dir = file.getParentFile() + val value = ExternalCatalogUtils.unescapePathName( + dir.getName.substring(dir.getName.indexOf("=") + 1)) + assert(value == expected) + } + val ts = Timestamp.valueOf("2016-12-01 00:00:00") + val df = Seq((1, ts)).toDF("i", "ts") + withTempPath { f => + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + checkPartitionValues(files.head, "2016-12-01 00:00:00") + } + withTempPath { f => + df.write.option("timeZone", "GMT").partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // use timeZone option "GMT" to format partition value. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + withTempPath { f => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> "GMT") { + df.write.partitionBy("ts").parquet(f.getAbsolutePath) + val files = recursiveList(f).filter(_.getAbsolutePath.endsWith("parquet")) + assert(files.length == 1) + // if there isn't timeZone option, then use session local timezone. + checkPartitionValues(files.head, "2016-12-01 08:00:00") + } + } + } + /** Lists files recursively. 
*/ private def recursiveList(f: File): Array[File] = { require(f.isDirectory) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala index faf9afc49a2d3..7ab339e005295 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PathOptionSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.sources +import java.net.URI + import org.apache.hadoop.fs.Path import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogUtils import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, Metadata, MetadataBuilder, StructType} @@ -78,7 +81,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { // should exist even path option is not specified when creating table withTable("src") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") - assert(getPathOption("src") == Some(defaultTablePath("src"))) + assert(getPathOption("src") == Some(CatalogUtils.URIToString(defaultTablePath("src")))) } } @@ -105,7 +108,8 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { |USING ${classOf[TestOptionsSource].getCanonicalName} |AS SELECT 1 """.stripMargin) - assert(spark.table("src").schema.head.metadata.getString("path") == defaultTablePath("src")) + assert(spark.table("src").schema.head.metadata.getString("path") == + CatalogUtils.URIToString(defaultTablePath("src"))) } } @@ -123,7 +127,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { withTable("src", "src2") { sql(s"CREATE TABLE src(i int) USING ${classOf[TestOptionsSource].getCanonicalName}") sql("ALTER TABLE src RENAME TO src2") - assert(getPathOption("src2") == Some(defaultTablePath("src2"))) + assert(getPathOption("src2") == Some(CatalogUtils.URIToString(defaultTablePath("src2")))) } } @@ -133,7 +137,7 @@ class PathOptionSuite extends DataSourceTest with SharedSQLContext { }.head } - private def defaultTablePath(tableName: String): String = { + private def defaultTablePath(tableName: String): URI = { spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 5110d89c85b15..1586850c77fca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -52,10 +52,7 @@ abstract class FileStreamSourceTest query.nonEmpty, "Cannot add data when there is no query for finding the active file stream source") - val sources = query.get.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - } + val sources = getSourcesFromStreamingQuery(query.get) if (sources.isEmpty) { throw new Exception( "Could not find file source in the StreamExecution logical plan to add data to") @@ -134,6 +131,14 @@ abstract class FileStreamSourceTest }.head } + protected def getSourcesFromStreamingQuery(query: StreamExecution): Seq[FileStreamSource] = { + 
query.logicalPlan.collect { + case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => + source.asInstanceOf[FileStreamSource] + } + } + + protected def withTempDirs(body: (File, File) => Unit) { val src = Utils.createTempDir(namePrefix = "streaming.src") val tmp = Utils.createTempDir(namePrefix = "streaming.tmp") @@ -388,9 +393,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { CheckAnswer("a", "b", "c", "d"), AssertOnQuery("seen files should contain only one entry") { streamExecution => - val source = streamExecution.logicalPlan.collect { case e: StreamingExecutionRelation => - e.source.asInstanceOf[FileStreamSource] - }.head + val source = getSourcesFromStreamingQuery(streamExecution).head assert(source.seenFiles.size == 1) true } @@ -662,6 +665,101 @@ class FileStreamSourceSuite extends FileStreamSourceTest { } } + test("read data from outputs of another streaming query") { + withSQLConf(SQLConf.FILE_SINK_LOG_COMPACT_INTERVAL.key -> "3") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + val q1 = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + .start(outputDir.getCanonicalPath) + + // q2 is a streaming query that reads q1's text outputs + val q2 = + createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + def q1AddData(data: String*): StreamAction = + Execute { _ => + q1Source.addData(data) + q1.processAllAvailable() + } + def q2ProcessAllAvailable(): StreamAction = Execute { q2 => q2.processAllAvailable() } + + testStream(q2)( + // batch 0 + q1AddData("drop1", "keep2"), + q2ProcessAllAvailable(), + CheckAnswer("keep2"), + + // batch 1 + Assert { + // create a text file that won't be on q1's sink log + // thus even if its content contains "keep", it should NOT appear in q2's answer + val shouldNotKeep = new File(outputDir, "should_not_keep.txt") + stringToFile(shouldNotKeep, "should_not_keep!!!") + shouldNotKeep.exists() + }, + q1AddData("keep3"), + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3"), + + // batch 2: check that things work well when the sink log gets compacted + q1AddData("keep4"), + Assert { + // compact interval is 3, so file "2.compact" should exist + new File(outputDir, s"${FileStreamSink.metadataDir}/2.compact").exists() + }, + q2ProcessAllAvailable(), + CheckAnswer("keep2", "keep3", "keep4"), + + Execute { _ => q1.stop() } + ) + } + } + } + + test("start before another streaming query, and read its output") { + withTempDirs { case (outputDir, checkpointDir) => + // q1 is a streaming query that reads from memory and writes to text files + val q1Source = MemoryStream[String] + // define q1, but don't start it for now + val q1Write = + q1Source + .toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("text") + var q1: StreamingQuery = null + + val q2 = createFileStream("text", outputDir.getCanonicalPath).filter($"value" contains "keep") + + testStream(q2)( + AssertOnQuery { q2 => + val fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has not started yet, verify that q2 doesn't know whether q1 has metadata + fileSource.sourceHasMetadata === None + }, + Execute { _ => + q1 = q1Write.start(outputDir.getCanonicalPath) + q1Source.addData("drop1", "keep2") + q1.processAllAvailable() + }, + AssertOnQuery { q2 => + q2.processAllAvailable() + val 
fileSource = getSourcesFromStreamingQuery(q2).head + // q1 has started, verify that q2 knows q1 has metadata by now + fileSource.sourceHasMetadata === Some(true) + }, + CheckAnswer("keep2"), + Execute { _ => q1.stop() } + ) + } + } + test("when schema inference is turned on, should read partition data") { def createFile(content: String, src: File, tmp: File): Unit = { val tempFile = Utils.tempFileWith(new File(tmp, "text")) @@ -755,10 +853,7 @@ class FileStreamSourceSuite extends FileStreamSourceTest { .streamingQuery q.processAllAvailable() val memorySink = q.sink.asInstanceOf[MemorySink] - val fileSource = q.logicalPlan.collect { - case StreamingExecutionRelation(source, _) if source.isInstanceOf[FileStreamSource] => - source.asInstanceOf[FileStreamSource] - }.head + val fileSource = getSourcesFromStreamingQuery(q).head /** Check the data read in the last batch */ def checkLastBatchData(data: Int*): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index f44cfada29e2b..6dfcd8baba20e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.streaming +import java.io.{InterruptedIOException, IOException} +import java.util.concurrent.{CountDownLatch, TimeoutException, TimeUnit} + import scala.reflect.ClassTag import scala.util.control.ControlThrowable @@ -350,13 +353,45 @@ class StreamSuite extends StreamTest { } } } -} -/** - * A fake StreamSourceProvider thats creates a fake Source that cannot be reused. - */ -class FakeDefaultSource extends StreamSourceProvider { + test("handle IOException when the streaming thread is interrupted (pre Hadoop 2.8)") { + // This test uses a fake source to throw the same IOException as pre Hadoop 2.8 when the + // streaming thread is interrupted. We should handle it properly by not failing the query. + ThrowingIOExceptionLikeHadoop12074.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingIOExceptionLikeHadoop12074].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingIOExceptionLikeHadoop12074.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingIOExceptionLikeHadoop12074.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } + test("handle InterruptedIOException when the streaming thread is interrupted (Hadoop 2.8+)") { + // This test uses a fake source to throw the same InterruptedIOException as Hadoop 2.8+ when the + // streaming thread is interrupted. We should handle it properly by not failing the query. 
+ ThrowingInterruptedIOException.createSourceLatch = new CountDownLatch(1) + val query = spark + .readStream + .format(classOf[ThrowingInterruptedIOException].getName) + .load() + .writeStream + .format("console") + .start() + assert(ThrowingInterruptedIOException.createSourceLatch + .await(streamingTimeout.toMillis, TimeUnit.MILLISECONDS), + "ThrowingInterruptedIOException.createSource wasn't called before timeout") + query.stop() + assert(query.exception.isEmpty) + } +} + +abstract class FakeSource extends StreamSourceProvider { private val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) override def sourceSchema( @@ -364,6 +399,10 @@ class FakeDefaultSource extends StreamSourceProvider { schema: Option[StructType], providerName: String, parameters: Map[String, String]): (String, StructType) = ("fakeSource", fakeSchema) +} + +/** A fake StreamSourceProvider that creates a fake Source that cannot be reused. */ +class FakeDefaultSource extends FakeSource { override def createSource( spark: SQLContext, @@ -395,3 +434,63 @@ class FakeDefaultSource extends StreamSourceProvider { } } } + +/** A fake source that throws the same IOException like pre Hadoop 2.8 when it's interrupted. */ +class ThrowingIOExceptionLikeHadoop12074 extends FakeSource { + import ThrowingIOExceptionLikeHadoop12074._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + throw new IOException(ie.toString) + } + } +} + +object ThrowingIOExceptionLikeHadoop12074 { + /** + * A latch to allow the user to wait until [[ThrowingIOExceptionLikeHadoop12074.createSource]] is + * called. + */ + @volatile var createSourceLatch: CountDownLatch = null +} + +/** A fake source that throws InterruptedIOException like Hadoop 2.8+ when it's interrupted. */ +class ThrowingInterruptedIOException extends FakeSource { + import ThrowingInterruptedIOException._ + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + createSourceLatch.countDown() + try { + Thread.sleep(30000) + throw new TimeoutException("sleep was not interrupted in 30 seconds") + } catch { + case ie: InterruptedException => + val iie = new InterruptedIOException(ie.toString) + iie.initCause(ie) + throw iie + } + } +} + +object ThrowingInterruptedIOException { + /** + * A latch to allow the user to wait until [[ThrowingInterruptedIOException.createSource]] is + * called. 
+ */ + @volatile var createSourceLatch: CountDownLatch = null +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index af2f31a34d8da..60e2375a9817d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -208,6 +208,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { } } + /** Execute arbitrary code */ + object Execute { + def apply(func: StreamExecution => Any): AssertOnQuery = + AssertOnQuery(query => { func(query); true }) + } + class StreamManualClock(time: Long = 0L) extends ManualClock(time) with Serializable { private var waitStartTime: Option[Long] = None @@ -472,7 +478,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts { case a: AssertOnQuery => verify(currentStream != null || lastStream != null, - "cannot assert when not stream has been started") + "cannot assert when no stream has been started") val streamToAssert = Option(currentStream).getOrElse(lastStream) verify(a.condition(streamToAssert), s"Assert on query failed: ${a.message}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 1525ad5fd5178..a0a2b2b4c9b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -20,10 +20,12 @@ package org.apache.spark.sql.streaming import java.util.concurrent.CountDownLatch import org.apache.commons.lang3.RandomStringUtils +import org.mockito.Mockito._ import org.scalactic.TolerantNumerics import org.scalatest.concurrent.Eventually._ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout +import org.scalatest.mock.MockitoSugar import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset} @@ -32,11 +34,11 @@ import org.apache.spark.SparkException import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.streaming.util.BlockingSource +import org.apache.spark.sql.streaming.util.{BlockingSource, MockSourceProvider} import org.apache.spark.util.ManualClock -class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { +class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging with MockitoSugar { import AwaitTerminationTester._ import testImplicits._ @@ -481,6 +483,75 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging { } } + test("StreamExecution should call stop() on sources when a stream is stopped") { + var calledStop = false + val source = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source) { + val df = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + + testStream(df)(StopStream) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + + testQuietly("SPARK-19774: 
StreamExecution should call stop() on sources when a stream fails") { + var calledStop = false + val source1 = new Source { + override def stop(): Unit = { + throw new RuntimeException("Oh no!") + } + override def getOffset: Option[Offset] = Some(LongOffset(1)) + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.range(2).toDF(MockSourceProvider.fakeSchema.fieldNames: _*) + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + val source2 = new Source { + override def stop(): Unit = { + calledStop = true + } + override def getOffset: Option[Offset] = None + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + spark.emptyDataFrame + } + override def schema: StructType = MockSourceProvider.fakeSchema + } + + MockSourceProvider.withMockSources(source1, source2) { + val df1 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + val df2 = spark.readStream + .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + .load() + .as[Int] + + testStream(df1.union(df2).map(i => i / 0))( + AssertOnQuery { sq => + intercept[StreamingQueryException](sq.processAllAvailable()) + sq.exception.isDefined && !sq.isActive + } + ) + + assert(calledStop, "Did not call stop on source for stopped stream") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala new file mode 100644 index 0000000000000..0bf05381a7f36 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/util/MockSourceProvider.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming.util + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} + +/** + * A StreamSourceProvider that provides mocked Sources for unit testing. Example usage: + * + * {{{ + * MockSourceProvider.withMockSources(source1, source2) { + * val df1 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * val df2 = spark.readStream + * .format("org.apache.spark.sql.streaming.util.MockSourceProvider") + * .load() + * + * df1.union(df2) + * ... 
+ * } + * }}} + */ +class MockSourceProvider extends StreamSourceProvider { + override def sourceSchema( + spark: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + ("dummySource", MockSourceProvider.fakeSchema) + } + + override def createSource( + spark: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + MockSourceProvider.sourceProviderFunction() + } +} + +object MockSourceProvider { + // Function to generate sources. May provide multiple sources if the user implements such a + // function. + private var sourceProviderFunction: () => Source = _ + + final val fakeSchema = StructType(StructField("a", IntegerType) :: Nil) + + def withMockSources(source: Source, otherSources: Source*)(f: => Unit): Unit = { + var i = 0 + val sources = source +: otherSources + sourceProviderFunction = () => { + val source = sources(i % sources.length) + i += 1 + source + } + try { + f + } finally { + sourceProviderFunction = null + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala index 9f27d06dcb366..7c9ea7d393630 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/DataFrameCallbackSuite.scala @@ -60,7 +60,7 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { spark.listenerManager.unregister(listener) } - test("execute callback functions when a DataFrame action failed") { + testQuietly("execute callback functions when a DataFrame action failed") { val metrics = ArrayBuffer.empty[(String, QueryExecution, Exception)] val listener = new QueryExecutionListener { override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = { @@ -75,8 +75,6 @@ class DataFrameCallbackSuite extends QueryTest with SharedSQLContext { val errorUdf = udf[Int, Int] { _ => throw new RuntimeException("udf error") } val df = sparkContext.makeRDD(Seq(1 -> "a")).toDF("i", "j") - // Ignore the log when we are expecting an exception. 
- sparkContext.setLogLevel("FATAL") val e = intercept[SparkException](df.select(errorUdf($"i")).collect()) assert(metrics.length == 1) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 50bb44f7d4e6e..9ab4624594924 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -38,7 +38,7 @@ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils.escapePathName import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.ColumnStat -import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.PartitioningUtils import org.apache.spark.sql.hive.client.HiveClient @@ -210,7 +210,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat tableDefinition.storage.locationUri.isEmpty val tableLocation = if (needDefaultTableLocation) { - Some(defaultTablePath(tableDefinition.identifier)) + Some(CatalogUtils.stringToURI(defaultTablePath(tableDefinition.identifier))) } else { tableDefinition.storage.locationUri } @@ -260,7 +260,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, in older version of Spark we already store table location in storage properties // with key "path". Here we keep this behaviour for backward compatibility. val storagePropsWithLocation = table.storage.properties ++ - table.storage.locationUri.map("path" -> _) + table.storage.locationUri.map("path" -> CatalogUtils.URIToString(_)) // converts the table metadata to Spark SQL specific format, i.e. set data schema, names and // bucket specification to empty. Note that partition columns are retained, so that we can @@ -285,7 +285,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // compatible format, which means the data source is file-based and must have a `path`. require(table.storage.locationUri.isDefined, "External file-based data source table must have a `path` entry in storage properties.") - Some(new Path(table.location).toUri.toString) + Some(table.location) } else { None } @@ -432,13 +432,13 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // // Please refer to https://issues.apache.org/jira/browse/SPARK-15269 for more details. val tempPath = { - val dbLocation = getDatabase(tableDefinition.database).locationUri + val dbLocation = new Path(getDatabase(tableDefinition.database).locationUri) new Path(dbLocation, tableDefinition.identifier.table + "-__PLACEHOLDER__") } try { client.createTable( - tableDefinition.withNewStorage(locationUri = Some(tempPath.toString)), + tableDefinition.withNewStorage(locationUri = Some(tempPath.toUri)), ignoreIfExists) } finally { FileSystem.get(tempPath.toUri, hadoopConf).delete(tempPath, true) @@ -563,7 +563,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // want to alter the table location to a file path, we will fail. This should be fixed // in the future. 
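// The locationUri values handled in the hunks above are now java.net.URI objects, and the
// conversions go through CatalogUtils.stringToURI / CatalogUtils.URIToString rather than
// new URI(...). The reason is that raw filesystem paths may contain characters that are
// illegal in an unescaped URI (spaces, '%', ':'), which the DDL tests later in this patch
// exercise. A minimal illustration of the difference, using Hadoop's Path to do the
// escaping; this is a stand-in for what such a conversion has to do, not a quote of the
// Spark helper:
import java.net.URI
import org.apache.hadoop.fs.Path

val rawLocation = "/warehouse/a b"   // a location a user can legitimately choose
// new URI(rawLocation)              // would throw URISyntaxException on the space
val escaped: URI = new Path(rawLocation).toUri
assert(escaped.toString == "/warehouse/a%20b" && escaped.getPath == "/warehouse/a b")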
- val newLocation = tableDefinition.storage.locationUri + val newLocation = tableDefinition.storage.locationUri.map(CatalogUtils.URIToString(_)) val storageWithPathOption = tableDefinition.storage.copy( properties = tableDefinition.storage.properties ++ newLocation.map("path" -> _)) @@ -704,7 +704,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val storageWithLocation = { val tableLocation = getLocationFromStorageProps(table) // We pass None as `newPath` here, to remove the path option in storage properties. - updateLocationInStorageProps(table, newPath = None).copy(locationUri = tableLocation) + updateLocationInStorageProps(table, newPath = None).copy( + locationUri = tableLocation.map(CatalogUtils.stringToURI(_))) } val partitionProvider = table.properties.get(TABLE_PARTITION_PROVIDER) @@ -848,10 +849,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat // However, Hive metastore is not case preserving and will generate wrong partition location // with lower cased partition column names. Here we set the default partition location // manually to avoid this problem. - val partitionPath = p.storage.locationUri.map(uri => new Path(new URI(uri))).getOrElse { + val partitionPath = p.storage.locationUri.map(uri => new Path(uri)).getOrElse { ExternalCatalogUtils.generatePartitionPath(p.spec, partitionColumnNames, tablePath) } - p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri.toString))) + p.copy(storage = p.storage.copy(locationUri = Some(partitionPath.toUri))) } val lowerCasedParts = partsWithLocation.map(p => p.copy(spec = lowerCasePartitionSpec(p.spec))) client.createPartitions(db, table, lowerCasedParts, ignoreIfExists) @@ -890,7 +891,7 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val newParts = newSpecs.map { spec => val rightPath = renamePartitionDirectory(fs, tablePath, partitionColumnNames, spec) val partition = client.getPartition(db, table, lowerCasePartitionSpec(spec)) - partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toString))) + partition.copy(storage = partition.storage.copy(locationUri = Some(rightPath.toUri))) } alterPartitions(db, table, newParts) } @@ -1008,7 +1009,8 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat override def listPartitionsByFilter( db: String, table: String, - predicates: Seq[Expression]): Seq[CatalogTablePartition] = withClient { + predicates: Seq[Expression], + defaultTimeZoneId: String): Seq[CatalogTablePartition] = withClient { val rawTable = getRawTable(db, table) val catalogTable = restoreTableMetadata(rawTable) val partitionColumnNames = catalogTable.partitionColumnNames.toSet @@ -1034,7 +1036,9 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat val index = partitionSchema.indexWhere(_.name == att.name) BoundReference(index, partitionSchema(index).dataType, nullable = true) }) - clientPrunedPartitions.filter { p => boundPredicate(p.toRow(partitionSchema)) } + clientPrunedPartitions.filter { p => + boundPredicate(p.toRow(partitionSchema, defaultTimeZoneId)) + } } else { client.getPartitions(catalogTable).map { part => part.copy(spec = restorePartitionSpec(part.spec, partColNameMap)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 151a69aebf1d8..4d3b6c3cec1c6 100644 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -128,7 +128,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log QualifiedTableName(relation.tableMeta.database, relation.tableMeta.identifier.table) val lazyPruningEnabled = sparkSession.sqlContext.conf.manageFilesourcePartitions - val tablePath = new Path(new URI(relation.tableMeta.location)) + val tablePath = new Path(relation.tableMeta.location) val result = if (relation.isPartitioned) { val partitionSchema = relation.tableMeta.partitionSchema val rootPaths: Seq[Path] = if (lazyPruningEnabled) { @@ -141,7 +141,7 @@ private[hive] class HiveMetastoreCatalog(sparkSession: SparkSession) extends Log // locations,_omitting_ the table's base path. val paths = sparkSession.sharedState.externalCatalog .listPartitions(tableIdentifier.database, tableIdentifier.name) - .map(p => new Path(new URI(p.storage.locationUri.get))) + .map(p => new Path(p.storage.locationUri.get)) if (paths.isEmpty) { Seq(tablePath) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index c9be1b9d100b0..f1ea86890c210 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -199,6 +199,11 @@ private[sql] class HiveSessionCatalog( } } + // TODO Removes this method after implementing Spark native "histogram_numeric". + override def functionExists(name: FunctionIdentifier): Boolean = { + super.functionExists(name) || hiveFunctions.contains(name.funcName) + } + /** List of functions we pass over to Hive. Note that over time this list should go to 0. */ // We have a list of Hive built-in functions that we do not support. So, we will check // Hive's function registry and lazily load needed functions into our own function registry. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala index 624cfa206eeb2..b5ce027d51e73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveStrategies.scala @@ -133,7 +133,7 @@ class DetermineTableStats(session: SparkSession) extends Rule[LogicalPlan] { } else if (session.sessionState.conf.fallBackToHdfsForStatsEnabled) { try { val hadoopConf = session.sessionState.newHadoopConf() - val tablePath = new Path(new URI(table.location)) + val tablePath = new Path(table.location) val fs: FileSystem = tablePath.getFileSystem(hadoopConf) fs.getContentSummary(tablePath).getLength } catch { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index c326ac4cc1a53..469c9d84de054 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -96,6 +96,7 @@ private[hive] class HiveClientImpl( case hive.v1_0 => new Shim_v1_0() case hive.v1_1 => new Shim_v1_1() case hive.v1_2 => new Shim_v1_2() + case hive.v2_0 => new Shim_v2_0() } // Create an internal session state for this HiveClientImpl. 
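// On the listPartitionsByFilter change a few hunks above: partition values come back from
// the metastore as strings and have to be cast to the partition column types before the
// pushed-down predicate can be evaluated (p.toRow(partitionSchema, defaultTimeZoneId)).
// For timestamp partition columns that cast is time-zone dependent, which is why the
// session-local time zone is now threaded through. A plain java.time illustration of the
// dependency, not Spark code:
import java.time.{LocalDateTime, ZoneId}

val partitionValue = LocalDateTime.parse("2017-03-03T12:13:14")
val inUtc = partitionValue.atZone(ZoneId.of("UTC")).toInstant
val inTokyo = partitionValue.atZone(ZoneId.of("Asia/Tokyo")).toInstant
// The same partition string denotes different instants in different zones, so a predicate
// over a timestamp partition column can change its result unless the catalog is told which
// zone to use.
assert(inUtc != inTokyo)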
@@ -268,16 +269,21 @@ private[hive] class HiveClientImpl( */ def withHiveState[A](f: => A): A = retryLocked { val original = Thread.currentThread().getContextClassLoader - // Set the thread local metastore client to the client associated with this HiveClientImpl. - Hive.set(client) + val originalConfLoader = state.getConf.getClassLoader // The classloader in clientLoader could be changed after addJar, always use the latest - // classloader + // classloader. We explicitly set the context class loader since "conf.setClassLoader" does + // not do that, and the Hive client libraries may need to load classes defined by the client's + // class loader. + Thread.currentThread().setContextClassLoader(clientLoader.classLoader) state.getConf.setClassLoader(clientLoader.classLoader) + // Set the thread local metastore client to the client associated with this HiveClientImpl. + Hive.set(client) // setCurrentSessionState will use the classLoader associated // with the HiveConf in `state` to override the context class loader of the current // thread. shim.setCurrentSessionState(state) val ret = try f finally { + state.getConf.setClassLoader(originalConfLoader) Thread.currentThread().setContextClassLoader(original) HiveCatalogMetrics.incrementHiveClientCalls(1) } @@ -311,7 +317,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull), ignoreIfExists) } @@ -329,7 +335,7 @@ private[hive] class HiveClientImpl( new HiveDatabase( database.name, database.description, - database.locationUri, + CatalogUtils.URIToString(database.locationUri), Option(database.properties).map(_.asJava).orNull)) } @@ -338,7 +344,7 @@ private[hive] class HiveClientImpl( CatalogDatabase( name = d.getName, description = d.getDescription, - locationUri = d.getLocationUri, + locationUri = CatalogUtils.stringToURI(d.getLocationUri), properties = Option(d.getParameters).map(_.asScala.toMap).orNull) }.getOrElse(throw new NoSuchDatabaseException(dbName)) } @@ -404,7 +410,7 @@ private[hive] class HiveClientImpl( createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, storage = CatalogStorageFormat( - locationUri = shim.getDataLocation(h), + locationUri = shim.getDataLocation(h).map(CatalogUtils.stringToURI(_)), // To avoid ClassNotFound exception, we try our best to not get the format class, but get // the class name directly. However, for non-native tables, there is no interface to get // the format class name, so we may still throw ClassNotFound in this case. 
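// The withHiveState change above is, at its core, a save-and-restore of two class loaders:
// the thread context class loader and the loader on the Hive SessionState's conf, both of
// which must point at the isolated client loader while Hive code runs and be put back
// afterwards. A minimal standalone sketch of that pattern, with illustrative names rather
// than the HiveClientImpl code itself:
def withContextClassLoader[A](isolatedLoader: ClassLoader)(body: => A): A = {
  val thread = Thread.currentThread()
  val original = thread.getContextClassLoader
  thread.setContextClassLoader(isolatedLoader)
  try body finally thread.setContextClassLoader(original)
}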
@@ -845,7 +851,8 @@ private[hive] object HiveClientImpl { conf.foreach(c => hiveTable.setOwner(c.getUser)) hiveTable.setCreateTime((table.createTime / 1000).toInt) hiveTable.setLastAccessTime((table.lastAccessTime / 1000).toInt) - table.storage.locationUri.foreach { loc => hiveTable.getTTable.getSd.setLocation(loc)} + table.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach { loc => + hiveTable.getTTable.getSd.setLocation(loc)} table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) hiveTable.setSerializationLib( @@ -879,7 +886,7 @@ private[hive] object HiveClientImpl { } val storageDesc = new StorageDescriptor val serdeInfo = new SerDeInfo - p.storage.locationUri.foreach(storageDesc.setLocation) + p.storage.locationUri.map(CatalogUtils.URIToString(_)).foreach(storageDesc.setLocation) p.storage.inputFormat.foreach(storageDesc.setInputFormat) p.storage.outputFormat.foreach(storageDesc.setOutputFormat) p.storage.serde.foreach(serdeInfo.setSerializationLib) @@ -900,7 +907,7 @@ private[hive] object HiveClientImpl { CatalogTablePartition( spec = Option(hp.getSpec).map(_.asScala.toMap).getOrElse(Map.empty), storage = CatalogStorageFormat( - locationUri = Option(apiPartition.getSd.getLocation), + locationUri = Option(CatalogUtils.stringToURI(apiPartition.getSd.getLocation)), inputFormat = Option(apiPartition.getSd.getInputFormat), outputFormat = Option(apiPartition.getSd.getOutputFormat), serde = Option(apiPartition.getSd.getSerdeInfo.getSerializationLib), diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala index 9fe1c76d3325d..c6188fc683e77 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveShim.scala @@ -24,10 +24,9 @@ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap, Set => JS import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ -import scala.util.Try import scala.util.control.NonFatal -import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.api.{Function => HiveFunction, FunctionType, MetaException, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver @@ -41,7 +40,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.analysis.NoSuchPermanentFunctionException -import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, FunctionResource, FunctionResourceType} +import org.apache.spark.sql.catalyst.catalog.{CatalogFunction, CatalogTablePartition, CatalogUtils, FunctionResource, FunctionResourceType} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{IntegralType, StringType} @@ -268,7 +267,7 @@ private[client] class Shim_v0_12 extends Shim with Logging { val table = hive.getTable(database, tableName) parts.foreach { s => val location = s.storage.locationUri.map( - uri => new Path(table.getPath, new Path(new URI(uri)))).orNull + uri => new Path(table.getPath, new Path(uri))).orNull val params = if (s.parameters.nonEmpty) s.parameters.asJava else null val 
spec = s.spec.asJava if (hive.getPartition(table, spec, false) != null && ignoreIfExists) { @@ -463,7 +462,7 @@ private[client] class Shim_v0_13 extends Shim_v0_12 { val addPartitionDesc = new AddPartitionDesc(db, table, ignoreIfExists) parts.zipWithIndex.foreach { case (s, i) => addPartitionDesc.addPartition( - s.spec.asJava, s.storage.locationUri.map(u => new Path(new URI(u)).toString).orNull) + s.spec.asJava, s.storage.locationUri.map(CatalogUtils.URIToString(_)).orNull) if (s.parameters.nonEmpty) { addPartitionDesc.getPartition(i).setPartParams(s.parameters.asJava) } @@ -833,3 +832,77 @@ private[client] class Shim_v1_2 extends Shim_v1_1 { } } + +private[client] class Shim_v2_0 extends Shim_v1_2 { + private lazy val loadPartitionMethod = + findMethod( + classOf[Hive], + "loadPartition", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadTableMethod = + findMethod( + classOf[Hive], + "loadTable", + classOf[Path], + classOf[String], + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JBoolean.TYPE) + private lazy val loadDynamicPartitionsMethod = + findMethod( + classOf[Hive], + "loadDynamicPartitions", + classOf[Path], + classOf[String], + classOf[JMap[String, String]], + JBoolean.TYPE, + JInteger.TYPE, + JBoolean.TYPE, + JBoolean.TYPE, + JLong.TYPE) + + override def loadPartition( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + inheritTableSpecs: Boolean, + isSkewedStoreAsSubdir: Boolean, + isSrcLocal: Boolean): Unit = { + loadPartitionMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + inheritTableSpecs: JBoolean, isSkewedStoreAsSubdir: JBoolean, + isSrcLocal: JBoolean, JBoolean.FALSE) + } + + override def loadTable( + hive: Hive, + loadPath: Path, + tableName: String, + replace: Boolean, + isSrcLocal: Boolean): Unit = { + loadTableMethod.invoke(hive, loadPath, tableName, replace: JBoolean, isSrcLocal: JBoolean, + JBoolean.FALSE, JBoolean.FALSE) + } + + override def loadDynamicPartitions( + hive: Hive, + loadPath: Path, + tableName: String, + partSpec: JMap[String, String], + replace: Boolean, + numDP: Int, + listBucketingEnabled: Boolean): Unit = { + loadDynamicPartitionsMethod.invoke(hive, loadPath, tableName, partSpec, replace: JBoolean, + numDP: JInteger, listBucketingEnabled: JBoolean, JBoolean.FALSE, 0L: JLong) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala index d2487a2c034c0..6f69a4adf29d5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/IsolatedClientLoader.scala @@ -94,6 +94,7 @@ private[hive] object IsolatedClientLoader extends Logging { case "1.0" | "1.0.0" => hive.v1_0 case "1.1" | "1.1.0" => hive.v1_1 case "1.2" | "1.2.0" | "1.2.1" => hive.v1_2 + case "2.0" | "2.0.0" | "2.0.1" => hive.v2_0 } private def downloadVersion( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala index 4e2193b6abc3f..790ad74e6639e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/package.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.hive 
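// The Shim_v2_0 added above follows the same pattern as the older shims: look the Hive
// method up once, lazily, against whatever Hive version is actually on the classpath, and
// box Scala primitives explicitly when invoking it, since Method.invoke takes Objects.
// A compact standalone sketch of that pattern, with illustrative names rather than the
// Spark shim API:
import java.lang.reflect.Method
import java.lang.{Boolean => JBoolean}

class ReflectiveCall(targetClass: Class[_]) {
  private lazy val setFlagMethod: Method =
    targetClass.getMethod("setFlag", JBoolean.TYPE)   // resolved against the loaded class

  def setFlag(target: AnyRef, value: Boolean): Unit = {
    setFlagMethod.invoke(target, value: JBoolean)     // primitive boxed at the call site
  }
}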
/** Support for interacting with different versions of the HiveMetastoreClient */ package object client { - private[hive] abstract class HiveVersion( + private[hive] sealed abstract class HiveVersion( val fullVersion: String, val extraDeps: Seq[String] = Nil, val exclusions: Seq[String] = Nil) @@ -62,6 +62,12 @@ package object client { "org.pentaho:pentaho-aggdesigner-algorithm", "net.hydromatic:linq4j", "net.hydromatic:quidem")) + + case object v2_0 extends HiveVersion("2.0.1", + exclusions = Seq("org.apache.curator:*", + "org.pentaho:pentaho-aggdesigner-algorithm")) + + val allSupportedHiveVersions = Set(v12, v13, v14, v1_0, v1_1, v1_2, v2_0) } // scalastyle:on diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala index 14b9565be02f1..28f074849c0f5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala @@ -163,7 +163,8 @@ case class HiveTableScanExec( sparkSession.sharedState.externalCatalog.listPartitionsByFilter( relation.tableMeta.database, relation.tableMeta.identifier.table, - normalizedFilters) + normalizedFilters, + sparkSession.sessionState.conf.sessionLocalTimeZone) } else { sparkSession.sharedState.externalCatalog.listPartitions( relation.tableMeta.database, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index f107149adaf10..b8536d0c1bd58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -148,9 +148,16 @@ case class InsertIntoHiveTable( // We have to follow the Hive behavior here, to avoid troubles. For example, if we create // staging directory under the table director for Hive prior to 1.1, the staging directory will // be removed by Hive when Hive is trying to empty the table directory. - if (hiveVersion == v12 || hiveVersion == v13 || hiveVersion == v14 || hiveVersion == v1_0) { + val hiveVersionsUsingOldExternalTempPath: Set[HiveVersion] = Set(v12, v13, v14, v1_0) + val hiveVersionsUsingNewExternalTempPath: Set[HiveVersion] = Set(v1_1, v1_2, v2_0) + + // Ensure all the supported versions are considered here. + assert(hiveVersionsUsingNewExternalTempPath ++ hiveVersionsUsingOldExternalTempPath == + allSupportedHiveVersions) + + if (hiveVersionsUsingOldExternalTempPath.contains(hiveVersion)) { oldVersionExternalTempPath(path, hadoopConf, scratchDir) - } else if (hiveVersion == v1_1 || hiveVersion == v1_2) { + } else if (hiveVersionsUsingNewExternalTempPath.contains(hiveVersion)) { newVersionExternalTempPath(path, hadoopConf, stagingDir) } else { throw new IllegalStateException("Unsupported hive version: " + hiveVersion.fullVersion) @@ -386,8 +393,8 @@ case class InsertIntoHiveTable( logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) } - // Invalidate the cache. - sparkSession.catalog.uncacheTable(table.qualifiedName) + // un-cache this table. 
+ sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 8ccc2b7527f24..2b3f36064c1f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto tempPath.delete() table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") - sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") - checkAnswer( - table("refreshTable"), - table("src").collect()) + sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet") + checkAnswer(table("refreshTable"), table("src")) // Cache the table. sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 6d7a1c3937a96..490e02d0bd541 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute @@ -70,7 +72,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText @@ -102,7 +104,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle assert(desc.identifier.database == Some("mydb")) assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/user/external/page_view")) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) assert(desc.schema.isEmpty) // will be populated later when the table is actually created // TODO will be SQLText assert(desc.comment == Some("This is the staging page view table")) @@ -338,7 +340,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle val query = "CREATE EXTERNAL TABLE tab1 (id int, name string) LOCATION '/path/to/nowhere'" val (desc, _) = extractTableDesc(query) assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some("/path/to/nowhere")) + assert(desc.storage.locationUri == Some(new URI("/path/to/nowhere"))) } test("create table - if not exists") { @@ -469,7 +471,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle 
assert(desc.viewText.isEmpty) assert(desc.viewDefaultDatabase.isEmpty) assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.locationUri == Some(new URI("/path/to/mercury"))) assert(desc.storage.inputFormat == Some("winput")) assert(desc.storage.outputFormat == Some("wowput")) assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) @@ -644,7 +646,7 @@ class HiveDDLCommandSuite extends PlanTest with SQLTestUtils with TestHiveSingle .add("id", "int") .add("name", "string", nullable = true, comment = "blabla")) assert(table.provider == Some(DDLUtils.HIVE_PROVIDER)) - assert(table.storage.locationUri == Some("/tmp/file")) + assert(table.storage.locationUri == Some(new URI("/tmp/file"))) assert(table.storage.properties == Map("my_prop" -> "1")) assert(table.comment == Some("BLABLA")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala index ee632d24b717e..705d43f1f3aba 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogBackwardCompatibilitySuite.scala @@ -40,7 +40,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest spark.sharedState.externalCatalog.asInstanceOf[HiveExternalCatalog].client val tempDir = Utils.createTempDir().getCanonicalFile - val tempDirUri = tempDir.toURI.toString.stripSuffix("/") + val tempDirUri = tempDir.toURI + val tempDirStr = tempDir.getAbsolutePath override def beforeEach(): Unit = { sql("CREATE DATABASE test_db") @@ -59,9 +60,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest } private def defaultTableURI(tableName: String): URI = { - val defaultPath = - spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) - new Path(defaultPath).toUri + spark.sessionState.catalog.defaultTablePath(TableIdentifier(tableName, Some("test_db"))) } // Raw table metadata that are dumped from tables created by Spark 2.0. 
Note that, all spark @@ -170,8 +169,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl7", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl7").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl7") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map( @@ -184,7 +183,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( locationUri = Some(tempDirUri), - properties = Map("path" -> tempDirUri)), + properties = Map("path" -> tempDirStr)), schema = simpleSchema, properties = Map( "spark.sql.sources.provider" -> "parquet", @@ -195,8 +194,8 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest identifier = TableIdentifier("tbl9", Some("test_db")), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(defaultTableURI("tbl9").toString + "-__PLACEHOLDER__"), - properties = Map("path" -> tempDirUri)), + locationUri = Some(new URI(defaultTableURI("tbl9") + "-__PLACEHOLDER__")), + properties = Map("path" -> tempDirStr)), schema = new StructType(), provider = Some("json"), properties = Map("spark.sql.sources.provider" -> "json")) @@ -220,7 +219,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest if (tbl.tableType == CatalogTableType.EXTERNAL) { // trim the URI prefix - val tableLocation = new URI(readBack.storage.locationUri.get).getPath + val tableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = tempDir.toURI.getPath.stripSuffix("/") assert(tableLocation == expectedLocation) } @@ -236,7 +235,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest val readBack = getTableMetadata(tbl.identifier.table) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expected = dir.toURI.getPath.stripSuffix("/") assert(actualTableLocation == expected) } @@ -252,7 +251,7 @@ class HiveExternalCatalogBackwardCompatibilitySuite extends QueryTest assert(readBack.schema.sameType(expectedSchema)) // trim the URI prefix - val actualTableLocation = new URI(readBack.storage.locationUri.get).getPath + val actualTableLocation = readBack.storage.locationUri.get.getPath val expectedLocation = if (tbl.tableType == CatalogTableType.EXTERNAL) { tempDir.toURI.getPath.stripSuffix("/") } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala index a60c210b04c8a..4349f1aa23be0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -52,7 +52,7 @@ class HiveExternalCatalogSuite extends ExternalCatalogSuite { test("list partitions by filter") { val catalog = newBasicCatalog() - val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1)) + val selectedPartitions = catalog.listPartitionsByFilter("db2", "tbl2", Seq('a.int === 1), "GMT") assert(selectedPartitions.length == 1) 
assert(selectedPartitions.head.spec == part1.spec) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 16cf4d7ec67f6..892a22ddfafc8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import java.net.URI + import org.apache.spark.sql.{QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -140,7 +142,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.serde === Some(serde)) assert(hiveTable.tableType === CatalogTableType.EXTERNAL) - assert(hiveTable.storage.locationUri === Some(path.toString)) + assert(hiveTable.storage.locationUri === Some(new URI(path.getAbsolutePath))) val columns = hiveTable.schema assert(columns.map(_.name) === Seq("d1", "d2")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 8f0d5d886c9d5..5f15a705a2e99 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -485,7 +485,7 @@ object SetWarehouseLocationTest extends Logging { val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("default"))) val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocation" + CatalogUtils.stringToURI(s"file:${expectedWarehouseLocation.toString}/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -500,8 +500,8 @@ object SetWarehouseLocationTest extends Logging { sparkSession.sql("create table testLocation (a int)") val tableMetadata = catalog.getTableMetadata(TableIdentifier("testLocation", Some("testLocationDB"))) - val expectedLocation = - "file:" + expectedWarehouseLocation.toString + "/testlocationdb.db/testlocation" + val expectedLocation = CatalogUtils.stringToURI( + s"file:${expectedWarehouseLocation.toString}/testlocationdb.db/testlocation") val actualLocation = tableMetadata.location if (actualLocation != expectedLocation) { throw new Exception( @@ -868,14 +868,16 @@ object SPARK_18360 { val rawTable = hiveClient.getTable("default", "test_tbl") // Hive will use the value of `hive.metastore.warehouse.dir` to generate default table // location for tables in default database. - assert(rawTable.storage.locationUri.get.contains(newWarehousePath)) + assert(rawTable.storage.locationUri.map( + CatalogUtils.URIToString(_)).get.contains(newWarehousePath)) hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = false, purge = false) spark.sharedState.externalCatalog.createTable(tableMeta, ignoreIfExists = false) val readBack = spark.sharedState.externalCatalog.getTable("default", "test_tbl") // Spark SQL will use the location of default database to generate default table // location for tables in default database. 
- assert(readBack.storage.locationUri.get.contains(defaultDbLocation)) + assert(readBack.storage.locationUri.map(CatalogUtils.URIToString(_)) + .get.contains(defaultDbLocation)) } finally { hiveClient.dropTable("default", "test_tbl", ignoreIfNotExists = true, purge = false) hiveClient.runSqlHive(s"SET hive.metastore.warehouse.dir=$defaultDbLocation") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 71ce5a7c4a15a..d6999af84eac0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -284,19 +284,6 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef sql("DROP TABLE hiveTableWithStructValue") } - test("Reject partitioning that does not match table") { - withSQLConf(("hive.exec.dynamic.partition.mode", "nonstrict")) { - sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") - val data = (1 to 10).map(i => (i, s"data-$i", if ((i % 2) == 0) "even" else "odd")) - .toDF("id", "data", "part") - - intercept[AnalysisException] { - // cannot partition by 2 fields when there is only one in the table definition - data.write.partitionBy("part", "data").insertInto("partitioned") - } - } - } - test("Test partition mode = strict") { withSQLConf(("hive.exec.dynamic.partition.mode", "strict")) { sql("CREATE TABLE partitioned (id bigint, data string) PARTITIONED BY (part string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 03ea0c8c77682..f02b7218d6eee 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -1011,7 +1011,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv identifier = TableIdentifier("not_skip_hive_metadata"), tableType = CatalogTableType.EXTERNAL, storage = CatalogStorageFormat.empty.copy( - locationUri = Some(tempPath.getCanonicalPath), + locationUri = Some(tempPath.toURI), properties = Map("skipHiveMetadata" -> "false") ), schema = schema, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala index 47ee4dd4d952c..4aea6d14efb0e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.hive +import java.net.URI + +import org.apache.hadoop.fs.Path + import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils @@ -26,8 +30,8 @@ class MultiDatabaseSuite extends QueryTest with SQLTestUtils with TestHiveSingle private def checkTablePath(dbName: String, tableName: String): Unit = { val metastoreTable = spark.sharedState.externalCatalog.getTable(dbName, tableName) - val expectedPath = - spark.sharedState.externalCatalog.getDatabase(dbName).locationUri + "/" + tableName + val expectedPath = new Path(new Path( + spark.sharedState.externalCatalog.getDatabase(dbName).locationUri), tableName).toUri 
assert(metastoreTable.location === expectedPath) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala index b792a168a4f9b..50506197b3138 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/PartitionedTablePerfStatsSuite.scala @@ -411,4 +411,15 @@ class PartitionedTablePerfStatsSuite } } } + + test("resolveRelation for a FileFormat DataSource without userSchema scan filesystem only once") { + withTempDir { dir => + import spark.implicits._ + Seq(1).toDF("a").write.mode("overwrite").save(dir.getAbsolutePath) + HiveCatalogMetrics.reset() + spark.read.parquet(dir.getAbsolutePath) + assert(HiveCatalogMetrics.METRIC_FILES_DISCOVERED.getCount() == 1) + assert(HiveCatalogMetrics.METRIC_FILE_CACHE_HITS.getCount() == 1) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala index 591a968c82847..e85ea5a59427d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientBuilder.scala @@ -35,22 +35,26 @@ private[client] class HiveClientBuilder { Some(new File(sys.props("java.io.tmpdir"), "hive-ivy-cache").getAbsolutePath)) } - private def buildConf() = { + private def buildConf(extraConf: Map[String, String]) = { lazy val warehousePath = Utils.createTempDir() lazy val metastorePath = Utils.createTempDir() metastorePath.delete() - Map( + extraConf ++ Map( "javax.jdo.option.ConnectionURL" -> s"jdbc:derby:;databaseName=$metastorePath;create=true", "hive.metastore.warehouse.dir" -> warehousePath.toString) } - def buildClient(version: String, hadoopConf: Configuration): HiveClient = { + // for testing only + def buildClient( + version: String, + hadoopConf: Configuration, + extraConf: Map[String, String] = Map.empty): HiveClient = { IsolatedClientLoader.forVersion( hiveMetastoreVersion = version, hadoopVersion = VersionInfo.getVersion, sparkConf = sparkConf, hadoopConf = hadoopConf, - config = buildConf(), + config = buildConf(extraConf), ivyPath = ivyPath).createClient() } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index 6feb277ca88e9..dd624eca6b7b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive.client import java.io.{ByteArrayOutputStream, File, PrintStream} +import java.net.URI import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -54,7 +55,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test("success sanity check") { val badClient = buildClient(HiveUtils.hiveExecutionVersion, new Configuration()) - val db = new CatalogDatabase("default", "desc", "loc", Map()) + val db = new CatalogDatabase("default", "desc", new URI("loc"), Map()) badClient.createDatabase(db, ignoreIfExists = true) } @@ -88,7 +89,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w assert(getNestedMessages(e) contains "Unknown column 'A0.OWNER_NAME' in 'field list'") } - 
private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2") + private val versions = Seq("0.12", "0.13", "0.14", "1.0", "1.1", "1.2", "2.0") private var client: HiveClient = null @@ -98,7 +99,12 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w System.gc() // Hack to avoid SEGV on some JVM versions. val hadoopConf = new Configuration() hadoopConf.set("test", "success") - client = buildClient(version, hadoopConf) + // Hive changed the default of datanucleus.schema.autoCreateAll from true to false since 2.0 + // For details, see the JIRA HIVE-6113 + if (version == "2.0") { + hadoopConf.set("datanucleus.schema.autoCreateAll", "true") + } + client = buildClient(version, hadoopConf, HiveUtils.hiveClientConfigurations(hadoopConf)) } def table(database: String, tableName: String): CatalogTable = { @@ -120,10 +126,10 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w // Database related API /////////////////////////////////////////////////////////////////////////// - val tempDatabasePath = Utils.createTempDir().getCanonicalPath + val tempDatabasePath = Utils.createTempDir().toURI test(s"$version: createDatabase") { - val defaultDB = CatalogDatabase("default", "desc", "loc", Map()) + val defaultDB = CatalogDatabase("default", "desc", new URI("loc"), Map()) client.createDatabase(defaultDB, ignoreIfExists = true) val tempDB = CatalogDatabase( "temporary", description = "test create", tempDatabasePath, Map()) @@ -341,7 +347,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w test(s"$version: alterPartitions") { val spec = Map("key1" -> "1", "key2" -> "2") - val newLocation = Utils.createTempDir().getPath() + val newLocation = new URI(Utils.createTempDir().toURI.toString.stripSuffix("/")) val storage = storageFormat.copy( locationUri = Some(newLocation), // needed for 0.12 alter partitions @@ -655,7 +661,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w val expectedPath = s"file:${tPath.toUri.getPath.stripSuffix("/")}" val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) - assert(table.location.stripSuffix("/") == expectedPath) + assert(table.location == CatalogUtils.stringToURI(expectedPath)) assert(tPath.getFileSystem(spark.sessionState.newHadoopConf()).exists(tPath)) checkAnswer(spark.table("t"), Row("1") :: Nil) @@ -664,7 +670,7 @@ class VersionsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton w val table1 = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) val expectedPath1 = s"file:${t1Path.toUri.getPath.stripSuffix("/")}" - assert(table1.location.stripSuffix("/") == expectedPath1) + assert(table1.location == CatalogUtils.stringToURI(expectedPath1)) assert(t1Path.getFileSystem(spark.sessionState.newHadoopConf()).exists(t1Path)) checkAnswer(spark.table("t1"), Row(2) :: Nil) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 792ac1e259494..df2c1cee942b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.hive.execution import java.io.File +import java.lang.reflect.InvocationTargetException +import java.net.URI import org.apache.hadoop.fs.Path import org.scalatest.BeforeAndAfterEach @@ -25,7 +27,7 @@ 
import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog @@ -710,7 +712,7 @@ class HiveDDLSuite } sql(s"CREATE DATABASE $dbName Location '${tmpDir.toURI.getPath.stripSuffix("/")}'") val db1 = catalog.getDatabaseMetadata(dbName) - val dbPath = tmpDir.toURI.toString.stripSuffix("/") + val dbPath = new URI(tmpDir.toURI.toString.stripSuffix("/")) assert(db1 == CatalogDatabase(dbName, "", dbPath, Map.empty)) sql("USE db1") @@ -747,11 +749,12 @@ class HiveDDLSuite sql(s"CREATE DATABASE $dbName") val catalog = spark.sessionState.catalog val expectedDBLocation = s"file:${dbPath.toUri.getPath.stripSuffix("/")}/$dbName.db" + val expectedDBUri = CatalogUtils.stringToURI(expectedDBLocation) val db1 = catalog.getDatabaseMetadata(dbName) assert(db1 == CatalogDatabase( dbName, "", - expectedDBLocation, + expectedDBUri, Map.empty)) // the database directory was created assert(fs.exists(dbPath) && fs.isDirectory(dbPath)) @@ -1587,4 +1590,294 @@ class HiveDDLSuite } } } + + Seq(true, false).foreach { shouldDelete => + val tcName = if (shouldDelete) "non-existent" else "existed" + test(s"CTAS for external data source table with a $tcName location") { + withTable("t", "t1") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING parquet + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new URI(dir.getAbsolutePath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING parquet + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new URI(dir.getAbsolutePath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + + test(s"CTAS for external hive table with a $tcName location") { + withTable("t", "t1") { + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t + |USING hive + |LOCATION '$dir' + |AS SELECT 3 as a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + checkAnswer(spark.table("t"), Row(3, 4, 1, 2)) + } + // partition table + withTempDir { + dir => + if (shouldDelete) { + dir.delete() + } + spark.sql( + s""" + |CREATE TABLE t1 + |USING hive + |PARTITIONED BY(a, b) + |LOCATION '$dir' + |AS SELECT 3 as 
a, 4 as b, 1 as c, 2 as d + """.stripMargin) + val dirPath = new Path(dir.getAbsolutePath) + val fs = dirPath.getFileSystem(spark.sessionState.newHadoopConf()) + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(new Path(table.location) == fs.makeQualified(dirPath)) + + val partDir = new File(dir, "a=3") + assert(partDir.exists()) + + checkAnswer(spark.table("t1"), Row(1, 2, 3, 4)) + } + } + } + } + } + + Seq("parquet", "hive").foreach { datasource => + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"partition column name of $datasource table containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING $datasource + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4") + val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4" + val partFile1 = new File(dir, partEscaped1) + assert(partFile1.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil) + } + } + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"datasource table: location uri contains $specialChars") { + withTable("t", "t1") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t(a string) + |USING parquet + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING parquet + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + assert(table.location == new Path(loc.getAbsolutePath).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"hive table: location uri contains $specialChars") { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + 
spark.sql( + s""" + |CREATE TABLE t(a string) + |USING hive + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val path = new Path(loc.getAbsolutePath) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(path).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t SELECT 1") + assert(loc.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1") :: Nil) + } else { + val e = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t SELECT 1") + }.getTargetException.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + + withTempDir { dir => + val loc = new File(dir, specialChars) + loc.mkdir() + spark.sql( + s""" + |CREATE TABLE t1(a string, b string) + |USING hive + |PARTITIONED BY(b) + |LOCATION '$loc' + """.stripMargin) + + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t1")) + val path = new Path(loc.getAbsolutePath) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(path).toUri) + assert(new Path(table.location).toString.contains(specialChars)) + + assert(loc.listFiles().isEmpty) + if (specialChars != "a:b") { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + val partFile = new File(loc, "b=2") + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t1"), Row("1", "2") :: Nil) + + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + val partFile1 = new File(loc, "b=2017-03-03 12:13%3A14") + assert(!partFile1.exists()) + val partFile2 = new File(loc, "b=2017-03-03 12%3A13%253A14") + assert(partFile2.listFiles().length >= 1) + checkAnswer(spark.table("t1"), + Row("1", "2") :: Row("1", "2017-03-03 12:13%3A14") :: Nil) + } else { + val e = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b=2) SELECT 1") + }.getTargetException.getMessage + assert(e.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + + val e1 = intercept[InvocationTargetException] { + spark.sql("INSERT INTO TABLE t1 PARTITION(b='2017-03-03 12:13%3A14') SELECT 1") + }.getTargetException.getMessage + assert(e1.contains("java.net.URISyntaxException: Relative path in absolute URI: a:b")) + } + } + } + } + } + + Seq("a b", "a:b", "a%b").foreach { specialChars => + test(s"location uri contains $specialChars for database") { + try { + withTable("t") { + withTempDir { dir => + val loc = new File(dir, specialChars) + spark.sql(s"CREATE DATABASE tmpdb LOCATION '$loc'") + spark.sql("USE tmpdb") + + Seq(1).toDF("a").write.saveAsTable("t") + val tblloc = new File(loc, "t") + val table = spark.sessionState.catalog.getTableMetadata(TableIdentifier("t")) + val tblPath = new Path(tblloc.getAbsolutePath) + val fs = tblPath.getFileSystem(spark.sessionState.newHadoopConf()) + assert(table.location == fs.makeQualified(tblPath).toUri) + assert(tblloc.listFiles().nonEmpty) + } + } + } finally { + spark.sql("DROP DATABASE IF EXISTS tmpdb") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index ef2d451e6b2d6..be9a5fd71bd25 100644 --- 
a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.TestUtils import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry, NoSuchPartitionException} -import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} @@ -544,7 +544,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } userSpecifiedLocation match { case Some(location) => - assert(r.tableMeta.location === location) + assert(r.tableMeta.location === CatalogUtils.stringToURI(location)) case None => // OK. } // Also make sure that the format and serde are as desired. diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 3512c4a890313..81af24979d822 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " + diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index 7fcf45e7dedc9..ee2fd45a7e851 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -152,11 +152,9 @@ trait DStreamCheckpointTester { self: SparkFunSuite => stopSparkContext: Boolean ): Seq[Seq[V]] = { try { - val batchDuration = ssc.graph.batchDuration val batchCounter = new BatchCounter(ssc) ssc.start() val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] - val currentTime = clock.getTimeMillis() logInfo("Manual clock before advancing = " + clock.getTimeMillis()) clock.setTime(targetBatchTime.milliseconds) @@ -171,7 +169,7 @@ trait DStreamCheckpointTester { self: SparkFunSuite => eventually(timeout(10 seconds)) { val checkpointFilesOfLatestTime = Checkpoint.getCheckpointFiles(checkpointDir).filter { - _.toString.contains(clock.getTimeMillis.toString) + _.getName.contains(clock.getTimeMillis.toString) } // Checkpoint files are written twice for every batch interval. So assert that both // are written to make sure that both of them have been written.
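A note on the Hive 2.0 client set-up added to VersionsSuite above: HIVE-6113 changed the default of datanucleus.schema.autoCreateAll from true to false in Hive 2.0, so the embedded Derby metastore the suite spins up would no longer create its schema on first use. A minimal sketch of the configuration the test builds before calling buildClient; the property name and the "2.0" check come from the diff, everything else is illustrative:

import org.apache.hadoop.conf.Configuration

// Sketch only: Hadoop configuration for a Hive 2.0 metastore client, mirroring
// the guard in the test. Left at the new default (false), the embedded Derby
// metastore would start without its schema having been created.
val hiveVersion = "2.0"                        // illustrative value
val hadoopConf = new Configuration()
if (hiveVersion == "2.0") {
  hadoopConf.set("datanucleus.schema.autoCreateAll", "true")
}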
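Several assertions above changed from comparing location strings to comparing java.net.URI values built via new Path(...).toUri or CatalogUtils.stringToURI(...). A standalone sketch of why raw java.net.URI is avoided for paths containing spaces or percent signs; the path and printed values here are illustrative, not taken from the patch:

import java.net.URI
import org.apache.hadoop.fs.Path

// Sketch only: turning a filesystem path with special characters into a URI.
object LocationUriSketch {
  def main(args: Array[String]): Unit = {
    val raw = "/tmp/spark test/a%b"

    // new URI(raw) would throw URISyntaxException here: the space is illegal in
    // a URI and the trailing "%b" is not a complete percent escape. Hadoop's
    // Path quotes such characters instead, which is why the tests build URIs
    // through it.
    val viaPath: URI = new Path(raw).toUri
    println(viaPath)          // encoded form, e.g. /tmp/spark%20test/a%25b
    println(viaPath.getPath)  // decoded form, /tmp/spark test/a%b

    // Converting back to a Path shows the original characters again, matching
    // assertions like new Path(table.location).toString.contains(specialChars).
    println(new Path(viaPath).toString)
  }
}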
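The partition-column tests above rely on Hive-style escaping of partition directory names through ExternalCatalogUtils.escapePathName. A small sketch of where the hard-coded expected directory names come from; the input value is taken from the tests, and the claim about which characters are escaped reflects the behaviour those tests assert:

import org.apache.spark.sql.catalyst.catalog.ExternalCatalogUtils

// Sketch only: a partition value becomes part of a directory name, with ':'
// escaped to %3A and a literal '%' escaped to %25, while spaces are kept as-is.
val value = "2017-03-03 12:13%3A14"
val dirName = "b=" + ExternalCatalogUtils.escapePathName(value)
// Expected by the tests: "b=2017-03-03 12%3A13%253A14" -- the "%3A" already
// present in the value is itself re-escaped to "%253A" on disk.
println(dirName)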
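The "a:b" branches above expect INSERT into a Hive table to fail with "Relative path in absolute URI: a:b". A hedged sketch of the underlying behaviour using only Hadoop's Path; the exact point at which Hive ends up building the offending relative path is not shown in this diff:

import org.apache.hadoop.fs.Path

// Sketch only: ':' is harmless inside an absolute path, but a relative path
// whose first segment contains ':' is parsed as a URI scheme and rejected.
val ok = new Path("/base/a:b")            // fine; prints /base/a:b
val failureMessage: String = try {
  new Path("a:b").toString                // constructor throws before toString runs
} catch {
  // Path wraps the URISyntaxException in an IllegalArgumentException whose
  // message matches the substring the intercepted exceptions are checked for.
  case e: IllegalArgumentException => e.getMessage
}
println(ok)
println(failureMessage)  // java.net.URISyntaxException: Relative path in absolute URI: a:b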